docs(llm): 为 utils_model 模块补充详细文档和注释
为 `utils_model.py` 中的关键类和方法添加了全面的文档字符串和内联注释,以提升代码的可读性和可维护性。 主要变更包括: - 为 `_ModelSelector`, `_PromptProcessor`, `_RequestExecutor`, 和 `LLMRequest` 类中的核心方法扩充了详细的文档,解释其功能、参数和返回值。 - 在复杂的逻辑块(如重试机制、错误处理、内容混淆)中增加了内联注释,以阐明其实现细节。 - 移除了文件中旧的、多余的作者信息头。
This commit is contained in:
@@ -1,11 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@software:
|
||||
@file: utils_model.py
|
||||
@time: 2024/7/28 上午1:09
|
||||
@author: Mai ū
|
||||
@contact: 2496955113@qq.com
|
||||
@desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。
|
||||
它被设计为一个高度容错和可扩展的系统,包含以下主要组件:
|
||||
|
||||
@@ -190,9 +185,21 @@ class _ModelSelector:
|
||||
return model_info, api_provider, client
|
||||
|
||||
def update_usage_penalty(self, model_name: str, increase: bool):
|
||||
"""更新模型的使用惩罚值,用于负载均衡。"""
|
||||
"""
|
||||
更新模型的使用惩罚值。
|
||||
|
||||
在模型被选中时增加惩罚值,请求完成后减少惩罚值。
|
||||
这有助于在短期内将请求分散到不同的模型,实现更动态的负载均衡。
|
||||
|
||||
Args:
|
||||
model_name (str): 要更新惩罚值的模型名称。
|
||||
increase (bool): True表示增加惩罚值,False表示减少。
|
||||
"""
|
||||
# 获取当前模型的统计数据
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_name]
|
||||
# 根据操作是增加还是减少来确定调整量
|
||||
adjustment = 1 if increase else -1
|
||||
# 更新模型的惩罚值
|
||||
self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment)
|
||||
|
||||
def update_failure_penalty(self, model_name: str, e: Exception):
|
||||
@@ -253,11 +260,29 @@ class _PromptProcessor:
|
||||
"""
|
||||
|
||||
def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
|
||||
"""为请求准备最终的提示词,应用内容混淆和反截断指令。"""
|
||||
"""
|
||||
为请求准备最终的提示词。
|
||||
|
||||
此方法会根据API提供商和模型配置,对原始提示词应用内容混淆和反截断指令,
|
||||
生成最终发送给模型的完整提示内容。
|
||||
|
||||
Args:
|
||||
prompt (str): 原始的用户提示词。
|
||||
model_info (ModelInfo): 目标模型的信息。
|
||||
api_provider (APIProvider): API提供商的配置。
|
||||
task_name (str): 当前任务的名称,用于日志记录。
|
||||
|
||||
Returns:
|
||||
str: 处理后的、可以直接发送给模型的完整提示词。
|
||||
"""
|
||||
# 步骤1: 根据API提供商的配置应用内容混淆
|
||||
processed_prompt = self._apply_content_obfuscation(prompt, api_provider)
|
||||
|
||||
# 步骤2: 检查模型是否需要注入反截断指令
|
||||
if getattr(model_info, "use_anti_truncation", False):
|
||||
processed_prompt += self.anti_truncation_instruction
|
||||
logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")
|
||||
|
||||
return processed_prompt
|
||||
|
||||
def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]:
|
||||
@@ -277,33 +302,73 @@ class _PromptProcessor:
|
||||
return content, reasoning, is_truncated
|
||||
|
||||
def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str:
|
||||
"""根据API提供商配置对文本进行混淆处理。"""
|
||||
"""
|
||||
根据API提供商的配置对文本进行内容混淆。
|
||||
|
||||
如果提供商配置中启用了内容混淆,此方法会在文本前部加入抗审查指令,
|
||||
并在文本中注入随机噪音,以降低内容被审查或修改的风险。
|
||||
|
||||
Args:
|
||||
text (str): 原始文本内容。
|
||||
api_provider (APIProvider): API提供商的配置。
|
||||
|
||||
Returns:
|
||||
str: 经过混淆处理的文本。
|
||||
"""
|
||||
# 检查当前API提供商是否启用了内容混淆功能
|
||||
if not getattr(api_provider, "enable_content_obfuscation", False):
|
||||
return text
|
||||
|
||||
# 获取混淆强度,默认为1
|
||||
intensity = getattr(api_provider, "obfuscation_intensity", 1)
|
||||
logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")
|
||||
|
||||
# 将抗审查指令和原始文本拼接
|
||||
processed_text = self.noise_instruction + "\n\n" + text
|
||||
|
||||
# 在拼接后的文本中注入随机噪音
|
||||
return self._inject_random_noise(processed_text, intensity)
|
||||
|
||||
@staticmethod
|
||||
def _inject_random_noise(text: str, intensity: int) -> str:
|
||||
"""在文本中注入随机乱码。"""
|
||||
"""
|
||||
在文本中按指定强度注入随机噪音字符串。
|
||||
|
||||
该方法通过在文本的单词之间随机插入无意义的字符串(噪音)来实现内容混淆。
|
||||
强度越高,插入噪音的概率和长度就越大。
|
||||
|
||||
Args:
|
||||
text (str): 待处理的文本。
|
||||
intensity (int): 混淆强度 (1-3),决定噪音的概率和长度。
|
||||
|
||||
Returns:
|
||||
str: 注入噪音后的文本。
|
||||
"""
|
||||
# 定义不同强度级别的噪音参数:概率和长度范围
|
||||
params = {
|
||||
1: {"probability": 15, "length": (3, 6)},
|
||||
2: {"probability": 25, "length": (5, 10)},
|
||||
3: {"probability": 35, "length": (8, 15)},
|
||||
1: {"probability": 15, "length": (3, 6)}, # 低强度
|
||||
2: {"probability": 25, "length": (5, 10)}, # 中强度
|
||||
3: {"probability": 35, "length": (8, 15)}, # 高强度
|
||||
}
|
||||
# 根据传入的强度选择配置,如果强度无效则使用默认值
|
||||
config = params.get(intensity, params[1])
|
||||
|
||||
words = text.split()
|
||||
result = []
|
||||
# 遍历每个单词
|
||||
for word in words:
|
||||
result.append(word)
|
||||
# 根据概率决定是否在此单词后注入噪音
|
||||
if random.randint(1, 100) <= config["probability"]:
|
||||
# 确定噪音的长度
|
||||
noise_length = random.randint(*config["length"])
|
||||
# 定义噪音字符集
|
||||
chars = string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
# 生成噪音字符串
|
||||
noise = "".join(random.choice(chars) for _ in range(noise_length))
|
||||
result.append(noise)
|
||||
|
||||
# 将处理后的单词列表重新组合成字符串
|
||||
return " ".join(result)
|
||||
|
||||
@staticmethod
|
||||
@@ -448,30 +513,68 @@ class _RequestExecutor:
|
||||
def _handle_resp_not_ok(
|
||||
self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
|
||||
) -> Tuple[int, Optional[List[Message]]]:
|
||||
"""处理非200的HTTP响应异常。"""
|
||||
"""
|
||||
处理非200的HTTP响应异常。
|
||||
|
||||
根据不同的HTTP状态码决定下一步操作:
|
||||
- 4xx 客户端错误:通常不可重试,直接放弃。
|
||||
- 413 (Payload Too Large): 尝试压缩消息体后重试一次。
|
||||
- 429 (Too Many Requests) / 5xx 服务器错误:可重试。
|
||||
|
||||
Args:
|
||||
e (RespNotOkException): 捕获到的响应异常。
|
||||
model_info (ModelInfo): 当前模型信息。
|
||||
api_provider (APIProvider): API提供商配置。
|
||||
remain_try (int): 剩余重试次数。
|
||||
messages_info (tuple): 包含消息列表和是否已压缩的标志。
|
||||
|
||||
Returns:
|
||||
Tuple[int, Optional[List[Message]]]: (等待间隔, 新的消息列表)。
|
||||
等待间隔为-1表示不再重试。新的消息列表用于压缩后重试。
|
||||
"""
|
||||
model_name = model_info.name
|
||||
# 处理客户端错误 (400-404),这些错误通常是请求本身有问题,不应重试
|
||||
if e.status_code in [400, 401, 402, 403, 404]:
|
||||
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。")
|
||||
return -1, None
|
||||
# 处理请求体过大的情况
|
||||
elif e.status_code == 413:
|
||||
messages, is_compressed = messages_info
|
||||
# 如果消息存在且尚未被压缩,则尝试压缩后立即重试
|
||||
if messages and not is_compressed:
|
||||
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试。")
|
||||
return 0, compress_messages(messages)
|
||||
# 如果已经压缩过或没有消息体,则放弃
|
||||
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大且无法压缩,放弃请求。")
|
||||
return -1, None
|
||||
# 处理请求频繁或服务器端错误,这些情况适合重试
|
||||
elif e.status_code == 429 or e.status_code >= 500:
|
||||
reason = "请求过于频繁" if e.status_code == 429 else "服务器错误"
|
||||
return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name)
|
||||
# 处理其他未知的HTTP错误
|
||||
else:
|
||||
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}")
|
||||
return -1, None
|
||||
|
||||
def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]:
|
||||
"""辅助函数:检查是否可以重试。"""
|
||||
if remain_try > 1: # 剩余次数大于1才重试
|
||||
"""
|
||||
辅助函数,根据剩余次数决定是否进行下一次重试。
|
||||
|
||||
Args:
|
||||
remain_try (int): 剩余的重试次数。
|
||||
interval (int): 重试前的等待间隔(秒)。
|
||||
reason (str): 本次失败的原因。
|
||||
model_name (str): 失败的模型名称。
|
||||
|
||||
Returns:
|
||||
Tuple[int, None]: (等待间隔, None)。如果等待间隔为-1,表示不应再重试。
|
||||
"""
|
||||
# 只有在剩余重试次数大于1时才进行下一次重试(因为当前这次失败已经消耗掉一次)
|
||||
if remain_try > 1:
|
||||
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。")
|
||||
return interval, None
|
||||
|
||||
# 如果已无剩余重试次数,则记录错误并返回-1表示放弃
|
||||
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。")
|
||||
return -1, None
|
||||
|
||||
@@ -822,16 +925,28 @@ class LLMRequest:
|
||||
return response.embedding, model_info.name
|
||||
|
||||
def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
|
||||
"""异步记录用量到数据库。"""
|
||||
"""
|
||||
记录模型使用情况。
|
||||
|
||||
此方法首先在内存中更新模型的累计token使用量,然后创建一个异步任务,
|
||||
将详细的用量数据(包括模型信息、token数、耗时等)写入数据库。
|
||||
|
||||
Args:
|
||||
model_info (ModelInfo): 使用的模型信息。
|
||||
usage (Optional[UsageRecord]): API返回的用量记录。
|
||||
time_cost (float): 本次请求的总耗时。
|
||||
endpoint (str): 请求的API端点 (e.g., "/chat/completions")。
|
||||
"""
|
||||
if usage:
|
||||
# 更新内存中的token计数
|
||||
# 步骤1: 更新内存中的token计数,用于负载均衡
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
|
||||
|
||||
# 步骤2: 创建一个后台任务,将用量数据异步写入数据库
|
||||
asyncio.create_task(llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info,
|
||||
model_usage=usage,
|
||||
user_id="system",
|
||||
user_id="system", # 此处可根据业务需求修改
|
||||
time_cost=time_cost,
|
||||
request_type=self.task_name,
|
||||
endpoint=endpoint,
|
||||
@@ -839,15 +954,34 @@ class LLMRequest:
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
|
||||
"""构建工具选项列表。"""
|
||||
"""
|
||||
根据输入的字典列表构建并验证 `ToolOption` 对象列表。
|
||||
|
||||
此方法将标准化的工具定义(字典格式)转换为内部使用的 `ToolOption` 对象,
|
||||
同时会验证参数格式的正确性。
|
||||
|
||||
Args:
|
||||
tools (Optional[List[Dict[str, Any]]]): 工具定义的列表。
|
||||
每个工具是一个字典,包含 "name", "description", 和 "parameters"。
|
||||
"parameters" 是一个元组列表,每个元组包含 (name, type, desc, required, enum)。
|
||||
|
||||
Returns:
|
||||
Optional[List[ToolOption]]: 构建好的 `ToolOption` 对象列表,如果输入为空则返回 None。
|
||||
"""
|
||||
# 如果没有提供工具,直接返回 None
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
tool_options: List[ToolOption] = []
|
||||
# 遍历每个工具定义
|
||||
for tool in tools:
|
||||
try:
|
||||
# 使用建造者模式创建 ToolOption
|
||||
builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", ""))
|
||||
|
||||
# 遍历工具的参数
|
||||
for param in tool.get("parameters", []):
|
||||
# 参数格式验证
|
||||
# 严格验证参数格式是否为包含5个元素的元组
|
||||
assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组"
|
||||
builder.add_param(
|
||||
name=param[0],
|
||||
@@ -856,7 +990,11 @@ class LLMRequest:
|
||||
required=param[3],
|
||||
enum_values=param[4],
|
||||
)
|
||||
# 将构建好的 ToolOption 添加到列表中
|
||||
tool_options.append(builder.build())
|
||||
except (KeyError, IndexError, TypeError, AssertionError) as e:
|
||||
# 如果构建过程中出现任何错误,记录日志并跳过该工具
|
||||
logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}")
|
||||
|
||||
# 如果列表非空则返回列表,否则返回 None
|
||||
return tool_options or None
|
||||
|
||||
Reference in New Issue
Block a user