diff --git a/bot.py b/bot.py
index d2b9f4b3e..b9ebd5057 100644
--- a/bot.py
+++ b/bot.py
@@ -111,9 +111,9 @@ async def graceful_shutdown(main_system_instance):
     try:
         from src.chat.message_receive.chat_stream import get_chat_manager

         chat_manager = get_chat_manager()
-        if hasattr(chat_manager, "_stop_auto_save"):
+        if hasattr(chat_manager, "stop_auto_save"):
             logger.info("正在停止聊天管理器...")
-            chat_manager._stop_auto_save()
+            chat_manager.stop_auto_save()
     except Exception as e:
         logger.warning(f"停止聊天管理器时出错: {e}")
diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py
index 88f8601d6..cd017b6a1 100644
--- a/src/llm_models/model_client/base_client.py
+++ b/src/llm_models/model_client/base_client.py
@@ -4,12 +4,15 @@
 from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Any

+from src.common.logger import get_logger
 from src.config.api_ada_configs import APIProvider, ModelInfo

 from ..payload_content.message import Message
 from ..payload_content.resp_format import RespFormat
 from ..payload_content.tool_option import ToolCall, ToolOption

+logger = get_logger("model_client.base_client")
+

 @dataclass
 class UsageRecord:
@@ -144,6 +147,10 @@ class ClientRegistry:
         """APIProvider.type -> BaseClient的映射表"""
         self.client_instance_cache: dict[str, BaseClient] = {}
         """APIProvider.name -> BaseClient的映射表"""
+        self._event_loop_cache: dict[str, int | None] = {}
+        """APIProvider.name -> event loop id的映射表,用于检测事件循环变化"""
+        self._loop_change_count: int = 0
+        """事件循环变化导致缓存失效的次数"""

     def register_client_class(self, client_type: str):
         """
@@ -160,29 +167,91 @@
         return decorator

+    def _get_current_loop_id(self) -> int | None:
+        """
+        获取当前事件循环的ID
+        Returns:
+            int | None: 事件循环ID,如果没有运行中的循环则返回None
+        """
+        try:
+            loop = asyncio.get_running_loop()
+            return id(loop)
+        except RuntimeError:
+            # 没有运行中的事件循环
+            return None
+
+    def _is_event_loop_changed(self, provider_name: str) -> bool:
+        """
+        检查事件循环是否发生变化
+        Args:
+            provider_name: Provider名称
+        Returns:
+            bool: 事件循环是否变化
+        """
+        current_loop_id = self._get_current_loop_id()
+
+        # 如果没有缓存的循环ID,说明是首次创建
+        if provider_name not in self._event_loop_cache:
+            return False
+
+        # 比较当前循环ID与缓存的循环ID
+        cached_loop_id = self._event_loop_cache[provider_name]
+        return current_loop_id != cached_loop_id
+
     def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
         """
-        获取注册的API客户端实例
+        获取注册的API客户端实例(带事件循环检测)

         Args:
             api_provider: APIProvider实例
-            force_new: 是否强制创建新实例(用于解决事件循环问题)
+            force_new: 是否强制创建新实例(通常不需要,会自动检测事件循环变化)

         Returns:
             BaseClient: 注册的API客户端实例
         """
+        provider_name = api_provider.name
+
         # 如果强制创建新实例,直接创建不使用缓存
         if force_new:
             if client_class := self.client_registry.get(api_provider.client_type):
-                return client_class(api_provider)
+                new_instance = client_class(api_provider)
+                # 更新事件循环缓存
+                self._event_loop_cache[provider_name] = self._get_current_loop_id()
+                return new_instance
             else:
                 raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")

+        # 检查事件循环是否变化
+        if self._is_event_loop_changed(provider_name):
+            # 事件循环已变化,需要重新创建实例
+            logger.debug(f"检测到事件循环变化,为 {provider_name} 重新创建客户端实例")
+            self._loop_change_count += 1
+
+            # 移除旧实例
+            if provider_name in self.client_instance_cache:
+                del self.client_instance_cache[provider_name]
+
         # 正常的缓存逻辑
-        if api_provider.name not in self.client_instance_cache:
+        if provider_name not in self.client_instance_cache:
             if client_class := self.client_registry.get(api_provider.client_type):
-                self.client_instance_cache[api_provider.name] = client_class(api_provider)
+                self.client_instance_cache[provider_name] = client_class(api_provider)
+                # 缓存当前事件循环ID
+                self._event_loop_cache[provider_name] = self._get_current_loop_id()
             else:
                 raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
-        return self.client_instance_cache[api_provider.name]
+
+        return self.client_instance_cache[provider_name]
+
+    def get_cache_stats(self) -> dict:
+        """
+        获取缓存统计信息
+        Returns:
+            dict: 包含缓存统计的字典
+        """
+        return {
+            "cached_instances": len(self.client_instance_cache),
+            "tracked_loops": len(self._event_loop_cache),
+            "loop_change_count": self._loop_change_count,
+            "cached_providers": list(self.client_instance_cache.keys()),
+        }


 client_registry = ClientRegistry()
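The registry change above keys cache validity on one observable: `id(asyncio.get_running_loop())` differs for every fresh event loop, so an instance created under a previous `asyncio.run()` can be detected as stale and rebuilt. A minimal standalone sketch of that detection idea (the helper and names below are illustrative, not the project's actual classes):

```python
import asyncio

# Illustrative stand-ins for the registry's caches; not the project's code.
_loop_ids: dict[str, int] = {}
_instances: dict[str, object] = {}


def get_instance(name: str, factory) -> object:
    """Return a cached instance, rebuilding it when the running loop changes."""
    current_loop_id = id(asyncio.get_running_loop())
    if _loop_ids.get(name) != current_loop_id:
        # A new loop (e.g. another asyncio.run()) invalidates the old instance,
        # whose connections were bound to the previous, now-closed loop.
        _instances.pop(name, None)
        _loop_ids[name] = current_loop_id
    if name not in _instances:
        _instances[name] = factory()
    return _instances[name]


async def main() -> None:
    first = get_instance("demo", object)
    assert first is get_instance("demo", object)  # same loop -> same instance


asyncio.run(main())
asyncio.run(main())  # a second run starts a new loop, so the instance is rebuilt
```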
diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py
index 509268a33..093a2051c 100644
--- a/src/llm_models/model_client/openai_client.py
+++ b/src/llm_models/model_client/openai_client.py
@@ -383,17 +383,61 @@ def _default_normal_response_parser(

 @client_registry.register_client_class("openai")
 class OpenaiClient(BaseClient):
+    # 类级别的全局缓存:所有 OpenaiClient 实例共享
+    _global_client_cache: dict[int, AsyncOpenAI] = {}
+    """全局 AsyncOpenAI 客户端缓存:config_hash -> AsyncOpenAI 实例"""
+
     def __init__(self, api_provider: APIProvider):
         super().__init__(api_provider)
+        self._config_hash = self._calculate_config_hash()
+        """当前 provider 的配置哈希值"""
+
+    def _calculate_config_hash(self) -> int:
+        """计算当前配置的哈希值"""
+        config_tuple = (
+            self.api_provider.base_url,
+            self.api_provider.get_api_key(),
+            self.api_provider.timeout,
+        )
+        return hash(config_tuple)

     def _create_client(self) -> AsyncOpenAI:
-        """动态创建OpenAI客户端"""
-        return AsyncOpenAI(
+        """
+        获取或创建 OpenAI 客户端实例(全局缓存)
+
+        多个 OpenaiClient 实例如果配置相同(base_url + api_key + timeout),
+        将共享同一个 AsyncOpenAI 客户端实例,最大化连接池复用。
+        """
+        # 检查全局缓存
+        if self._config_hash in self._global_client_cache:
+            return self._global_client_cache[self._config_hash]
+
+        # 创建新的 AsyncOpenAI 实例
+        logger.debug(
+            f"创建新的 AsyncOpenAI 客户端实例 "
+            f"(base_url={self.api_provider.base_url}, "
+            f"config_hash={self._config_hash})"
+        )
+
+        client = AsyncOpenAI(
             base_url=self.api_provider.base_url,
             api_key=self.api_provider.get_api_key(),
             max_retries=0,
             timeout=self.api_provider.timeout,
         )
+
+        # 存入全局缓存
+        self._global_client_cache[self._config_hash] = client
+
+        return client
+
+    @classmethod
+    def get_cache_stats(cls) -> dict:
+        """获取全局缓存统计信息"""
+        return {
+            "cached_openai_clients": len(cls._global_client_cache),
+            "config_hashes": list(cls._global_client_cache.keys()),
+        }

     async def get_response(
         self,
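Seen in isolation, the class-level cache above means any two `OpenaiClient` instances whose providers resolve to the same `(base_url, api_key, timeout)` triple share one `AsyncOpenAI` object, and therefore one connection pool. A short sketch of that behaviour (the URLs and key are placeholders; the cache dict here only mirrors the idea):

```python
from openai import AsyncOpenAI

_shared: dict[int, AsyncOpenAI] = {}


def shared_client(base_url: str, api_key: str, timeout: float) -> AsyncOpenAI:
    """Return one AsyncOpenAI instance per distinct (base_url, api_key, timeout)."""
    key = hash((base_url, api_key, timeout))
    if key not in _shared:
        _shared[key] = AsyncOpenAI(
            base_url=base_url, api_key=api_key, timeout=timeout, max_retries=0
        )
    return _shared[key]


a = shared_client("https://api.example.com/v1", "sk-placeholder", 30.0)
b = shared_client("https://api.example.com/v1", "sk-placeholder", 30.0)
c = shared_client("https://api.example.com/v1", "sk-placeholder", 60.0)
assert a is b       # identical config -> shared connection pool
assert a is not c   # any differing field (here timeout) -> separate client
```

Because the cache lives on the class, entries persist for the life of the process; nothing in the diff closes them, which matches the stated goal of maximising connection-pool reuse.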
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index afb2f13ed..387a9ba20 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -48,7 +48,7 @@ logger = get_logger("model_utils")
 # ==============================================================================


-def _normalize_image_format(image_format: str) -> str:
+async def _normalize_image_format(image_format: str) -> str:
     """
     标准化图片格式名称,确保与各种API的兼容性

@@ -152,7 +152,7 @@ class _ModelSelector:
         self.model_list = model_list
         self.model_usage = model_usage

-    def select_best_available_model(
+    async def select_best_available_model(
         self, failed_models_in_this_request: set, request_type: str
     ) -> tuple[ModelInfo, APIProvider, BaseClient] | None:
         """
@@ -190,17 +190,16 @@
         model_info = model_config.get_model_info(least_used_model_name)
         api_provider = model_config.get_provider(model_info.api_provider)

-        # 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。
-        # 这是为了避免在某些高并发场景下,共享的ClientSession可能引发的事件循环相关问题。
-        force_new_client = request_type == "embedding"
-        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
+        # 自动事件循环检测:ClientRegistry 会自动检测事件循环变化并处理缓存失效
+        # 无需手动指定 force_new,embedding 请求也能享受缓存优势
+        client = client_registry.get_client_class_instance(api_provider)
         logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")

         # 增加所选模型的请求使用惩罚值,以实现动态负载均衡。
-        self.update_usage_penalty(model_info.name, increase=True)
+        await self.update_usage_penalty(model_info.name, increase=True)
         return model_info, api_provider, client

-    def update_usage_penalty(self, model_name: str, increase: bool):
+    async def update_usage_penalty(self, model_name: str, increase: bool):
         """
         更新模型的使用惩罚值。

@@ -218,7 +217,7 @@
         # 更新模型的惩罚值
         self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment)

-    def update_failure_penalty(self, model_name: str, e: Exception):
+    async def update_failure_penalty(self, model_name: str, e: Exception):
         """
         根据异常类型动态调整模型的失败惩罚值。
         关键错误(如网络连接、服务器错误)会获得更高的惩罚,
@@ -281,7 +280,7 @@
 这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
 """

-    def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
+    async def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
         """
         为请求准备最终的提示词。

@@ -298,7 +297,7 @@
             str: 处理后的、可以直接发送给模型的完整提示词。
         """
         # 步骤1: 根据API提供商的配置应用内容混淆
-        processed_prompt = self._apply_content_obfuscation(prompt, api_provider)
+        processed_prompt = await self._apply_content_obfuscation(prompt, api_provider)

         # 步骤2: 检查模型是否需要注入反截断指令
         if getattr(model_info, "use_anti_truncation", False):
@@ -307,14 +306,14 @@

         return processed_prompt

-    def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
+    async def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
         """
         处理响应内容,提取思维链并检查截断。

         Returns:
             Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断)
         """
-        content, reasoning = self._extract_reasoning(content)
+        content, reasoning = await self._extract_reasoning(content)
         is_truncated = False
         if use_anti_truncation:
             if content.endswith(self.end_marker):
@@ -323,7 +322,7 @@
         return content, reasoning, is_truncated

-    def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str:
+    async def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str:
         """
         根据API提供商的配置对文本进行内容混淆。

@@ -349,10 +348,10 @@
         processed_text = self.noise_instruction + "\n\n" + text

         # 在拼接后的文本中注入随机噪音
-        return self._inject_random_noise(processed_text, intensity)
+        return await self._inject_random_noise(processed_text, intensity)

     @staticmethod
-    def _inject_random_noise(text: str, intensity: int) -> str:
+    async def _inject_random_noise(text: str, intensity: int) -> str:
         """
         在文本中按指定强度注入随机噪音字符串。

@@ -394,7 +393,7 @@
         return " ".join(result)

     @staticmethod
-    def _extract_reasoning(content: str) -> tuple[str, str]:
+    async def _extract_reasoning(content: str) -> tuple[str, str]:
         """
         从模型返回的完整内容中提取被...标签包裹的思考过程,
         并返回清理后的内容和思考过程。
@@ -490,10 +489,10 @@
             except Exception as e:
                 logger.debug(f"请求失败: {e!s}")
                 # 记录失败并更新模型的惩罚值
-                self.model_selector.update_failure_penalty(model_info.name, e)
+                await self.model_selector.update_failure_penalty(model_info.name, e)

                 # 处理异常,决定是否重试以及等待多久
-                wait_interval, new_compressed_messages = self._handle_exception(
+                wait_interval, new_compressed_messages = await self._handle_exception(
                     e,
                     model_info,
                     api_provider,
@@ -513,7 +512,7 @@
         logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
         raise RuntimeError("请求失败,已达到最大重试次数")

-    def _handle_exception(
+    async def _handle_exception(
         self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
     ) -> tuple[int, list[Message] | None]:
         """
@@ -526,9 +525,9 @@
         retry_interval = api_provider.retry_interval

         if isinstance(e, (NetworkConnectionError, ReqAbortException)):
-            return self._check_retry(remain_try, retry_interval, "连接异常", model_name)
+            return await self._check_retry(remain_try, retry_interval, "连接异常", model_name)
         elif isinstance(e, RespNotOkException):
-            return self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
+            return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
         elif isinstance(e, RespParseException):
             logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}")
             return -1, None
@@ -536,7 +535,7 @@
             logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {e!s}")
             return -1, None

-    def _handle_resp_not_ok(
+    async def _handle_resp_not_ok(
         self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
     ) -> tuple[int, list[Message] | None]:
         """
@@ -578,13 +577,13 @@
         # 处理请求频繁或服务器端错误,这些情况适合重试
         elif e.status_code == 429 or e.status_code >= 500:
             reason = "请求过于频繁" if e.status_code == 429 else "服务器错误"
-            return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name)
+            return await self._check_retry(remain_try, api_provider.retry_interval, reason, model_name)
         # 处理其他未知的HTTP错误
         else:
             logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}")
             return -1, None

-    def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
+    async def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
         """
         辅助函数,根据剩余次数决定是否进行下一次重试。

@@ -654,7 +653,7 @@
         last_exception: Exception | None = None

         for attempt in range(max_attempts):
-            selection_result = self.model_selector.select_best_available_model(
+            selection_result = await self.model_selector.select_best_available_model(
                 failed_models_in_this_request, str(request_type.value)
             )
             if selection_result is None:
@@ -669,7 +668,7 @@
             request_kwargs = kwargs.copy()
             if request_type == RequestType.RESPONSE and "prompt" in request_kwargs:
                 prompt = request_kwargs.pop("prompt")
-                processed_prompt = self.prompt_processor.prepare_prompt(
+                processed_prompt = await self.prompt_processor.prepare_prompt(
                     prompt, model_info, api_provider, self.task_name
                 )
                 message = MessageBuilder().add_text_content(processed_prompt).build()
@@ -688,7 +687,7 @@

                 # 成功,立即返回
                 logger.debug(f"模型 '{model_info.name}' 成功生成了回复。")
-                self.model_selector.update_usage_penalty(model_info.name, increase=False)
+                await self.model_selector.update_usage_penalty(model_info.name, increase=False)
                 return response, model_info

             except Exception as e:
@@ -738,7 +737,7 @@
         # --- 响应内容处理和空回复/截断检查 ---
         content = response.content or ""
         use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
-        processed_content, reasoning, is_truncated = self.prompt_processor.process_response(
+        processed_content, reasoning, is_truncated = await self.prompt_processor.process_response(
             content, use_anti_truncation
         )

@@ -821,12 +820,12 @@
         start_time = time.time()

         # 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行
-        selection_result = self._model_selector.select_best_available_model(set(), "response")
+        selection_result = await self._model_selector.select_best_available_model(set(), "response")
         if not selection_result:
             raise RuntimeError("无法为图像响应选择可用模型。")
         model_info, api_provider, client = selection_result

-        normalized_format = _normalize_image_format(image_format)
+        normalized_format = await _normalize_image_format(image_format)
         message = (
             MessageBuilder()
             .add_text_content(prompt)
@@ -849,7 +848,7 @@
         )
         await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")

-        content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
+        content, reasoning, _ = await self._prompt_processor.process_response(response.content or "", False)
         reasoning = response.reasoning_content or reasoning

         return content, (reasoning, model_info.name, response.tool_calls)
@@ -935,7 +934,7 @@
             (响应内容, (推理过程, 模型名称, 工具调用))
         """
         start_time = time.time()
-        tool_options = self._build_tool_options(tools)
+        tool_options = await self._build_tool_options(tools)

         response, model_info = await self._strategy.execute_with_failover(
             RequestType.RESPONSE,
@@ -1008,7 +1007,7 @@
         )

     @staticmethod
-    def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
+    async def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
         """
         根据输入的字典列表构建并验证 `ToolOption` 对象列表。
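Both new `get_cache_stats` helpers return plain dicts, so the two caches can be observed together, for example from a periodic maintenance task. A small sketch, assuming the `src.*` package layout implied by the imports in this diff:

```python
# Illustrative only: reads the two stats hooks added in this change.
from src.llm_models.model_client.base_client import client_registry
from src.llm_models.model_client.openai_client import OpenaiClient


def log_client_cache_stats(logger) -> None:
    registry_stats = client_registry.get_cache_stats()
    openai_stats = OpenaiClient.get_cache_stats()
    logger.info(
        f"client cache: {registry_stats['cached_instances']} provider instances, "
        f"{registry_stats['loop_change_count']} loop-change invalidations, "
        f"{openai_stats['cached_openai_clients']} shared AsyncOpenAI clients"
    )
```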