feat(model): optimize client caching and event-loop detection

- Add event-loop change detection to ClientRegistry so cache invalidation is handled automatically
- Implement a global AsyncOpenAI client cache for OpenaiClient to improve connection-pool reuse
- Convert the synchronous methods in utils_model to async to keep them compatible with the event loop
- Drop the special handling for embedding requests; every request now benefits from the cache
- Add cache statistics for easier monitoring and debugging
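The ClientRegistry change the message refers to is not part of the hunks below, which only touch utils_model. As a rough, hypothetical sketch of what event-loop-aware client caching can look like — the `_create_client` factory, the `api_provider.name` key, and the `stats` dict are illustrative assumptions, not this project's actual API:

import asyncio

class ClientRegistry:
    """Sketch only: one cached client per provider, invalidated when the event loop changes."""

    def __init__(self) -> None:
        self._cache: dict[str, "BaseClient"] = {}
        self._loop_id: int | None = None  # id() of the loop the cached clients were created on
        self.stats = {"hits": 0, "misses": 0, "invalidations": 0}  # cache statistics for monitoring

    def get_client_class_instance(self, api_provider) -> "BaseClient":
        # Event-loop change detection: if the running loop differs from the one the cache
        # was built for, drop every cached client so none of them keeps a connection pool
        # bound to a closed or foreign loop.
        try:
            loop_id = id(asyncio.get_running_loop())
        except RuntimeError:
            loop_id = None
        if loop_id != self._loop_id:
            self._cache.clear()
            self._loop_id = loop_id
            self.stats["invalidations"] += 1

        key = api_provider.name  # assumed to be unique per provider
        if key not in self._cache:
            self.stats["misses"] += 1
            self._cache[key] = self._create_client(api_provider)
        else:
            self.stats["hits"] += 1
        return self._cache[key]

    def _create_client(self, api_provider) -> "BaseClient":
        raise NotImplementedError  # the real registry constructs the provider-specific client here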
@@ -48,7 +48,7 @@ logger = get_logger("model_utils")
 # ==============================================================================


-def _normalize_image_format(image_format: str) -> str:
+async def _normalize_image_format(image_format: str) -> str:
     """
     Normalize the image format name to ensure compatibility with the various APIs

@@ -152,7 +152,7 @@ class _ModelSelector:
         self.model_list = model_list
         self.model_usage = model_usage

-    def select_best_available_model(
+    async def select_best_available_model(
         self, failed_models_in_this_request: set, request_type: str
     ) -> tuple[ModelInfo, APIProvider, BaseClient] | None:
         """
@@ -190,17 +190,16 @@ class _ModelSelector:

         model_info = model_config.get_model_info(least_used_model_name)
         api_provider = model_config.get_provider(model_info.api_provider)
-        # Special case: for embedding tasks, force the creation of a new aiohttp.ClientSession.
-        # This avoids event-loop issues a shared ClientSession can cause in some high-concurrency scenarios.
-        force_new_client = request_type == "embedding"
-        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
+        # Automatic event-loop detection: ClientRegistry detects event-loop changes and handles cache invalidation itself
+        # No need to pass force_new manually; embedding requests benefit from the cache as well
+        client = client_registry.get_client_class_instance(api_provider)

         logger.debug(f"Selected the best available model for this request: {model_info.name}")
         # Increase the selected model's request-usage penalty to achieve dynamic load balancing.
-        self.update_usage_penalty(model_info.name, increase=True)
+        await self.update_usage_penalty(model_info.name, increase=True)
         return model_info, api_provider, client

-    def update_usage_penalty(self, model_name: str, increase: bool):
+    async def update_usage_penalty(self, model_name: str, increase: bool):
         """
         Update the model's usage penalty.

@@ -218,7 +217,7 @@ class _ModelSelector:
         # Update the model's penalty values
         self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment)

-    def update_failure_penalty(self, model_name: str, e: Exception):
+    async def update_failure_penalty(self, model_name: str, e: Exception):
         """
         Dynamically adjust the model's failure penalty based on the exception type.
         Critical errors (such as network connection or server errors) receive a higher penalty,
@@ -281,7 +280,7 @@ class _PromptProcessor:
        This helps me determine whether your output was truncated. Do not add any other text or punctuation before or after `{self.end_marker}`.
        """

-    def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
+    async def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
         """
         Prepare the final prompt for the request.

@@ -298,7 +297,7 @@ class _PromptProcessor:
             str: The processed, complete prompt that can be sent directly to the model.
         """
         # Step 1: apply content obfuscation according to the API provider's configuration
-        processed_prompt = self._apply_content_obfuscation(prompt, api_provider)
+        processed_prompt = await self._apply_content_obfuscation(prompt, api_provider)

         # Step 2: check whether the model needs the anti-truncation instruction injected
         if getattr(model_info, "use_anti_truncation", False):
@@ -307,14 +306,14 @@ class _PromptProcessor:

         return processed_prompt

-    def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
+    async def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
         """
         Process the response content, extract the chain of thought, and check for truncation.

         Returns:
             Tuple[str, str, bool]: (processed content, chain-of-thought content, whether it was truncated)
         """
-        content, reasoning = self._extract_reasoning(content)
+        content, reasoning = await self._extract_reasoning(content)
         is_truncated = False
         if use_anti_truncation:
             if content.endswith(self.end_marker):
@@ -323,7 +322,7 @@ class _PromptProcessor:
                 is_truncated = True
         return content, reasoning, is_truncated

-    def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str:
+    async def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str:
         """
         Obfuscate the text according to the API provider's configuration.

@@ -349,10 +348,10 @@ class _PromptProcessor:
             processed_text = self.noise_instruction + "\n\n" + text

         # Inject random noise into the concatenated text
-        return self._inject_random_noise(processed_text, intensity)
+        return await self._inject_random_noise(processed_text, intensity)

     @staticmethod
-    def _inject_random_noise(text: str, intensity: int) -> str:
+    async def _inject_random_noise(text: str, intensity: int) -> str:
         """
         Inject random noise strings into the text at the given intensity.

@@ -394,7 +393,7 @@ class _PromptProcessor:
         return " ".join(result)

     @staticmethod
-    def _extract_reasoning(content: str) -> tuple[str, str]:
+    async def _extract_reasoning(content: str) -> tuple[str, str]:
         """
         Extract the reasoning wrapped in <think>...</think> tags from the model's full output,
         and return the cleaned content together with the reasoning.
@@ -490,10 +489,10 @@ class _RequestExecutor:
             except Exception as e:
                 logger.debug(f"Request failed: {e!s}")
                 # Record the failure and update the model's penalty
-                self.model_selector.update_failure_penalty(model_info.name, e)
+                await self.model_selector.update_failure_penalty(model_info.name, e)

                 # Handle the exception and decide whether to retry and how long to wait
-                wait_interval, new_compressed_messages = self._handle_exception(
+                wait_interval, new_compressed_messages = await self._handle_exception(
                     e,
                     model_info,
                     api_provider,
@@ -513,7 +512,7 @@ class _RequestExecutor:
         logger.error(f"Model '{model_info.name}' request failed after reaching the maximum of {api_provider.max_retry} retries")
         raise RuntimeError("Request failed; maximum number of retries reached")

-    def _handle_exception(
+    async def _handle_exception(
         self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
     ) -> tuple[int, list[Message] | None]:
         """
@@ -526,9 +525,9 @@ class _RequestExecutor:
         retry_interval = api_provider.retry_interval

         if isinstance(e, (NetworkConnectionError, ReqAbortException)):
-            return self._check_retry(remain_try, retry_interval, "connection error", model_name)
+            return await self._check_retry(remain_try, retry_interval, "connection error", model_name)
         elif isinstance(e, RespNotOkException):
-            return self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
+            return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
         elif isinstance(e, RespParseException):
             logger.error(f"Task-'{self.task_name}' Model-'{model_name}': response parse error - {e.message}")
             return -1, None
@@ -536,7 +535,7 @@ class _RequestExecutor:
             logger.error(f"Task-'{self.task_name}' Model-'{model_name}': unknown exception - {e!s}")
             return -1, None

-    def _handle_resp_not_ok(
+    async def _handle_resp_not_ok(
         self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
     ) -> tuple[int, list[Message] | None]:
         """
@@ -578,13 +577,13 @@ class _RequestExecutor:
         # Handle rate limiting and server-side errors; these cases are worth retrying
         elif e.status_code == 429 or e.status_code >= 500:
             reason = "too many requests" if e.status_code == 429 else "server error"
-            return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name)
+            return await self._check_retry(remain_try, api_provider.retry_interval, reason, model_name)
         # Handle any other unknown HTTP errors
         else:
             logger.warning(f"Task-'{self.task_name}' Model-'{model_name}': unknown response error {e.status_code} - {e.message}")
             return -1, None

-    def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
+    async def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
         """
         Helper that decides, based on the remaining attempts, whether to go for another retry.

@@ -654,7 +653,7 @@ class _RequestStrategy:
         last_exception: Exception | None = None

         for attempt in range(max_attempts):
-            selection_result = self.model_selector.select_best_available_model(
+            selection_result = await self.model_selector.select_best_available_model(
                 failed_models_in_this_request, str(request_type.value)
             )
             if selection_result is None:
@@ -669,7 +668,7 @@ class _RequestStrategy:
             request_kwargs = kwargs.copy()
             if request_type == RequestType.RESPONSE and "prompt" in request_kwargs:
                 prompt = request_kwargs.pop("prompt")
-                processed_prompt = self.prompt_processor.prepare_prompt(
+                processed_prompt = await self.prompt_processor.prepare_prompt(
                     prompt, model_info, api_provider, self.task_name
                 )
                 message = MessageBuilder().add_text_content(processed_prompt).build()
@@ -688,7 +687,7 @@ class _RequestStrategy:

                 # Success; return immediately
                 logger.debug(f"Model '{model_info.name}' generated a response successfully.")
-                self.model_selector.update_usage_penalty(model_info.name, increase=False)
+                await self.model_selector.update_usage_penalty(model_info.name, increase=False)
                 return response, model_info

             except Exception as e:
@@ -738,7 +737,7 @@ class _RequestStrategy:
         # --- Response content processing and empty-response/truncation checks ---
         content = response.content or ""
         use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
-        processed_content, reasoning, is_truncated = self.prompt_processor.process_response(
+        processed_content, reasoning, is_truncated = await self.prompt_processor.process_response(
             content, use_anti_truncation
         )

@@ -821,12 +820,12 @@ class LLMRequest:
         start_time = time.time()

         # Image requests do not use the full failover strategy yet; select a model directly and execute
-        selection_result = self._model_selector.select_best_available_model(set(), "response")
+        selection_result = await self._model_selector.select_best_available_model(set(), "response")
         if not selection_result:
             raise RuntimeError("No available model could be selected for the image response.")
         model_info, api_provider, client = selection_result

-        normalized_format = _normalize_image_format(image_format)
+        normalized_format = await _normalize_image_format(image_format)
         message = (
             MessageBuilder()
             .add_text_content(prompt)
@@ -849,7 +848,7 @@ class LLMRequest:
         )

         await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
-        content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
+        content, reasoning, _ = await self._prompt_processor.process_response(response.content or "", False)
         reasoning = response.reasoning_content or reasoning

         return content, (reasoning, model_info.name, response.tool_calls)
@@ -935,7 +934,7 @@ class LLMRequest:
             (response content, (reasoning, model name, tool calls))
         """
         start_time = time.time()
-        tool_options = self._build_tool_options(tools)
+        tool_options = await self._build_tool_options(tools)

         response, model_info = await self._strategy.execute_with_failover(
             RequestType.RESPONSE,
@@ -1008,7 +1007,7 @@ class LLMRequest:
         )

     @staticmethod
-    def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
+    async def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
         """
         Build and validate a list of `ToolOption` objects from the given list of dicts.

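The OpenaiClient side of the commit (the global AsyncOpenAI client cache) is likewise outside these hunks. A minimal sketch of the idea, assuming a cache keyed by base URL, API key, and event loop — the function name and module-level dict are illustrative, not the project's actual code:

import asyncio
from openai import AsyncOpenAI

# Illustrative global cache: one AsyncOpenAI client per (base_url, api_key, event loop),
# so the underlying HTTP connection pool is reused instead of being rebuilt per request.
_OPENAI_CLIENT_CACHE: dict[tuple[str, str, int], AsyncOpenAI] = {}

def get_cached_async_openai(base_url: str, api_key: str) -> AsyncOpenAI:
    loop_id = id(asyncio.get_running_loop())
    key = (base_url, api_key, loop_id)
    client = _OPENAI_CLIENT_CACHE.get(key)
    if client is None:
        # Creating a fresh client per event loop avoids "attached to a different loop"
        # errors while still sharing connections within one loop.
        client = AsyncOpenAI(base_url=base_url, api_key=api_key)
        _OPENAI_CLIENT_CACHE[key] = client
    return client

Call sites would then request a client from the cache instead of constructing AsyncOpenAI on every request.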