feat(model): optimize client caching and event-loop change detection
- Add event-loop change detection to ClientRegistry, with automatic cache invalidation
- Implement a global AsyncOpenAI client cache for OpenaiClient to improve connection-pool reuse
- Convert the synchronous methods in utils_model to async to keep them event-loop compatible
- Remove the special-case handling for embedding requests; all requests now benefit from caching
- Add cache statistics to ease monitoring and debugging
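The core of the loop-detection change is comparing `id(asyncio.get_running_loop())` against a cached value per provider (see the ClientRegistry hunks below). A minimal standalone sketch of why this works: each `asyncio.run()` call drives a fresh loop object.

```python
import asyncio

async def loop_id() -> int:
    # id() of the running loop serves as a cheap fingerprint;
    # a cached client created under a different loop is stale.
    return id(asyncio.get_running_loop())

first = asyncio.run(loop_id())
second = asyncio.run(loop_id())
# The two ids usually differ, since each asyncio.run() creates a new loop.
# (id() values can in principle be reused after garbage collection,
# so this is a pragmatic heuristic, not a guarantee.)
print(first, second, first != second)
```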
bot.py
@@ -111,9 +111,9 @@ async def graceful_shutdown(main_system_instance):
     try:
         from src.chat.message_receive.chat_stream import get_chat_manager
         chat_manager = get_chat_manager()
-        if hasattr(chat_manager, "_stop_auto_save"):
+        if hasattr(chat_manager, "stop_auto_save"):
             logger.info("Stopping chat manager...")
-            chat_manager._stop_auto_save()
+            chat_manager.stop_auto_save()
     except Exception as e:
         logger.warning(f"Error while stopping chat manager: {e}")
 
base_client.py
@@ -4,12 +4,15 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Any
 
+from src.common.logger import get_logger
 from src.config.api_ada_configs import APIProvider, ModelInfo
 
 from ..payload_content.message import Message
 from ..payload_content.resp_format import RespFormat
 from ..payload_content.tool_option import ToolCall, ToolOption
 
+logger = get_logger("model_client.base_client")
+
 
 @dataclass
 class UsageRecord:
@@ -144,6 +147,10 @@ class ClientRegistry:
         """APIProvider.type -> BaseClient mapping"""
         self.client_instance_cache: dict[str, BaseClient] = {}
         """APIProvider.name -> BaseClient mapping"""
+        self._event_loop_cache: dict[str, int | None] = {}
+        """APIProvider.name -> event loop id mapping, used to detect event-loop changes"""
+        self._loop_change_count: int = 0
+        """Number of cache invalidations caused by event-loop changes"""
 
     def register_client_class(self, client_type: str):
         """
@@ -160,29 +167,91 @@ class ClientRegistry:
 
         return decorator
 
+    def _get_current_loop_id(self) -> int | None:
+        """
+        Get the ID of the current event loop.
+
+        Returns:
+            int | None: the event loop ID, or None if no loop is running
+        """
+        try:
+            loop = asyncio.get_running_loop()
+            return id(loop)
+        except RuntimeError:
+            # No running event loop
+            return None
+
+    def _is_event_loop_changed(self, provider_name: str) -> bool:
+        """
+        Check whether the event loop has changed.
+
+        Args:
+            provider_name: the provider name
+
+        Returns:
+            bool: whether the event loop has changed
+        """
+        current_loop_id = self._get_current_loop_id()
+
+        # No cached loop ID means this is the first creation
+        if provider_name not in self._event_loop_cache:
+            return False
+
+        # Compare the current loop ID against the cached one
+        cached_loop_id = self._event_loop_cache[provider_name]
+        return current_loop_id != cached_loop_id
+
     def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
         """
-        Get a registered API client instance
+        Get a registered API client instance (with event-loop detection)
 
         Args:
             api_provider: an APIProvider instance
-            force_new: whether to force a new instance (works around event-loop issues)
+            force_new: whether to force a new instance (usually unnecessary; event-loop changes are detected automatically)
 
         Returns:
             BaseClient: the registered API client instance
         """
+        provider_name = api_provider.name
+
         # If a new instance is forced, create it directly and bypass the cache
         if force_new:
             if client_class := self.client_registry.get(api_provider.client_type):
-                return client_class(api_provider)
+                new_instance = client_class(api_provider)
+                # Update the event-loop cache
+                self._event_loop_cache[provider_name] = self._get_current_loop_id()
+                return new_instance
             else:
                 raise KeyError(f"Client of type '{api_provider.client_type}' is not registered")
+
+        # Check whether the event loop has changed
+        if self._is_event_loop_changed(provider_name):
+            # The event loop changed; the instance must be recreated
+            logger.debug(f"Event-loop change detected; recreating the client instance for {provider_name}")
+            self._loop_change_count += 1
+
+            # Drop the stale instance
+            if provider_name in self.client_instance_cache:
+                del self.client_instance_cache[provider_name]
+
         # Normal caching logic
-        if api_provider.name not in self.client_instance_cache:
+        if provider_name not in self.client_instance_cache:
             if client_class := self.client_registry.get(api_provider.client_type):
-                self.client_instance_cache[api_provider.name] = client_class(api_provider)
+                self.client_instance_cache[provider_name] = client_class(api_provider)
+                # Cache the current event-loop ID
+                self._event_loop_cache[provider_name] = self._get_current_loop_id()
             else:
                 raise KeyError(f"Client of type '{api_provider.client_type}' is not registered")
-        return self.client_instance_cache[api_provider.name]
+
+        return self.client_instance_cache[provider_name]
+
+    def get_cache_stats(self) -> dict:
+        """
+        Get cache statistics.
+
+        Returns:
+            dict: a dictionary with cache statistics
+        """
+        return {
+            "cached_instances": len(self.client_instance_cache),
+            "tracked_loops": len(self._event_loop_cache),
+            "loop_change_count": self._loop_change_count,
+            "cached_providers": list(self.client_instance_cache.keys()),
+        }
 
 
 client_registry = ClientRegistry()
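Note: a hedged illustration of how the registry now behaves across loops. Only `get_client_class_instance` and `get_cache_stats` come from the hunk above; the provider setup is hypothetical.

```python
import asyncio
# Assumes: from the module patched above, import client_registry (and an
# APIProvider configured elsewhere).

async def use_registry(provider) -> None:
    # The first call in this loop either reuses the cached instance or,
    # if the loop changed since it was cached, rebuilds it.
    client = client_registry.get_client_class_instance(provider)
    stats = client_registry.get_cache_stats()
    print(stats["cached_instances"], stats["loop_change_count"])

# provider = APIProvider(...)            # hypothetical configuration
# asyncio.run(use_registry(provider))    # loop 1: instance created, change count 0
# asyncio.run(use_registry(provider))    # loop 2: loop change detected, instance rebuilt
```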
@@ -383,18 +383,62 @@ def _default_normal_response_parser(
 
 @client_registry.register_client_class("openai")
 class OpenaiClient(BaseClient):
+    # Class-level global cache, shared by all OpenaiClient instances
+    _global_client_cache: dict[int, AsyncOpenAI] = {}
+    """Global AsyncOpenAI client cache: config_hash -> AsyncOpenAI instance"""
+
     def __init__(self, api_provider: APIProvider):
         super().__init__(api_provider)
+        self._config_hash = self._calculate_config_hash()
+        """Configuration hash of the current provider"""
+
+    def _calculate_config_hash(self) -> int:
+        """Compute a hash of the current configuration"""
+        config_tuple = (
+            self.api_provider.base_url,
+            self.api_provider.get_api_key(),
+            self.api_provider.timeout,
+        )
+        return hash(config_tuple)
 
     def _create_client(self) -> AsyncOpenAI:
-        """Create an OpenAI client on demand"""
-        return AsyncOpenAI(
+        """
+        Get or create an OpenAI client instance (globally cached).
+
+        OpenaiClient instances with identical configuration (base_url + api_key + timeout)
+        share a single AsyncOpenAI client, maximizing connection-pool reuse.
+        """
+        # Check the global cache
+        if self._config_hash in self._global_client_cache:
+            return self._global_client_cache[self._config_hash]
+
+        # Create a new AsyncOpenAI instance
+        logger.debug(
+            f"Creating a new AsyncOpenAI client instance "
+            f"(base_url={self.api_provider.base_url}, "
+            f"config_hash={self._config_hash})"
+        )
+
+        client = AsyncOpenAI(
             base_url=self.api_provider.base_url,
             api_key=self.api_provider.get_api_key(),
             max_retries=0,
             timeout=self.api_provider.timeout,
         )
+
+        # Store it in the global cache
+        self._global_client_cache[self._config_hash] = client
+
+        return client
+
+    @classmethod
+    def get_cache_stats(cls) -> dict:
+        """Get global cache statistics"""
+        return {
+            "cached_openai_clients": len(cls._global_client_cache),
+            "config_hashes": list(cls._global_client_cache.keys()),
+        }
 
     async def get_response(
         self,
         model_info: ModelInfo,
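The same caching pattern, reduced to a standalone sketch (the URL and key below are placeholders). One design note: keying on `hash()` of the config tuple is compact, but two distinct configurations could in principle collide; keying the dict on the tuple itself would trade a little memory for exactness.

```python
from openai import AsyncOpenAI

_pool: dict[int, AsyncOpenAI] = {}

def shared_client(base_url: str, api_key: str, timeout: float) -> AsyncOpenAI:
    # hash() of str is randomized per process (PYTHONHASHSEED), which is
    # fine here: the cache only ever lives inside one process.
    key = hash((base_url, api_key, timeout))
    if key not in _pool:
        _pool[key] = AsyncOpenAI(
            base_url=base_url, api_key=api_key, max_retries=0, timeout=timeout
        )
    return _pool[key]

a = shared_client("https://api.example.com/v1", "sk-placeholder", 30.0)
b = shared_client("https://api.example.com/v1", "sk-placeholder", 30.0)
assert a is b  # identical configuration -> one shared connection pool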
utils_model.py
@@ -48,7 +48,7 @@ logger = get_logger("model_utils")
 # ==============================================================================
 
 
-def _normalize_image_format(image_format: str) -> str:
+async def _normalize_image_format(image_format: str) -> str:
     """
     Normalize the image format name for compatibility with the various APIs
 
@@ -152,7 +152,7 @@ class _ModelSelector:
         self.model_list = model_list
         self.model_usage = model_usage
 
-    def select_best_available_model(
+    async def select_best_available_model(
         self, failed_models_in_this_request: set, request_type: str
     ) -> tuple[ModelInfo, APIProvider, BaseClient] | None:
         """
@@ -190,17 +190,16 @@ class _ModelSelector:
 
         model_info = model_config.get_model_info(least_used_model_name)
         api_provider = model_config.get_provider(model_info.api_provider)
-        # Special case: force a new aiohttp.ClientSession for embedding tasks,
-        # to avoid event-loop issues a shared ClientSession can cause under high concurrency.
-        force_new_client = request_type == "embedding"
-        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
+        # Automatic event-loop detection: ClientRegistry detects loop changes and invalidates the cache.
+        # No need to pass force_new manually; embedding requests also benefit from caching.
+        client = client_registry.get_client_class_instance(api_provider)
 
         logger.debug(f"Selected the best available model for this request: {model_info.name}")
         # Increase the selected model's usage penalty for dynamic load balancing.
-        self.update_usage_penalty(model_info.name, increase=True)
+        await self.update_usage_penalty(model_info.name, increase=True)
         return model_info, api_provider, client
 
-    def update_usage_penalty(self, model_name: str, increase: bool):
+    async def update_usage_penalty(self, model_name: str, increase: bool):
         """
         Update the model's usage penalty.
 
@@ -218,7 +217,7 @@ class _ModelSelector:
         # Update the model's penalty values
         self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment)
 
-    def update_failure_penalty(self, model_name: str, e: Exception):
+    async def update_failure_penalty(self, model_name: str, e: Exception):
         """
         Dynamically adjust the model's failure penalty based on the exception type.
         Critical errors (such as network connection or server errors) receive a higher penalty,
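For context, the load-balancing scheme the selector hunks imply, in a hypothetical standalone form: `model_usage` maps a model name to `(total_tokens, failure_penalty, usage_penalty)`; selection picks the least-penalized model, bumps its usage penalty while the request is in flight, and releases the bump on success.

```python
model_usage: dict[str, tuple[int, float, float]] = {
    "model-a": (0, 0.0, 0.0),  # (total_tokens, failure_penalty, usage_penalty)
    "model-b": (0, 0.0, 0.0),
}

def select_least_penalized() -> str:
    # Rank by combined penalties; the name breaks ties deterministically.
    return min(model_usage, key=lambda m: (model_usage[m][1] + model_usage[m][2], m))

def update_usage_penalty(name: str, increase: bool, step: float = 1.0) -> None:
    tokens, fail_pen, use_pen = model_usage[name]
    model_usage[name] = (tokens, fail_pen, use_pen + (step if increase else -step))

chosen = select_least_penalized()
update_usage_penalty(chosen, increase=True)   # request starts
update_usage_penalty(chosen, increase=False)  # request succeeded
```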
@@ -281,7 +280,7 @@ class _PromptProcessor:
         This helps me determine whether your output was truncated. Do not add any other text or punctuation before or after `{self.end_marker}`.
         """
 
-    def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
+    async def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
         """
         Prepare the final prompt for a request.
 
@@ -298,7 +297,7 @@ class _PromptProcessor:
             str: the processed prompt, ready to be sent to the model.
         """
         # Step 1: apply content obfuscation according to the API provider's configuration
-        processed_prompt = self._apply_content_obfuscation(prompt, api_provider)
+        processed_prompt = await self._apply_content_obfuscation(prompt, api_provider)
 
         # Step 2: check whether the model needs the anti-truncation instruction injected
         if getattr(model_info, "use_anti_truncation", False):
@@ -307,14 +306,14 @@ class _PromptProcessor:
 
         return processed_prompt
 
-    def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
+    async def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
         """
         Process the response content: extract the chain of thought and check for truncation.
 
         Returns:
             Tuple[str, str, bool]: (processed content, chain-of-thought content, truncated or not)
         """
-        content, reasoning = self._extract_reasoning(content)
+        content, reasoning = await self._extract_reasoning(content)
         is_truncated = False
         if use_anti_truncation:
             if content.endswith(self.end_marker):
@@ -323,7 +322,7 @@ class _PromptProcessor:
                 is_truncated = True
         return content, reasoning, is_truncated
 
-    def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str:
+    async def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str:
         """
         Apply content obfuscation to the text according to the API provider's configuration.
 
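The anti-truncation check spread over the two hunks above, condensed into a hypothetical sketch. It assumes the marker-present branch strips the marker and the marker-absent branch flags truncation; the real marker lives in `_PromptProcessor.end_marker`.

```python
END_MARKER = "[END]"  # placeholder value for illustration

def check_truncation(content: str) -> tuple[str, bool]:
    # A reply that ends with the marker is complete: strip the marker.
    # A missing marker suggests the model's output was cut off.
    if content.endswith(END_MARKER):
        return content[: -len(END_MARKER)].rstrip(), False
    return content, True

print(check_truncation("The answer is 42.[END]"))  # ('The answer is 42.', False)
print(check_truncation("The answer is"))           # ('The answer is', True)
```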
@@ -349,10 +348,10 @@ class _PromptProcessor:
             processed_text = self.noise_instruction + "\n\n" + text
 
         # Inject random noise into the concatenated text
-        return self._inject_random_noise(processed_text, intensity)
+        return await self._inject_random_noise(processed_text, intensity)
 
     @staticmethod
-    def _inject_random_noise(text: str, intensity: int) -> str:
+    async def _inject_random_noise(text: str, intensity: int) -> str:
         """
         Inject random noise strings into the text at the given intensity.
 
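The body of `_inject_random_noise` is mostly outside this diff; a hypothetical word-level sketch consistent with the visible `" ".join(result)` tail in the next hunk:

```python
import random
import string

def inject_random_noise(text: str, intensity: int) -> str:
    # Scatter short random tokens between words; higher intensity
    # (assumed to range 0-10 here) inserts noise more often.
    result: list[str] = []
    for word in text.split():
        result.append(word)
        if random.randint(1, 10) <= intensity:
            result.append("".join(random.choices(string.ascii_lowercase, k=3)))
    return " ".join(result)

print(inject_random_noise("obfuscate this prompt", intensity=5))
```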
@@ -394,7 +393,7 @@ class _PromptProcessor:
         return " ".join(result)
 
     @staticmethod
-    def _extract_reasoning(content: str) -> tuple[str, str]:
+    async def _extract_reasoning(content: str) -> tuple[str, str]:
         """
         Extract the thinking process wrapped in <think>...</think> tags from the model's full output,
         and return the cleaned content together with the thinking process.
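A plausible shape for `_extract_reasoning`, based only on its docstring (the real implementation is outside this diff):

```python
import re

_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)

def extract_reasoning(content: str) -> tuple[str, str]:
    # Split a reply into (visible content, chain-of-thought).
    reasoning = "\n".join(m.strip() for m in _THINK_RE.findall(content))
    cleaned = _THINK_RE.sub("", content).strip()
    return cleaned, reasoning

print(extract_reasoning("<think>check the units</think>The answer is 42."))
# -> ('The answer is 42.', 'check the units')
```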
@@ -490,10 +489,10 @@ class _RequestExecutor:
             except Exception as e:
                 logger.debug(f"Request failed: {e!s}")
                 # Record the failure and update the model's penalty values
-                self.model_selector.update_failure_penalty(model_info.name, e)
+                await self.model_selector.update_failure_penalty(model_info.name, e)
 
                 # Handle the exception and decide whether to retry and how long to wait
-                wait_interval, new_compressed_messages = self._handle_exception(
+                wait_interval, new_compressed_messages = await self._handle_exception(
                     e,
                     model_info,
                     api_provider,
@@ -513,7 +512,7 @@ class _RequestExecutor:
             logger.error(f"Model '{model_info.name}' request failed after reaching the maximum of {api_provider.max_retry} retries")
             raise RuntimeError("Request failed; maximum retry count reached")
 
-    def _handle_exception(
+    async def _handle_exception(
         self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
     ) -> tuple[int, list[Message] | None]:
         """
@@ -526,9 +525,9 @@ class _RequestExecutor:
         retry_interval = api_provider.retry_interval
 
         if isinstance(e, (NetworkConnectionError, ReqAbortException)):
-            return self._check_retry(remain_try, retry_interval, "connection error", model_name)
+            return await self._check_retry(remain_try, retry_interval, "connection error", model_name)
         elif isinstance(e, RespNotOkException):
-            return self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
+            return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
         elif isinstance(e, RespParseException):
             logger.error(f"Task-'{self.task_name}' model-'{model_name}': response parse error - {e.message}")
             return -1, None
@@ -536,7 +535,7 @@ class _RequestExecutor:
             logger.error(f"Task-'{self.task_name}' model-'{model_name}': unknown exception - {e!s}")
             return -1, None
 
-    def _handle_resp_not_ok(
+    async def _handle_resp_not_ok(
         self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
     ) -> tuple[int, list[Message] | None]:
         """
@@ -578,13 +577,13 @@ class _RequestExecutor:
         # Handle rate limiting and server-side errors; these are worth retrying
         elif e.status_code == 429 or e.status_code >= 500:
             reason = "rate limited" if e.status_code == 429 else "server error"
-            return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name)
+            return await self._check_retry(remain_try, api_provider.retry_interval, reason, model_name)
         # Handle other unknown HTTP errors
         else:
             logger.warning(f"Task-'{self.task_name}' model-'{model_name}': unknown response error {e.status_code} - {e.message}")
             return -1, None
 
-    def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
+    async def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
         """
         Helper that decides, based on the remaining attempt count, whether to retry.
 
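The retry policy across `_handle_resp_not_ok` and `_check_retry`, reduced to its decision rule (a sketch; the real helpers also return compressed messages alongside the interval):

```python
def retry_decision(status_code: int, remain_try: int, interval: int) -> int:
    # Returns the wait interval before the next attempt, or -1 to give up,
    # mirroring the (wait_interval, ...) convention in the hunks above.
    # 429 (rate limited) and 5xx (server side) are treated as transient.
    transient = status_code == 429 or status_code >= 500
    return interval if transient and remain_try > 0 else -1

assert retry_decision(429, remain_try=2, interval=10) == 10
assert retry_decision(404, remain_try=2, interval=10) == -1
```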
@@ -654,7 +653,7 @@ class _RequestStrategy:
         last_exception: Exception | None = None
 
         for attempt in range(max_attempts):
-            selection_result = self.model_selector.select_best_available_model(
+            selection_result = await self.model_selector.select_best_available_model(
                 failed_models_in_this_request, str(request_type.value)
             )
             if selection_result is None:
@@ -669,7 +668,7 @@ class _RequestStrategy:
             request_kwargs = kwargs.copy()
             if request_type == RequestType.RESPONSE and "prompt" in request_kwargs:
                 prompt = request_kwargs.pop("prompt")
-                processed_prompt = self.prompt_processor.prepare_prompt(
+                processed_prompt = await self.prompt_processor.prepare_prompt(
                     prompt, model_info, api_provider, self.task_name
                 )
                 message = MessageBuilder().add_text_content(processed_prompt).build()
@@ -688,7 +687,7 @@ class _RequestStrategy:
 
                 # Success: return immediately
                 logger.debug(f"Model '{model_info.name}' generated a reply successfully.")
-                self.model_selector.update_usage_penalty(model_info.name, increase=False)
+                await self.model_selector.update_usage_penalty(model_info.name, increase=False)
                 return response, model_info
 
             except Exception as e:
@@ -738,7 +737,7 @@ class _RequestStrategy:
         # --- Response content processing and empty-reply/truncation checks ---
         content = response.content or ""
         use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
-        processed_content, reasoning, is_truncated = self.prompt_processor.process_response(
+        processed_content, reasoning, is_truncated = await self.prompt_processor.process_response(
             content, use_anti_truncation
         )
 
@@ -821,12 +820,12 @@ class LLMRequest:
         start_time = time.time()
 
         # Image requests do not use the full failover strategy yet; select a model and execute directly
-        selection_result = self._model_selector.select_best_available_model(set(), "response")
+        selection_result = await self._model_selector.select_best_available_model(set(), "response")
         if not selection_result:
             raise RuntimeError("No available model could be selected for the image response.")
         model_info, api_provider, client = selection_result
 
-        normalized_format = _normalize_image_format(image_format)
+        normalized_format = await _normalize_image_format(image_format)
         message = (
             MessageBuilder()
             .add_text_content(prompt)
@@ -849,7 +848,7 @@ class LLMRequest:
         )
 
         await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
-        content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
+        content, reasoning, _ = await self._prompt_processor.process_response(response.content or "", False)
         reasoning = response.reasoning_content or reasoning
 
         return content, (reasoning, model_info.name, response.tool_calls)
@@ -935,7 +934,7 @@ class LLMRequest:
             (response content, (reasoning process, model name, tool calls))
         """
         start_time = time.time()
-        tool_options = self._build_tool_options(tools)
+        tool_options = await self._build_tool_options(tools)
 
         response, model_info = await self._strategy.execute_with_failover(
             RequestType.RESPONSE,
@@ -1008,7 +1007,7 @@ class LLMRequest:
         )
 
     @staticmethod
-    def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
+    async def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
         """
        Build and validate a list of `ToolOption` objects from the input list of dicts.
 