feat(model): 优化客户端缓存和事件循环检测机制

- 在 ClientRegistry 中添加事件循环变化检测,自动处理缓存失效
- 为 OpenaiClient 实现全局 AsyncOpenAI 客户端缓存,提升连接池复用效率
- 将 utils_model 中的同步方法改为异步,确保与事件循环兼容
- 移除 embedding 请求的特殊处理,现在所有请求都能享受缓存优势
- 添加缓存统计功能,便于监控和调试
This commit is contained in:
Windpicker-owo
2025-10-06 01:05:50 +08:00
parent a72012bf78
commit f59a31865c
4 changed files with 156 additions and 44 deletions

4
bot.py
View File

@@ -111,9 +111,9 @@ async def graceful_shutdown(main_system_instance):
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
if hasattr(chat_manager, "_stop_auto_save"): if hasattr(chat_manager, "stop_auto_save"):
logger.info("正在停止聊天管理器...") logger.info("正在停止聊天管理器...")
chat_manager._stop_auto_save() chat_manager.stop_auto_save()
except Exception as e: except Exception as e:
logger.warning(f"停止聊天管理器时出错: {e}") logger.warning(f"停止聊天管理器时出错: {e}")

View File

@@ -4,12 +4,15 @@ from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from src.common.logger import get_logger
from src.config.api_ada_configs import APIProvider, ModelInfo from src.config.api_ada_configs import APIProvider, ModelInfo
from ..payload_content.message import Message from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolCall, ToolOption from ..payload_content.tool_option import ToolCall, ToolOption
logger = get_logger("model_client.base_client")
@dataclass @dataclass
class UsageRecord: class UsageRecord:
@@ -144,6 +147,10 @@ class ClientRegistry:
"""APIProvider.type -> BaseClient的映射表""" """APIProvider.type -> BaseClient的映射表"""
self.client_instance_cache: dict[str, BaseClient] = {} self.client_instance_cache: dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表""" """APIProvider.name -> BaseClient的映射表"""
self._event_loop_cache: dict[str, int | None] = {}
"""APIProvider.name -> event loop id的映射表用于检测事件循环变化"""
self._loop_change_count: int = 0
"""事件循环变化导致缓存失效的次数"""
def register_client_class(self, client_type: str): def register_client_class(self, client_type: str):
""" """
@@ -160,29 +167,91 @@ class ClientRegistry:
return decorator return decorator
def _get_current_loop_id(self) -> int | None:
"""
获取当前事件循环的ID
Returns:
int | None: 事件循环ID如果没有运行中的循环则返回None
"""
try:
loop = asyncio.get_running_loop()
return id(loop)
except RuntimeError:
# 没有运行中的事件循环
return None
def _is_event_loop_changed(self, provider_name: str) -> bool:
"""
检查事件循环是否发生变化
Args:
provider_name: Provider名称
Returns:
bool: 事件循环是否变化
"""
current_loop_id = self._get_current_loop_id()
# 如果没有缓存的循环ID说明是首次创建
if provider_name not in self._event_loop_cache:
return False
# 比较当前循环ID与缓存的循环ID
cached_loop_id = self._event_loop_cache[provider_name]
return current_loop_id != cached_loop_id
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient: def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
""" """
获取注册的API客户端实例 获取注册的API客户端实例(带事件循环检测)
Args: Args:
api_provider: APIProvider实例 api_provider: APIProvider实例
force_new: 是否强制创建新实例(用于解决事件循环问题 force_new: 是否强制创建新实例(通常不需要,会自动检测事件循环变化
Returns: Returns:
BaseClient: 注册的API客户端实例 BaseClient: 注册的API客户端实例
""" """
provider_name = api_provider.name
# 如果强制创建新实例,直接创建不使用缓存 # 如果强制创建新实例,直接创建不使用缓存
if force_new: if force_new:
if client_class := self.client_registry.get(api_provider.client_type): if client_class := self.client_registry.get(api_provider.client_type):
return client_class(api_provider) new_instance = client_class(api_provider)
# 更新事件循环缓存
self._event_loop_cache[provider_name] = self._get_current_loop_id()
return new_instance
else: else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
# 检查事件循环是否变化
if self._is_event_loop_changed(provider_name):
# 事件循环已变化,需要重新创建实例
logger.debug(f"检测到事件循环变化,为 {provider_name} 重新创建客户端实例")
self._loop_change_count += 1
# 移除旧实例
if provider_name in self.client_instance_cache:
del self.client_instance_cache[provider_name]
# 正常的缓存逻辑 # 正常的缓存逻辑
if api_provider.name not in self.client_instance_cache: if provider_name not in self.client_instance_cache:
if client_class := self.client_registry.get(api_provider.client_type): if client_class := self.client_registry.get(api_provider.client_type):
self.client_instance_cache[api_provider.name] = client_class(api_provider) self.client_instance_cache[provider_name] = client_class(api_provider)
# 缓存当前事件循环ID
self._event_loop_cache[provider_name] = self._get_current_loop_id()
else: else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
return self.client_instance_cache[api_provider.name]
return self.client_instance_cache[provider_name]
def get_cache_stats(self) -> dict:
"""
获取缓存统计信息
Returns:
dict: 包含缓存统计的字典
"""
return {
"cached_instances": len(self.client_instance_cache),
"tracked_loops": len(self._event_loop_cache),
"loop_change_count": self._loop_change_count,
"cached_providers": list(self.client_instance_cache.keys()),
}
client_registry = ClientRegistry() client_registry = ClientRegistry()

View File

@@ -383,18 +383,62 @@ def _default_normal_response_parser(
@client_registry.register_client_class("openai") @client_registry.register_client_class("openai")
class OpenaiClient(BaseClient): class OpenaiClient(BaseClient):
# 类级别的全局缓存:所有 OpenaiClient 实例共享
_global_client_cache: dict[int, AsyncOpenAI] = {}
"""全局 AsyncOpenAI 客户端缓存config_hash -> AsyncOpenAI 实例"""
def __init__(self, api_provider: APIProvider): def __init__(self, api_provider: APIProvider):
super().__init__(api_provider) super().__init__(api_provider)
self._config_hash = self._calculate_config_hash()
"""当前 provider 的配置哈希值"""
def _calculate_config_hash(self) -> int:
"""计算当前配置的哈希值"""
config_tuple = (
self.api_provider.base_url,
self.api_provider.get_api_key(),
self.api_provider.timeout,
)
return hash(config_tuple)
def _create_client(self) -> AsyncOpenAI: def _create_client(self) -> AsyncOpenAI:
"""动态创建OpenAI客户端""" """
return AsyncOpenAI( 获取或创建 OpenAI 客户端实例(全局缓存)
多个 OpenaiClient 实例如果配置相同base_url + api_key + timeout
将共享同一个 AsyncOpenAI 客户端实例,最大化连接池复用。
"""
# 检查全局缓存
if self._config_hash in self._global_client_cache:
return self._global_client_cache[self._config_hash]
# 创建新的 AsyncOpenAI 实例
logger.debug(
f"创建新的 AsyncOpenAI 客户端实例 "
f"(base_url={self.api_provider.base_url}, "
f"config_hash={self._config_hash})"
)
client = AsyncOpenAI(
base_url=self.api_provider.base_url, base_url=self.api_provider.base_url,
api_key=self.api_provider.get_api_key(), api_key=self.api_provider.get_api_key(),
max_retries=0, max_retries=0,
timeout=self.api_provider.timeout, timeout=self.api_provider.timeout,
) )
# 存入全局缓存
self._global_client_cache[self._config_hash] = client
return client
@classmethod
def get_cache_stats(cls) -> dict:
"""获取全局缓存统计信息"""
return {
"cached_openai_clients": len(cls._global_client_cache),
"config_hashes": list(cls._global_client_cache.keys()),
}
async def get_response( async def get_response(
self, self,
model_info: ModelInfo, model_info: ModelInfo,

View File

@@ -48,7 +48,7 @@ logger = get_logger("model_utils")
# ============================================================================== # ==============================================================================
def _normalize_image_format(image_format: str) -> str: async def _normalize_image_format(image_format: str) -> str:
""" """
标准化图片格式名称确保与各种API的兼容性 标准化图片格式名称确保与各种API的兼容性
@@ -152,7 +152,7 @@ class _ModelSelector:
self.model_list = model_list self.model_list = model_list
self.model_usage = model_usage self.model_usage = model_usage
def select_best_available_model( async def select_best_available_model(
self, failed_models_in_this_request: set, request_type: str self, failed_models_in_this_request: set, request_type: str
) -> tuple[ModelInfo, APIProvider, BaseClient] | None: ) -> tuple[ModelInfo, APIProvider, BaseClient] | None:
""" """
@@ -190,17 +190,16 @@ class _ModelSelector:
model_info = model_config.get_model_info(least_used_model_name) model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider) api_provider = model_config.get_provider(model_info.api_provider)
# 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。 # 自动事件循环检测ClientRegistry 会自动检测事件循环变化并处理缓存失效
# 这是为了避免在某些高并发场景下共享的ClientSession可能引发的事件循环相关问题。 # 无需手动指定 force_newembedding 请求也能享受缓存优势
force_new_client = request_type == "embedding" client = client_registry.get_client_class_instance(api_provider)
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")
# 增加所选模型的请求使用惩罚值,以实现动态负载均衡。 # 增加所选模型的请求使用惩罚值,以实现动态负载均衡。
self.update_usage_penalty(model_info.name, increase=True) await self.update_usage_penalty(model_info.name, increase=True)
return model_info, api_provider, client return model_info, api_provider, client
def update_usage_penalty(self, model_name: str, increase: bool): async def update_usage_penalty(self, model_name: str, increase: bool):
""" """
更新模型的使用惩罚值。 更新模型的使用惩罚值。
@@ -218,7 +217,7 @@ class _ModelSelector:
# 更新模型的惩罚值 # 更新模型的惩罚值
self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment) self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment)
def update_failure_penalty(self, model_name: str, e: Exception): async def update_failure_penalty(self, model_name: str, e: Exception):
""" """
根据异常类型动态调整模型的失败惩罚值。 根据异常类型动态调整模型的失败惩罚值。
关键错误(如网络连接、服务器错误)会获得更高的惩罚, 关键错误(如网络连接、服务器错误)会获得更高的惩罚,
@@ -281,7 +280,7 @@ class _PromptProcessor:
这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。 这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
""" """
def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str: async def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
""" """
为请求准备最终的提示词。 为请求准备最终的提示词。
@@ -298,7 +297,7 @@ class _PromptProcessor:
str: 处理后的、可以直接发送给模型的完整提示词。 str: 处理后的、可以直接发送给模型的完整提示词。
""" """
# 步骤1: 根据API提供商的配置应用内容混淆 # 步骤1: 根据API提供商的配置应用内容混淆
processed_prompt = self._apply_content_obfuscation(prompt, api_provider) processed_prompt = await self._apply_content_obfuscation(prompt, api_provider)
# 步骤2: 检查模型是否需要注入反截断指令 # 步骤2: 检查模型是否需要注入反截断指令
if getattr(model_info, "use_anti_truncation", False): if getattr(model_info, "use_anti_truncation", False):
@@ -307,14 +306,14 @@ class _PromptProcessor:
return processed_prompt return processed_prompt
def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]: async def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
""" """
处理响应内容,提取思维链并检查截断。 处理响应内容,提取思维链并检查截断。
Returns: Returns:
Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断) Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断)
""" """
content, reasoning = self._extract_reasoning(content) content, reasoning = await self._extract_reasoning(content)
is_truncated = False is_truncated = False
if use_anti_truncation: if use_anti_truncation:
if content.endswith(self.end_marker): if content.endswith(self.end_marker):
@@ -323,7 +322,7 @@ class _PromptProcessor:
is_truncated = True is_truncated = True
return content, reasoning, is_truncated return content, reasoning, is_truncated
def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str: async def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str:
""" """
根据API提供商的配置对文本进行内容混淆。 根据API提供商的配置对文本进行内容混淆。
@@ -349,10 +348,10 @@ class _PromptProcessor:
processed_text = self.noise_instruction + "\n\n" + text processed_text = self.noise_instruction + "\n\n" + text
# 在拼接后的文本中注入随机噪音 # 在拼接后的文本中注入随机噪音
return self._inject_random_noise(processed_text, intensity) return await self._inject_random_noise(processed_text, intensity)
@staticmethod @staticmethod
def _inject_random_noise(text: str, intensity: int) -> str: async def _inject_random_noise(text: str, intensity: int) -> str:
""" """
在文本中按指定强度注入随机噪音字符串。 在文本中按指定强度注入随机噪音字符串。
@@ -394,7 +393,7 @@ class _PromptProcessor:
return " ".join(result) return " ".join(result)
@staticmethod @staticmethod
def _extract_reasoning(content: str) -> tuple[str, str]: async def _extract_reasoning(content: str) -> tuple[str, str]:
""" """
从模型返回的完整内容中提取被<think>...</think>标签包裹的思考过程, 从模型返回的完整内容中提取被<think>...</think>标签包裹的思考过程,
并返回清理后的内容和思考过程。 并返回清理后的内容和思考过程。
@@ -490,10 +489,10 @@ class _RequestExecutor:
except Exception as e: except Exception as e:
logger.debug(f"请求失败: {e!s}") logger.debug(f"请求失败: {e!s}")
# 记录失败并更新模型的惩罚值 # 记录失败并更新模型的惩罚值
self.model_selector.update_failure_penalty(model_info.name, e) await self.model_selector.update_failure_penalty(model_info.name, e)
# 处理异常,决定是否重试以及等待多久 # 处理异常,决定是否重试以及等待多久
wait_interval, new_compressed_messages = self._handle_exception( wait_interval, new_compressed_messages = await self._handle_exception(
e, e,
model_info, model_info,
api_provider, api_provider,
@@ -513,7 +512,7 @@ class _RequestExecutor:
logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}") logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}")
raise RuntimeError("请求失败,已达到最大重试次数") raise RuntimeError("请求失败,已达到最大重试次数")
def _handle_exception( async def _handle_exception(
self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
) -> tuple[int, list[Message] | None]: ) -> tuple[int, list[Message] | None]:
""" """
@@ -526,9 +525,9 @@ class _RequestExecutor:
retry_interval = api_provider.retry_interval retry_interval = api_provider.retry_interval
if isinstance(e, (NetworkConnectionError, ReqAbortException)): if isinstance(e, (NetworkConnectionError, ReqAbortException)):
return self._check_retry(remain_try, retry_interval, "连接异常", model_name) return await self._check_retry(remain_try, retry_interval, "连接异常", model_name)
elif isinstance(e, RespNotOkException): elif isinstance(e, RespNotOkException):
return self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info) return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
elif isinstance(e, RespParseException): elif isinstance(e, RespParseException):
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}") logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}")
return -1, None return -1, None
@@ -536,7 +535,7 @@ class _RequestExecutor:
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {e!s}") logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {e!s}")
return -1, None return -1, None
def _handle_resp_not_ok( async def _handle_resp_not_ok(
self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
) -> tuple[int, list[Message] | None]: ) -> tuple[int, list[Message] | None]:
""" """
@@ -578,13 +577,13 @@ class _RequestExecutor:
# 处理请求频繁或服务器端错误,这些情况适合重试 # 处理请求频繁或服务器端错误,这些情况适合重试
elif e.status_code == 429 or e.status_code >= 500: elif e.status_code == 429 or e.status_code >= 500:
reason = "请求过于频繁" if e.status_code == 429 else "服务器错误" reason = "请求过于频繁" if e.status_code == 429 else "服务器错误"
return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name) return await self._check_retry(remain_try, api_provider.retry_interval, reason, model_name)
# 处理其他未知的HTTP错误 # 处理其他未知的HTTP错误
else: else:
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}") logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}")
return -1, None return -1, None
def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]: async def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
""" """
辅助函数,根据剩余次数决定是否进行下一次重试。 辅助函数,根据剩余次数决定是否进行下一次重试。
@@ -654,7 +653,7 @@ class _RequestStrategy:
last_exception: Exception | None = None last_exception: Exception | None = None
for attempt in range(max_attempts): for attempt in range(max_attempts):
selection_result = self.model_selector.select_best_available_model( selection_result = await self.model_selector.select_best_available_model(
failed_models_in_this_request, str(request_type.value) failed_models_in_this_request, str(request_type.value)
) )
if selection_result is None: if selection_result is None:
@@ -669,7 +668,7 @@ class _RequestStrategy:
request_kwargs = kwargs.copy() request_kwargs = kwargs.copy()
if request_type == RequestType.RESPONSE and "prompt" in request_kwargs: if request_type == RequestType.RESPONSE and "prompt" in request_kwargs:
prompt = request_kwargs.pop("prompt") prompt = request_kwargs.pop("prompt")
processed_prompt = self.prompt_processor.prepare_prompt( processed_prompt = await self.prompt_processor.prepare_prompt(
prompt, model_info, api_provider, self.task_name prompt, model_info, api_provider, self.task_name
) )
message = MessageBuilder().add_text_content(processed_prompt).build() message = MessageBuilder().add_text_content(processed_prompt).build()
@@ -688,7 +687,7 @@ class _RequestStrategy:
# 成功,立即返回 # 成功,立即返回
logger.debug(f"模型 '{model_info.name}' 成功生成了回复。") logger.debug(f"模型 '{model_info.name}' 成功生成了回复。")
self.model_selector.update_usage_penalty(model_info.name, increase=False) await self.model_selector.update_usage_penalty(model_info.name, increase=False)
return response, model_info return response, model_info
except Exception as e: except Exception as e:
@@ -738,7 +737,7 @@ class _RequestStrategy:
# --- 响应内容处理和空回复/截断检查 --- # --- 响应内容处理和空回复/截断检查 ---
content = response.content or "" content = response.content or ""
use_anti_truncation = getattr(model_info, "use_anti_truncation", False) use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
processed_content, reasoning, is_truncated = self.prompt_processor.process_response( processed_content, reasoning, is_truncated = await self.prompt_processor.process_response(
content, use_anti_truncation content, use_anti_truncation
) )
@@ -821,12 +820,12 @@ class LLMRequest:
start_time = time.time() start_time = time.time()
# 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行 # 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行
selection_result = self._model_selector.select_best_available_model(set(), "response") selection_result = await self._model_selector.select_best_available_model(set(), "response")
if not selection_result: if not selection_result:
raise RuntimeError("无法为图像响应选择可用模型。") raise RuntimeError("无法为图像响应选择可用模型。")
model_info, api_provider, client = selection_result model_info, api_provider, client = selection_result
normalized_format = _normalize_image_format(image_format) normalized_format = await _normalize_image_format(image_format)
message = ( message = (
MessageBuilder() MessageBuilder()
.add_text_content(prompt) .add_text_content(prompt)
@@ -849,7 +848,7 @@ class LLMRequest:
) )
await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False) content, reasoning, _ = await self._prompt_processor.process_response(response.content or "", False)
reasoning = response.reasoning_content or reasoning reasoning = response.reasoning_content or reasoning
return content, (reasoning, model_info.name, response.tool_calls) return content, (reasoning, model_info.name, response.tool_calls)
@@ -935,7 +934,7 @@ class LLMRequest:
(响应内容, (推理过程, 模型名称, 工具调用)) (响应内容, (推理过程, 模型名称, 工具调用))
""" """
start_time = time.time() start_time = time.time()
tool_options = self._build_tool_options(tools) tool_options = await self._build_tool_options(tools)
response, model_info = await self._strategy.execute_with_failover( response, model_info = await self._strategy.execute_with_failover(
RequestType.RESPONSE, RequestType.RESPONSE,
@@ -1008,7 +1007,7 @@ class LLMRequest:
) )
@staticmethod @staticmethod
def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None: async def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
""" """
根据输入的字典列表构建并验证 `ToolOption` 对象列表。 根据输入的字典列表构建并验证 `ToolOption` 对象列表。