style: format code

John Richard
2025-10-02 19:38:39 +08:00
parent d5627b0661
commit ecb02cae31
111 changed files with 2344 additions and 2316 deletions


@@ -18,6 +18,7 @@
- **LLMRequest (main interface)**:
Acts as the module's unified entry point (Facade), giving upper-level business logic a concise interface for issuing text, image, audio, and other types of LLM requests.
"""
import re
import asyncio
import time
@@ -26,14 +27,13 @@ import string
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder
from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord
from .utils import compress_messages, llm_usage_recorder
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
@@ -46,6 +46,7 @@ logger = get_logger("model_utils")
# Standalone Utility Functions
# ==============================================================================
def _normalize_image_format(image_format: str) -> str:
"""
Normalize the image format name to ensure compatibility with the various APIs.
@@ -57,17 +58,26 @@ def _normalize_image_format(image_format: str) -> str:
str: The normalized image format.
"""
format_mapping = {
"jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg",
"png": "png", "PNG": "png",
"webp": "webp", "WEBP": "webp",
"gif": "gif", "GIF": "gif",
"heic": "heic", "HEIC": "heic",
"heif": "heif", "HEIF": "heif",
"jpg": "jpeg",
"JPG": "jpeg",
"JPEG": "jpeg",
"jpeg": "jpeg",
"png": "png",
"PNG": "png",
"webp": "webp",
"WEBP": "webp",
"gif": "gif",
"GIF": "gif",
"heic": "heic",
"HEIC": "heic",
"heif": "heif",
"HEIF": "heif",
}
normalized = format_mapping.get(image_format, image_format.lower())
logger.debug(f"图片格式标准化: {image_format} -> {normalized}")
return normalized
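# Editor's sketch (not part of this commit): how the mapping above behaves for a few inputs;
# only the helper defined directly above is assumed to be in scope.
def _demo_normalize_image_format() -> None:
    assert _normalize_image_format("JPG") == "jpeg"
    assert _normalize_image_format("PNG") == "png"
    assert _normalize_image_format("BMP") == "bmp"  # unmapped formats fall back to .lower()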
async def execute_concurrently(
coro_callable: Callable[..., Coroutine[Any, Any, Any]],
concurrency_count: int,
@@ -103,25 +113,29 @@ async def execute_concurrently(
for i, res in enumerate(results):
if isinstance(res, Exception):
logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}")
first_exception = next((res for res in results if isinstance(res, Exception)), None)
if first_exception:
raise first_exception
raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")
class RequestType(Enum):
"""请求类型枚举"""
RESPONSE = "response"
EMBEDDING = "embedding"
AUDIO = "audio"
# ==============================================================================
# Helper Classes for LLMRequest Refactoring
# ==============================================================================
class _ModelSelector:
"""负责模型选择、负载均衡和动态故障切换的策略。"""
CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数
DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量
@@ -168,16 +182,18 @@ class _ModelSelector:
# - `usage_penalty * 1000`: short-term usage penalty. It increases each time the model is selected and decreases once the request completes. The high weight ensures that when several models are healthy, requests are spread evenly (round-robin).
least_used_model_name = min(
candidate_models_usage,
key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000,
key=lambda k: candidate_models_usage[k][0]
+ candidate_models_usage[k][1] * 300
+ candidate_models_usage[k][2] * 1000,
)
model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider)
# Special case: for embedding tasks, force the creation of a new aiohttp.ClientSession.
# This avoids event-loop issues that a shared ClientSession can trigger in some high-concurrency scenarios.
force_new_client = request_type == "embedding"
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")
# 增加所选模型的请求使用惩罚值,以实现动态负载均衡。
self.update_usage_penalty(model_info.name, increase=True)
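# Editor's sketch (not part of this commit): the selection score minimized above, written out
# as a standalone function. The weights (1 / 300 / 1000) come from the lambda; the tuple layout
# (total_tokens, penalty, usage_penalty) matches self.model_usage.
def _selection_score(usage: Tuple[int, int, int]) -> int:
    total_tokens, penalty, usage_penalty = usage
    # Long-term load, failure history, and short-term round-robin pressure, respectively.
    return total_tokens + penalty * 300 + usage_penalty * 1000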
@@ -214,26 +230,32 @@ class _ModelSelector:
if isinstance(e, (NetworkConnectionError, ReqAbortException)):
# Network errors or aborted requests usually indicate infrastructure problems and deserve a heavy penalty
penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}")
logger.warning(
f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}"
)
elif isinstance(e, RespNotOkException):
# For HTTP response errors, focus on server-side failures
if e.status_code >= 500:
# 5xx errors indicate a server-side problem and deserve a heavy penalty
penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}")
logger.warning(
f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}"
)
else:
# 4xx client errors usually do not mean the model itself is unavailable; apply the base penalty
logger.warning(f"Model '{model_name}' returned a client-side response error (status code: {e.status_code}); applying the base penalty: {penalty_increment}")
logger.warning(
f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}"
)
else:
# Other unknown exceptions get the base penalty
logger.warning(f"Model '{model_name}' raised an unknown exception: {type(e).__name__}; applying the base penalty: {penalty_increment}")
self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)
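# Editor's sketch (not part of this commit): the penalty policy above condensed into one helper.
# The exception classes and the CRITICAL/DEFAULT increments come from this hunk; nothing else is assumed.
def _penalty_increment_for(e: Exception) -> int:
    if isinstance(e, (NetworkConnectionError, ReqAbortException)):
        return _ModelSelector.CRITICAL_PENALTY_MULTIPLIER  # infrastructure-level failure
    if isinstance(e, RespNotOkException) and e.status_code >= 500:
        return _ModelSelector.CRITICAL_PENALTY_MULTIPLIER  # server-side 5xx error
    return _ModelSelector.DEFAULT_PENALTY_INCREMENT  # 4xx or unknown errors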
class _PromptProcessor:
"""封装所有与提示词和响应内容的预处理和后处理逻辑。"""
def __init__(self):
"""
Initialize the prompt processor.
@@ -276,18 +298,18 @@ class _PromptProcessor:
"""
# Step 1: apply content obfuscation according to the API provider's configuration
processed_prompt = self._apply_content_obfuscation(prompt, api_provider)
# Step 2: check whether the model needs the anti-truncation instruction injected
if getattr(model_info, "use_anti_truncation", False):
processed_prompt += self.anti_truncation_instruction
logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")
return processed_prompt
def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]:
"""
Process the response content, extract the chain of thought, and check for truncation.
Returns:
Tuple[str, str, bool]: (processed content, chain-of-thought content, whether truncated)
"""
@@ -317,14 +339,14 @@ class _PromptProcessor:
# Check whether content obfuscation is enabled for the current API provider
if not getattr(api_provider, "enable_content_obfuscation", False):
return text
# Get the obfuscation intensity (defaults to 1)
intensity = getattr(api_provider, "obfuscation_intensity", 1)
logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")
# Concatenate the anti-censorship instruction with the original text
processed_text = self.noise_instruction + "\n\n" + text
# Inject random noise into the concatenated text
return self._inject_random_noise(processed_text, intensity)
@@ -346,12 +368,12 @@ class _PromptProcessor:
# Noise parameters per intensity level: probability and length range
params = {
1: {"probability": 15, "length": (3, 6)}, # 低强度
2: {"probability": 25, "length": (5, 10)}, # 中强度
3: {"probability": 35, "length": (8, 15)}, # 高强度
2: {"probability": 25, "length": (5, 10)}, # 中强度
3: {"probability": 35, "length": (8, 15)}, # 高强度
}
# Pick the configuration for the given intensity, falling back to the default if it is invalid
config = params.get(intensity, params[1])
words = text.split()
result = []
# Iterate over each word
@@ -366,7 +388,7 @@ class _PromptProcessor:
# Generate the noise string
noise = "".join(random.choice(chars) for _ in range(noise_length))
result.append(noise)
# Reassemble the processed word list into a string
return " ".join(result)
@@ -396,7 +418,7 @@ class _PromptProcessor:
else:
reasoning = ""
clean_content = content.strip()
return clean_content, reasoning
@@ -441,7 +463,7 @@ class _RequestExecutor:
"""
retry_remain = api_provider.max_retry
compressed_messages: Optional[List[Message]] = None
while retry_remain > 0:
try:
# Prefer the compressed message list when available
@@ -451,11 +473,11 @@ class _RequestExecutor:
# Call a different client method depending on the request type
if request_type == RequestType.RESPONSE:
assert current_messages is not None, "message_list cannot be None for response requests"
# Fix: prevent 'message_list' from being passed twice through kwargs
request_params = kwargs.copy()
request_params.pop("message_list", None)
return await client.get_response(
model_info=model_info, message_list=current_messages, **request_params
)
@@ -463,15 +485,19 @@ class _RequestExecutor:
return await client.get_embedding(model_info=model_info, **kwargs)
elif request_type == RequestType.AUDIO:
return await client.get_audio_transcriptions(model_info=model_info, **kwargs)
except Exception as e:
logger.debug(f"请求失败: {str(e)}")
# 记录失败并更新模型的惩罚值
self.model_selector.update_failure_penalty(model_info.name, e)
# Handle the exception and decide whether to retry and how long to wait
wait_interval, new_compressed_messages = self._handle_exception(
e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None)
e,
model_info,
api_provider,
retry_remain,
(kwargs.get("message_list"), compressed_messages is not None),
)
if new_compressed_messages:
compressed_messages = new_compressed_messages # switch to the compressed messages
@@ -482,7 +508,7 @@ class _RequestExecutor:
await asyncio.sleep(wait_interval) # wait the specified interval before retrying
finally:
retry_remain -= 1
logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}")
raise RuntimeError("请求失败,已达到最大重试次数")
@@ -491,7 +517,7 @@ class _RequestExecutor:
) -> Tuple[int, Optional[List[Message]]]:
"""
Default exception handler; decides whether to retry.
Returns:
(wait interval (-1 means do not retry again), new message list (used for compressed messages))
"""
@@ -534,7 +560,9 @@ class _RequestExecutor:
model_name = model_info.name
# Handle client errors (400-404); these usually mean the request itself is bad and should not be retried
if e.status_code in [400, 401, 402, 403, 404]:
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。")
logger.warning(
f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。"
)
return -1, None
# Handle an oversized request body
elif e.status_code == 413:
@@ -570,9 +598,11 @@ class _RequestExecutor:
"""
# Only schedule another retry if more than one attempt remains, since the current failure has already consumed one
if remain_try > 1:
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。")
logger.warning(
f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。"
)
return interval, None
# If no retries remain, log the error and return -1 to give up
logger.error(f"Task '{self.task_name}', model '{model_name}': {reason}; maximum retries reached, giving up.")
return -1, None
@@ -585,7 +615,14 @@ class _RequestStrategy:
It keeps working even when an individual model or API endpoint fails.
"""
def __init__(self, model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, model_list: List[str], task_name: str):
def __init__(
self,
model_selector: _ModelSelector,
prompt_processor: _PromptProcessor,
executor: _RequestExecutor,
model_list: List[str],
task_name: str,
):
"""
Initialize the request strategy.
@@ -616,11 +653,13 @@ class _RequestStrategy:
last_exception: Optional[Exception] = None
for attempt in range(max_attempts):
selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request, str(request_type.value))
selection_result = self.model_selector.select_best_available_model(
failed_models_in_this_request, str(request_type.value)
)
if selection_result is None:
logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
break
model_info, api_provider, client = selection_result
logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_info.name}'...")
@@ -637,32 +676,36 @@ class _RequestStrategy:
# Merge in the model-specific extra parameters
if model_info.extra_params:
request_kwargs["extra_params"] = {**model_info.extra_params, **request_kwargs.get("extra_params", {})}
request_kwargs["extra_params"] = {
**model_info.extra_params,
**request_kwargs.get("extra_params", {}),
}
response = await self._try_model_request(
model_info, api_provider, client, request_type, **request_kwargs
)
response = await self._try_model_request(model_info, api_provider, client, request_type, **request_kwargs)
# Success; return immediately
logger.debug(f"Model '{model_info.name}' successfully generated a reply.")
self.model_selector.update_usage_penalty(model_info.name, increase=False)
return response, model_info
except Exception as e:
logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
failed_models_in_this_request.add(model_info.name)
last_exception = e
# The usage penalty was already increased at selection; it is not decreased on failure, which lowers this model's chance of being selected again
logger.error(f"This request has tried {max_attempts} models and all of them failed.")
if raise_when_empty:
if last_exception:
raise RuntimeError("所有模型均未能生成响应。") from last_exception
raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")
# 如果不抛出异常,返回一个备用响应
fallback_model_info = model_config.get_model_info(self.model_list[0])
return APIResponse(content="All model requests failed"), fallback_model_info
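# Editor's sketch (not part of this commit): the control flow of execute_with_failover above,
# condensed. Attribute and method names are taken from this hunk; everything else is schematic.
async def _failover_shape(self, request_type, max_attempts: int, **kwargs):
    failed_models, last_exception = set(), None
    for _ in range(max_attempts):
        selection = self.model_selector.select_best_available_model(failed_models, str(request_type.value))
        if selection is None:
            break  # no healthy candidates left
        model_info, api_provider, client = selection
        try:
            response = await self._try_model_request(model_info, api_provider, client, request_type, **kwargs)
            self.model_selector.update_usage_penalty(model_info.name, increase=False)
            return response, model_info  # first success wins
        except Exception as e:
            failed_models.add(model_info.name)  # exclude this model for the rest of the request
            last_exception = e
    raise RuntimeError("No model was able to generate a response.") from last_exception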
async def _try_model_request(
self, model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, request_type: RequestType, **kwargs
) -> APIResponse:
@@ -684,46 +727,49 @@ class _RequestStrategy:
RuntimeError: If the response is still empty or truncated after the maximum number of retries.
"""
max_empty_retry = api_provider.max_retry
for i in range(max_empty_retry + 1):
response = await self.executor.execute_request(
api_provider, client, request_type, model_info, **kwargs
)
response = await self.executor.execute_request(api_provider, client, request_type, model_info, **kwargs)
if request_type != RequestType.RESPONSE:
return response # For non-response types, return directly
return response  # For non-response types, return directly
# --- Response content processing and empty-reply/truncation checks ---
content = response.content or ""
use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
processed_content, reasoning, is_truncated = self.prompt_processor.process_response(content, use_anti_truncation)
processed_content, reasoning, is_truncated = self.prompt_processor.process_response(
content, use_anti_truncation
)
# Update the response object
response.content = processed_content
response.reasoning_content = response.reasoning_content or reasoning
is_empty_reply = not response.tool_calls and not (response.content and response.content.strip())
if not is_empty_reply and not is_truncated:
return response # A valid response was obtained
return response  # A valid response was obtained
if i < max_empty_retry:
reason = "空回复" if is_empty_reply else "截断"
logger.warning(f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})...")
logger.warning(
f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})..."
)
if api_provider.retry_interval > 0:
await asyncio.sleep(api_provider.retry_interval)
else:
reason = "空回复" if is_empty_reply else "截断"
logger.error(f"模型 '{model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。")
raise RuntimeError(f"模型 '{model_info.name}' 已达到空回复/截断的最大内部重试次数。")
raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里
raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里
# ==============================================================================
# Main Facade Class
# ==============================================================================
class LLMRequest:
"""
LLM request coordinator.
@@ -745,7 +791,7 @@ class LLMRequest:
model: (0, 0, 0) for model in self.model_for_task.model_list
}
"""模型使用量记录,(total_tokens, penalty, usage_penalty)"""
# Initialize the helper classes
self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
self._prompt_processor = _PromptProcessor()
@@ -769,36 +815,44 @@ class LLMRequest:
prompt (str): The prompt text
image_base64 (str): Base64-encoded image string
image_format (str): Image format (e.g. 'png', 'jpeg')
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, and tool-call list
"""
start_time = time.time()
# Image requests do not use the full failover strategy yet; pick a model directly and execute
selection_result = self._model_selector.select_best_available_model(set(), "response")
if not selection_result:
raise RuntimeError("无法为图像响应选择可用模型。")
model_info, api_provider, client = selection_result
normalized_format = _normalize_image_format(image_format)
message = MessageBuilder().add_text_content(prompt).add_image_content(
image_base64=image_base64,
image_format=normalized_format,
support_formats=client.get_support_image_formats(),
).build()
message = (
MessageBuilder()
.add_text_content(prompt)
.add_image_content(
image_base64=image_base64,
image_format=normalized_format,
support_formats=client.get_support_image_formats(),
)
.build()
)
response = await self._executor.execute_request(
api_provider, client, RequestType.RESPONSE, model_info,
api_provider,
client,
RequestType.RESPONSE,
model_info,
message_list=[message],
temperature=temperature,
max_tokens=max_tokens,
)
await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
reasoning = response.reasoning_content or reasoning
return content, (reasoning, model_info.name, response.tool_calls)
async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
@@ -812,9 +866,7 @@ class LLMRequest:
Returns:
Optional[str]: The transcribed text of the audio, or None if every model fails.
"""
response, _ = await self._strategy.execute_with_failover(
RequestType.AUDIO, audio_base64=voice_base64
)
response, _ = await self._strategy.execute_with_failover(RequestType.AUDIO, audio_base64=voice_base64)
return response.content or None
async def generate_response_async(
@@ -834,7 +886,7 @@ class LLMRequest:
max_tokens (int, optional): Maximum number of tokens
tools: Tool configuration
raise_when_empty (bool): Whether to raise an exception on an empty reply
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, and tool-call list
"""
@@ -842,12 +894,16 @@ class LLMRequest:
if concurrency_count <= 1:
return await self._execute_single_text_request(prompt, temperature, max_tokens, tools, raise_when_empty)
try:
return await execute_concurrently(
self._execute_single_text_request,
concurrency_count,
prompt, temperature, max_tokens, tools, raise_when_empty=False
prompt,
temperature,
max_tokens,
tools,
raise_when_empty=False,
)
except Exception as e:
logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}")
@@ -885,7 +941,7 @@ class LLMRequest:
response, model_info = await self._strategy.execute_with_failover(
RequestType.RESPONSE,
raise_when_empty=raise_when_empty,
prompt=prompt, # pass the raw prompt; the strategy handles it
prompt=prompt,  # pass the raw prompt; the strategy handles it
tool_options=tool_options,
temperature=self.model_for_task.temperature if temperature is None else temperature,
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
@@ -906,21 +962,20 @@ class LLMRequest:
Args:
embedding_input (str): The text to embed
Returns:
(Tuple[List[float], str]): (embedding vector, name of the model used)
"""
start_time = time.time()
response, model_info = await self._strategy.execute_with_failover(
RequestType.EMBEDDING,
embedding_input=embedding_input
RequestType.EMBEDDING, embedding_input=embedding_input
)
await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
if not response.embedding:
raise RuntimeError("获取embedding失败")
return response.embedding, model_info.name
async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
@@ -940,16 +995,18 @@ class LLMRequest:
# Step 1: update the in-memory token count used for load balancing
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
# Step 2: create a background task to write the usage data to the database asynchronously
asyncio.create_task(llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system", # 此处可根据业务需求修改
time_cost=time_cost,
request_type=self.task_name,
endpoint=endpoint,
))
asyncio.create_task(
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system", # 此处可根据业务需求修改
time_cost=time_cost,
request_type=self.task_name,
endpoint=endpoint,
)
)
@staticmethod
def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
@@ -970,14 +1027,14 @@ class LLMRequest:
# If no tools are provided, return None immediately
if not tools:
return None
tool_options: List[ToolOption] = []
# Iterate over each tool definition
for tool in tools:
try:
# Build the ToolOption with the builder pattern
builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", ""))
# Iterate over the tool's parameters
for param in tool.get("parameters", []):
# Strictly validate that each parameter is a 5-element tuple
@@ -994,6 +1051,6 @@ class LLMRequest:
except (KeyError, IndexError, TypeError, AssertionError) as e:
# If anything goes wrong during building, log it and skip this tool
logger.error(f"Failed to build tool '{tool.get('name', 'N/A')}': {e}")
# Return the list if it is non-empty, otherwise return None
return tool_options or None
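# Editor's sketch (not part of this commit): the tool-definition shape that _build_tool_options
# appears to expect - a "name", an optional "description", and "parameters" given as 5-element
# tuples. The meaning of the tuple positions is an assumption; only the 5-element requirement
# is visible in the diff.
example_tools = [
    {
        "name": "get_weather",
        "description": "Query the weather for a city",
        "parameters": [
            # (param_name, param_type, description, required, enum_values) - assumed ordering
            ("city", "string", "Name of the city", True, None),
        ],
    }
]
# tool_options = LLMRequest._build_tool_options(example_tools)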