diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index e3699d540..3efa9cd2d 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,11 +1,6 @@ # -*- coding: utf-8 -*- # -*- coding: utf-8 -*- """ -@software: -@file: utils_model.py -@time: 2024/7/28 上午1:09 -@author: Mai ū -@contact: 2496955113@qq.com @desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。 它被设计为一个高度容错和可扩展的系统,包含以下主要组件: @@ -190,9 +185,21 @@ class _ModelSelector: return model_info, api_provider, client def update_usage_penalty(self, model_name: str, increase: bool): - """更新模型的使用惩罚值,用于负载均衡。""" + """ + 更新模型的使用惩罚值。 + + 在模型被选中时增加惩罚值,请求完成后减少惩罚值。 + 这有助于在短期内将请求分散到不同的模型,实现更动态的负载均衡。 + + Args: + model_name (str): 要更新惩罚值的模型名称。 + increase (bool): True表示增加惩罚值,False表示减少。 + """ + # 获取当前模型的统计数据 total_tokens, penalty, usage_penalty = self.model_usage[model_name] + # 根据操作是增加还是减少来确定调整量 adjustment = 1 if increase else -1 + # 更新模型的惩罚值 self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment) def update_failure_penalty(self, model_name: str, e: Exception): @@ -253,11 +260,29 @@ class _PromptProcessor: """ def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str: - """为请求准备最终的提示词,应用内容混淆和反截断指令。""" + """ + 为请求准备最终的提示词。 + + 此方法会根据API提供商和模型配置,对原始提示词应用内容混淆和反截断指令, + 生成最终发送给模型的完整提示内容。 + + Args: + prompt (str): 原始的用户提示词。 + model_info (ModelInfo): 目标模型的信息。 + api_provider (APIProvider): API提供商的配置。 + task_name (str): 当前任务的名称,用于日志记录。 + + Returns: + str: 处理后的、可以直接发送给模型的完整提示词。 + """ + # 步骤1: 根据API提供商的配置应用内容混淆 processed_prompt = self._apply_content_obfuscation(prompt, api_provider) + + # 步骤2: 检查模型是否需要注入反截断指令 if getattr(model_info, "use_anti_truncation", False): processed_prompt += self.anti_truncation_instruction logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。") + return processed_prompt def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]: @@ -277,33 +302,73 @@ class _PromptProcessor: return content, reasoning, is_truncated def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str: - """根据API提供商配置对文本进行混淆处理。""" + """ + 根据API提供商的配置对文本进行内容混淆。 + + 如果提供商配置中启用了内容混淆,此方法会在文本前部加入抗审查指令, + 并在文本中注入随机噪音,以降低内容被审查或修改的风险。 + + Args: + text (str): 原始文本内容。 + api_provider (APIProvider): API提供商的配置。 + + Returns: + str: 经过混淆处理的文本。 + """ + # 检查当前API提供商是否启用了内容混淆功能 if not getattr(api_provider, "enable_content_obfuscation", False): return text + # 获取混淆强度,默认为1 intensity = getattr(api_provider, "obfuscation_intensity", 1) logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") + + # 将抗审查指令和原始文本拼接 processed_text = self.noise_instruction + "\n\n" + text + + # 在拼接后的文本中注入随机噪音 return self._inject_random_noise(processed_text, intensity) @staticmethod def _inject_random_noise(text: str, intensity: int) -> str: - """在文本中注入随机乱码。""" + """ + 在文本中按指定强度注入随机噪音字符串。 + + 该方法通过在文本的单词之间随机插入无意义的字符串(噪音)来实现内容混淆。 + 强度越高,插入噪音的概率和长度就越大。 + + Args: + text (str): 待处理的文本。 + intensity (int): 混淆强度 (1-3),决定噪音的概率和长度。 + + Returns: + str: 注入噪音后的文本。 + """ + # 定义不同强度级别的噪音参数:概率和长度范围 params = { - 1: {"probability": 15, "length": (3, 6)}, - 2: {"probability": 25, "length": (5, 10)}, - 3: {"probability": 35, "length": (8, 15)}, + 1: {"probability": 15, "length": (3, 6)}, # 低强度 + 2: {"probability": 25, "length": (5, 10)}, # 中强度 + 3: {"probability": 35, "length": (8, 15)}, # 高强度 } + # 根据传入的强度选择配置,如果强度无效则使用默认值 config = params.get(intensity, params[1]) + words = text.split() result = [] + # 遍历每个单词 for word in words: result.append(word) + # 根据概率决定是否在此单词后注入噪音 if random.randint(1, 100) <= config["probability"]: + # 确定噪音的长度 noise_length = random.randint(*config["length"]) + # 定义噪音字符集 chars = string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?" + # 生成噪音字符串 noise = "".join(random.choice(chars) for _ in range(noise_length)) result.append(noise) + + # 将处理后的单词列表重新组合成字符串 return " ".join(result) @staticmethod @@ -448,30 +513,68 @@ class _RequestExecutor: def _handle_resp_not_ok( self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info ) -> Tuple[int, Optional[List[Message]]]: - """处理非200的HTTP响应异常。""" + """ + 处理非200的HTTP响应异常。 + + 根据不同的HTTP状态码决定下一步操作: + - 4xx 客户端错误:通常不可重试,直接放弃。 + - 413 (Payload Too Large): 尝试压缩消息体后重试一次。 + - 429 (Too Many Requests) / 5xx 服务器错误:可重试。 + + Args: + e (RespNotOkException): 捕获到的响应异常。 + model_info (ModelInfo): 当前模型信息。 + api_provider (APIProvider): API提供商配置。 + remain_try (int): 剩余重试次数。 + messages_info (tuple): 包含消息列表和是否已压缩的标志。 + + Returns: + Tuple[int, Optional[List[Message]]]: (等待间隔, 新的消息列表)。 + 等待间隔为-1表示不再重试。新的消息列表用于压缩后重试。 + """ model_name = model_info.name + # 处理客户端错误 (400-404),这些错误通常是请求本身有问题,不应重试 if e.status_code in [400, 401, 402, 403, 404]: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。") return -1, None + # 处理请求体过大的情况 elif e.status_code == 413: messages, is_compressed = messages_info + # 如果消息存在且尚未被压缩,则尝试压缩后立即重试 if messages and not is_compressed: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试。") return 0, compress_messages(messages) + # 如果已经压缩过或没有消息体,则放弃 logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大且无法压缩,放弃请求。") return -1, None + # 处理请求频繁或服务器端错误,这些情况适合重试 elif e.status_code == 429 or e.status_code >= 500: reason = "请求过于频繁" if e.status_code == 429 else "服务器错误" return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name) + # 处理其他未知的HTTP错误 else: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}") return -1, None def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]: - """辅助函数:检查是否可以重试。""" - if remain_try > 1: # 剩余次数大于1才重试 + """ + 辅助函数,根据剩余次数决定是否进行下一次重试。 + + Args: + remain_try (int): 剩余的重试次数。 + interval (int): 重试前的等待间隔(秒)。 + reason (str): 本次失败的原因。 + model_name (str): 失败的模型名称。 + + Returns: + Tuple[int, None]: (等待间隔, None)。如果等待间隔为-1,表示不应再重试。 + """ + # 只有在剩余重试次数大于1时才进行下一次重试(因为当前这次失败已经消耗掉一次) + if remain_try > 1: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。") return interval, None + + # 如果已无剩余重试次数,则记录错误并返回-1表示放弃 logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。") return -1, None @@ -822,16 +925,28 @@ class LLMRequest: return response.embedding, model_info.name def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str): - """异步记录用量到数据库。""" + """ + 记录模型使用情况。 + + 此方法首先在内存中更新模型的累计token使用量,然后创建一个异步任务, + 将详细的用量数据(包括模型信息、token数、耗时等)写入数据库。 + + Args: + model_info (ModelInfo): 使用的模型信息。 + usage (Optional[UsageRecord]): API返回的用量记录。 + time_cost (float): 本次请求的总耗时。 + endpoint (str): 请求的API端点 (e.g., "/chat/completions")。 + """ if usage: - # 更新内存中的token计数 + # 步骤1: 更新内存中的token计数,用于负载均衡 total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty) + # 步骤2: 创建一个后台任务,将用量数据异步写入数据库 asyncio.create_task(llm_usage_recorder.record_usage_to_database( model_info=model_info, model_usage=usage, - user_id="system", + user_id="system", # 此处可根据业务需求修改 time_cost=time_cost, request_type=self.task_name, endpoint=endpoint, @@ -839,15 +954,34 @@ class LLMRequest: @staticmethod def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: - """构建工具选项列表。""" + """ + 根据输入的字典列表构建并验证 `ToolOption` 对象列表。 + + 此方法将标准化的工具定义(字典格式)转换为内部使用的 `ToolOption` 对象, + 同时会验证参数格式的正确性。 + + Args: + tools (Optional[List[Dict[str, Any]]]): 工具定义的列表。 + 每个工具是一个字典,包含 "name", "description", 和 "parameters"。 + "parameters" 是一个元组列表,每个元组包含 (name, type, desc, required, enum)。 + + Returns: + Optional[List[ToolOption]]: 构建好的 `ToolOption` 对象列表,如果输入为空则返回 None。 + """ + # 如果没有提供工具,直接返回 None if not tools: return None + tool_options: List[ToolOption] = [] + # 遍历每个工具定义 for tool in tools: try: + # 使用建造者模式创建 ToolOption builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", "")) + + # 遍历工具的参数 for param in tool.get("parameters", []): - # 参数格式验证 + # 严格验证参数格式是否为包含5个元素的元组 assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" builder.add_param( name=param[0], @@ -856,7 +990,11 @@ class LLMRequest: required=param[3], enum_values=param[4], ) + # 将构建好的 ToolOption 添加到列表中 tool_options.append(builder.build()) except (KeyError, IndexError, TypeError, AssertionError) as e: + # 如果构建过程中出现任何错误,记录日志并跳过该工具 logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}") + + # 如果列表非空则返回列表,否则返回 None return tool_options or None