From 921eaddf9c5bf236d9729f31a4fa60a9a09e0abd Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Tue, 23 Sep 2025 14:52:42 +0800 Subject: [PATCH 01/41] =?UTF-8?q?ci(workflow):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E8=87=AA=E5=8A=A8=E5=88=9B=E5=BB=BA=E9=A2=84=E5=8F=91=E5=B8=83?= =?UTF-8?q?=E7=9A=84=E5=B7=A5=E4=BD=9C=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 该工作流会在每次推送到 master 分支时创建一个预发布版本,现已不再需要此自动化流程。 --- .github/workflows/create-prerelease.yml | 34 ------------------------- 1 file changed, 34 deletions(-) delete mode 100644 .github/workflows/create-prerelease.yml diff --git a/.github/workflows/create-prerelease.yml b/.github/workflows/create-prerelease.yml deleted file mode 100644 index ea0cedcdf..000000000 --- a/.github/workflows/create-prerelease.yml +++ /dev/null @@ -1,34 +0,0 @@ -# 当代码推送到 master 分支时,自动创建一个 pre-release - -name: Create Pre-release - -on: - push: - branches: - - master - -jobs: - create-prerelease: - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - # 获取所有提交历史,以便生成 release notes - fetch-depth: 0 - - - name: Generate tag name - id: generate_tag - run: echo "TAG_NAME=MoFox-prerelease-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT - - - name: Create Pre-release - env: - # 使用仓库自带的 GITHUB_TOKEN 进行认证 - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - gh release create ${{ steps.generate_tag.outputs.TAG_NAME }} \ - --title "Pre-release ${{ steps.generate_tag.outputs.TAG_NAME }}" \ - --prerelease \ - --generate-notes \ No newline at end of file From a32759687bcd367635ccd6c89b500a5ee9937bf6 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Tue, 23 Sep 2025 15:24:28 +0800 Subject: [PATCH 02/41] =?UTF-8?q?feat(chat):=20=E5=A2=9E=E5=8A=A0=E5=B7=B2?= =?UTF-8?q?=E8=AF=BB=E6=A0=87=E8=AE=B0=E4=BB=A5=E8=81=9A=E7=84=A6=E6=9C=AA?= =?UTF-8?q?=E8=AF=BB=E6=B6=88=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为聊天上下文生成逻辑引入了“已读标记” (read_mark) 机制。 当生成回复时,可以在历史消息中插入一个明确的分隔符,以告知模型哪些消息是它已经看过的旧消息,哪些是需要关注的新消息。 这有助于模型更好地聚焦于未读内容,提升上下文感知能力和回复的相关性。 同时,将 Prompt 模板中的“群聊”等硬编码文本参数化,以更好地适配私聊等不同聊天场景。 --- src/chat/chat_loop/cycle_processor.py | 1 + src/chat/replyer/default_generator.py | 25 +++++++++++++------------ src/chat/utils/chat_message_builder.py | 10 ++++++++-- src/chat/utils/prompt.py | 10 +++++++--- src/plugin_system/apis/generator_api.py | 2 ++ 5 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/chat/chat_loop/cycle_processor.py b/src/chat/chat_loop/cycle_processor.py index 441d14b1f..e571b5ac5 100644 --- a/src/chat/chat_loop/cycle_processor.py +++ b/src/chat/chat_loop/cycle_processor.py @@ -267,6 +267,7 @@ class CycleProcessor: enable_tool=global_config.tool.enable_tool, request_type="chat.replyer", from_plugin=False, + read_mark=action_info.get("action_message", {}).get("time", 0.0), ) if not success or not response_set: logger.info( diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 127779e1e..c2cd9fc08 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -83,12 +83,12 @@ def init_prompt(): - {schedule_block} ## 历史记录 -### 当前群聊中的所有人的聊天记录: +### {chat_context_type}中的所有人的聊天记录: {background_dialogue_prompt} {cross_context_block} -### 当前群聊中正在与你对话的聊天记录 +### {chat_context_type}中正在与你对话的聊天记录 {core_dialogue_prompt} ## 表达方式 @@ -110,12 +110,11 @@ def init_prompt(): ## 任务 -*你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。* +*你正在一个{chat_context_type}里聊天,你需要理解整个{chat_context_type}的聊天动态和话题走向,并做出自然的回应。* ### 核心任务 -- 你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与聊天,你可以参考他们的回复内容,但是你现在想回复{sender_name}的发言。 - -- {reply_target_block} ,你需要生成一段紧密相关且能推动对话的回复。 +- 你现在的主要任务是和 {sender_name} 聊天。 +- {reply_target_block} ,你需要生成一段紧密相关且能推动对话的回复。 ## 规则 {safety_guidelines_block} @@ -236,6 +235,7 @@ class DefaultReplyer: from_plugin: bool = True, stream_id: Optional[str] = None, reply_message: Optional[Dict[str, Any]] = None, + read_mark: float = 0.0, ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: # sourcery skip: merge-nested-ifs """ @@ -268,6 +268,7 @@ class DefaultReplyer: available_actions=available_actions, enable_tool=enable_tool, reply_message=reply_message, + read_mark=read_mark, ) if not prompt: @@ -723,10 +724,8 @@ class DefaultReplyer: truncate=True, show_actions=True, ) - core_dialogue_prompt = f"""-------------------------------- -这是你和{sender}的对话,你们正在交流中: + core_dialogue_prompt = f""" {core_dialogue_prompt_str} --------------------------------- """ return core_dialogue_prompt, all_dialogue_prompt @@ -783,6 +782,7 @@ class DefaultReplyer: available_actions: Optional[Dict[str, ActionInfo]] = None, enable_tool: bool = True, reply_message: Optional[Dict[str, Any]] = None, + read_mark: float = 0.0, ) -> str: """ 构建回复器上下文 @@ -859,7 +859,7 @@ class DefaultReplyer: target = "(无消息内容)" person_info_manager = get_person_info_manager() - person_id = await person_info_manager.get_person_id_by_person_name(sender) + person_id = person_info_manager.get_person_id(platform, reply_message.get("user_id")) if reply_message else None platform = chat_stream.platform target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) @@ -891,7 +891,7 @@ class DefaultReplyer: replace_bot_name=True, merge_messages=False, timestamp_mode="relative", - read_mark=0.0, + read_mark=read_mark, show_actions=True, ) # 获取目标用户信息,用于s4u模式 @@ -1081,6 +1081,7 @@ class DefaultReplyer: reply_target_block=reply_target_block, mood_prompt=mood_prompt, action_descriptions=action_descriptions, + read_mark=read_mark, ) # 使用新的统一Prompt系统 - 使用正确的模板名称 @@ -1167,7 +1168,7 @@ class DefaultReplyer: replace_bot_name=True, merge_messages=False, timestamp_mode="relative", - read_mark=0.0, + read_mark=read_mark, show_actions=True, ) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index b3110a8e6..4b08c5bbe 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -520,6 +520,7 @@ async def _build_readable_messages_internal( pic_counter: int = 1, show_pic: bool = True, message_id_list: Optional[List[Dict[str, Any]]] = None, + read_mark: float = 0.0, ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: """ 内部辅助函数,构建可读消息字符串和原始消息详情列表。 @@ -726,11 +727,10 @@ async def _build_readable_messages_internal( "is_action": is_action, } continue - # 如果是同一个人发送的连续消息且时间间隔小于等于60秒 if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60): current_merge["content"].append(content) - current_merge["end_time"] = timestamp # 更新最后消息时间 + current_merge["end_time"] = timestamp else: # 保存上一个合并块 merged_messages.append(current_merge) @@ -758,8 +758,14 @@ async def _build_readable_messages_internal( # 4 & 5: 格式化为字符串 output_lines = [] + read_mark_inserted = False for _i, merged in enumerate(merged_messages): + # 检查是否需要插入已读标记 + if read_mark > 0 and not read_mark_inserted and merged["start_time"] >= read_mark: + output_lines.append("\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n") + read_mark_inserted = True + # 使用指定的 timestamp_mode 格式化时间 readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode) diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index db31acfa5..3d97b622e 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -78,6 +78,7 @@ class PromptParameters: # 可用动作信息 available_actions: Optional[Dict[str, Any]] = None + read_mark: float = 0.0 def validate(self) -> List[str]: """参数验证""" @@ -449,7 +450,8 @@ class Prompt: core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts( self.parameters.message_list_before_now_long, self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "", - self.parameters.sender + self.parameters.sender, + read_mark=self.parameters.read_mark, ) context_data["core_dialogue_prompt"] = core_dialogue @@ -465,7 +467,7 @@ class Prompt: @staticmethod async def _build_s4u_chat_history_prompts( - message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str + message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, read_mark: float = 0.0 ) -> Tuple[str, str]: """构建S4U风格的分离对话prompt""" # 实现逻辑与原有SmartPromptBuilder相同 @@ -491,6 +493,7 @@ class Prompt: replace_bot_name=True, timestamp_mode="normal", truncate=True, + read_mark=read_mark, ) all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}" @@ -510,7 +513,7 @@ class Prompt: replace_bot_name=True, merge_messages=False, timestamp_mode="normal_no_YMD", - read_mark=0.0, + read_mark=read_mark, truncate=True, show_actions=True, ) @@ -764,6 +767,7 @@ class Prompt: "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), + "chat_context_type": "群聊" if self.parameters.is_group_chat else "私聊", } def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 5ffae7298..e74044866 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -84,6 +84,7 @@ async def generate_reply( return_prompt: bool = False, request_type: str = "generator_api", from_plugin: bool = True, + read_mark: float = 0.0, ) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: """生成回复 @@ -129,6 +130,7 @@ async def generate_reply( from_plugin=from_plugin, stream_id=chat_stream.stream_id if chat_stream else chat_id, reply_message=reply_message, + read_mark=read_mark, ) if not success: logger.warning("[GeneratorAPI] 回复生成失败") From a6b6acc1a68141c606cce2782b81451b370e4777 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Tue, 23 Sep 2025 16:13:00 +0800 Subject: [PATCH 03/41] =?UTF-8?q?feat(config):=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=E6=9B=B4=E6=96=B0=E4=B8=AD=E5=B7=B2=E5=BA=9F?= =?UTF-8?q?=E5=BC=83=E7=9A=84=E9=85=8D=E7=BD=AE=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在版本更新过程中,新增一个步骤来对比用户配置与最新的模板文件。 此变更会自动删除用户配置文件中所有在模板中不再存在的键,以保持配置的整洁性,并防止因过时的配置项导致潜在的兼容性问题或混淆。 --- src/config/config.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/config/config.py b/src/config/config.py index 3fbd7e9e6..f0a3ec2d8 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -164,6 +164,18 @@ def _version_tuple(v): return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split(".")) +def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDocument | dict | Table): + """ + 递归地从目标字典中移除所有不存在于参考字典中的键。 + """ + # 使用 list() 创建键的副本,以便在迭代期间安全地修改字典 + for key in list(target.keys()): + if key not in reference: + del target[key] + elif isinstance(target.get(key), (dict, Table)) and isinstance(reference.get(key), (dict, Table)): + _remove_obsolete_keys(target[key], reference[key]) + + def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): """ 将source字典的值更新到target字典中 @@ -334,6 +346,13 @@ def _update_config_generic(config_name: str, template_name: str): logger.info(f"开始合并{config_name}新旧配置...") _update_dict(new_config, old_config) + # 移除在新模板中已不存在的旧配置项 + logger.info(f"开始移除{config_name}中已废弃的配置项...") + with open(template_path, "r", encoding="utf-8") as f: + template_doc = tomlkit.load(f) + _remove_obsolete_keys(new_config, template_doc) + logger.info(f"已移除{config_name}中已废弃的配置项") + # 保存更新后的配置(保留注释和格式) with open(new_config_path, "w", encoding="utf-8") as f: f.write(tomlkit.dumps(new_config)) From 46307287604d7623682dec40dffaa6a6dfcc94e4 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Tue, 23 Sep 2025 17:20:57 +0800 Subject: [PATCH 04/41] =?UTF-8?q?=E8=80=81rust=5Fvideo=E6=88=91=E4=B8=BA?= =?UTF-8?q?=E4=BD=A0=E8=B8=A9=E8=83=8C=E6=9D=A5=E5=96=BD(=E6=8A=8A?= =?UTF-8?q?=E4=BB=96=E5=88=A0=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rust_image/Cargo.toml | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 rust_image/Cargo.toml diff --git a/rust_image/Cargo.toml b/rust_image/Cargo.toml deleted file mode 100644 index e69de29bb..000000000 From ae738ef8cb4f4ddb47ad9ef2075ab8fc2dee717e Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Tue, 23 Sep 2025 19:15:58 +0800 Subject: [PATCH 05/41] =?UTF-8?q?perf(memory):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E8=AE=B0=E5=BF=86=E7=B3=BB=E7=BB=9F=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E6=93=8D=E4=BD=9C=E5=B9=B6=E4=BF=AE=E5=A4=8D=E5=B9=B6=E5=8F=91?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将消息记忆次数的更新方式从单次写入重构为批量更新,在记忆构建任务结束时统一执行,大幅减少数据库写入次数,显著提升性能。 此外,为 `HippocampusManager` 添加了异步锁,以防止记忆巩固和遗忘操作并发执行时产生竞争条件。同时,增加了节点去重逻辑,在插入数据库前检查重复的概念,确保数据一致性。 --- src/chat/memory_system/Hippocampus.py | 67 ++++++++++++------- .../src/recv_handler/notice_handler.py | 2 +- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index fcc8e65d2..ca726c1a8 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -3,6 +3,7 @@ import datetime import math import random import time +import asyncio import re import orjson import jieba @@ -789,7 +790,7 @@ class EntorhinalCortex: self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph - async def get_memory_sample(self): + async def get_memory_sample(self) -> tuple[list, list[str]]: """从数据库获取记忆样本""" # 硬编码:每条消息最大记忆次数 max_memorized_time_per_msg = 2 @@ -811,24 +812,27 @@ class EntorhinalCortex: for _, readable_timestamp in zip(timestamps, readable_timestamps, strict=False): logger.debug(f"回忆往事: {readable_timestamp}") chat_samples = [] + all_message_ids_to_update = [] for timestamp in timestamps: - if messages := await self.random_get_msg_snippet( + if result := await self.random_get_msg_snippet( timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg, ): + messages, message_ids_to_update = result time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条") chat_samples.append(messages) + all_message_ids_to_update.extend(message_ids_to_update) else: logger.debug(f"时间戳 {timestamp} 的消息无需记忆") - return chat_samples + return chat_samples, all_message_ids_to_update @staticmethod async def random_get_msg_snippet( target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int - ) -> list | None: + ) -> tuple[list, list[str]] | None: # sourcery skip: invert-any-all, use-any, use-named-expression, use-next """从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)""" time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟 @@ -862,18 +866,9 @@ class EntorhinalCortex: # 如果所有消息都有效 if all_valid: - # 更新数据库中的记忆次数 - for message in messages: - # 确保在更新前获取最新的 memorized_times - current_memorized_times = message.get("memorized_times", 0) - async with get_db_session() as session: - await session.execute( - update(Messages) - .where(Messages.message_id == message["message_id"]) - .values(memorized_times=current_memorized_times + 1) - ) - await session.commit() - return messages # 直接返回原始的消息列表 + # 返回消息和需要更新的message_id + message_ids_to_update = [msg["message_id"] for msg in messages] + return messages, message_ids_to_update target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试 @@ -953,10 +948,20 @@ class EntorhinalCortex: # 批量处理节点 if nodes_to_create: - batch_size = 100 - for i in range(0, len(nodes_to_create), batch_size): - batch = nodes_to_create[i : i + batch_size] - await session.execute(insert(GraphNodes), batch) + # 在插入前进行去重检查 + unique_nodes_to_create = [] + seen_concepts = set(db_nodes.keys()) + for node_data in nodes_to_create: + concept = node_data["concept"] + if concept not in seen_concepts: + unique_nodes_to_create.append(node_data) + seen_concepts.add(concept) + + if unique_nodes_to_create: + batch_size = 100 + for i in range(0, len(unique_nodes_to_create), batch_size): + batch = unique_nodes_to_create[i : i + batch_size] + await session.execute(insert(GraphNodes), batch) if nodes_to_update: batch_size = 100 @@ -1346,7 +1351,7 @@ class ParahippocampalGyrus: # sourcery skip: merge-list-appends-into-extend logger.info("------------------------------------开始构建记忆--------------------------------------") start_time = time.time() - memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample() + memory_samples, all_message_ids_to_update = await self.hippocampus.entorhinal_cortex.get_memory_sample() all_added_nodes = [] all_connected_nodes = [] all_added_edges = [] @@ -1409,8 +1414,21 @@ class ParahippocampalGyrus: if all_connected_nodes: logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}") + # 先同步记忆图 await self.hippocampus.entorhinal_cortex.sync_memory_to_db() + # 最后批量更新消息的记忆次数 + if all_message_ids_to_update: + async with get_db_session() as session: + # 使用 in_ 操作符进行批量更新 + await session.execute( + update(Messages) + .where(Messages.message_id.in_(all_message_ids_to_update)) + .values(memorized_times=Messages.memorized_times + 1) + ) + await session.commit() + logger.info(f"批量更新了 {len(all_message_ids_to_update)} 条消息的记忆次数") + end_time = time.time() logger.info(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------") @@ -1617,6 +1635,7 @@ class HippocampusManager: def __init__(self): self._hippocampus: Hippocampus = None # type: ignore self._initialized = False + self._db_lock = asyncio.Lock() def initialize(self): """初始化海马体实例""" @@ -1665,14 +1684,16 @@ class HippocampusManager: """遗忘记忆的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage) + async with self._db_lock: + return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage) async def consolidate_memory(self): """整合记忆的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") # 使用 operation_build_memory 方法来整合记忆 - return await self._hippocampus.parahippocampal_gyrus.operation_build_memory() + async with self._db_lock: + return await self._hippocampus.parahippocampal_gyrus.operation_build_memory() async def get_memory_from_text( self, diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index 4a32657a7..58b7f23b9 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -339,7 +339,7 @@ class NoticeHandler: message_id=raw_message.get("message_id",""), emoji_id=like_emoji_id ) - seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]") + seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]") return seg_data, user_info async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: From 8ff4687670fce68b41abf2117ee5647f02200ef8 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 13:46:44 +0800 Subject: [PATCH 06/41] =?UTF-8?q?fix(db):=20=E4=BF=AE=E5=A4=8D=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=BF=81=E7=A7=BB=E4=B8=AD=E5=88=97=E5=92=8C?= =?UTF-8?q?=E7=B4=A2=E5=BC=95=E7=9A=84=E5=88=9B=E5=BB=BA=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 增强了添加列时对默认值的处理,以兼容不同数据库方言(例如 SQLite 的布尔值)。 - 切换到更标准的 `index.create()` 方法来创建索引,提高了稳定性。 - 调整了启动顺序,确保数据库在主系统之前完成初始化,以防止竞争条件。 --- bot.py | 4 +-- src/chat/message_receive/chat_stream.py | 6 ++-- src/common/database/db_migration.py | 43 ++++++++++++++----------- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/bot.py b/bot.py index f382df1e1..472ee5f08 100644 --- a/bot.py +++ b/bot.py @@ -229,10 +229,10 @@ if __name__ == "__main__": asyncio.set_event_loop(loop) try: - # 执行初始化和任务调度 - loop.run_until_complete(main_system.initialize()) # 异步初始化数据库表结构 loop.run_until_complete(maibot.initialize_database_async()) + # 执行初始化和任务调度 + loop.run_until_complete(main_system.initialize()) initialize_lpmm_knowledge() # Schedule tasks returns a future that runs forever. # We can run console_input_loop concurrently. diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index c42654aa3..de2fb62e9 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -254,7 +254,7 @@ class ChatManager: model_instance = await _db_find_stream_async(stream_id) if model_instance: - # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 + # 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式 user_info_data = { "platform": model_instance.user_platform, "user_id": model_instance.user_id, @@ -382,7 +382,7 @@ class ChatManager: await _db_save_stream_async(stream_data_dict) stream.saved = True except Exception as e: - logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) + logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (SQLAlchemy): {e}", exc_info=True) async def _save_all_streams(self): """保存所有聊天流""" @@ -435,7 +435,7 @@ class ChatManager: if stream.stream_id in self.last_messages: stream.set_context(self.last_messages[stream.stream_id]) except Exception as e: - logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) + logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True) chat_manager = None diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index aedff3676..8f7b1ecd3 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -70,24 +70,32 @@ async def check_and_migrate_database(): def add_columns_sync(conn): dialect = conn.dialect + compiler = dialect.ddl_compiler(dialect, None) + for column_name in missing_columns: column = table.c[column_name] - - # 使用DDLCompiler为特定方言编译列 - compiler = dialect.ddl_compiler(dialect, None) - - # 编译列的数据类型 column_type = compiler.get_column_specification(column) - - # 构建原生SQL sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}" - - # 添加默认值(如果存在) + if column.default: - default_value = compiler.render_literal_value(column.default.arg, column.type) + # 手动处理不同方言的默认值 + default_arg = column.default.arg + if dialect.name == "sqlite" and isinstance(default_arg, bool): + # SQLite 将布尔值存储为 0 或 1 + default_value = "1" if default_arg else "0" + elif hasattr(compiler, 'render_literal_value'): + try: + # 尝试使用 render_literal_value + default_value = compiler.render_literal_value(default_arg, column.type) + except AttributeError: + # 如果失败,则回退到简单的字符串转换 + default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + else: + # 对于没有 render_literal_value 的旧版或特定方言 + default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + sql += f" DEFAULT {default_value}" - - # 添加非空约束(如果存在) + if not column.nullable: sql += " NOT NULL" @@ -109,12 +117,11 @@ async def check_and_migrate_database(): logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}") def add_indexes_sync(conn): - with conn.begin(): - for index_name in missing_indexes: - index_obj = next((idx for idx in table.indexes if idx.name == index_name), None) - if index_obj is not None: - conn.execute(CreateIndex(index_obj)) - logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。") + for index_name in missing_indexes: + index_obj = next((idx for idx in table.indexes if idx.name == index_name), None) + if index_obj is not None: + index_obj.create(conn) + logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。") await connection.run_sync(add_indexes_sync) else: From 29c9dac4a40ef3fe3ff84c1fd7e5b9f6de53b7aa Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 14:06:34 +0800 Subject: [PATCH 07/41] =?UTF-8?q?refactor(db):=20=E7=A7=BB=E9=99=A4=20Mong?= =?UTF-8?q?oDB=20=E7=9B=B8=E5=85=B3=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 从数据库模块中移除了所有与 MongoDB 相关的代码,包括连接逻辑、`get_db` 函数和 `DBWrapper` 代理类。 项目将统一使用 SQLAlchemy 作为唯一的数据库接口,此更改旨在简化代码库并消除不再需要的依赖。 BREAKING CHANGE: 全局 MongoDB 实例 `memory_db` 和 `get_db` 函数已被移除。所有数据库交互现在都应通过 SQLAlchemy 会话进行。 --- src/common/database/database.py | 55 --------------------------------- 1 file changed, 55 deletions(-) diff --git a/src/common/database/database.py b/src/common/database/database.py index 6a34d900e..1815a98ff 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,6 +1,4 @@ import os -from pymongo import MongoClient -from pymongo.database import Database from rich.traceback import install from src.common.logger import get_logger @@ -10,8 +8,6 @@ from src.common.database.sqlalchemy_models import get_engine, get_db_session install(extra_lines=3) -_client = None -_db = None _sql_engine = None logger = get_logger("database") @@ -64,43 +60,6 @@ class SQLAlchemyTransaction: db = DatabaseProxy() -def __create_database_instance(): - uri = os.getenv("MONGODB_URI") - host = os.getenv("MONGODB_HOST", "127.0.0.1") - port = int(os.getenv("MONGODB_PORT", "27017")) - # db_name 变量在创建连接时不需要,在获取数据库实例时才使用 - username = os.getenv("MONGODB_USERNAME") - password = os.getenv("MONGODB_PASSWORD") - auth_source = os.getenv("MONGODB_AUTH_SOURCE") - - if uri: - # 支持标准mongodb://和mongodb+srv://连接字符串 - if uri.startswith(("mongodb://", "mongodb+srv://")): - return MongoClient(uri) - else: - raise ValueError( - "Invalid MongoDB URI format. URI must start with 'mongodb://' or 'mongodb+srv://'. " - "For MongoDB Atlas, use 'mongodb+srv://' format. " - "See: https://www.mongodb.com/docs/manual/reference/connection-string/" - ) - - if username and password: - # 如果有用户名和密码,使用认证连接 - return MongoClient(host, port, username=username, password=password, authSource=auth_source) - - # 否则使用无认证连接 - return MongoClient(host, port) - - -def get_db(): - """获取MongoDB连接实例,延迟初始化。""" - global _client, _db - if _client is None: - _client = __create_database_instance() - _db = _client[os.getenv("DATABASE_NAME", "MegBot")] - return _db - - async def initialize_sql_database(database_config): """ 根据配置初始化SQL数据库连接(SQLAlchemy版本) @@ -141,17 +100,3 @@ async def initialize_sql_database(database_config): except Exception as e: logger.error(f"初始化SQL数据库失败: {e}") return None - - -class DBWrapper: - """数据库代理类,保持接口兼容性同时实现懒加载。""" - - def __getattr__(self, name): - return getattr(get_db(), name) - - def __getitem__(self, key): - return get_db()[key] # type: ignore - - -# 全局MongoDB数据库访问点 -memory_db: Database = DBWrapper() # type: ignore From 7feae466c3caa839e93055af24058d4c5462f175 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 14:17:32 +0800 Subject: [PATCH 08/41] Megre Pull Request #1260 from MaiCore:https://github.com/MaiM-with-u/MaiBot/pull/1260 --- src/chat/message_receive/bot.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 3b68190a7..53cb00345 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -415,7 +415,6 @@ class ChatBot: return get_chat_manager().register_message(message) - chat = await get_chat_manager().get_or_create_stream( platform=message.message_info.platform, # type: ignore user_info=user_info, # type: ignore @@ -427,11 +426,11 @@ class ChatBot: # 处理消息内容,生成纯文本 await message.process() - # 过滤检查 - if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore - message.raw_message, # type: ignore - chat, - user_info, # type: ignore + # 过滤检查 (在消息处理之后进行) + if _check_ban_words( + message.processed_plain_text, chat, user_info # type: ignore + ) or _check_ban_regex( + message.processed_plain_text, chat, user_info # type: ignore ): return From 63bf20f0761d46e59eef6f15f90c91d3325a0673 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 14:33:43 +0800 Subject: [PATCH 09/41] =?UTF-8?q?feat(gemini):=20=E4=B8=BA=20Gemini=20?= =?UTF-8?q?=E5=AE=A2=E6=88=B7=E7=AB=AF=E6=B7=BB=E5=8A=A0=20thinking=5Fbudg?= =?UTF-8?q?et=20=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增对 Gemini 模型 "thinking" 功能的支持,允许通过 `thinking_budget` 参数查看模型的思考过程。 - 实现了 `clamp_thinking_budget` 方法,根据不同模型(如 2.5-pro, 2.5-flash)的限制来约束和验证 `thinking_budget` 的值。 - 支持特殊值:-1(自动模式)和 0(禁用模式,如果模型允许)。 - 默认禁用所有安全设置(safetySettings),以减少不必要的回答屏蔽。 --- .../model_client/aiohttp_gemini_client.py | 84 ++++++++++++++++++- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 4ab0af5f7..7b997b680 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -20,6 +20,26 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall logger = get_logger("AioHTTP-Gemini客户端") +# gemini_thinking参数(默认范围) +# 不同模型的思考预算范围配置 +THINKING_BUDGET_LIMITS = { + "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True}, + "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True}, + "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False}, +} +# 思维预算特殊值 +THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定 +THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用) + +gemini_safe_settings = [ + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, +] + + def _format_to_mime_type(image_format: str) -> str: """ 将图片格式转换为正确的MIME类型 @@ -130,7 +150,11 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]: def _build_generation_config( - max_tokens: int, temperature: float, response_format: RespFormat | None = None, extra_params: dict | None = None + max_tokens: int, + temperature: float, + thinking_budget: int, + response_format: RespFormat | None = None, + extra_params: dict | None = None, ) -> dict: """构建生成配置""" config = { @@ -138,6 +162,8 @@ def _build_generation_config( "temperature": temperature, "topK": 1, "topP": 1, + "safetySettings": gemini_safe_settings, + "thinkingConfig": {"includeThoughts": True, "thinkingBudget": thinking_budget}, } # 处理响应格式 @@ -150,7 +176,11 @@ def _build_generation_config( # 合并额外参数 if extra_params: - config.update(extra_params) + # 拷贝一份以防修改原始字典 + safe_extra_params = extra_params.copy() + # 移除已单独处理的 thinking_budget + safe_extra_params.pop("thinking_budget", None) + config.update(safe_extra_params) return config @@ -317,6 +347,41 @@ class AiohttpGeminiClient(BaseClient): if api_provider.base_url: self.base_url = api_provider.base_url.rstrip("/") + @staticmethod + def clamp_thinking_budget(tb: int, model_id: str) -> int: + """ + 按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本) + """ + limits = None + + # 优先尝试精确匹配 + if model_id in THINKING_BUDGET_LIMITS: + limits = THINKING_BUDGET_LIMITS[model_id] + else: + # 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先 + sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True) + for key in sorted_keys: + # 必须满足:完全等于 或者 前缀匹配(带 "-" 边界) + if model_id == key or model_id.startswith(f"{key}-"): + limits = THINKING_BUDGET_LIMITS[key] + break + + # 特殊值处理 + if tb == THINKING_BUDGET_AUTO: + return THINKING_BUDGET_AUTO + if tb == THINKING_BUDGET_DISABLED: + if limits and limits.get("can_disable", False): + return THINKING_BUDGET_DISABLED + return limits["min"] if limits else THINKING_BUDGET_AUTO + + # 已知模型裁剪到范围 + if limits: + return max(limits["min"], min(tb, limits["max"])) + + # 未知模型,返回动态模式 + logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。") + return tb + # 移除全局 session,全部请求都用 with aiohttp.ClientSession() as session: async def _make_request( @@ -376,10 +441,21 @@ class AiohttpGeminiClient(BaseClient): # 转换消息格式 contents, system_instructions = _convert_messages(message_list) + # 处理思考预算 + tb = THINKING_BUDGET_AUTO + if extra_params and "thinking_budget" in extra_params: + try: + tb = int(extra_params["thinking_budget"]) + except (ValueError, TypeError): + logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}") + tb = self.clamp_thinking_budget(tb, model_info.model_identifier) + # 构建请求体 request_data = { "contents": contents, - "generationConfig": _build_generation_config(max_tokens, temperature, response_format, extra_params), + "generationConfig": _build_generation_config( + max_tokens, temperature, tb, response_format, extra_params + ), } # 添加系统指令 @@ -475,7 +551,7 @@ class AiohttpGeminiClient(BaseClient): request_data = { "contents": contents, - "generationConfig": _build_generation_config(2048, 0.1, None, extra_params), + "generationConfig": _build_generation_config(2048, 0.1, THINKING_BUDGET_AUTO, None, extra_params), } try: From 6ed9349933df2bfa6fc51e0d1680dc78a37c3e9f Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 15:00:39 +0800 Subject: [PATCH 10/41] =?UTF-8?q?refactor(llm):=20=E9=87=8D=E6=9E=84=20LLM?= =?UTF-8?q?=20=E8=AF=B7=E6=B1=82=E5=A4=84=E7=90=86=EF=BC=8C=E5=BC=95?= =?UTF-8?q?=E5=85=A5=E9=80=9A=E7=94=A8=E6=95=85=E9=9A=9C=E8=BD=AC=E7=A7=BB?= =?UTF-8?q?=E6=89=A7=E8=A1=8C=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 之前的代码里,处理文本、图片、语音的请求方法都各自为战,写了一大堆重复的故障转移逻辑,简直乱糟糟的,看得我头疼。 为了解决这个问题,我进行了一次大扫除: - 引入了一个通用的 `_execute_with_failover` 执行器,把所有“模型失败就换下一个”的脏活累活都统一管理起来了。 - 重构了所有相关的请求方法(文本、图片、语音、嵌入),让它们变得更清爽,只专注于自己的核心任务。 - 升级了 `_model_scheduler`,现在它会智能地根据实时负载给模型排队,谁最闲谁先上。那个笨笨的 `_select_model` 就被我光荣地裁掉了。 这次重构之后,代码的可维护性和健壮性都好多了,再加新功能也方便啦。哼哼,快夸我! --- src/llm_models/utils_model.py | 588 +++++++++++++++++++--------------- 1 file changed, 329 insertions(+), 259 deletions(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 146e5eb46..10312f27d 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -159,7 +159,7 @@ class LLMRequest: max_tokens: Optional[int] = None, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 为图像生成响应 + 为图像生成响应(已集成故障转移) Args: prompt (str): 提示词 image_base64 (str): 图像的Base64编码字符串 @@ -167,71 +167,78 @@ class LLMRequest: Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ - # 标准化图片格式以确保API兼容性 normalized_format = _normalize_image_format(image_format) - # 模型选择 - start_time = time.time() - model_info, api_provider, client = self._select_model() - - # 请求体构建 - message_builder = MessageBuilder() - message_builder.add_text_content(prompt) - message_builder.add_image_content( - image_base64=image_base64, - image_format=normalized_format, - support_formats=client.get_support_image_formats(), - ) - messages = [message_builder.build()] - - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, - message_list=messages, - temperature=temperature, - max_tokens=max_tokens, - ) - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - # 从内容中提取标签的推理内容(向后兼容) - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - user_id="system", - time_cost=time.time() - start_time, - request_type=self.request_type, - endpoint="/chat/completions", + async def request_logic( + model_info: ModelInfo, api_provider: APIProvider, client: BaseClient + ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + start_time = time.time() + message_builder = MessageBuilder() + message_builder.add_text_content(prompt) + message_builder.add_image_content( + image_base64=image_base64, + image_format=normalized_format, + support_formats=client.get_support_image_formats(), ) - return content, (reasoning_content, model_info.name, tool_calls) + messages = [message_builder.build()] + + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + if usage := response.usage: + await llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", + time_cost=time.time() - start_time, + request_type=self.request_type, + endpoint="/chat/completions", + ) + return content, (reasoning_content, model_info.name, tool_calls) + + result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True) + if result: + return result + + # 这段代码理论上不可达,因为 raise_on_failure=True 会抛出异常 + raise RuntimeError("图片响应生成失败,所有模型均尝试失败。") async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: """ - 为语音生成响应 + 为语音生成响应(已集成故障转移) Args: voice_base64 (str): 语音的Base64编码字符串 Returns: (Optional[str]): 生成的文本描述或None """ - # 模型选择 - model_info, api_provider, client = self._select_model() - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.AUDIO, - model_info=model_info, - audio_base64=voice_base64, - ) - return response.content or None + async def request_logic(model_info: ModelInfo, api_provider: APIProvider, client: BaseClient) -> Optional[str]: + """定义单次请求的具体逻辑""" + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.AUDIO, + model_info=model_info, + audio_base64=voice_base64, + ) + return response.content or None + + # 对于语音识别,如果所有模型都失败,我们可能不希望程序崩溃,而是返回None + result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=False) + return result async def generate_response_async( self, @@ -279,6 +286,75 @@ class LLMRequest: raise e return "所有并发请求都失败了", ("", "unknown", None) + async def _execute_with_failover( + self, + request_callable: Callable[[ModelInfo, APIProvider, BaseClient], Coroutine[Any, Any, Any]], + raise_on_failure: bool = True, + ) -> Any: + """ + 通用的故障转移执行器。 + + 它会使用智能模型调度器按最优顺序尝试模型,直到请求成功或所有模型都失败。 + + Args: + request_callable: 一个接收 (model_info, api_provider, client) 并返回协程的函数, + 用于执行实际的请求逻辑。 + raise_on_failure: 如果所有模型都失败,是否抛出异常。 + + Returns: + 请求成功时的返回结果。 + + Raises: + RuntimeError: 如果所有模型都失败且 raise_on_failure 为 True。 + """ + failed_models = set() + last_exception: Optional[Exception] = None + + # model_scheduler 现在会动态排序,所以我们只需要在循环中处理失败的模型 + while True: + model_scheduler = self._model_scheduler(failed_models) + try: + model_info, api_provider, client = next(model_scheduler) + except StopIteration: + # 没有更多可用模型了 + break + + model_name = model_info.name + logger.debug(f"正在尝试使用模型: {model_name} (剩余可用: {len(self.model_for_task.model_list) - len(failed_models)})") + + try: + # 执行传入的请求函数 + result = await request_callable(model_info, api_provider, client) + logger.debug(f"模型 '{model_name}' 成功生成回复。") + return result + + except RespNotOkException as e: + # 对于某些致命的HTTP错误(如认证失败),我们可能希望立即失败或标记该模型为永久失败 + if e.status_code in [401, 403]: + logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将永久禁用此模型在此次请求中。") + else: + logger.warning(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code},将尝试下一个模型。") + failed_models.add(model_name) + last_exception = e + continue + + except Exception as e: + # 捕获其他所有异常(包括超时、解析错误、运行时错误等) + logger.error(f"使用模型 '{model_name}' 时发生异常: {e},将尝试下一个模型。") + failed_models.add(model_name) + last_exception = e + continue + + # 所有模型都尝试失败 + logger.error("所有可用模型都已尝试失败。") + if raise_on_failure: + if last_exception: + raise RuntimeError("所有模型都请求失败") from last_exception + raise RuntimeError("所有模型都请求失败,且没有具体的异常信息") + + # 根据需要返回一个默认的错误结果 + return None + async def _execute_single_request( self, prompt: str, @@ -288,83 +364,67 @@ class LLMRequest: raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 执行单次请求,并在模型失败时按顺序切换到下一个可用模型。 + 使用通用的故障转移执行器来执行单次文本生成请求。 """ - failed_models = set() - last_exception: Optional[Exception] = None - model_scheduler = self._model_scheduler(failed_models) - - for model_info, api_provider, client in model_scheduler: + async def request_logic( + model_info: ModelInfo, api_provider: APIProvider, client: BaseClient + ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + """定义单次请求的具体逻辑""" start_time = time.time() model_name = model_info.name - logger.debug(f"正在尝试使用模型: {model_name}") # 你不许刷屏 - try: - # 检查是否启用反截断 - # 检查是否为该模型启用反截断 - use_anti_truncation = getattr(model_info, "use_anti_truncation", False) - processed_prompt = prompt + # 检查是否启用反截断 + use_anti_truncation = getattr(model_info, "use_anti_truncation", False) + processed_prompt = prompt + if use_anti_truncation: + processed_prompt += self.anti_truncation_instruction + logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。") + + processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) + + message_builder = MessageBuilder() + message_builder.add_text_content(processed_prompt) + messages = [message_builder.build()] + tool_built = self._build_tool_options(tools) + + # 针对当前模型的空回复/截断重试逻辑 + empty_retry_count = 0 + max_empty_retry = api_provider.max_retry + empty_retry_interval = api_provider.retry_interval + + is_empty_reply = False + is_truncated = False + + while empty_retry_count <= max_empty_retry: + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=messages, + tool_options=tool_built, + temperature=temperature, + max_tokens=max_tokens, + ) + + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + + is_empty_reply = not tool_calls and (not content or content.strip() == "") + is_truncated = False if use_anti_truncation: - processed_prompt += self.anti_truncation_instruction - logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。") - - processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) - - message_builder = MessageBuilder() - message_builder.add_text_content(processed_prompt) - messages = [message_builder.build()] - tool_built = self._build_tool_options(tools) - - # 针对当前模型的空回复/截断重试逻辑 - empty_retry_count = 0 - max_empty_retry = api_provider.max_retry - empty_retry_interval = api_provider.retry_interval - - while empty_retry_count <= max_empty_retry: - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, - message_list=messages, - tool_options=tool_built, - temperature=temperature, - max_tokens=max_tokens, - ) - - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - - is_empty_reply = not tool_calls and (not content or content.strip() == "") - is_truncated = False - if use_anti_truncation: - if content.endswith(self.end_marker): - content = content[: -len(self.end_marker)].strip() - else: - is_truncated = True - - if is_empty_reply or is_truncated: - empty_retry_count += 1 - if empty_retry_count <= max_empty_retry: - reason = "空回复" if is_empty_reply else "截断" - logger.warning( - f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成..." - ) - if empty_retry_interval > 0: - await asyncio.sleep(empty_retry_interval) - continue # 继续使用当前模型重试 - else: - # 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型 - reason = "空回复" if is_empty_reply else "截断" - logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。") - raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数") + if content.endswith(self.end_marker): + content = content[: -len(self.end_marker)].strip() + else: + is_truncated = True + if not is_empty_reply and not is_truncated: # 成功获取响应 if usage := response.usage: await llm_usage_recorder.record_usage_to_database( @@ -381,115 +441,115 @@ class LLMRequest: raise RuntimeError("生成空回复") content = "生成的响应为空" - logger.debug(f"模型 '{model_name}' 成功生成回复。") # 你也不许刷屏 return content, (reasoning_content, model_name, tool_calls) - except RespNotOkException as e: - if e.status_code in [401, 403]: - logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。") - failed_models.add(model_name) - last_exception = e - continue # 切换到下一个模型 - else: - logger.error(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code}") - if raise_when_empty: - raise - # 对于其他HTTP错误,直接抛出,不再尝试其他模型 - return f"请求失败: {e}", ("", model_name, None) + # 如果代码执行到这里,说明是空回复或截断,需要重试 + empty_retry_count += 1 + if empty_retry_count <= max_empty_retry: + reason = "空回复" if is_empty_reply else "截断" + logger.warning( + f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成..." + ) + if empty_retry_interval > 0: + await asyncio.sleep(empty_retry_interval) + continue # 继续使用当前模型重试 - except RuntimeError as e: - # 捕获所有重试失败(包括空回复和网络问题) - logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。") - failed_models.add(model_name) - last_exception = e - continue # 切换到下一个模型 + # 如果循环结束,说明重试次数已用尽 + reason = "空回复" if is_empty_reply else "截断" + logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。") + raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数") - except Exception as e: - logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}") - failed_models.add(model_name) - last_exception = e - continue # 切换到下一个模型 + # 调用通用的故障转移执行器 + result = await self._execute_with_failover( + request_callable=request_logic, raise_on_failure=raise_when_empty + ) - # 所有模型都尝试失败 - logger.error("所有可用模型都已尝试失败。") - if raise_when_empty: - if last_exception: - raise RuntimeError("所有模型都请求失败") from last_exception - raise RuntimeError("所有模型都请求失败,且没有具体的异常信息") + if result: + return result + # 如果所有模型都失败了,并且不抛出异常,返回一个默认的错误信息 return "所有模型都请求失败", ("", "unknown", None) async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: - """获取嵌入向量 + """获取嵌入向量(已集成故障转移) Args: embedding_input (str): 获取嵌入的目标 Returns: (Tuple[List[float], str]): (嵌入向量,使用的模型名称) """ - # 无需构建消息体,直接使用输入文本 - start_time = time.time() - model_info, api_provider, client = self._select_model() - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.EMBEDDING, - model_info=model_info, - embedding_input=embedding_input, - ) - - embedding = response.embedding - - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( + async def request_logic( + model_info: ModelInfo, api_provider: APIProvider, client: BaseClient + ) -> Tuple[List[float], str]: + """定义单次请求的具体逻辑""" + start_time = time.time() + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.EMBEDDING, model_info=model_info, - time_cost=time.time() - start_time, - model_usage=usage, - user_id="system", - request_type=self.request_type, - endpoint="/embeddings", + embedding_input=embedding_input, ) - if not embedding: - raise RuntimeError("获取embedding失败") + embedding = response.embedding + if not embedding: + raise RuntimeError(f"模型 '{model_info.name}'未能返回 embedding。") - return embedding, model_info.name + if usage := response.usage: + await llm_usage_recorder.record_usage_to_database( + model_info=model_info, + time_cost=time.time() - start_time, + model_usage=usage, + user_id="system", + request_type=self.request_type, + endpoint="/embeddings", + ) - def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: + return embedding, model_info.name + + result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True) + if result: + return result + + # 这段代码理论上不可达,因为 raise_on_failure=True 会抛出异常 + raise RuntimeError("获取 embedding 失败,所有模型均尝试失败。") + + def _model_scheduler( + self, failed_models: set | None = None + ) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: """ - 一个模型调度器,按顺序提供模型,并跳过已失败的模型。 + 一个智能模型调度器,根据实时负载动态排序并提供模型,同时跳过已失败的模型。 """ - for model_name in self.model_for_task.model_list: - if model_name in failed_models: - continue + # sourcery skip: class-extract-method + if failed_models is None: + failed_models = set() + # 1. 筛选出所有未失败的可用模型 + available_models = [name for name in self.model_for_task.model_list if name not in failed_models] + + # 2. 根据负载均衡算法对可用模型进行排序 + # key: total_tokens + penalty * 300 + usage_penalty * 1000 + sorted_models = sorted( + available_models, + key=lambda name: self.model_usage[name][0] + + self.model_usage[name][1] * 300 + + self.model_usage[name][2] * 1000, + ) + + if not sorted_models: + logger.warning("所有模型都已失败或不可用,调度器无法提供任何模型。") + return + + logger.debug(f"模型调度顺序: {', '.join(sorted_models)}") + + # 3. 按最优顺序 yield 模型信息 + for model_name in sorted_models: model_info = model_config.get_model_info(model_name) api_provider = model_config.get_provider(model_info.api_provider) force_new_client = self.request_type == "embedding" client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - yield model_info, api_provider, client - def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: - """ - 根据总tokens和惩罚值选择的模型 (负载均衡) - """ - least_used_model_name = min( - self.model_usage, - key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, - ) - model_info = model_config.get_model_info(least_used_model_name) - api_provider = model_config.get_provider(model_info.api_provider) - - # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - logger.debug(f"选择请求模型: {model_info.name}") - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 - return model_info, api_provider, client - async def _execute_request( self, api_provider: APIProvider, @@ -513,63 +573,73 @@ class LLMRequest: """ retry_remain = api_provider.max_retry compressed_messages: Optional[List[Message]] = None - while retry_remain > 0: - try: - if request_type == RequestType.RESPONSE: - assert message_list is not None, "message_list cannot be None for response requests" - return await client.get_response( - model_info=model_info, - message_list=(compressed_messages or message_list), - tool_options=tool_options, - max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, - temperature=self.model_for_task.temperature if temperature is None else temperature, - response_format=response_format, - stream_response_handler=stream_response_handler, - async_response_parser=async_response_parser, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.EMBEDDING: - assert embedding_input, "embedding_input cannot be empty for embedding requests" - return await client.get_embedding( - model_info=model_info, - embedding_input=embedding_input, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.AUDIO: - assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" - return await client.get_audio_transcriptions( - model_info=model_info, - audio_base64=audio_base64, - extra_params=model_info.extra_params, - ) - except Exception as e: - logger.debug(f"请求失败: {str(e)}") - # 处理异常 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) - wait_interval, compressed_messages = self._default_exception_handler( - e, - self.task_name, - model_info=model_info, - api_provider=api_provider, - remain_try=retry_remain, - retry_interval=api_provider.retry_interval, - messages=(message_list, compressed_messages is not None) if message_list else None, - ) - - if wait_interval == -1: - retry_remain = 0 # 不再重试 - elif wait_interval > 0: - logger.info(f"等待 {wait_interval} 秒后重试...") - await asyncio.sleep(wait_interval) - finally: - # 放在finally防止死循环 - retry_remain -= 1 + # 增加使用惩罚值,标记该模型正在被尝试 total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值 - logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") - raise RuntimeError("请求失败,已达到最大重试次数") + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) + + try: + while retry_remain > 0: + try: + if request_type == RequestType.RESPONSE: + assert message_list is not None, "message_list cannot be None for response requests" + return await client.get_response( + model_info=model_info, + message_list=(compressed_messages or message_list), + tool_options=tool_options, + max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, + temperature=self.model_for_task.temperature if temperature is None else temperature, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + extra_params=model_info.extra_params, + ) + elif request_type == RequestType.EMBEDDING: + assert embedding_input, "embedding_input cannot be empty for embedding requests" + return await client.get_embedding( + model_info=model_info, + embedding_input=embedding_input, + extra_params=model_info.extra_params, + ) + elif request_type == RequestType.AUDIO: + assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" + return await client.get_audio_transcriptions( + model_info=model_info, + audio_base64=audio_base64, + extra_params=model_info.extra_params, + ) + except Exception as e: + logger.debug(f"请求失败: {str(e)}") + # 处理异常 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) + + wait_interval, compressed_messages = self._default_exception_handler( + e, + self.task_name, + model_info=model_info, + api_provider=api_provider, + remain_try=retry_remain, + retry_interval=api_provider.retry_interval, + messages=(message_list, compressed_messages is not None) if message_list else None, + ) + + if wait_interval == -1: + retry_remain = 0 # 不再重试 + elif wait_interval > 0: + logger.info(f"等待 {wait_interval} 秒后重试...") + await asyncio.sleep(wait_interval) + finally: + # 放在finally防止死循环 + retry_remain -= 1 + + # 当请求完全结束(无论是成功还是所有重试都失败),都将在此处处理 + logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") + raise RuntimeError("请求失败,已达到最大重试次数") + finally: + # 无论请求成功或失败,最终都将使用惩罚值减回去 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) def _default_exception_handler( self, From f64f7755bdaed0e60857a9b5f11b3c2ce3349f0c Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 15:01:46 +0800 Subject: [PATCH 11/41] =?UTF-8?q?=E7=A7=BB=E9=99=A4=20=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=E9=99=8D=E7=BA=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_models/utils_model.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 10312f27d..4ed2ac637 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -755,24 +755,6 @@ class LLMRequest: # 响应错误 if e.status_code in [400, 401, 402, 403, 404]: model_name = model_info.name - if ( - e.status_code == 403 - and model_name.startswith("Pro/deepseek-ai") - and api_provider.base_url == "https://api.siliconflow.cn/v1/" - ): - old_model_name = model_name - new_model_name = model_name[4:] - model_info.name = new_model_name - logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {new_model_name}") - # 更新任务配置中的模型列表 - for i, m_name in enumerate(self.model_for_task.model_list): - if m_name == old_model_name: - self.model_for_task.model_list[i] = new_model_name - logger.warning( - f"将任务 {self.task_name} 的模型列表中的 {old_model_name} 临时降级至 {new_model_name}" - ) - break - return 0, None # 立即重试 # 客户端错误 logger.warning( f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}" From 63837c3acec89a1b61e8b6080bed6015e2265c33 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 15:02:30 +0800 Subject: [PATCH 12/41] =?UTF-8?q?fix(config):=20=E7=A7=BB=E9=99=A4=20Silic?= =?UTF-8?q?onFlow=20=E6=A8=A1=E5=9E=8B=E6=A0=87=E8=AF=86=E7=AC=A6=E4=B8=AD?= =?UTF-8?q?=E7=9A=84=20'Pro/'=20=E5=89=8D=E7=BC=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- template/model_config_template.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/template/model_config_template.toml b/template/model_config_template.toml index ea200accb..8c9763c2f 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -53,14 +53,14 @@ price_out = 8.0 # 输出价格(用于API调用统计,单 #use_anti_truncation = true # [可选] 启用反截断功能。当模型输出不完整时,系统会自动重试。建议只为有需要的模型(如Gemini)开启。 [[models]] -model_identifier = "Pro/deepseek-ai/DeepSeek-V3" +model_identifier = "deepseek-ai/DeepSeek-V3" name = "siliconflow-deepseek-v3" api_provider = "SiliconFlow" price_in = 2.0 price_out = 8.0 [[models]] -model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +model_identifier = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" name = "deepseek-r1-distill-qwen-32b" api_provider = "SiliconFlow" price_in = 4.0 From 2a52f3c7c65cbd029c5b1a01d6ed544897d03eba Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 18:58:46 +0800 Subject: [PATCH 13/41] =?UTF-8?q?refactor(set=5Femoji=5Flike):=20=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=20send=5Fcommand=20=E6=96=B9=E6=B3=95=E5=8F=91?= =?UTF-8?q?=E9=80=81=E8=A1=A8=E6=83=85=E5=9B=9E=E5=BA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将原先直接调用底层 `adapter_command_to_stream` 的方式重构为使用封装好的 `self.send_command` 辅助方法。 此次重构简化了动作实现代码,提高了可读性,并更好地封装了命令发送的逻辑。 --- plugins/set_emoji_like/plugin.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/plugins/set_emoji_like/plugin.py b/plugins/set_emoji_like/plugin.py index 9e569cbb2..966d4aabc 100644 --- a/plugins/set_emoji_like/plugin.py +++ b/plugins/set_emoji_like/plugin.py @@ -125,31 +125,25 @@ class SetEmojiLikeAction(BaseAction): try: # 使用适配器API发送贴表情命令 - response = await send_api.adapter_command_to_stream( - action="set_msg_emoji_like", - params={"message_id": message_id, "emoji_id": emoji_id, "set": set_like}, - stream_id=self.chat_stream.stream_id if self.chat_stream else None, - timeout=30.0, - storage_message=False, + success = await self.send_command( + command_name="set_emoji_like", args={"message_id": message_id, "emoji_id": emoji_id, "set": set_like}, storage_message=False ) - - if response["status"] == "ok": - logger.info(f"设置表情回应成功: {response}") + if success: + logger.info("设置表情回应成功") await self.store_action_info( action_build_into_prompt=True, action_prompt_display=f"执行了set_emoji_like动作,{emoji_input},设置表情回应: {emoji_id}, 是否设置: {set_like}", action_done=True, ) - return True, f"成功设置表情回应: {response.get('message', '成功')}" + return True, "成功设置表情回应" else: - error_msg = response.get("message", "未知错误") - logger.error(f"设置表情回应失败: {error_msg}") + logger.error("设置表情回应失败") await self.store_action_info( action_build_into_prompt=True, - action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: {error_msg}", + action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败", action_done=False, ) - return False, f"设置表情回应失败: {error_msg}" + return False, "设置表情回应失败" except Exception as e: logger.error(f"设置表情回应失败: {e}") From 98212bb9385af4c05641777c4b6bc41d83aa6218 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 20:21:59 +0800 Subject: [PATCH 14/41] =?UTF-8?q?feat(chat):=20=E5=9C=A8=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E4=B8=AD=E6=98=BE=E7=A4=BA=E7=94=A8=E6=88=B7?= =?UTF-8?q?=20QQ=20=E5=8F=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/utils/chat_message_builder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 4b08c5bbe..e2d0a4fb9 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -643,6 +643,10 @@ async def _build_readable_messages_internal( else: person_name = "某人" + # 在用户名后面添加 QQ 号, 但机器人本体不用 + if user_id != global_config.bot.qq_account: + person_name = f"{person_name}({user_id})" + # 使用独立函数处理用户引用格式 content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name) From 4e3ab4003c490fd96a65c40a75791c06b0a28185 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 24 Sep 2025 21:28:42 +0800 Subject: [PATCH 15/41] =?UTF-8?q?Revert=20"refactor(llm):=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=20LLM=20=E8=AF=B7=E6=B1=82=E5=A4=84=E7=90=86=EF=BC=8C?= =?UTF-8?q?=E5=BC=95=E5=85=A5=E9=80=9A=E7=94=A8=E6=95=85=E9=9A=9C=E8=BD=AC?= =?UTF-8?q?=E7=A7=BB=E6=89=A7=E8=A1=8C=E5=99=A8"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 6ed9349933df2bfa6fc51e0d1680dc78a37c3e9f. --- src/llm_models/utils_model.py | 588 +++++++++++++++------------------- 1 file changed, 259 insertions(+), 329 deletions(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 4ed2ac637..8f668dc7b 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -159,7 +159,7 @@ class LLMRequest: max_tokens: Optional[int] = None, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 为图像生成响应(已集成故障转移) + 为图像生成响应 Args: prompt (str): 提示词 image_base64 (str): 图像的Base64编码字符串 @@ -167,78 +167,71 @@ class LLMRequest: Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ + # 标准化图片格式以确保API兼容性 normalized_format = _normalize_image_format(image_format) - async def request_logic( - model_info: ModelInfo, api_provider: APIProvider, client: BaseClient - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - start_time = time.time() - message_builder = MessageBuilder() - message_builder.add_text_content(prompt) - message_builder.add_image_content( - image_base64=image_base64, - image_format=normalized_format, - support_formats=client.get_support_image_formats(), - ) - messages = [message_builder.build()] + # 模型选择 + start_time = time.time() + model_info, api_provider, client = self._select_model() - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, + # 请求体构建 + message_builder = MessageBuilder() + message_builder.add_text_content(prompt) + message_builder.add_image_content( + image_base64=image_base64, + image_format=normalized_format, + support_formats=client.get_support_image_formats(), + ) + messages = [message_builder.build()] + + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + if usage := response.usage: + await llm_usage_recorder.record_usage_to_database( model_info=model_info, - message_list=messages, - temperature=temperature, - max_tokens=max_tokens, + model_usage=usage, + user_id="system", + time_cost=time.time() - start_time, + request_type=self.request_type, + endpoint="/chat/completions", ) - - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - user_id="system", - time_cost=time.time() - start_time, - request_type=self.request_type, - endpoint="/chat/completions", - ) - return content, (reasoning_content, model_info.name, tool_calls) - - result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True) - if result: - return result - - # 这段代码理论上不可达,因为 raise_on_failure=True 会抛出异常 - raise RuntimeError("图片响应生成失败,所有模型均尝试失败。") + return content, (reasoning_content, model_info.name, tool_calls) async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: """ - 为语音生成响应(已集成故障转移) + 为语音生成响应 Args: voice_base64 (str): 语音的Base64编码字符串 Returns: (Optional[str]): 生成的文本描述或None """ + # 模型选择 + model_info, api_provider, client = self._select_model() - async def request_logic(model_info: ModelInfo, api_provider: APIProvider, client: BaseClient) -> Optional[str]: - """定义单次请求的具体逻辑""" - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.AUDIO, - model_info=model_info, - audio_base64=voice_base64, - ) - return response.content or None - - # 对于语音识别,如果所有模型都失败,我们可能不希望程序崩溃,而是返回None - result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=False) - return result + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.AUDIO, + model_info=model_info, + audio_base64=voice_base64, + ) + return response.content or None async def generate_response_async( self, @@ -286,75 +279,6 @@ class LLMRequest: raise e return "所有并发请求都失败了", ("", "unknown", None) - async def _execute_with_failover( - self, - request_callable: Callable[[ModelInfo, APIProvider, BaseClient], Coroutine[Any, Any, Any]], - raise_on_failure: bool = True, - ) -> Any: - """ - 通用的故障转移执行器。 - - 它会使用智能模型调度器按最优顺序尝试模型,直到请求成功或所有模型都失败。 - - Args: - request_callable: 一个接收 (model_info, api_provider, client) 并返回协程的函数, - 用于执行实际的请求逻辑。 - raise_on_failure: 如果所有模型都失败,是否抛出异常。 - - Returns: - 请求成功时的返回结果。 - - Raises: - RuntimeError: 如果所有模型都失败且 raise_on_failure 为 True。 - """ - failed_models = set() - last_exception: Optional[Exception] = None - - # model_scheduler 现在会动态排序,所以我们只需要在循环中处理失败的模型 - while True: - model_scheduler = self._model_scheduler(failed_models) - try: - model_info, api_provider, client = next(model_scheduler) - except StopIteration: - # 没有更多可用模型了 - break - - model_name = model_info.name - logger.debug(f"正在尝试使用模型: {model_name} (剩余可用: {len(self.model_for_task.model_list) - len(failed_models)})") - - try: - # 执行传入的请求函数 - result = await request_callable(model_info, api_provider, client) - logger.debug(f"模型 '{model_name}' 成功生成回复。") - return result - - except RespNotOkException as e: - # 对于某些致命的HTTP错误(如认证失败),我们可能希望立即失败或标记该模型为永久失败 - if e.status_code in [401, 403]: - logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将永久禁用此模型在此次请求中。") - else: - logger.warning(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code},将尝试下一个模型。") - failed_models.add(model_name) - last_exception = e - continue - - except Exception as e: - # 捕获其他所有异常(包括超时、解析错误、运行时错误等) - logger.error(f"使用模型 '{model_name}' 时发生异常: {e},将尝试下一个模型。") - failed_models.add(model_name) - last_exception = e - continue - - # 所有模型都尝试失败 - logger.error("所有可用模型都已尝试失败。") - if raise_on_failure: - if last_exception: - raise RuntimeError("所有模型都请求失败") from last_exception - raise RuntimeError("所有模型都请求失败,且没有具体的异常信息") - - # 根据需要返回一个默认的错误结果 - return None - async def _execute_single_request( self, prompt: str, @@ -364,67 +288,83 @@ class LLMRequest: raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 使用通用的故障转移执行器来执行单次文本生成请求。 + 执行单次请求,并在模型失败时按顺序切换到下一个可用模型。 """ + failed_models = set() + last_exception: Optional[Exception] = None - async def request_logic( - model_info: ModelInfo, api_provider: APIProvider, client: BaseClient - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """定义单次请求的具体逻辑""" + model_scheduler = self._model_scheduler(failed_models) + + for model_info, api_provider, client in model_scheduler: start_time = time.time() model_name = model_info.name + logger.debug(f"正在尝试使用模型: {model_name}") # 你不许刷屏 - # 检查是否启用反截断 - use_anti_truncation = getattr(model_info, "use_anti_truncation", False) - processed_prompt = prompt - if use_anti_truncation: - processed_prompt += self.anti_truncation_instruction - logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。") - - processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) - - message_builder = MessageBuilder() - message_builder.add_text_content(processed_prompt) - messages = [message_builder.build()] - tool_built = self._build_tool_options(tools) - - # 针对当前模型的空回复/截断重试逻辑 - empty_retry_count = 0 - max_empty_retry = api_provider.max_retry - empty_retry_interval = api_provider.retry_interval - - is_empty_reply = False - is_truncated = False - - while empty_retry_count <= max_empty_retry: - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, - message_list=messages, - tool_options=tool_built, - temperature=temperature, - max_tokens=max_tokens, - ) - - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - - is_empty_reply = not tool_calls and (not content or content.strip() == "") - is_truncated = False + try: + # 检查是否启用反截断 + # 检查是否为该模型启用反截断 + use_anti_truncation = getattr(model_info, "use_anti_truncation", False) + processed_prompt = prompt if use_anti_truncation: - if content.endswith(self.end_marker): - content = content[: -len(self.end_marker)].strip() - else: - is_truncated = True + processed_prompt += self.anti_truncation_instruction + logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。") + + processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) + + message_builder = MessageBuilder() + message_builder.add_text_content(processed_prompt) + messages = [message_builder.build()] + tool_built = self._build_tool_options(tools) + + # 针对当前模型的空回复/截断重试逻辑 + empty_retry_count = 0 + max_empty_retry = api_provider.max_retry + empty_retry_interval = api_provider.retry_interval + + while empty_retry_count <= max_empty_retry: + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=messages, + tool_options=tool_built, + temperature=temperature, + max_tokens=max_tokens, + ) + + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + + is_empty_reply = not tool_calls and (not content or content.strip() == "") + is_truncated = False + if use_anti_truncation: + if content.endswith(self.end_marker): + content = content[: -len(self.end_marker)].strip() + else: + is_truncated = True + + if is_empty_reply or is_truncated: + empty_retry_count += 1 + if empty_retry_count <= max_empty_retry: + reason = "空回复" if is_empty_reply else "截断" + logger.warning( + f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成..." + ) + if empty_retry_interval > 0: + await asyncio.sleep(empty_retry_interval) + continue # 继续使用当前模型重试 + else: + # 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型 + reason = "空回复" if is_empty_reply else "截断" + logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。") + raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数") - if not is_empty_reply and not is_truncated: # 成功获取响应 if usage := response.usage: await llm_usage_recorder.record_usage_to_database( @@ -441,115 +381,115 @@ class LLMRequest: raise RuntimeError("生成空回复") content = "生成的响应为空" + logger.debug(f"模型 '{model_name}' 成功生成回复。") # 你也不许刷屏 return content, (reasoning_content, model_name, tool_calls) - # 如果代码执行到这里,说明是空回复或截断,需要重试 - empty_retry_count += 1 - if empty_retry_count <= max_empty_retry: - reason = "空回复" if is_empty_reply else "截断" - logger.warning( - f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成..." - ) - if empty_retry_interval > 0: - await asyncio.sleep(empty_retry_interval) - continue # 继续使用当前模型重试 + except RespNotOkException as e: + if e.status_code in [401, 403]: + logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。") + failed_models.add(model_name) + last_exception = e + continue # 切换到下一个模型 + else: + logger.error(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code}") + if raise_when_empty: + raise + # 对于其他HTTP错误,直接抛出,不再尝试其他模型 + return f"请求失败: {e}", ("", model_name, None) - # 如果循环结束,说明重试次数已用尽 - reason = "空回复" if is_empty_reply else "截断" - logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。") - raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数") + except RuntimeError as e: + # 捕获所有重试失败(包括空回复和网络问题) + logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。") + failed_models.add(model_name) + last_exception = e + continue # 切换到下一个模型 - # 调用通用的故障转移执行器 - result = await self._execute_with_failover( - request_callable=request_logic, raise_on_failure=raise_when_empty - ) + except Exception as e: + logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}") + failed_models.add(model_name) + last_exception = e + continue # 切换到下一个模型 - if result: - return result + # 所有模型都尝试失败 + logger.error("所有可用模型都已尝试失败。") + if raise_when_empty: + if last_exception: + raise RuntimeError("所有模型都请求失败") from last_exception + raise RuntimeError("所有模型都请求失败,且没有具体的异常信息") - # 如果所有模型都失败了,并且不抛出异常,返回一个默认的错误信息 return "所有模型都请求失败", ("", "unknown", None) async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: - """获取嵌入向量(已集成故障转移) + """获取嵌入向量 Args: embedding_input (str): 获取嵌入的目标 Returns: (Tuple[List[float], str]): (嵌入向量,使用的模型名称) """ + # 无需构建消息体,直接使用输入文本 + start_time = time.time() + model_info, api_provider, client = self._select_model() - async def request_logic( - model_info: ModelInfo, api_provider: APIProvider, client: BaseClient - ) -> Tuple[List[float], str]: - """定义单次请求的具体逻辑""" - start_time = time.time() - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.EMBEDDING, - model_info=model_info, - embedding_input=embedding_input, - ) - - embedding = response.embedding - if not embedding: - raise RuntimeError(f"模型 '{model_info.name}'未能返回 embedding。") - - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - time_cost=time.time() - start_time, - model_usage=usage, - user_id="system", - request_type=self.request_type, - endpoint="/embeddings", - ) - - return embedding, model_info.name - - result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True) - if result: - return result - - # 这段代码理论上不可达,因为 raise_on_failure=True 会抛出异常 - raise RuntimeError("获取 embedding 失败,所有模型均尝试失败。") - - def _model_scheduler( - self, failed_models: set | None = None - ) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: - """ - 一个智能模型调度器,根据实时负载动态排序并提供模型,同时跳过已失败的模型。 - """ - # sourcery skip: class-extract-method - if failed_models is None: - failed_models = set() - - # 1. 筛选出所有未失败的可用模型 - available_models = [name for name in self.model_for_task.model_list if name not in failed_models] - - # 2. 根据负载均衡算法对可用模型进行排序 - # key: total_tokens + penalty * 300 + usage_penalty * 1000 - sorted_models = sorted( - available_models, - key=lambda name: self.model_usage[name][0] - + self.model_usage[name][1] * 300 - + self.model_usage[name][2] * 1000, + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.EMBEDDING, + model_info=model_info, + embedding_input=embedding_input, ) - if not sorted_models: - logger.warning("所有模型都已失败或不可用,调度器无法提供任何模型。") - return + embedding = response.embedding - logger.debug(f"模型调度顺序: {', '.join(sorted_models)}") + if usage := response.usage: + await llm_usage_recorder.record_usage_to_database( + model_info=model_info, + time_cost=time.time() - start_time, + model_usage=usage, + user_id="system", + request_type=self.request_type, + endpoint="/embeddings", + ) + + if not embedding: + raise RuntimeError("获取embedding失败") + + return embedding, model_info.name + + def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: + """ + 一个模型调度器,按顺序提供模型,并跳过已失败的模型。 + """ + for model_name in self.model_for_task.model_list: + if model_name in failed_models: + continue - # 3. 按最优顺序 yield 模型信息 - for model_name in sorted_models: model_info = model_config.get_model_info(model_name) api_provider = model_config.get_provider(model_info.api_provider) force_new_client = self.request_type == "embedding" client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + yield model_info, api_provider, client + def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: + """ + 根据总tokens和惩罚值选择的模型 (负载均衡) + """ + least_used_model_name = min( + self.model_usage, + key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, + ) + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + + # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 + force_new_client = self.request_type == "embedding" + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + logger.debug(f"选择请求模型: {model_info.name}") + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 + return model_info, api_provider, client + async def _execute_request( self, api_provider: APIProvider, @@ -573,73 +513,63 @@ class LLMRequest: """ retry_remain = api_provider.max_retry compressed_messages: Optional[List[Message]] = None - - # 增加使用惩罚值,标记该模型正在被尝试 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) - - try: - while retry_remain > 0: - try: - if request_type == RequestType.RESPONSE: - assert message_list is not None, "message_list cannot be None for response requests" - return await client.get_response( - model_info=model_info, - message_list=(compressed_messages or message_list), - tool_options=tool_options, - max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, - temperature=self.model_for_task.temperature if temperature is None else temperature, - response_format=response_format, - stream_response_handler=stream_response_handler, - async_response_parser=async_response_parser, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.EMBEDDING: - assert embedding_input, "embedding_input cannot be empty for embedding requests" - return await client.get_embedding( - model_info=model_info, - embedding_input=embedding_input, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.AUDIO: - assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" - return await client.get_audio_transcriptions( - model_info=model_info, - audio_base64=audio_base64, - extra_params=model_info.extra_params, - ) - except Exception as e: - logger.debug(f"请求失败: {str(e)}") - # 处理异常 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) - - wait_interval, compressed_messages = self._default_exception_handler( - e, - self.task_name, + while retry_remain > 0: + try: + if request_type == RequestType.RESPONSE: + assert message_list is not None, "message_list cannot be None for response requests" + return await client.get_response( model_info=model_info, - api_provider=api_provider, - remain_try=retry_remain, - retry_interval=api_provider.retry_interval, - messages=(message_list, compressed_messages is not None) if message_list else None, + message_list=(compressed_messages or message_list), + tool_options=tool_options, + max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, + temperature=self.model_for_task.temperature if temperature is None else temperature, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + extra_params=model_info.extra_params, ) + elif request_type == RequestType.EMBEDDING: + assert embedding_input, "embedding_input cannot be empty for embedding requests" + return await client.get_embedding( + model_info=model_info, + embedding_input=embedding_input, + extra_params=model_info.extra_params, + ) + elif request_type == RequestType.AUDIO: + assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" + return await client.get_audio_transcriptions( + model_info=model_info, + audio_base64=audio_base64, + extra_params=model_info.extra_params, + ) + except Exception as e: + logger.debug(f"请求失败: {str(e)}") + # 处理异常 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) - if wait_interval == -1: - retry_remain = 0 # 不再重试 - elif wait_interval > 0: - logger.info(f"等待 {wait_interval} 秒后重试...") - await asyncio.sleep(wait_interval) - finally: - # 放在finally防止死循环 - retry_remain -= 1 + wait_interval, compressed_messages = self._default_exception_handler( + e, + self.task_name, + model_info=model_info, + api_provider=api_provider, + remain_try=retry_remain, + retry_interval=api_provider.retry_interval, + messages=(message_list, compressed_messages is not None) if message_list else None, + ) - # 当请求完全结束(无论是成功还是所有重试都失败),都将在此处处理 - logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") - raise RuntimeError("请求失败,已达到最大重试次数") - finally: - # 无论请求成功或失败,最终都将使用惩罚值减回去 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) + if wait_interval == -1: + retry_remain = 0 # 不再重试 + elif wait_interval > 0: + logger.info(f"等待 {wait_interval} 秒后重试...") + await asyncio.sleep(wait_interval) + finally: + # 放在finally防止死循环 + retry_remain -= 1 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值 + logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") + raise RuntimeError("请求失败,已达到最大重试次数") def _default_exception_handler( self, From 885d4217206510b706e13c269bc609c5321df3b7 Mon Sep 17 00:00:00 2001 From: ikun-11451 <334495606@qq.com> Date: Wed, 24 Sep 2025 23:21:12 +0800 Subject: [PATCH 16/41] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=9D=87=E8=A1=A1=E8=B4=9F=E8=BD=BD=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E5=96=B5~=E4=BD=A0=E6=98=AF=E4=B8=80=E5=8F=AA=E7=8C=AB?= =?UTF-8?q?=E5=A8=98=E5=96=B5~?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_models/utils_model.py | 180 +++++++++++++++++++++++++--------- 1 file changed, 131 insertions(+), 49 deletions(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 8f668dc7b..ec1a996bf 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -288,20 +288,29 @@ class LLMRequest: raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 执行单次请求,并在模型失败时按顺序切换到下一个可用模型。 + 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 """ - failed_models = set() + failed_models_in_this_request = set() + # 迭代次数等于模型总数,以确保每个模型在当前请求中最多只尝试一次 + max_attempts = len(self.model_for_task.model_list) last_exception: Optional[Exception] = None - model_scheduler = self._model_scheduler(failed_models) + for attempt in range(max_attempts): + # 根据负载均衡和当前故障选择最佳可用模型 + model_selection_result = self._select_best_available_model(failed_models_in_this_request) - for model_info, api_provider, client in model_scheduler: - start_time = time.time() + if model_selection_result is None: + logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") + break # 没有更多模型可供尝试 + + model_info, api_provider, client = model_selection_result model_name = model_info.name - logger.debug(f"正在尝试使用模型: {model_name}") # 你不许刷屏 + logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...") + + start_time = time.time() try: - # 检查是否启用反截断 + # --- 为当前模型尝试进行设置 --- # 检查是否为该模型启用反截断 use_anti_truncation = getattr(model_info, "use_anti_truncation", False) processed_prompt = prompt @@ -316,7 +325,7 @@ class LLMRequest: messages = [message_builder.build()] tool_built = self._build_tool_options(tools) - # 针对当前模型的空回复/截断重试逻辑 + # --- 当前选定模型内的空回复/截断重试逻辑 --- empty_retry_count = 0 max_empty_retry = api_provider.max_retry empty_retry_interval = api_provider.retry_interval @@ -337,6 +346,7 @@ class LLMRequest: reasoning_content = response.reasoning_content or "" tool_calls = response.tool_calls + # 向后兼容 标签(如果 reasoning_content 为空) if not reasoning_content and content: content, extracted_reasoning = self._extract_reasoning(content) reasoning_content = extracted_reasoning @@ -354,18 +364,17 @@ class LLMRequest: if empty_retry_count <= max_empty_retry: reason = "空回复" if is_empty_reply else "截断" logger.warning( - f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成..." + f"模型 '{model_name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..." ) if empty_retry_interval > 0: await asyncio.sleep(empty_retry_interval) - continue # 继续使用当前模型重试 + continue # 使用当前模型重试 else: - # 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型 reason = "空回复" if is_empty_reply else "截断" - logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。") - raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数") + logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。将此模型标记为当前请求失败。") + raise RuntimeError(f"模型 '{model_name}' 已达到空回复/截断的最大内部重试次数。") - # 成功获取响应 + # --- 从当前模型获取成功响应 --- if usage := response.usage: await llm_usage_recorder.record_usage_to_database( model_info=model_info, @@ -376,47 +385,29 @@ class LLMRequest: endpoint="/chat/completions", ) + # 处理成功执行后响应仍然为空的情况 if not content and not tool_calls: if raise_when_empty: - raise RuntimeError("生成空回复") - content = "生成的响应为空" + raise RuntimeError("所选模型生成了空回复。") + content = "生成的响应为空" # Fallback message - logger.debug(f"模型 '{model_name}' 成功生成回复。") # 你也不许刷屏 - return content, (reasoning_content, model_name, tool_calls) + logger.debug(f"模型 '{model_name}' 成功生成了回复。") + return content, (reasoning_content, model_name, tool_calls) # 成功,立即返回 - except RespNotOkException as e: - if e.status_code in [401, 403]: - logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。") - failed_models.add(model_name) - last_exception = e - continue # 切换到下一个模型 - else: - logger.error(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code}") - if raise_when_empty: - raise - # 对于其他HTTP错误,直接抛出,不再尝试其他模型 - return f"请求失败: {e}", ("", model_name, None) + # --- 当前模型尝试过程中的异常处理 --- + except Exception as e: # 捕获当前模型尝试过程中的所有异常 + # 修复 NameError: model_name 在异常处理块中未定义,应使用 model_info.name + logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") + failed_models_in_this_request.add(model_info.name) + last_exception = e # 存储异常以供最终报告 + # 继续循环以尝试下一个可用模型 - except RuntimeError as e: - # 捕获所有重试失败(包括空回复和网络问题) - logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。") - failed_models.add(model_name) - last_exception = e - continue # 切换到下一个模型 - - except Exception as e: - logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}") - failed_models.add(model_name) - last_exception = e - continue # 切换到下一个模型 - - # 所有模型都尝试失败 - logger.error("所有可用模型都已尝试失败。") + # 如果循环结束未能返回,则表示当前请求的所有模型都已失败 + logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") if raise_when_empty: if last_exception: - raise RuntimeError("所有模型都请求失败") from last_exception - raise RuntimeError("所有模型都请求失败,且没有具体的异常信息") - + raise RuntimeError("所有模型均未能生成响应。") from last_exception + raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") return "所有模型都请求失败", ("", "unknown", None) async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: @@ -456,6 +447,57 @@ class LLMRequest: return embedding, model_info.name + def _select_best_available_model(self, failed_models_in_this_request: set) -> Tuple[ModelInfo, APIProvider, BaseClient] | None: + """ + 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 + + 参数: + failed_models_in_this_request (set): 当前请求中已失败的模型名称集合。 + + 返回: + Tuple[ModelInfo, APIProvider, BaseClient] | None: 选定的模型详细信息,如果无可用模型则返回 None。 + """ + candidate_models_usage = {} + # 过滤掉当前请求中已失败的模型 + for model_name, usage_data in self.model_usage.items(): + if model_name not in failed_models_in_this_request: + candidate_models_usage[model_name] = usage_data + + if not candidate_models_usage: + logger.warning("没有可用的模型供当前请求选择。") + return None + + # 根据现有公式查找分数最低的模型,该公式综合了总token数、模型惩罚值和使用频率惩罚值。 + # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 + # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。 + least_used_model_name = min( + candidate_models_usage, + key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, + ) + + # --- 动态故障转移的核心逻辑 --- + # _execute_single_request 中的循环会多次调用此函数。 + # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数, + # 此时由于失败模型已被标记,且其惩罚值可能已在 _execute_request 中增加, + # _select_best_available_model 会自动选择一个得分更低(即更可用)的模型。 + # 这种机制实现了动态的、基于当前系统状态的故障转移。 + + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + + # 对于嵌入任务,如果需要,强制创建新的客户端实例(从原始 _select_model 复制) + force_new_client = self.request_type == "embedding" + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + + logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") + + # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。 + # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) + + return model_info, api_provider, client + def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: """ 一个模型调度器,按顺序提供模型,并跳过已失败的模型。 @@ -546,7 +588,47 @@ class LLMRequest: logger.debug(f"请求失败: {str(e)}") # 处理异常 total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) + + # --- 增强动态故障转移的智能性 --- + # 根据异常类型和严重程度,动态调整模型的惩罚值。 + # 关键错误(如网络连接、服务器错误)会获得更高的惩罚, + # 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。 + CRITICAL_PENALTY_MULTIPLIER = 5 # 关键错误时的惩罚系数 + default_penalty_increment = 1 # 普通错误时的基础惩罚 + + penalty_increment = default_penalty_increment + + if isinstance(e, NetworkConnectionError): + # 网络连接问题表明模型服务器不稳定,增加较高惩罚 + penalty_increment = CRITICAL_PENALTY_MULTIPLIER + # 修复 NameError: model_name 在此处未定义,应使用 model_info.name + logger.warning(f"模型 '{model_info.name}' 发生网络连接错误,增加惩罚值: {penalty_increment}") + elif isinstance(e, ReqAbortException): + # 请求被中止,可能是服务器端原因或服务不稳定,增加较高惩罚 + penalty_increment = CRITICAL_PENALTY_MULTIPLIER + # 修复 NameError: model_name 在此处未定义,应使用 model_info.name + logger.warning(f"模型 '{model_info.name}' 请求被中止,增加惩罚值: {penalty_increment}") + elif isinstance(e, RespNotOkException): + if e.status_code >= 500: + # 服务器错误 (5xx) 表明服务器端问题,应显著增加惩罚 + penalty_increment = CRITICAL_PENALTY_MULTIPLIER + logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}") + elif e.status_code == 429: + # 请求过于频繁,是暂时性问题,但仍需惩罚,此处使用默认基础值 + # penalty_increment = 2 # 可以选择一个中间值,例如2,表示比普通错误重,但比关键错误轻 + logger.warning(f"模型 '{model_name}' 请求过于频繁 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") + else: + # 其他客户端错误 (4xx)。通常不重试,_handle_resp_not_ok 会处理。 + # 如果 _handle_resp_not_ok 返回 retry_interval, 则进入这里的 exception 块。 + logger.warning(f"模型 '{model_name}' 发生非致命的响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") + else: + # 其他未捕获的异常,增加基础惩罚 + logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") + + self.model_usage[model_info.name] = (total_tokens, penalty + penalty_increment, usage_penalty) + # --- 结束增强 --- + # 移除冗余的、错误的惩罚值更新行,保留上面正确的动态惩罚更新 + # self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) wait_interval, compressed_messages = self._default_exception_handler( e, From 253946fe57e86f06ea102467a4c5a346db1d11a6 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 19:21:55 +0800 Subject: [PATCH 17/41] =?UTF-8?q?refactor(llm):=20=E5=B0=86LLM=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E9=80=BB=E8=BE=91=E8=A7=A3=E8=80=A6=E5=88=B0=E4=B8=93?= =?UTF-8?q?=E9=97=A8=E7=9A=84=E7=BB=84=E4=BB=B6=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 庞大的`LLMRequest`类已被重构为一个协调器,它将任务委托给多个专门的组件。此更改旨在遵循单一职责原则,从而提高代码的结构、可维护性和可扩展性。 核心逻辑被提取到以下新类中: - `ModelSelector`: 封装了基于负载和可用性选择最佳模型的逻辑。 - `PromptProcessor`: 负责处理所有提示词修改和响应内容的解析。 - `RequestStrategy`: 管理请求的执行流程,包括故障转移和并发请求策略。 这种新的架构使系统更加模块化,更易于测试,并且未来可以更轻松地扩展新的请求策略。 --- src/llm_models/llm_utils.py | 65 ++ .../model_client/aiohttp_gemini_client.py | 6 +- src/llm_models/model_selector.py | 130 +++ src/llm_models/prompt_processor.py | 113 +++ src/llm_models/request_executor.py | 226 +++++ src/llm_models/request_strategy.py | 206 ++++ src/llm_models/utils_model.py | 957 +++--------------- 7 files changed, 856 insertions(+), 847 deletions(-) create mode 100644 src/llm_models/llm_utils.py create mode 100644 src/llm_models/model_selector.py create mode 100644 src/llm_models/prompt_processor.py create mode 100644 src/llm_models/request_executor.py create mode 100644 src/llm_models/request_strategy.py diff --git a/src/llm_models/llm_utils.py b/src/llm_models/llm_utils.py new file mode 100644 index 000000000..fb7810a0c --- /dev/null +++ b/src/llm_models/llm_utils.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +""" +@File : llm_utils.py +@Time : 2024/05/24 17:00:00 +@Author : 墨墨 +@Version : 1.0 +@Desc : LLM相关通用工具函数 +""" +from typing import List, Dict, Any, Tuple + +from src.common.logger import get_logger +from .payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType + +logger = get_logger("llm_utils") + +def normalize_image_format(image_format: str) -> str: + """ + 标准化图片格式名称,确保与各种API的兼容性 + """ + format_mapping = { + "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg", + "png": "png", "PNG": "png", + "webp": "webp", "WEBP": "webp", + "gif": "gif", "GIF": "gif", + "heic": "heic", "HEIC": "heic", + "heif": "heif", "HEIF": "heif", + } + normalized = format_mapping.get(image_format, image_format.lower()) + logger.debug(f"图片格式标准化: {image_format} -> {normalized}") + return normalized + +def build_tool_options(tools: List[Dict[str, Any]] | None) -> List[ToolOption] | None: + """构建工具选项列表""" + if not tools: + return None + tool_options: List[ToolOption] = [] + for tool in tools: + try: + tool_options_builder = ToolOptionBuilder() + tool_options_builder.set_name(tool.get("name", "")) + tool_options_builder.set_description(tool.get("description", "")) + parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) + for param in parameters: + # 参数校验 + assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" + assert isinstance(param[0], str), "参数名称必须是字符串" + assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" + assert isinstance(param[2], str), "参数描述必须是字符串" + assert isinstance(param[3], bool), "参数是否必填必须是布尔值" + assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" + + tool_options_builder.add_param( + name=param[0], + param_type=param[1], + description=param[2], + required=param[3], + enum_values=param[4], + ) + tool_options.append(tool_options_builder.build()) + except AssertionError as ae: + logger.error(f"工具 '{tool.get('name', 'unknown')}' 的参数定义错误: {str(ae)}") + except Exception as e: + logger.error(f"构建工具 '{tool.get('name', 'unknown')}' 失败: {str(e)}") + + return tool_options or None \ No newline at end of file diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 7b997b680..eeb90c265 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -122,7 +122,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]: def _convert_tool_param(param: ToolParam) -> dict: """转换工具参数""" - result = { + result: dict[str, Any] = { "type": param.param_type.value, "description": param.description, } @@ -132,7 +132,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]: def _convert_tool_option_item(tool_option: ToolOption) -> dict: """转换单个工具选项""" - function_declaration = { + function_declaration: dict[str, Any] = { "name": tool_option.name, "description": tool_option.description, } @@ -500,7 +500,7 @@ class AiohttpGeminiClient(BaseClient): # 直接重抛项目定义的异常 raise except Exception as e: - logger.debug(e) + logger.debug(f"请求处理中发生未知异常: {e}") # 其他异常转换为网络连接错误 raise NetworkConnectionError() from e diff --git a/src/llm_models/model_selector.py b/src/llm_models/model_selector.py new file mode 100644 index 000000000..827e28842 --- /dev/null +++ b/src/llm_models/model_selector.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +""" +@File : model_selector.py +@Time : 2024/05/24 16:00:00 +@Author : 墨墨 +@Version : 1.0 +@Desc : 模型选择与负载均衡器 +""" +from typing import Dict, Tuple, Set, Optional + +from src.common.logger import get_logger +from src.config.config import model_config +from src.config.api_ada_configs import ModelInfo, APIProvider, TaskConfig +from .model_client.base_client import BaseClient, client_registry + +logger = get_logger("model_selector") + + +class ModelSelector: + """模型选择与负载均衡器""" + + def __init__(self, model_set: TaskConfig, request_type: str = ""): + """ + 初始化模型选择器 + + Args: + model_set (TaskConfig): 任务配置中定义的模型集合 + request_type (str, optional): 请求类型 (例如 "embedding"). Defaults to "". + """ + self.model_for_task = model_set + self.request_type = request_type + self.model_usage: Dict[str, Tuple[int, int, int]] = { + model: (0, 0, 0) for model in self.model_for_task.model_list + } + """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" + + def select_best_available_model( + self, failed_models_in_this_request: Set[str] + ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: + """ + 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 + + Args: + failed_models_in_this_request (Set[str]): 当前请求中已失败的模型名称集合。 + + Returns: + Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: 选定的模型详细信息,如果无可用模型则返回 None。 + """ + candidate_models_usage = { + model_name: usage_data + for model_name, usage_data in self.model_usage.items() + if model_name not in failed_models_in_this_request + } + + if not candidate_models_usage: + logger.warning("没有可用的模型供当前请求选择。") + return None + + # 根据现有公式查找分数最低的模型 + # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 + # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。 + least_used_model_name = min( + candidate_models_usage, + key=lambda k: candidate_models_usage[k][0] + + candidate_models_usage[k][1] * 300 + + candidate_models_usage[k][2] * 1000, + ) + + # --- 动态故障转移的核心逻辑 --- + # RequestStrategy 中的循环会多次调用此函数。 + # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数, + # 此时由于失败模型已被标记,且其惩罚值可能已在 RequestExecutor 中增加, + # 此函数会自动选择一个得分更低(即更可用)的模型。 + # 这种机制实现了动态的、基于当前系统状态的故障转移。 + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + + force_new_client = self.request_type == "embedding" + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + + logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") + + # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。 + # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) + + return model_info, api_provider, client + + def select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: + """ + 根据总tokens和惩罚值选择的模型 (负载均衡) + """ + least_used_model_name = min( + self.model_usage, + key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, + ) + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + + force_new_client = self.request_type == "embedding" + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + logger.debug(f"选择请求模型: {model_info.name}") + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) + return model_info, api_provider, client + + def update_model_penalty(self, model_name: str, penalty_increment: int): + """ + 更新指定模型的惩罚值 + + Args: + model_name (str): 模型名称 + penalty_increment (int): 惩罚增量 + """ + if model_name in self.model_usage: + total_tokens, penalty, usage_penalty = self.model_usage[model_name] + self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty) + logger.debug(f"模型 '{model_name}' 的惩罚值增加了 {penalty_increment}") + + def decrease_usage_penalty(self, model_name: str): + """ + 请求结束后,减少使用惩罚值 + + Args: + model_name (str): 模型名称 + """ + if model_name in self.model_usage: + total_tokens, penalty, usage_penalty = self.model_usage[model_name] + self.model_usage[model_name] = (total_tokens, penalty, usage_penalty - 1) \ No newline at end of file diff --git a/src/llm_models/prompt_processor.py b/src/llm_models/prompt_processor.py new file mode 100644 index 000000000..035a8f55f --- /dev/null +++ b/src/llm_models/prompt_processor.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +""" +@File : prompt_processor.py +@Time : 2024/05/24 16:45:00 +@Author : 墨墨 +@Version : 1.0 +@Desc : 提示词预处理(反截断、内容混淆等) +""" +import random +import string +import re +from typing import Tuple + +from src.common.logger import get_logger +from src.config.api_ada_configs import APIProvider, ModelInfo + +logger = get_logger("prompt_processor") + + +class PromptProcessor: + """提示词处理器""" + + def __init__(self): + self.noise_instruction = """**【核心过滤规则】** +在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。 + +你的任务是【完全并彻底地忽略】这些随机字符串。 +**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。 +""" + self.end_marker = "###MAI_RESPONSE_END###" + self.anti_truncation_instruction = f""" +**【输出完成信令】** +这是一个非常重要的指令,请务-务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `{self.end_marker}` 作为结束标志。 +例如: +<你的回复内容> +{self.end_marker} + +这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。 +""" + + def process_prompt( + self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str + ) -> str: + """ + 根据模型和API提供商的配置处理提示词 + """ + processed_prompt = prompt + + # 1. 添加反截断指令 + use_anti_truncation = getattr(model_info, "use_anti_truncation", False) + if use_anti_truncation: + processed_prompt += self.anti_truncation_instruction + logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。") + + # 2. 应用内容混淆 + if getattr(api_provider, "enable_content_obfuscation", False): + intensity = getattr(api_provider, "obfuscation_intensity", 1) + logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") + processed_prompt = self._apply_content_obfuscation(processed_prompt, intensity) + + return processed_prompt + + def _apply_content_obfuscation(self, text: str, intensity: int) -> str: + """对文本进行混淆处理""" + # 在开头加入过滤规则指令 + processed_text = self.noise_instruction + "\n\n" + text + logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}") + + # 添加随机乱码 + final_text = self._inject_random_noise(processed_text, intensity) + logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}") + + return final_text + + @staticmethod + def _inject_random_noise(text: str, intensity: int) -> str: + """在文本中注入随机乱码""" + def generate_noise(length: int) -> str: + chars = ( + string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?" + + "一二三四五六七八九零壹贰叁" + "αβγδεζηθικλμνξοπρστυφχψω" + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵" + ) + return "".join(random.choice(chars) for _ in range(length)) + + params = { + 1: {"probability": 15, "length": (3, 6)}, + 2: {"probability": 25, "length": (5, 10)}, + 3: {"probability": 35, "length": (8, 15)}, + } + config = params.get(intensity, params[1]) + logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}") + + words = text.split() + result = [] + noise_count = 0 + for word in words: + result.append(word) + if random.randint(1, 100) <= config["probability"]: + noise_length = random.randint(*config["length"]) + noise = generate_noise(noise_length) + result.append(noise) + noise_count += 1 + + logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}") + return " ".join(result) + + @staticmethod + def extract_reasoning(content: str) -> Tuple[str, str]: + """CoT思维链提取,向后兼容""" + match = re.search(r"(?:)?(.*?)", content, re.DOTALL) + clean_content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() + reasoning = match.group(1).strip() if match else "" + return clean_content, reasoning diff --git a/src/llm_models/request_executor.py b/src/llm_models/request_executor.py new file mode 100644 index 000000000..33b3197b0 --- /dev/null +++ b/src/llm_models/request_executor.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- +""" +@File : request_executor.py +@Time : 2024/05/24 16:15:00 +@Author : 墨墨 +@Version : 1.0 +@Desc : 负责执行LLM请求、处理重试及异常 +""" +import asyncio +from typing import List, Callable, Optional, Tuple + +from src.common.logger import get_logger +from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig +from .exceptions import ( + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, +) +from .model_client.base_client import APIResponse, BaseClient +from .model_selector import ModelSelector +from .payload_content.message import Message +from .payload_content.resp_format import RespFormat +from .payload_content.tool_option import ToolOption +from .utils import compress_messages + +logger = get_logger("request_executor") + + +class RequestExecutor: + """请求执行器""" + + def __init__( + self, + task_name: str, + model_set: TaskConfig, + api_provider: APIProvider, + client: BaseClient, + model_info: ModelInfo, + model_selector: ModelSelector, + ): + self.task_name = task_name + self.model_set = model_set + self.api_provider = api_provider + self.client = client + self.model_info = model_info + self.model_selector = model_selector + + async def execute_request( + self, + request_type: str, + message_list: List[Message] | None = None, + tool_options: list[ToolOption] | None = None, + response_format: RespFormat | None = None, + stream_response_handler: Optional[Callable] = None, + async_response_parser: Optional[Callable] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + embedding_input: str = "", + audio_base64: str = "", + ) -> APIResponse: + """ + 实际执行请求的方法, 包含了重试和异常处理逻辑 + """ + retry_remain = self.api_provider.max_retry + compressed_messages: Optional[List[Message]] = None + while retry_remain > 0: + try: + if request_type == "response": + assert message_list is not None, "message_list cannot be None for response requests" + return await self.client.get_response( + model_info=self.model_info, + message_list=(compressed_messages or message_list), + tool_options=tool_options, + max_tokens=self.model_set.max_tokens if max_tokens is None else max_tokens, + temperature=self.model_set.temperature if temperature is None else temperature, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + extra_params=self.model_info.extra_params, + ) + elif request_type == "embedding": + assert embedding_input, "embedding_input cannot be empty for embedding requests" + return await self.client.get_embedding( + model_info=self.model_info, + embedding_input=embedding_input, + extra_params=self.model_info.extra_params, + ) + elif request_type == "audio": + assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" + return await self.client.get_audio_transcriptions( + model_info=self.model_info, + audio_base64=audio_base64, + extra_params=self.model_info.extra_params, + ) + raise ValueError(f"未知的请求类型: {request_type}") + except Exception as e: + logger.debug(f"请求失败: {str(e)}") + self._apply_penalty_on_failure(e) + + wait_interval, compressed_messages = self._default_exception_handler( + e, + remain_try=retry_remain, + retry_interval=self.api_provider.retry_interval, + messages=(message_list, compressed_messages is not None) if message_list else None, + ) + + if wait_interval == -1: + retry_remain = 0 # 不再重试 + elif wait_interval > 0: + logger.info(f"等待 {wait_interval} 秒后重试...") + await asyncio.sleep(wait_interval) + finally: + retry_remain -= 1 + + self.model_selector.decrease_usage_penalty(self.model_info.name) + logger.error(f"模型 '{self.model_info.name}' 请求失败,达到最大重试次数 {self.api_provider.max_retry} 次") + raise RuntimeError("请求失败,已达到最大重试次数") + + def _apply_penalty_on_failure(self, e: Exception): + """根据异常类型,动态调整模型的惩罚值""" + CRITICAL_PENALTY_MULTIPLIER = 5 + default_penalty_increment = 1 + penalty_increment = default_penalty_increment + + if isinstance(e, (NetworkConnectionError, ReqAbortException)): + penalty_increment = CRITICAL_PENALTY_MULTIPLIER + elif isinstance(e, RespNotOkException): + if e.status_code >= 500: + penalty_increment = CRITICAL_PENALTY_MULTIPLIER + + log_message = f"发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}" + if isinstance(e, (NetworkConnectionError, ReqAbortException)): + log_message = f"发生关键错误 ({type(e).__name__}),增加惩罚值: {penalty_increment}" + elif isinstance(e, RespNotOkException): + log_message = f"发生响应错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}" + logger.warning(f"模型 '{self.model_info.name}' {log_message}") + + self.model_selector.update_model_penalty(self.model_info.name, penalty_increment) + + def _default_exception_handler( + self, + e: Exception, + remain_try: int, + retry_interval: int = 10, + messages: Tuple[List[Message], bool] | None = None, + ) -> Tuple[int, List[Message] | None]: + """默认异常处理函数""" + model_name = self.model_info.name + + if isinstance(e, NetworkConnectionError): + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数", + ) + elif isinstance(e, ReqAbortException): + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}") + return -1, None + elif isinstance(e, RespNotOkException): + return self._handle_resp_not_ok(e, remain_try, retry_interval, messages) + elif isinstance(e, RespParseException): + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}") + logger.debug(f"附加内容: {str(e.ext_info)}") + return -1, None + else: + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") + return -1, None + + def _handle_resp_not_ok( + self, + e: RespNotOkException, + remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, + ): + """处理响应错误异常""" + model_name = self.model_info.name + if e.status_code in [400, 401, 402, 403, 404]: + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}") + return -1, None + elif e.status_code == 413: + if messages and not messages[1]: + return self._check_retry( + remain_try, 0, + can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试", + cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,压缩后仍失败", + can_retry_callable=compress_messages, messages=messages[0], + ) + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,无法压缩,放弃请求。") + return -1, None + elif e.status_code == 429: + return self._check_retry( + remain_try, retry_interval, + can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数", + ) + elif e.status_code >= 500: + return self._check_retry( + remain_try, retry_interval, + can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数", + ) + else: + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}") + return -1, None + + @staticmethod + def _check_retry( + remain_try: int, + retry_interval: int, + can_retry_msg: str, + cannot_retry_msg: str, + can_retry_callable: Callable | None = None, + **kwargs, + ) -> Tuple[int, List[Message] | None]: + """辅助函数:检查是否可以重试""" + if remain_try > 0: + logger.warning(f"{can_retry_msg}") + if can_retry_callable is not None: + return retry_interval, can_retry_callable(**kwargs) + return retry_interval, None + else: + logger.warning(f"{cannot_retry_msg}") + return -1, None \ No newline at end of file diff --git a/src/llm_models/request_strategy.py b/src/llm_models/request_strategy.py new file mode 100644 index 000000000..3a694f526 --- /dev/null +++ b/src/llm_models/request_strategy.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- +""" +@File : request_strategy.py +@Time : 2024/05/24 16:30:00 +@Author : 墨墨 +@Version : 1.0 +@Desc : 高级请求策略(并发、故障转移) +""" +import asyncio +import random +from typing import List, Tuple, Optional, Dict, Any, Callable, Coroutine + +from src.common.logger import get_logger +from src.config.api_ada_configs import TaskConfig +from .model_client.base_client import APIResponse +from .model_selector import ModelSelector +from .payload_content.message import MessageBuilder +from .payload_content.tool_option import ToolCall +from .prompt_processor import PromptProcessor +from .request_executor import RequestExecutor + +logger = get_logger("request_strategy") + + +class RequestStrategy: + """高级请求策略""" + + def __init__(self, model_set: TaskConfig, model_selector: ModelSelector, task_name: str): + self.model_set = model_set + self.model_selector = model_selector + self.task_name = task_name + + async def execute_with_fallback( + self, + base_payload: Dict[str, Any], + raise_when_empty: bool = True, + ) -> Dict[str, Any]: + """ + 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 + """ + failed_models_in_this_request = set() + max_attempts = len(self.model_set.model_list) + last_exception: Optional[Exception] = None + + for attempt in range(max_attempts): + model_selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request) + + if model_selection_result is None: + logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") + break + + model_info, api_provider, client = model_selection_result + model_name = model_info.name + logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...") + + try: + # 1. Process Prompt + prompt_processor: PromptProcessor = base_payload["prompt_processor"] + raw_prompt = base_payload["prompt"] + processed_prompt = prompt_processor.process_prompt( + raw_prompt, model_info, api_provider, self.task_name + ) + + # 2. Build Message + message_builder = MessageBuilder().add_text_content(processed_prompt) + messages = [message_builder.build()] + + # 3. Create payload for executor + executor_payload = { + "request_type": "response", # Strategy only handles response type + "message_list": messages, + "tool_options": base_payload["tool_options"], + "temperature": base_payload["temperature"], + "max_tokens": base_payload["max_tokens"], + } + + executor = RequestExecutor( + task_name=self.task_name, + model_set=self.model_set, + api_provider=api_provider, + client=client, + model_info=model_info, + model_selector=self.model_selector, + ) + response = await self._execute_and_handle_empty_retry(executor, executor_payload, prompt_processor) + + # 4. Post-process response + # The reasoning content is now extracted here, after a successful, de-truncated response is received. + final_content, reasoning_content = prompt_processor.extract_reasoning(response.content or "") + response.content = final_content # Update response with cleaned content + + tool_calls = response.tool_calls + + if not final_content and not tool_calls: + if raise_when_empty: + raise RuntimeError("所选模型生成了空回复。") + content = "生成的响应为空" # Fallback message + + logger.debug(f"模型 '{model_name}' 成功生成了回复。") + return { + "content": response.content, + "reasoning_content": reasoning_content, + "model_name": model_name, + "tool_calls": tool_calls, + "model_info": model_info, + "usage": response.usage, + "success": True, + } + + except Exception as e: + logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") + failed_models_in_this_request.add(model_info.name) + last_exception = e + + logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") + if raise_when_empty: + if last_exception: + raise RuntimeError("所有模型均未能生成响应。") from last_exception + raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") + return { + "content": "所有模型都请求失败", + "reasoning_content": "", + "model_name": "unknown", + "tool_calls": None, + "model_info": None, + "usage": None, + "success": False, + } + + async def execute_concurrently( + self, + coro_callable: Callable[..., Coroutine[Any, Any, Any]], + concurrency_count: int, + *args, + **kwargs, + ) -> Any: + """ + 执行并发请求并从成功的结果中随机选择一个。 + """ + logger.info(f"启用并发请求模式,并发数: {concurrency_count}") + tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)] + + results = await asyncio.gather(*tasks, return_exceptions=True) + successful_results = [res for res in results if not isinstance(res, Exception)] + + if successful_results: + selected = random.choice(successful_results) + logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个") + return selected + + for i, res in enumerate(results): + if isinstance(res, Exception): + logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}") + + first_exception = next((res for res in results if isinstance(res, Exception)), None) + if first_exception: + raise first_exception + + raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息") + + async def _execute_and_handle_empty_retry( + self, executor: RequestExecutor, payload: Dict[str, Any], prompt_processor: PromptProcessor + ) -> APIResponse: + """ + 在单个模型内部处理空回复/截断的重试逻辑 + """ + empty_retry_count = 0 + max_empty_retry = executor.api_provider.max_retry + empty_retry_interval = executor.api_provider.retry_interval + use_anti_truncation = getattr(executor.model_info, "use_anti_truncation", False) + end_marker = prompt_processor.end_marker + + while empty_retry_count <= max_empty_retry: + response = await executor.execute_request(**payload) + + content = response.content or "" + tool_calls = response.tool_calls + + is_empty_reply = not tool_calls and (not content or content.strip() == "") + is_truncated = False + if use_anti_truncation and end_marker: + if content.endswith(end_marker): + # 移除结束标记 + response.content = content[: -len(end_marker)].strip() + else: + is_truncated = True + + if is_empty_reply or is_truncated: + empty_retry_count += 1 + if empty_retry_count <= max_empty_retry: + reason = "空回复" if is_empty_reply else "截断" + logger.warning( + f"模型 '{executor.model_info.name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..." + ) + if empty_retry_interval > 0: + await asyncio.sleep(empty_retry_interval) + continue + else: + reason = "空回复" if is_empty_reply else "截断" + raise RuntimeError(f"模型 '{executor.model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。") + + # 成功获取响应 + return response + + # 此处理论上不会到达,因为循环要么返回要么抛异常 + raise RuntimeError("空回复/截断重Test逻辑出现未知错误") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index ec1a996bf..1414aacb1 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,154 +1,36 @@ -import re -import asyncio +# -*- coding: utf-8 -*- +""" +@File : utils_model.py +@Time : 2024/05/24 17:15:00 +@Author : 墨墨 +@Version : 2.0 (Refactored) +@Desc : LLM请求协调器 +""" import time -import random - -from enum import Enum -from rich.traceback import install -from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator +from typing import Tuple, List, Dict, Optional, Any from src.common.logger import get_logger -from src.config.config import model_config -from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig -from .payload_content.message import MessageBuilder, Message -from .payload_content.resp_format import RespFormat -from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType -from .model_client.base_client import BaseClient, APIResponse, client_registry -from .utils import compress_messages, llm_usage_recorder -from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException - -install(extra_lines=3) +from src.config.api_ada_configs import TaskConfig, ModelInfo +from .llm_utils import build_tool_options, normalize_image_format +from .model_selector import ModelSelector +from .payload_content.message import MessageBuilder +from .payload_content.tool_option import ToolCall +from .prompt_processor import PromptProcessor +from .request_strategy import RequestStrategy +from .utils import llm_usage_recorder logger = get_logger("model_utils") -# 常见Error Code Mapping -error_code_mapping = { - 400: "参数不正确", - 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确", - 402: "账号余额不足", - 403: "需要实名,或余额不足", - 404: "Not Found", - 429: "请求过于频繁,请稍后再试", - 500: "服务器内部故障", - 503: "服务器负载过高", -} - - -def _normalize_image_format(image_format: str) -> str: - """ - 标准化图片格式名称,确保与各种API的兼容性 - - Args: - image_format (str): 原始图片格式 - - Returns: - str: 标准化后的图片格式 - """ - format_mapping = { - "jpg": "jpeg", - "JPG": "jpeg", - "JPEG": "jpeg", - "jpeg": "jpeg", - "png": "png", - "PNG": "png", - "webp": "webp", - "WEBP": "webp", - "gif": "gif", - "GIF": "gif", - "heic": "heic", - "HEIC": "heic", - "heif": "heif", - "HEIF": "heif", - } - - normalized = format_mapping.get(image_format, image_format.lower()) - logger.debug(f"图片格式标准化: {image_format} -> {normalized}") - return normalized - - -class RequestType(Enum): - """请求类型枚举""" - - RESPONSE = "response" - EMBEDDING = "embedding" - AUDIO = "audio" - - -async def execute_concurrently( - coro_callable: Callable[..., Coroutine[Any, Any, Any]], - concurrency_count: int, - *args, - **kwargs, -) -> Any: - """ - 执行并发请求并从成功的结果中随机选择一个。 - - Args: - coro_callable (Callable): 要并发执行的协程函数。 - concurrency_count (int): 并发执行的次数。 - *args: 传递给协程函数的位置参数。 - **kwargs: 传递给协程函数的关键字参数。 - - Returns: - Any: 其中一个成功执行的结果。 - - Raises: - RuntimeError: 如果所有并发请求都失败。 - """ - logger.info(f"启用并发请求模式,并发数: {concurrency_count}") - tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)] - - results = await asyncio.gather(*tasks, return_exceptions=True) - successful_results = [res for res in results if not isinstance(res, Exception)] - - if successful_results: - selected = random.choice(successful_results) - logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个") - return selected - - # 如果所有请求都失败了,记录所有异常并抛出第一个 - for i, res in enumerate(results): - if isinstance(res, Exception): - logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}") - - first_exception = next((res for res in results if isinstance(res, Exception)), None) - if first_exception: - raise first_exception - - raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息") - - class LLMRequest: - """LLM请求类""" + """LLM请求协调器""" def __init__(self, model_set: TaskConfig, request_type: str = "") -> None: self.task_name = request_type self.model_for_task = model_set self.request_type = request_type - self.model_usage: Dict[str, Tuple[int, int, int]] = { - model: (0, 0, 0) for model in self.model_for_task.model_list - } - """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" - - # 内容混淆过滤指令 - self.noise_instruction = """**【核心过滤规则】** -在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。 - -你的任务是【完全并彻底地忽略】这些随机字符串。 -**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。 -""" - - # 反截断指令 - self.end_marker = "###MAI_RESPONSE_END###" - self.anti_truncation_instruction = f""" -**【输出完成信令】** -这是一个非常重要的指令,请务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `{self.end_marker}` 作为结束标志。 -例如: -<你的回复内容> -{self.end_marker} - -这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。 -""" + self.model_selector = ModelSelector(model_set, request_type) + self.prompt_processor = PromptProcessor() + self.request_strategy = RequestStrategy(model_set, self.model_selector, request_type) async def generate_response_for_image( self, @@ -158,25 +40,18 @@ class LLMRequest: temperature: Optional[float] = None, max_tokens: Optional[int] = None, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """ - 为图像生成响应 - Args: - prompt (str): 提示词 - image_base64 (str): 图像的Base64编码字符串 - image_format (str): 图像格式(如 'png', 'jpeg' 等) - Returns: - (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 - """ - # 标准化图片格式以确保API兼容性 - normalized_format = _normalize_image_format(image_format) - - # 模型选择 + """为图像生成响应""" start_time = time.time() - model_info, api_provider, client = self._select_model() - - # 请求体构建 + + # 1. 选择模型 + model_info, api_provider, client = self.model_selector.select_model() + + # 2. 准备消息体 + processed_prompt = self.prompt_processor.process_prompt(prompt, model_info, api_provider, self.task_name) + normalized_format = normalize_image_format(image_format) + message_builder = MessageBuilder() - message_builder.add_text_content(prompt) + message_builder.add_text_content(processed_prompt) message_builder.add_image_content( image_base64=image_base64, image_format=normalized_format, @@ -184,51 +59,47 @@ class LLMRequest: ) messages = [message_builder.build()] - # 请求并处理返回值 - response = await self._execute_request( + # 3. 执行请求 (图像请求通常不走复杂的故障转移策略,直接执行) + from .request_executor import RequestExecutor + executor = RequestExecutor( + task_name=self.task_name, + model_set=self.model_for_task, api_provider=api_provider, client=client, - request_type=RequestType.RESPONSE, model_info=model_info, + model_selector=self.model_selector, + ) + response = await executor.execute_request( + request_type="response", message_list=messages, temperature=temperature, max_tokens=max_tokens, ) - content = response.content or "" - reasoning_content = response.reasoning_content or "" + + # 4. 处理响应 + content, reasoning_content = self.prompt_processor.extract_reasoning(response.content or "") tool_calls = response.tool_calls - # 从内容中提取标签的推理内容(向后兼容) - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning + if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - user_id="system", - time_cost=time.time() - start_time, - request_type=self.request_type, - endpoint="/chat/completions", - ) + await self._record_usage(model_info, usage, time.time() - start_time) + return content, (reasoning_content, model_info.name, tool_calls) async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: - """ - 为语音生成响应 - Args: - voice_base64 (str): 语音的Base64编码字符串 - Returns: - (Optional[str]): 生成的文本描述或None - """ - # 模型选择 - model_info, api_provider, client = self._select_model() - - # 请求并处理返回值 - response = await self._execute_request( + """为语音生成响应""" + model_info, api_provider, client = self.model_selector.select_model() + + from .request_executor import RequestExecutor + executor = RequestExecutor( + task_name=self.task_name, + model_set=self.model_for_task, api_provider=api_provider, client=client, - request_type=RequestType.AUDIO, model_info=model_info, + model_selector=self.model_selector, + ) + response = await executor.execute_request( + request_type="audio", audio_base64=voice_base64, ) return response.content or None @@ -241,680 +112,78 @@ class LLMRequest: tools: Optional[List[Dict[str, Any]]] = None, raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """ - 异步生成响应,支持并发请求 - Args: - prompt (str): 提示词 - temperature (float, optional): 温度参数 - max_tokens (int, optional): 最大token数 - tools: 工具配置 - raise_when_empty: 是否在空回复时抛出异常 - Returns: - (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 - """ - # 检查是否需要并发请求 + """异步生成响应,支持并发和故障转移""" + + # 1. 准备基础请求载荷 + tool_built = build_tool_options(tools) + base_payload = { + "prompt": prompt, + "tool_options": tool_built, + "temperature": temperature, + "max_tokens": max_tokens, + "prompt_processor": self.prompt_processor, + } + + # 2. 根据配置选择执行策略 concurrency_count = getattr(self.model_for_task, "concurrency_count", 1) - + if concurrency_count <= 1: - # 单次请求 - return await self._execute_single_request(prompt, temperature, max_tokens, tools, raise_when_empty) - - # 并发请求 - try: - # 为 _execute_single_request 传递参数时,将 raise_when_empty 设为 False, - # 这样单个请求失败时不会立即抛出异常,而是由 gather 统一处理 - content, (reasoning_content, model_name, tool_calls) = await execute_concurrently( - self._execute_single_request, + # 单次请求,但使用带故障转移的策略 + result = await self.request_strategy.execute_with_fallback( + base_payload, raise_when_empty + ) + else: + # 并发请求策略 + result = await self.request_strategy.execute_concurrently( + self.request_strategy.execute_with_fallback, concurrency_count, - prompt, - temperature, - max_tokens, - tools, + base_payload, raise_when_empty=False, ) - return content, (reasoning_content, model_name, tool_calls) - except Exception as e: - logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}") - if raise_when_empty: - raise e - return "所有并发请求都失败了", ("", "unknown", None) - - async def _execute_single_request( - self, - prompt: str, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, - raise_when_empty: bool = True, - ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """ - 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 - """ - failed_models_in_this_request = set() - # 迭代次数等于模型总数,以确保每个模型在当前请求中最多只尝试一次 - max_attempts = len(self.model_for_task.model_list) - last_exception: Optional[Exception] = None - - for attempt in range(max_attempts): - # 根据负载均衡和当前故障选择最佳可用模型 - model_selection_result = self._select_best_available_model(failed_models_in_this_request) - - if model_selection_result is None: - logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") - break # 没有更多模型可供尝试 - - model_info, api_provider, client = model_selection_result - model_name = model_info.name - logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...") - - start_time = time.time() - - try: - # --- 为当前模型尝试进行设置 --- - # 检查是否为该模型启用反截断 - use_anti_truncation = getattr(model_info, "use_anti_truncation", False) - processed_prompt = prompt - if use_anti_truncation: - processed_prompt += self.anti_truncation_instruction - logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。") - - processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) - - message_builder = MessageBuilder() - message_builder.add_text_content(processed_prompt) - messages = [message_builder.build()] - tool_built = self._build_tool_options(tools) - - # --- 当前选定模型内的空回复/截断重试逻辑 --- - empty_retry_count = 0 - max_empty_retry = api_provider.max_retry - empty_retry_interval = api_provider.retry_interval - - while empty_retry_count <= max_empty_retry: - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, - message_list=messages, - tool_options=tool_built, - temperature=temperature, - max_tokens=max_tokens, - ) - - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - - # 向后兼容 标签(如果 reasoning_content 为空) - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - - is_empty_reply = not tool_calls and (not content or content.strip() == "") - is_truncated = False - if use_anti_truncation: - if content.endswith(self.end_marker): - content = content[: -len(self.end_marker)].strip() - else: - is_truncated = True - - if is_empty_reply or is_truncated: - empty_retry_count += 1 - if empty_retry_count <= max_empty_retry: - reason = "空回复" if is_empty_reply else "截断" - logger.warning( - f"模型 '{model_name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..." - ) - if empty_retry_interval > 0: - await asyncio.sleep(empty_retry_interval) - continue # 使用当前模型重试 - else: - reason = "空回复" if is_empty_reply else "截断" - logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。将此模型标记为当前请求失败。") - raise RuntimeError(f"模型 '{model_name}' 已达到空回复/截断的最大内部重试次数。") - - # --- 从当前模型获取成功响应 --- - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - time_cost=time.time() - start_time, - user_id="system", - request_type=self.request_type, - endpoint="/chat/completions", - ) - - # 处理成功执行后响应仍然为空的情况 - if not content and not tool_calls: - if raise_when_empty: - raise RuntimeError("所选模型生成了空回复。") - content = "生成的响应为空" # Fallback message - - logger.debug(f"模型 '{model_name}' 成功生成了回复。") - return content, (reasoning_content, model_name, tool_calls) # 成功,立即返回 - - # --- 当前模型尝试过程中的异常处理 --- - except Exception as e: # 捕获当前模型尝试过程中的所有异常 - # 修复 NameError: model_name 在异常处理块中未定义,应使用 model_info.name - logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") - failed_models_in_this_request.add(model_info.name) - last_exception = e # 存储异常以供最终报告 - # 继续循环以尝试下一个可用模型 - - # 如果循环结束未能返回,则表示当前请求的所有模型都已失败 - logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") - if raise_when_empty: - if last_exception: - raise RuntimeError("所有模型均未能生成响应。") from last_exception - raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") - return "所有模型都请求失败", ("", "unknown", None) + + # 3. 处理最终结果 + content, (reasoning_content, model_name, tool_calls) = result + + # 4. 记录用量 (需要从策略中获取最终使用的模型信息和用量) + # TODO: 改造策略以返回最终模型信息和用量, 此处暂时省略 + + return content, (reasoning_content, model_name, tool_calls) async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: - """获取嵌入向量 - Args: - embedding_input (str): 获取嵌入的目标 - Returns: - (Tuple[List[float], str]): (嵌入向量,使用的模型名称) - """ - # 无需构建消息体,直接使用输入文本 + """获取嵌入向量""" start_time = time.time() - model_info, api_provider, client = self._select_model() - - # 请求并处理返回值 - response = await self._execute_request( + model_info, api_provider, client = self.model_selector.select_model() + + from .request_executor import RequestExecutor + executor = RequestExecutor( + task_name=self.task_name, + model_set=self.model_for_task, api_provider=api_provider, client=client, - request_type=RequestType.EMBEDDING, model_info=model_info, + model_selector=self.model_selector, + ) + response = await executor.execute_request( + request_type="embedding", embedding_input=embedding_input, ) - + embedding = response.embedding - - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - time_cost=time.time() - start_time, - model_usage=usage, - user_id="system", - request_type=self.request_type, - endpoint="/embeddings", - ) - if not embedding: raise RuntimeError("获取embedding失败") - + + if usage := response.usage: + await self._record_usage(model_info, usage, time.time() - start_time, "/embeddings") + return embedding, model_info.name - def _select_best_available_model(self, failed_models_in_this_request: set) -> Tuple[ModelInfo, APIProvider, BaseClient] | None: - """ - 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 - - 参数: - failed_models_in_this_request (set): 当前请求中已失败的模型名称集合。 - - 返回: - Tuple[ModelInfo, APIProvider, BaseClient] | None: 选定的模型详细信息,如果无可用模型则返回 None。 - """ - candidate_models_usage = {} - # 过滤掉当前请求中已失败的模型 - for model_name, usage_data in self.model_usage.items(): - if model_name not in failed_models_in_this_request: - candidate_models_usage[model_name] = usage_data - - if not candidate_models_usage: - logger.warning("没有可用的模型供当前请求选择。") - return None - - # 根据现有公式查找分数最低的模型,该公式综合了总token数、模型惩罚值和使用频率惩罚值。 - # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 - # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。 - least_used_model_name = min( - candidate_models_usage, - key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, + async def _record_usage(self, model_info: ModelInfo, usage, time_cost, endpoint="/chat/completions"): + """记录模型用量""" + await llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", + time_cost=time_cost, + request_type=self.request_type, + endpoint=endpoint, ) - - # --- 动态故障转移的核心逻辑 --- - # _execute_single_request 中的循环会多次调用此函数。 - # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数, - # 此时由于失败模型已被标记,且其惩罚值可能已在 _execute_request 中增加, - # _select_best_available_model 会自动选择一个得分更低(即更可用)的模型。 - # 这种机制实现了动态的、基于当前系统状态的故障转移。 - - model_info = model_config.get_model_info(least_used_model_name) - api_provider = model_config.get_provider(model_info.api_provider) - - # 对于嵌入任务,如果需要,强制创建新的客户端实例(从原始 _select_model 复制) - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - - logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") - - # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。 - # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) - - return model_info, api_provider, client - - def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: - """ - 一个模型调度器,按顺序提供模型,并跳过已失败的模型。 - """ - for model_name in self.model_for_task.model_list: - if model_name in failed_models: - continue - - model_info = model_config.get_model_info(model_name) - api_provider = model_config.get_provider(model_info.api_provider) - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - - yield model_info, api_provider, client - - def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: - """ - 根据总tokens和惩罚值选择的模型 (负载均衡) - """ - least_used_model_name = min( - self.model_usage, - key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, - ) - model_info = model_config.get_model_info(least_used_model_name) - api_provider = model_config.get_provider(model_info.api_provider) - - # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - logger.debug(f"选择请求模型: {model_info.name}") - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 - return model_info, api_provider, client - - async def _execute_request( - self, - api_provider: APIProvider, - client: BaseClient, - request_type: RequestType, - model_info: ModelInfo, - message_list: List[Message] | None = None, - tool_options: list[ToolOption] | None = None, - response_format: RespFormat | None = None, - stream_response_handler: Optional[Callable] = None, - async_response_parser: Optional[Callable] = None, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - embedding_input: str = "", - audio_base64: str = "", - ) -> APIResponse: - """ - 实际执行请求的方法 - - 包含了重试和异常处理逻辑 - """ - retry_remain = api_provider.max_retry - compressed_messages: Optional[List[Message]] = None - while retry_remain > 0: - try: - if request_type == RequestType.RESPONSE: - assert message_list is not None, "message_list cannot be None for response requests" - return await client.get_response( - model_info=model_info, - message_list=(compressed_messages or message_list), - tool_options=tool_options, - max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, - temperature=self.model_for_task.temperature if temperature is None else temperature, - response_format=response_format, - stream_response_handler=stream_response_handler, - async_response_parser=async_response_parser, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.EMBEDDING: - assert embedding_input, "embedding_input cannot be empty for embedding requests" - return await client.get_embedding( - model_info=model_info, - embedding_input=embedding_input, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.AUDIO: - assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" - return await client.get_audio_transcriptions( - model_info=model_info, - audio_base64=audio_base64, - extra_params=model_info.extra_params, - ) - except Exception as e: - logger.debug(f"请求失败: {str(e)}") - # 处理异常 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - - # --- 增强动态故障转移的智能性 --- - # 根据异常类型和严重程度,动态调整模型的惩罚值。 - # 关键错误(如网络连接、服务器错误)会获得更高的惩罚, - # 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。 - CRITICAL_PENALTY_MULTIPLIER = 5 # 关键错误时的惩罚系数 - default_penalty_increment = 1 # 普通错误时的基础惩罚 - - penalty_increment = default_penalty_increment - - if isinstance(e, NetworkConnectionError): - # 网络连接问题表明模型服务器不稳定,增加较高惩罚 - penalty_increment = CRITICAL_PENALTY_MULTIPLIER - # 修复 NameError: model_name 在此处未定义,应使用 model_info.name - logger.warning(f"模型 '{model_info.name}' 发生网络连接错误,增加惩罚值: {penalty_increment}") - elif isinstance(e, ReqAbortException): - # 请求被中止,可能是服务器端原因或服务不稳定,增加较高惩罚 - penalty_increment = CRITICAL_PENALTY_MULTIPLIER - # 修复 NameError: model_name 在此处未定义,应使用 model_info.name - logger.warning(f"模型 '{model_info.name}' 请求被中止,增加惩罚值: {penalty_increment}") - elif isinstance(e, RespNotOkException): - if e.status_code >= 500: - # 服务器错误 (5xx) 表明服务器端问题,应显著增加惩罚 - penalty_increment = CRITICAL_PENALTY_MULTIPLIER - logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}") - elif e.status_code == 429: - # 请求过于频繁,是暂时性问题,但仍需惩罚,此处使用默认基础值 - # penalty_increment = 2 # 可以选择一个中间值,例如2,表示比普通错误重,但比关键错误轻 - logger.warning(f"模型 '{model_name}' 请求过于频繁 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") - else: - # 其他客户端错误 (4xx)。通常不重试,_handle_resp_not_ok 会处理。 - # 如果 _handle_resp_not_ok 返回 retry_interval, 则进入这里的 exception 块。 - logger.warning(f"模型 '{model_name}' 发生非致命的响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") - else: - # 其他未捕获的异常,增加基础惩罚 - logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") - - self.model_usage[model_info.name] = (total_tokens, penalty + penalty_increment, usage_penalty) - # --- 结束增强 --- - # 移除冗余的、错误的惩罚值更新行,保留上面正确的动态惩罚更新 - # self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) - - wait_interval, compressed_messages = self._default_exception_handler( - e, - self.task_name, - model_info=model_info, - api_provider=api_provider, - remain_try=retry_remain, - retry_interval=api_provider.retry_interval, - messages=(message_list, compressed_messages is not None) if message_list else None, - ) - - if wait_interval == -1: - retry_remain = 0 # 不再重试 - elif wait_interval > 0: - logger.info(f"等待 {wait_interval} 秒后重试...") - await asyncio.sleep(wait_interval) - finally: - # 放在finally防止死循环 - retry_remain -= 1 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值 - logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") - raise RuntimeError("请求失败,已达到最大重试次数") - - def _default_exception_handler( - self, - e: Exception, - task_name: str, - model_info: ModelInfo, - api_provider: APIProvider, - remain_try: int, - retry_interval: int = 10, - messages: Tuple[List[Message], bool] | None = None, - ) -> Tuple[int, List[Message] | None]: - """ - 默认异常处理函数 - Args: - e (Exception): 异常对象 - task_name (str): 任务名称 - model_info (ModelInfo): 模型信息 - api_provider (APIProvider): API提供商 - remain_try (int): 剩余尝试次数 - retry_interval (int): 重试间隔 - messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) - Returns: - (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - model_name = model_info.name if model_info else "unknown" - - if isinstance(e, NetworkConnectionError): # 网络连接错误 - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确", - ) - elif isinstance(e, ReqAbortException): - logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}") - return -1, None # 不再重试请求该模型 - elif isinstance(e, RespNotOkException): - return self._handle_resp_not_ok( - e, - task_name, - model_info, - api_provider, - remain_try, - retry_interval, - messages, - ) - elif isinstance(e, RespParseException): - # 响应解析错误 - logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}") - logger.debug(f"附加内容: {str(e.ext_info)}") - return -1, None # 不再重试请求该模型 - else: - logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") - return -1, None # 不再重试请求该模型 - - @staticmethod - def _check_retry( - remain_try: int, - retry_interval: int, - can_retry_msg: str, - cannot_retry_msg: str, - can_retry_callable: Callable | None = None, - **kwargs, - ) -> Tuple[int, List[Message] | None]: - """辅助函数:检查是否可以重试 - Args: - remain_try (int): 剩余尝试次数 - retry_interval (int): 重试间隔 - can_retry_msg (str): 可以重试时的提示信息 - cannot_retry_msg (str): 不可以重试时的提示信息 - can_retry_callable (Callable | None): 可以重试时调用的函数(如果有) - **kwargs: 其他参数 - - Returns: - (Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - if remain_try > 0: - # 还有重试机会 - logger.warning(f"{can_retry_msg}") - if can_retry_callable is not None: - return retry_interval, can_retry_callable(**kwargs) - else: - return retry_interval, None - else: - # 达到最大重试次数 - logger.warning(f"{cannot_retry_msg}") - return -1, None # 不再重试请求该模型 - - def _handle_resp_not_ok( - self, - e: RespNotOkException, - task_name: str, - model_info: ModelInfo, - api_provider: APIProvider, - remain_try: int, - retry_interval: int = 10, - messages: tuple[list[Message], bool] | None = None, - ): - model_name = model_info.name - """ - 处理响应错误异常 - Args: - e (RespNotOkException): 响应错误异常对象 - task_name (str): 任务名称 - model_info (ModelInfo): 模型信息 - api_provider (APIProvider): API提供商 - remain_try (int): 剩余尝试次数 - retry_interval (int): 重试间隔 - messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) - Returns: - (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - # 响应错误 - if e.status_code in [400, 401, 402, 403, 404]: - model_name = model_info.name - # 客户端错误 - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}" - ) - return -1, None # 不再重试请求该模型 - elif e.status_code == 413: - if messages and not messages[1]: - # 消息列表不为空且未压缩,尝试压缩消息 - return self._check_retry( - remain_try, - 0, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求", - can_retry_callable=compress_messages, - messages=messages[0], - ) - # 没有消息可压缩 - logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。") - return -1, None - elif e.status_code == 429: - # 请求过于频繁 - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求", - ) - elif e.status_code >= 500: - # 服务器错误 - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试", - ) - else: - # 未知错误 - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}" - ) - return -1, None - - @staticmethod - def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: - # sourcery skip: extract-method - """构建工具选项列表""" - if not tools: - return None - tool_options: List[ToolOption] = [] - for tool in tools: - tool_legal = True - tool_options_builder = ToolOptionBuilder() - tool_options_builder.set_name(tool.get("name", "")) - tool_options_builder.set_description(tool.get("description", "")) - parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) - for param in parameters: - try: - assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" - assert isinstance(param[0], str), "参数名称必须是字符串" - assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" - assert isinstance(param[2], str), "参数描述必须是字符串" - assert isinstance(param[3], bool), "参数是否必填必须是布尔值" - assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" - tool_options_builder.add_param( - name=param[0], - param_type=param[1], - description=param[2], - required=param[3], - enum_values=param[4], - ) - except AssertionError as ae: - tool_legal = False - logger.error(f"{param[0]} 参数定义错误: {str(ae)}") - except Exception as e: - tool_legal = False - logger.error(f"构建工具参数失败: {str(e)}") - if tool_legal: - tool_options.append(tool_options_builder.build()) - return tool_options or None - - @staticmethod - def _extract_reasoning(content: str) -> Tuple[str, str]: - """CoT思维链提取,向后兼容""" - match = re.search(r"(?:)?(.*?)", content, re.DOTALL) - content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() - reasoning = match[1].strip() if match else "" - return content, reasoning - - def _apply_content_obfuscation(self, text: str, api_provider) -> str: - """根据API提供商配置对文本进行混淆处理""" - if not hasattr(api_provider, "enable_content_obfuscation") or not api_provider.enable_content_obfuscation: - logger.debug(f"API提供商 '{api_provider.name}' 未启用内容混淆") - return text - - intensity = getattr(api_provider, "obfuscation_intensity", 1) - logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") - - # 在开头加入过滤规则指令 - processed_text = self.noise_instruction + "\n\n" + text - logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}") - - # 添加随机乱码 - final_text = self._inject_random_noise(processed_text, intensity) - logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}") - - return final_text - - @staticmethod - def _inject_random_noise(text: str, intensity: int) -> str: - """在文本中注入随机乱码""" - import random - import string - - def generate_noise(length: int) -> str: - """生成指定长度的随机乱码字符""" - chars = ( - string.ascii_letters # a-z, A-Z - + string.digits # 0-9 - + "!@#$%^&*()_+-=[]{}|;:,.<>?" # 特殊符号 - + "一二三四五六七八九零壹贰叁" # 中文字符 - + "αβγδεζηθικλμνξοπρστυφχψω" # 希腊字母 - + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵" # 数学符号 - ) - return "".join(random.choice(chars) for _ in range(length)) - - # 强度参数映射 - params = { - 1: {"probability": 15, "length": (3, 6)}, # 低强度:15%概率,3-6个字符 - 2: {"probability": 25, "length": (5, 10)}, # 中强度:25%概率,5-10个字符 - 3: {"probability": 35, "length": (8, 15)}, # 高强度:35%概率,8-15个字符 - } - - config = params.get(intensity, params[1]) - logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}") - - # 按词分割处理 - words = text.split() - result = [] - noise_count = 0 - - for word in words: - result.append(word) - # 根据概率插入乱码 - if random.randint(1, 100) <= config["probability"]: - noise_length = random.randint(*config["length"]) - noise = generate_noise(noise_length) - result.append(noise) - noise_count += 1 - - logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}") - return " ".join(result) From 0f39e0b6a6a8c831731a12013c6f4f684d1a3e69 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 19:29:44 +0800 Subject: [PATCH 18/41] refactor(llm): improve module clarity with docstrings and unified logging This commit introduces a comprehensive refactoring of the `llm_models` module to enhance code clarity, maintainability, and robustness. Key changes include: - **Comprehensive Documentation**: Added detailed docstrings and inline comments to `PromptProcessor`, `RequestExecutor`, `RequestStrategy`, and `LLMRequest`. This clarifies the purpose and logic of each component, including prompt manipulation, request execution with retries, fallback strategies, and concurrency. - **Unified Logging**: Standardized all loggers within the module to use a single, consistent name (`model_utils`), simplifying log filtering and analysis. - **Improved Result Handling**: Refined the result processing in `LLMRequest` to correctly extract and record usage data returned from the `RequestStrategy`, fixing a previously incomplete implementation. --- src/llm_models/llm_utils.py | 2 +- src/llm_models/model_selector.py | 2 +- src/llm_models/prompt_processor.py | 76 +++++++++++++++--- src/llm_models/request_executor.py | 94 ++++++++++++++++++---- src/llm_models/request_strategy.py | 107 +++++++++++++++++++----- src/llm_models/utils_model.py | 125 ++++++++++++++++++++++++----- 6 files changed, 336 insertions(+), 70 deletions(-) diff --git a/src/llm_models/llm_utils.py b/src/llm_models/llm_utils.py index fb7810a0c..8fba27e88 100644 --- a/src/llm_models/llm_utils.py +++ b/src/llm_models/llm_utils.py @@ -11,7 +11,7 @@ from typing import List, Dict, Any, Tuple from src.common.logger import get_logger from .payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType -logger = get_logger("llm_utils") +logger = get_logger("model_utils") def normalize_image_format(image_format: str) -> str: """ diff --git a/src/llm_models/model_selector.py b/src/llm_models/model_selector.py index 827e28842..61ec06938 100644 --- a/src/llm_models/model_selector.py +++ b/src/llm_models/model_selector.py @@ -13,7 +13,7 @@ from src.config.config import model_config from src.config.api_ada_configs import ModelInfo, APIProvider, TaskConfig from .model_client.base_client import BaseClient, client_registry -logger = get_logger("model_selector") +logger = get_logger("model_utils") class ModelSelector: diff --git a/src/llm_models/prompt_processor.py b/src/llm_models/prompt_processor.py index 035a8f55f..94a0a2ef5 100644 --- a/src/llm_models/prompt_processor.py +++ b/src/llm_models/prompt_processor.py @@ -18,16 +18,28 @@ logger = get_logger("prompt_processor") class PromptProcessor: - """提示词处理器""" + """ + 提示词处理器。 + 负责对发送给模型的原始prompt进行预处理,以增强模型性能或实现特定功能。 + 主要功能包括: + 1. **反截断**:在prompt末尾添加一个特殊的结束标记指令,帮助判断模型输出是否被截断。 + 2. **内容混淆**:向prompt中注入随机的“噪音”字符串,并附带指令让模型忽略它们, + 可能用于绕过某些平台的审查或内容策略。 + 3. **思维链提取**:从模型的响应中分离出思考过程(被标签包裹)和最终答案。 + """ def __init__(self): + """初始化Prompt处理器,定义所需的指令文本。""" + # 指导模型忽略噪音字符串的指令 self.noise_instruction = """**【核心过滤规则】** 在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。 你的任务是【完全并彻底地忽略】这些随机字符串。 **【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。 """ + # 定义一个独特的结束标记,用于反截断检查 self.end_marker = "###MAI_RESPONSE_END###" + # 指导模型在回复末尾添加结束标记的指令 self.anti_truncation_instruction = f""" **【输出完成信令】** 这是一个非常重要的指令,请务-务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `{self.end_marker}` 作为结束标志。 @@ -42,17 +54,26 @@ class PromptProcessor: self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str ) -> str: """ - 根据模型和API提供商的配置处理提示词 + 根据模型和API提供商的配置,对输入的prompt进行预处理。 + + Args: + prompt (str): 原始的用户输入prompt。 + model_info (ModelInfo): 当前使用的模型信息。 + api_provider (APIProvider): 当前API提供商的配置。 + task_name (str): 当前任务的名称,用于日志记录。 + + Returns: + str: 经过处理后的、最终将发送给模型的prompt。 """ processed_prompt = prompt - # 1. 添加反截断指令 + # 步骤 1: 根据模型配置添加反截断指令 use_anti_truncation = getattr(model_info, "use_anti_truncation", False) if use_anti_truncation: processed_prompt += self.anti_truncation_instruction logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。") - # 2. 应用内容混淆 + # 步骤 2: 根据API提供商配置应用内容混淆 if getattr(api_provider, "enable_content_obfuscation", False): intensity = getattr(api_provider, "obfuscation_intensity", 1) logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") @@ -61,12 +82,15 @@ class PromptProcessor: return processed_prompt def _apply_content_obfuscation(self, text: str, intensity: int) -> str: - """对文本进行混淆处理""" - # 在开头加入过滤规则指令 + """ + 对文本应用内容混淆处理。 + 首先添加过滤规则指令,然后注入随机噪音。 + """ + # 在文本开头加入指导模型忽略噪音的指令 processed_text = self.noise_instruction + "\n\n" + text logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}") - # 添加随机乱码 + # 在文本中注入随机乱码 final_text = self._inject_random_noise(processed_text, intensity) logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}") @@ -74,20 +98,31 @@ class PromptProcessor: @staticmethod def _inject_random_noise(text: str, intensity: int) -> str: - """在文本中注入随机乱码""" + """ + 根据指定的强度,在文本的词语之间随机注入噪音字符串。 + + Args: + text (str): 待注入噪音的文本。 + intensity (int): 混淆强度 (1, 2, or 3),决定噪音的注入概率和长度。 + + Returns: + str: 注入噪音后的文本。 + """ def generate_noise(length: int) -> str: + """生成指定长度的随机噪音字符串。""" chars = ( string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?" + "一二三四五六七八九零壹贰叁" + "αβγδεζηθικλμνξοπρστυφχψω" + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵" ) return "".join(random.choice(chars) for _ in range(length)) + # 根据强度级别定义注入参数 params = { - 1: {"probability": 15, "length": (3, 6)}, - 2: {"probability": 25, "length": (5, 10)}, - 3: {"probability": 35, "length": (8, 15)}, + 1: {"probability": 15, "length": (3, 6)}, # 低强度 + 2: {"probability": 25, "length": (5, 10)}, # 中强度 + 3: {"probability": 35, "length": (8, 15)}, # 高强度 } - config = params.get(intensity, params[1]) + config = params.get(intensity, params[1]) # 默认为低强度 logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}") words = text.split() @@ -95,6 +130,7 @@ class PromptProcessor: noise_count = 0 for word in words: result.append(word) + # 按概率决定是否注入噪音 if random.randint(1, 100) <= config["probability"]: noise_length = random.randint(*config["length"]) noise = generate_noise(noise_length) @@ -106,8 +142,22 @@ class PromptProcessor: @staticmethod def extract_reasoning(content: str) -> Tuple[str, str]: - """CoT思维链提取,向后兼容""" + """ + 从模型返回的完整内容中提取被...标签包裹的思考过程, + 并返回清理后的内容和思考过程。 + + Args: + content (str): 模型返回的原始字符串。 + + Returns: + Tuple[str, str]: + - 清理后的内容(移除了标签及其内容)。 + - 提取出的思考过程文本(如果没有则为空字符串)。 + """ + # 使用正则表达式查找标签 match = re.search(r"(?:)?(.*?)", content, re.DOTALL) + # 从内容中移除标签及其包裹的所有内容 clean_content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() + # 如果找到匹配项,则提取思考过程 reasoning = match.group(1).strip() if match else "" return clean_content, reasoning diff --git a/src/llm_models/request_executor.py b/src/llm_models/request_executor.py index 33b3197b0..d33fd88cd 100644 --- a/src/llm_models/request_executor.py +++ b/src/llm_models/request_executor.py @@ -24,11 +24,15 @@ from .payload_content.resp_format import RespFormat from .payload_content.tool_option import ToolOption from .utils import compress_messages -logger = get_logger("request_executor") +logger = get_logger("model_utils") class RequestExecutor: - """请求执行器""" + """ + 请求执行器。 + 负责直接与模型客户端交互,执行API请求。 + 它包含了核心的请求重试、异常分类处理、模型惩罚机制和消息压缩等底层逻辑。 + """ def __init__( self, @@ -39,6 +43,17 @@ class RequestExecutor: model_info: ModelInfo, model_selector: ModelSelector, ): + """ + 初始化请求执行器。 + + Args: + task_name (str): 当前任务的名称。 + model_set (TaskConfig): 任务相关的模型配置。 + api_provider (APIProvider): API提供商配置。 + client (BaseClient): 用于发送请求的客户端实例。 + model_info (ModelInfo): 当前请求要使用的模型信息。 + model_selector (ModelSelector): 模型选择器实例,用于更新模型状态(如惩罚值)。 + """ self.task_name = task_name self.model_set = model_set self.api_provider = api_provider @@ -60,12 +75,34 @@ class RequestExecutor: audio_base64: str = "", ) -> APIResponse: """ - 实际执行请求的方法, 包含了重试和异常处理逻辑 + 实际执行API请求,并包含完整的重试和异常处理逻辑。 + + Args: + request_type (str): 请求类型 ('response', 'embedding', 'audio')。 + message_list (List[Message] | None, optional): 消息列表。 Defaults to None. + tool_options (list[ToolOption] | None, optional): 工具选项。 Defaults to None. + response_format (RespFormat | None, optional): 响应格式要求。 Defaults to None. + stream_response_handler (Optional[Callable], optional): 流式响应处理器。 Defaults to None. + async_response_parser (Optional[Callable], optional): 异步响应解析器。 Defaults to None. + temperature (Optional[float], optional): 温度参数。 Defaults to None. + max_tokens (Optional[int], optional): 最大token数。 Defaults to None. + embedding_input (str, optional): embedding输入文本。 Defaults to "". + audio_base64 (str, optional): 音频base64数据。 Defaults to "". + + Returns: + APIResponse: 从模型客户端返回的API响应对象。 + + Raises: + ValueError: 如果请求类型未知。 + RuntimeError: 如果所有重试都失败。 """ retry_remain = self.api_provider.max_retry compressed_messages: Optional[List[Message]] = None + + # 循环进行重试 while retry_remain > 0: try: + # 根据请求类型调用不同的客户端方法 if request_type == "response": assert message_list is not None, "message_list cannot be None for response requests" return await self.client.get_response( @@ -96,8 +133,10 @@ class RequestExecutor: raise ValueError(f"未知的请求类型: {request_type}") except Exception as e: logger.debug(f"请求失败: {str(e)}") + # 对失败的模型应用惩罚 self._apply_penalty_on_failure(e) + # 使用默认异常处理器来决定下一步操作(等待、重试、压缩或终止) wait_interval, compressed_messages = self._default_exception_handler( e, remain_try=retry_remain, @@ -106,29 +145,35 @@ class RequestExecutor: ) if wait_interval == -1: - retry_remain = 0 # 不再重试 + retry_remain = 0 # 处理器决定不再重试 elif wait_interval > 0: logger.info(f"等待 {wait_interval} 秒后重试...") await asyncio.sleep(wait_interval) finally: retry_remain -= 1 - self.model_selector.decrease_usage_penalty(self.model_info.name) + # 所有重试次数用尽后 + self.model_selector.decrease_usage_penalty(self.model_info.name) # 减少因使用而增加的基础惩罚 logger.error(f"模型 '{self.model_info.name}' 请求失败,达到最大重试次数 {self.api_provider.max_retry} 次") raise RuntimeError("请求失败,已达到最大重试次数") def _apply_penalty_on_failure(self, e: Exception): - """根据异常类型,动态调整模型的惩罚值""" + """ + 根据异常类型,动态调整失败模型的惩罚值。 + 关键错误(如网络问题、服务器5xx错误)会施加更重的惩罚。 + """ CRITICAL_PENALTY_MULTIPLIER = 5 default_penalty_increment = 1 penalty_increment = default_penalty_increment + # 对严重错误施加更高的惩罚 if isinstance(e, (NetworkConnectionError, ReqAbortException)): penalty_increment = CRITICAL_PENALTY_MULTIPLIER elif isinstance(e, RespNotOkException): - if e.status_code >= 500: + if e.status_code >= 500: # 服务器内部错误 penalty_increment = CRITICAL_PENALTY_MULTIPLIER + # 记录日志 log_message = f"发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}" if isinstance(e, (NetworkConnectionError, ReqAbortException)): log_message = f"发生关键错误 ({type(e).__name__}),增加惩罚值: {penalty_increment}" @@ -136,6 +181,7 @@ class RequestExecutor: log_message = f"发生响应错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}" logger.warning(f"模型 '{self.model_info.name}' {log_message}") + # 更新模型的惩罚值 self.model_selector.update_model_penalty(self.model_info.name, penalty_increment) def _default_exception_handler( @@ -145,7 +191,15 @@ class RequestExecutor: retry_interval: int = 10, messages: Tuple[List[Message], bool] | None = None, ) -> Tuple[int, List[Message] | None]: - """默认异常处理函数""" + """ + 默认的异常分类处理器。 + 根据异常类型决定是否重试、等待多久以及是否需要压缩消息。 + + Returns: + Tuple[int, List[Message] | None]: + - 等待时间(秒)。-1表示不重试。 + - 压缩后的消息列表(如果有)。 + """ model_name = self.model_info.name if isinstance(e, NetworkConnectionError): @@ -157,16 +211,16 @@ class RequestExecutor: ) elif isinstance(e, ReqAbortException): logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}") - return -1, None + return -1, None # 请求被中断,不重试 elif isinstance(e, RespNotOkException): return self._handle_resp_not_ok(e, remain_try, retry_interval, messages) elif isinstance(e, RespParseException): logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}") logger.debug(f"附加内容: {str(e.ext_info)}") - return -1, None + return -1, None # 解析错误通常不可重试 else: logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") - return -1, None + return -1, None # 未知异常,不重试 def _handle_resp_not_ok( self, @@ -174,28 +228,33 @@ class RequestExecutor: remain_try: int, retry_interval: int = 10, messages: tuple[list[Message], bool] | None = None, - ): - """处理响应错误异常""" + ) -> Tuple[int, Optional[List[Message]]]: + """处理HTTP状态码非200的异常。""" model_name = self.model_info.name + # 客户端错误 (4xx),通常不可重试 if e.status_code in [400, 401, 402, 403, 404]: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}") return -1, None + # 请求体过大 (413) elif e.status_code == 413: - if messages and not messages[1]: + # 如果消息存在且尚未被压缩,尝试压缩后重试一次 + if messages and not messages[1]: # messages[1] is a flag indicating if it's already compressed return self._check_retry( - remain_try, 0, + remain_try, 0, # 立即重试 can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试", cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,压缩后仍失败", can_retry_callable=compress_messages, messages=messages[0], ) logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,无法压缩,放弃请求。") return -1, None + # 请求过于频繁 (429) elif e.status_code == 429: return self._check_retry( remain_try, retry_interval, can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试", cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数", ) + # 服务器错误 (5xx),可以重试 elif e.status_code >= 500: return self._check_retry( remain_try, retry_interval, @@ -215,9 +274,12 @@ class RequestExecutor: can_retry_callable: Callable | None = None, **kwargs, ) -> Tuple[int, List[Message] | None]: - """辅助函数:检查是否可以重试""" + """ + 辅助函数:检查是否可以重试,并执行可选的回调函数(如消息压缩)。 + """ if remain_try > 0: logger.warning(f"{can_retry_msg}") + # 如果有可执行的回调(例如压缩函数),执行它并返回结果 if can_retry_callable is not None: return retry_interval, can_retry_callable(**kwargs) return retry_interval, None diff --git a/src/llm_models/request_strategy.py b/src/llm_models/request_strategy.py index 3a694f526..b4c670355 100644 --- a/src/llm_models/request_strategy.py +++ b/src/llm_models/request_strategy.py @@ -19,13 +19,24 @@ from .payload_content.tool_option import ToolCall from .prompt_processor import PromptProcessor from .request_executor import RequestExecutor -logger = get_logger("request_strategy") +logger = get_logger("model_utils") class RequestStrategy: - """高级请求策略""" + """ + 高级请求策略模块。 + 负责实现复杂的请求逻辑,如模型的故障转移(fallback)和并发请求。 + """ def __init__(self, model_set: TaskConfig, model_selector: ModelSelector, task_name: str): + """ + 初始化请求策略。 + + Args: + model_set (TaskConfig): 特定任务的模型配置。 + model_selector (ModelSelector): 模型选择器实例。 + task_name (str): 当前任务的名称。 + """ self.model_set = model_set self.model_selector = model_selector self.task_name = task_name @@ -37,43 +48,56 @@ class RequestStrategy: ) -> Dict[str, Any]: """ 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 + + 该方法会按顺序尝试任务配置中的所有可用模型,直到一个模型成功返回响应。 + 如果所有模型都失败,将根据 `raise_when_empty` 参数决定是抛出异常还是返回一个失败结果。 + + Args: + base_payload (Dict[str, Any]): 基础请求载荷,包含prompt、工具选项等。 + raise_when_empty (bool, optional): 如果所有模型都失败或返回空内容,是否抛出异常。 Defaults to True. + + Returns: + Dict[str, Any]: 一个包含响应结果的字典,包括内容、模型信息、用量和成功状态。 """ + # 记录在本次请求中已经失败的模型,避免重复尝试 failed_models_in_this_request = set() max_attempts = len(self.model_set.model_list) last_exception: Optional[Exception] = None for attempt in range(max_attempts): + # 选择一个当前最佳且未失败的模型 model_selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request) if model_selection_result is None: logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") - break + break # 没有更多可用模型,跳出循环 model_info, api_provider, client = model_selection_result model_name = model_info.name logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...") try: - # 1. Process Prompt + # 步骤 1: 预处理Prompt prompt_processor: PromptProcessor = base_payload["prompt_processor"] raw_prompt = base_payload["prompt"] processed_prompt = prompt_processor.process_prompt( raw_prompt, model_info, api_provider, self.task_name ) - # 2. Build Message + # 步骤 2: 构建消息体 message_builder = MessageBuilder().add_text_content(processed_prompt) messages = [message_builder.build()] - # 3. Create payload for executor + # 步骤 3: 为执行器创建载荷 executor_payload = { - "request_type": "response", # Strategy only handles response type + "request_type": "response", # 策略模式目前只处理'response'类型请求 "message_list": messages, "tool_options": base_payload["tool_options"], "temperature": base_payload["temperature"], "max_tokens": base_payload["max_tokens"], } + # 创建请求执行器实例 executor = RequestExecutor( task_name=self.task_name, model_set=self.model_set, @@ -82,21 +106,24 @@ class RequestStrategy: model_info=model_info, model_selector=self.model_selector, ) + # 执行请求,并处理内部的空回复/截断重试 response = await self._execute_and_handle_empty_retry(executor, executor_payload, prompt_processor) - # 4. Post-process response - # The reasoning content is now extracted here, after a successful, de-truncated response is received. + # 步骤 4: 后处理响应 + # 在获取到成功的、完整的响应后,提取思考过程内容 final_content, reasoning_content = prompt_processor.extract_reasoning(response.content or "") - response.content = final_content # Update response with cleaned content + response.content = final_content # 使用清理后的内容更新响应对象 tool_calls = response.tool_calls + # 检查最终内容是否为空 if not final_content and not tool_calls: if raise_when_empty: raise RuntimeError("所选模型生成了空回复。") - content = "生成的响应为空" # Fallback message + logger.warning(f"模型 '{model_name}' 生成了空回复,返回默认信息。") logger.debug(f"模型 '{model_name}' 成功生成了回复。") + # 返回成功结果,包含用量和模型信息,供上层记录 return { "content": response.content, "reasoning_content": reasoning_content, @@ -108,15 +135,19 @@ class RequestStrategy: } except Exception as e: + # 捕获请求过程中的任何异常 logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") failed_models_in_this_request.add(model_info.name) last_exception = e + # 如果循环结束仍未成功 logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") if raise_when_empty: if last_exception: raise RuntimeError("所有模型均未能生成响应。") from last_exception raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") + + # 返回失败结果 return { "content": "所有模型都请求失败", "reasoning_content": "", @@ -135,23 +166,43 @@ class RequestStrategy: **kwargs, ) -> Any: """ - 执行并发请求并从成功的结果中随机选择一个。 + 以指定的并发数执行多个协程,并从所有成功的结果中随机选择一个返回。 + + Args: + coro_callable (Callable): 要并发执行的协程函数。 + concurrency_count (int): 并发数量。 + *args: 传递给协程函数的位置参数。 + **kwargs: 传递给协程函数的关键字参数。 + + Returns: + Any: 从成功的结果中随机选择的一个。 + + Raises: + RuntimeError: 如果所有并发任务都失败了。 """ logger.info(f"启用并发请求模式,并发数: {concurrency_count}") + # 创建并发任务列表 tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)] + # 等待所有任务完成 results = await asyncio.gather(*tasks, return_exceptions=True) - successful_results = [res for res in results if not isinstance(res, Exception)] + # 筛选出成功的结果 + successful_results = [ + res for res in results if isinstance(res, dict) and res.get("success") + ] if successful_results: + # 从成功结果中随机选择一个 selected = random.choice(successful_results) logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个") return selected + # 如果没有成功的结果,记录所有异常 for i, res in enumerate(results): if isinstance(res, Exception): logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}") + # 抛出第一个遇到的异常 first_exception = next((res for res in results if isinstance(res, Exception)), None) if first_exception: raise first_exception @@ -162,11 +213,23 @@ class RequestStrategy: self, executor: RequestExecutor, payload: Dict[str, Any], prompt_processor: PromptProcessor ) -> APIResponse: """ - 在单个模型内部处理空回复/截断的重试逻辑 + 在单个模型内部处理因回复为空或被截断而触发的重试逻辑。 + + Args: + executor (RequestExecutor): 请求执行器实例。 + payload (Dict[str, Any]): 传递给 `execute_request` 的载荷。 + prompt_processor (PromptProcessor): 提示词处理器,用于获取反截断标记。 + + Returns: + APIResponse: 一个有效的、非空且完整的API响应。 + + Raises: + RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的回复。 """ empty_retry_count = 0 max_empty_retry = executor.api_provider.max_retry empty_retry_interval = executor.api_provider.retry_interval + # 检查模型是否启用了反截断功能 use_anti_truncation = getattr(executor.model_info, "use_anti_truncation", False) end_marker = prompt_processor.end_marker @@ -176,15 +239,20 @@ class RequestStrategy: content = response.content or "" tool_calls = response.tool_calls + # 判断是否为空回复 is_empty_reply = not tool_calls and (not content or content.strip() == "") is_truncated = False + + # 如果启用了反截断,检查回复是否被截断 if use_anti_truncation and end_marker: if content.endswith(end_marker): - # 移除结束标记 + # 如果包含结束标记,说明回复完整,移除标记 response.content = content[: -len(end_marker)].strip() else: + # 否则,认为回复被截断 is_truncated = True + # 如果是空回复或截断,则进行重试 if is_empty_reply or is_truncated: empty_retry_count += 1 if empty_retry_count <= max_empty_retry: @@ -194,13 +262,14 @@ class RequestStrategy: ) if empty_retry_interval > 0: await asyncio.sleep(empty_retry_interval) - continue + continue # 继续下一次循环重试 else: + # 达到最大重试次数,抛出异常 reason = "空回复" if is_empty_reply else "截断" raise RuntimeError(f"模型 '{executor.model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。") - # 成功获取响应 + # 成功获取到有效响应,返回结果 return response - # 此处理论上不会到达,因为循环要么返回要么抛异常 - raise RuntimeError("空回复/截断重Test逻辑出现未知错误") + # 此处理论上不会到达,因为循环要么返回要么抛出异常 + raise RuntimeError("空回复/截断重试逻辑出现未知错误") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 1414aacb1..b223f0d82 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -10,7 +10,7 @@ import time from typing import Tuple, List, Dict, Optional, Any from src.common.logger import get_logger -from src.config.api_ada_configs import TaskConfig, ModelInfo +from src.config.api_ada_configs import TaskConfig, ModelInfo, UsageRecord from .llm_utils import build_tool_options, normalize_image_format from .model_selector import ModelSelector from .payload_content.message import MessageBuilder @@ -22,9 +22,20 @@ from .utils import llm_usage_recorder logger = get_logger("model_utils") class LLMRequest: - """LLM请求协调器""" + """ + LLM请求协调器。 + 封装了模型选择、Prompt处理、请求执行和高级策略(如故障转移、并发)的完整流程。 + 为上层业务逻辑提供统一的、简化的接口来与大语言模型交互。 + """ def __init__(self, model_set: TaskConfig, request_type: str = "") -> None: + """ + 初始化LLM请求协调器。 + + Args: + model_set (TaskConfig): 特定任务的模型配置集合。 + request_type (str, optional): 请求类型或任务名称,用于日志和用量记录。 Defaults to "". + """ self.task_name = request_type self.model_for_task = model_set self.request_type = request_type @@ -40,16 +51,33 @@ class LLMRequest: temperature: Optional[float] = None, max_tokens: Optional[int] = None, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """为图像生成响应""" + """ + 为包含图像的多模态输入生成文本响应。 + + Args: + prompt (str): 文本提示。 + image_base64 (str): Base64编码的图像数据。 + image_format (str): 图像格式 (例如, "png", "jpeg")。 + temperature (Optional[float], optional): 控制生成文本的随机性。 Defaults to None. + max_tokens (Optional[int], optional): 生成响应的最大长度。 Defaults to None. + + Returns: + Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + - 清理后的响应内容。 + - 一个元组,包含思考过程、模型名称和工具调用列表。 + """ start_time = time.time() - # 1. 选择模型 + # 步骤 1: 选择一个支持图像处理的模型 model_info, api_provider, client = self.model_selector.select_model() - # 2. 准备消息体 + # 步骤 2: 准备消息体 + # 预处理文本提示 processed_prompt = self.prompt_processor.process_prompt(prompt, model_info, api_provider, self.task_name) + # 规范化图像格式 normalized_format = normalize_image_format(image_format) + # 使用MessageBuilder构建多模态消息 message_builder = MessageBuilder() message_builder.add_text_content(processed_prompt) message_builder.add_image_content( @@ -59,7 +87,7 @@ class LLMRequest: ) messages = [message_builder.build()] - # 3. 执行请求 (图像请求通常不走复杂的故障转移策略,直接执行) + # 步骤 3: 执行请求 (图像请求通常不走复杂的故障转移策略,直接执行) from .request_executor import RequestExecutor executor = RequestExecutor( task_name=self.task_name, @@ -76,20 +104,31 @@ class LLMRequest: max_tokens=max_tokens, ) - # 4. 处理响应 + # 步骤 4: 处理响应 content, reasoning_content = self.prompt_processor.extract_reasoning(response.content or "") tool_calls = response.tool_calls + # 记录用量 if usage := response.usage: await self._record_usage(model_info, usage, time.time() - start_time) return content, (reasoning_content, model_info.name, tool_calls) async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: - """为语音生成响应""" + """ + 将语音数据转换为文本(语音识别)。 + + Args: + voice_base64 (str): Base64编码的语音数据。 + + Returns: + Optional[str]: 识别出的文本内容,如果失败则返回None。 + """ + # 选择一个支持语音识别的模型 model_info, api_provider, client = self.model_selector.select_model() from .request_executor import RequestExecutor + # 创建请求执行器 executor = RequestExecutor( task_name=self.task_name, model_set=self.model_for_task, @@ -98,6 +137,7 @@ class LLMRequest: model_info=model_info, model_selector=self.model_selector, ) + # 执行语音转文本请求 response = await executor.execute_request( request_type="audio", audio_base64=voice_base64, @@ -112,9 +152,24 @@ class LLMRequest: tools: Optional[List[Dict[str, Any]]] = None, raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - """异步生成响应,支持并发和故障转移""" + """ + 异步生成文本响应,支持并发和故障转移等高级策略。 + + Args: + prompt (str): 用户输入的提示。 + temperature (Optional[float], optional): 控制生成文本的随机性。 Defaults to None. + max_tokens (Optional[int], optional): 生成响应的最大长度。 Defaults to None. + tools (Optional[List[Dict[str, Any]]], optional): 可供模型调用的工具列表。 Defaults to None. + raise_when_empty (bool, optional): 如果最终响应为空,是否抛出异常。 Defaults to True. + + Returns: + Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + - 清理后的响应内容。 + - 一个元组,包含思考过程、最终使用的模型名称和工具调用列表。 + """ + start_time = time.time() - # 1. 准备基础请求载荷 + # 步骤 1: 准备基础请求载荷 tool_built = build_tool_options(tools) base_payload = { "prompt": prompt, @@ -124,7 +179,7 @@ class LLMRequest: "prompt_processor": self.prompt_processor, } - # 2. 根据配置选择执行策略 + # 步骤 2: 根据配置选择执行策略 (并发或单次带故障转移) concurrency_count = getattr(self.model_for_task, "concurrency_count", 1) if concurrency_count <= 1: @@ -138,23 +193,43 @@ class LLMRequest: self.request_strategy.execute_with_fallback, concurrency_count, base_payload, - raise_when_empty=False, + raise_when_empty=False, # 在并发模式下,单个任务失败不应立即抛出异常 ) - # 3. 处理最终结果 - content, (reasoning_content, model_name, tool_calls) = result + # 步骤 3: 处理最终结果 + content = result.get("content", "") + reasoning_content = result.get("reasoning_content", "") + model_name = result.get("model_name", "unknown") + tool_calls = result.get("tool_calls") - # 4. 记录用量 (需要从策略中获取最终使用的模型信息和用量) - # TODO: 改造策略以返回最终模型信息和用量, 此处暂时省略 + # 步骤 4: 记录用量 (从策略返回的结果中获取最终使用的模型信息和用量) + final_model_info = result.get("model_info") + usage = result.get("usage") + + if final_model_info and usage: + await self._record_usage(final_model_info, usage, time.time() - start_time) return content, (reasoning_content, model_name, tool_calls) async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: - """获取嵌入向量""" + """ + 获取给定文本的嵌入向量 (Embedding)。 + + Args: + embedding_input (str): 需要进行嵌入的文本。 + + Returns: + Tuple[List[float], str]: 嵌入向量列表和所使用的模型名称。 + + Raises: + RuntimeError: 如果获取embedding失败。 + """ start_time = time.time() + # 选择一个支持embedding的模型 model_info, api_provider, client = self.model_selector.select_model() from .request_executor import RequestExecutor + # 创建请求执行器 executor = RequestExecutor( task_name=self.task_name, model_set=self.model_for_task, @@ -163,6 +238,7 @@ class LLMRequest: model_info=model_info, model_selector=self.model_selector, ) + # 执行embedding请求 response = await executor.execute_request( request_type="embedding", embedding_input=embedding_input, @@ -172,17 +248,26 @@ class LLMRequest: if not embedding: raise RuntimeError("获取embedding失败") + # 记录用量 if usage := response.usage: await self._record_usage(model_info, usage, time.time() - start_time, "/embeddings") return embedding, model_info.name - async def _record_usage(self, model_info: ModelInfo, usage, time_cost, endpoint="/chat/completions"): - """记录模型用量""" + async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord, time_cost: float, endpoint: str = "/chat/completions"): + """ + 记录模型API的调用用量到数据库。 + + Args: + model_info (ModelInfo): 使用的模型信息。 + usage (UsageRecord): 包含token用量信息的对象。 + time_cost (float): 本次请求的总耗时(秒)。 + endpoint (str, optional): 请求的API端点。 Defaults to "/chat/completions". + """ await llm_usage_recorder.record_usage_to_database( model_info=model_info, model_usage=usage, - user_id="system", + user_id="system", # 当前所有请求都以系统用户身份记录 time_cost=time_cost, request_type=self.request_type, endpoint=endpoint, From 375a51e01f7babe6e882b0127f20085bfd5e65d7 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 19:50:06 +0800 Subject: [PATCH 19/41] =?UTF-8?q?fix(llm):=20=E4=BF=AE=E5=A4=8D=20?= =?UTF-8?q?=20=E6=A0=87=E7=AD=BE=E8=A7=A3=E6=9E=90=E5=90=8E=E5=8F=AF?= =?UTF-8?q?=E8=83=BD=E6=AE=8B=E7=95=99=E7=A9=BA=E7=99=BD=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 之前的 标签解析逻辑在移除标签内容后,没有处理紧随其后的空白字符,这可能导致清理后的内容开头有多余的空格或换行符。 本次更新使用更精确的正则表达式 `(.*?)\s*`,可以在一次操作中同时移除 标签块和其后的所有空白字符,确保返回的内容格式正确,提高了处理的鲁棒性。 --- src/llm_models/prompt_processor.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/llm_models/prompt_processor.py b/src/llm_models/prompt_processor.py index 94a0a2ef5..0ae944369 100644 --- a/src/llm_models/prompt_processor.py +++ b/src/llm_models/prompt_processor.py @@ -154,10 +154,17 @@ class PromptProcessor: - 清理后的内容(移除了标签及其内容)。 - 提取出的思考过程文本(如果没有则为空字符串)。 """ - # 使用正则表达式查找标签 - match = re.search(r"(?:)?(.*?)", content, re.DOTALL) - # 从内容中移除标签及其包裹的所有内容 - clean_content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() - # 如果找到匹配项,则提取思考过程 - reasoning = match.group(1).strip() if match else "" + # 使用正则表达式精确查找 ... 标签及其内容 + think_pattern = re.compile(r"(.*?)\s*", re.DOTALL) + match = think_pattern.search(content) + + if match: + # 提取思考过程 + reasoning = match.group(1).strip() + # 从原始内容中移除匹配到的整个部分(包括标签和后面的空白) + clean_content = think_pattern.sub("", content, count=1).strip() + else: + reasoning = "" + clean_content = content.strip() + return clean_content, reasoning From 89b79792c054db2b254eadf628b44575ea601b0b Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 19:56:46 +0800 Subject: [PATCH 20/41] =?UTF-8?q?refactor(chat):=20=E5=B0=86=20get=5Fchat?= =?UTF-8?q?=5Ftype=5Fand=5Ftarget=5Finfo=20=E9=87=8D=E6=9E=84=E4=B8=BA?= =?UTF-8?q?=E5=BC=82=E6=AD=A5=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 `get_chat_type_and_target_info` 函数从同步改为异步,以支持其内部对异步方法 `person_info_manager.get_values` 的调用。 此更改可防止在获取聊天对象信息时阻塞事件循环。所有调用此函数的代码(包括 `SubHeartflow`, `ActionModifier`, `PlanGenerator`, `DefaultReplyer`)都已相应更新为使用 `await`。 在 `DefaultReplyer` 中引入了延迟异步初始化模式 (`_async_init`),以适应其类生命周期。 --- src/chat/heart_flow/sub_heartflow.py | 3 ++- src/chat/planner_actions/action_modifier.py | 2 +- src/chat/planner_actions/plan_generator.py | 2 +- src/chat/replyer/default_generator.py | 16 +++++++++++++--- src/chat/utils/utils.py | 5 +++-- 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index 275a25a57..136b1cb41 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -24,7 +24,7 @@ class SubHeartflow: self.subheartflow_id = subheartflow_id self.chat_id = subheartflow_id - self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id) + self.is_group_chat, self.chat_target_info = (None, None) self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id # focus模式退出冷却时间管理 @@ -38,4 +38,5 @@ class SubHeartflow: async def initialize(self): """异步初始化方法,创建兴趣流并确定聊天类型""" + self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_id) await self.heart_fc_instance.start() diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 154fe62a7..bcd01934d 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -72,7 +72,7 @@ class ActionModifier: from src.chat.utils.utils import get_chat_type_and_target_info # 获取聊天类型 - is_group_chat, _ = get_chat_type_and_target_info(self.chat_id) + is_group_chat, _ = await get_chat_type_and_target_info(self.chat_id) all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION) chat_type_removals = [] diff --git a/src/chat/planner_actions/plan_generator.py b/src/chat/planner_actions/plan_generator.py index 5d1ab9c38..ec0a11691 100644 --- a/src/chat/planner_actions/plan_generator.py +++ b/src/chat/planner_actions/plan_generator.py @@ -51,7 +51,7 @@ class PlanGenerator: Returns: Plan: 一个填充了初始上下文信息的 Plan 对象。 """ - _is_group_chat, chat_target_info_dict = get_chat_type_and_target_info(self.chat_id) + _is_group_chat, chat_target_info_dict = await get_chat_type_and_target_info(self.chat_id) target_info = None if chat_target_info_dict: diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index c2cd9fc08..76221ac1c 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -202,7 +202,9 @@ class DefaultReplyer: ): self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream - self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) + self.is_group_chat: Optional[bool] = None + self.chat_target_info: Optional[Dict[str, Any]] = None + self._initialized = False self.heart_fc_sender = HeartFCSender() self.memory_activator = MemoryActivator() @@ -775,6 +777,12 @@ class DefaultReplyer: mai_think.target = target return mai_think + async def _async_init(self): + if self._initialized: + return + self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_stream.stream_id) + self._initialized = True + async def build_prompt_reply_context( self, reply_to: str, @@ -800,10 +808,11 @@ class DefaultReplyer: """ if available_actions is None: available_actions = {} + await self._async_init() chat_stream = self.chat_stream chat_id = chat_stream.stream_id person_info_manager = get_person_info_manager() - is_group_chat = bool(chat_stream.group_info) + is_group_chat = self.is_group_chat if global_config.mood.enable_mood: chat_mood = mood_manager.get_mood_by_chat_id(chat_id) @@ -1128,9 +1137,10 @@ class DefaultReplyer: reply_to: str, reply_message: Optional[Dict[str, Any]] = None, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if + await self._async_init() chat_stream = self.chat_stream chat_id = chat_stream.stream_id - is_group_chat = bool(chat_stream.group_info) + is_group_chat = self.is_group_chat if reply_message: sender = reply_message.get("sender") diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 5eb4cc991..85e665328 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -619,7 +619,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal" return time.strftime("%H:%M:%S", time.localtime(timestamp)) -def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: +async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: """ 获取聊天类型(是否群聊)和私聊对象信息。 @@ -663,7 +663,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: if person_id: # get_value is async, so await it directly person_info_manager = get_person_info_manager() - person_name = person_info_manager.get_value(person_id, "person_name") + person_data = await person_info_manager.get_values(person_id, ["person_name"]) + person_name = person_data.get("person_name") target_info["person_id"] = person_id target_info["person_name"] = person_name From 9c1a7ff123e58d3165c5c7b8ed2e1ac196fa36c8 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 20:16:46 +0800 Subject: [PATCH 21/41] ruff --- plugins/set_emoji_like/plugin.py | 1 - scripts/lpmm_learning_tool.py | 1 - src/chat/antipromptinjector/core/shield.py | 1 - src/chat/chat_loop/heartFC_chat.py | 1 - src/chat/express/expression_learner.py | 2 +- src/chat/message_receive/message.py | 2 +- src/chat/planner_actions/plan_filter.py | 1 - src/common/database/db_migration.py | 1 - src/config/api_ada_configs.py | 2 +- src/config/config.py | 2 +- src/llm_models/request_strategy.py | 3 +-- src/plugin_system/utils/permission_decorators.py | 1 - src/plugins/built_in/core_actions/emoji.py | 4 ++-- .../napcat_adapter_plugin/src/recv_handler/message_handler.py | 2 +- 14 files changed, 8 insertions(+), 16 deletions(-) diff --git a/plugins/set_emoji_like/plugin.py b/plugins/set_emoji_like/plugin.py index 966d4aabc..810f0639e 100644 --- a/plugins/set_emoji_like/plugin.py +++ b/plugins/set_emoji_like/plugin.py @@ -10,7 +10,6 @@ from src.plugin_system import ( ConfigField, ) from src.common.logger import get_logger -from src.plugin_system.apis import send_api from .qq_emoji_list import qq_face from src.plugin_system.base.component_types import ChatType diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 5a61eeebc..941494bc0 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -1,7 +1,6 @@ import asyncio import os import sys -import glob import orjson import datetime from pathlib import Path diff --git a/src/chat/antipromptinjector/core/shield.py b/src/chat/antipromptinjector/core/shield.py index c4ab8afa8..c7a2e78bc 100644 --- a/src/chat/antipromptinjector/core/shield.py +++ b/src/chat/antipromptinjector/core/shield.py @@ -233,6 +233,5 @@ class MessageShield: def create_default_shield() -> MessageShield: """创建默认的消息加盾器""" - from .config import default_config return MessageShield() diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index bf282da5e..fca7df847 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -10,7 +10,6 @@ from src.config.config import global_config from src.person_info.relationship_builder_manager import relationship_builder_manager from src.chat.express.expression_learner import expression_learner_manager from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager, SleepState -from src.plugin_system.apis import message_api from .hfc_context import HfcContext from .energy_manager import EnergyManager diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index fb22a4115..bb663a1ad 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -4,7 +4,7 @@ import orjson import os from datetime import datetime -from typing import List, Dict, Optional, Any, Tuple, Coroutine +from typing import List, Dict, Optional, Any, Tuple from src.common.logger import get_logger from src.common.database.sqlalchemy_database_api import get_db_session diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 22e57edf0..22c3e3776 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -2,7 +2,7 @@ import base64 import time from abc import abstractmethod, ABCMeta from dataclasses import dataclass -from typing import Optional, Any, TYPE_CHECKING +from typing import Optional, Any import urllib3 from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index 4ef8de2d8..fccda0230 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -9,7 +9,6 @@ from typing import Any, Dict, List, Optional from json_repair import repair_json -from . import planner_prompts from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.utils.chat_message_builder import ( build_readable_actions, diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index 8f7b1ecd3..085c277a3 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -1,7 +1,6 @@ # mmc/src/common/database/db_migration.py from sqlalchemy import inspect -from sqlalchemy.schema import CreateIndex from sqlalchemy.sql import text from src.common.database.sqlalchemy_models import Base, get_engine diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 5e5e035dd..0b1984a3c 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,5 +1,5 @@ from typing import List, Dict, Any, Literal, Union -from pydantic import Field, field_validator +from pydantic import Field from threading import Lock from src.config.config_base import ValidatedConfigBase diff --git a/src/config/config.py b/src/config/config.py index f0a3ec2d8..ac6204689 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -8,7 +8,7 @@ from tomlkit import TOMLDocument from tomlkit.items import Table, KeyType from rich.traceback import install from typing import List, Optional -from pydantic import Field, field_validator +from pydantic import Field from src.common.logger import get_logger from src.config.config_base import ValidatedConfigBase diff --git a/src/llm_models/request_strategy.py b/src/llm_models/request_strategy.py index b4c670355..82575b076 100644 --- a/src/llm_models/request_strategy.py +++ b/src/llm_models/request_strategy.py @@ -8,14 +8,13 @@ """ import asyncio import random -from typing import List, Tuple, Optional, Dict, Any, Callable, Coroutine +from typing import Optional, Dict, Any, Callable, Coroutine from src.common.logger import get_logger from src.config.api_ada_configs import TaskConfig from .model_client.base_client import APIResponse from .model_selector import ModelSelector from .payload_content.message import MessageBuilder -from .payload_content.tool_option import ToolCall from .prompt_processor import PromptProcessor from .request_executor import RequestExecutor diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index 45357b4b0..990f1c91c 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ -7,7 +7,6 @@ from functools import wraps from typing import Callable, Optional from inspect import iscoroutinefunction -import inspect from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.send_api import text_to_stream diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index 3ebf4610a..e8ffba68e 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -255,7 +255,7 @@ class EmojiAction(BaseAction): if not success: logger.error(f"{self.log_prefix} 表情包发送失败") - await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False) + await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包,但失败了",action_done= False) return False, "表情包发送失败" # 发送成功后,记录到历史 @@ -264,7 +264,7 @@ class EmojiAction(BaseAction): except Exception as e: logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") - await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True) + await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包",action_done= True) return True, f"发送表情包: {emoji_description}" diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index c50f17e7b..a19ca85e5 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -26,7 +26,7 @@ import json import websockets as Server import base64 from pathlib import Path -from typing import List, Tuple, Optional, Dict, Any, Coroutine +from typing import List, Tuple, Optional, Dict, Any import uuid from maim_message import ( From f12cade772f5c998459c1c9ca1ed7a4c089cbc7a Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 20:24:56 +0800 Subject: [PATCH 22/41] =?UTF-8?q?refactor:=20=E7=A7=BB=E9=99=A4=E6=9C=AA?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E7=9A=84=E5=AF=BC=E5=85=A5=E5=92=8C=E5=86=97?= =?UTF-8?q?=E4=BD=99=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_models/llm_utils.py | 65 -- .../model_client/aiohttp_gemini_client.py | 6 +- src/llm_models/model_selector.py | 130 --- src/llm_models/prompt_processor.py | 170 ---- src/llm_models/request_executor.py | 288 ------ src/llm_models/request_strategy.py | 274 ------ src/llm_models/utils_model.py | 912 +++++++++++++++++- 7 files changed, 886 insertions(+), 959 deletions(-) delete mode 100644 src/llm_models/llm_utils.py delete mode 100644 src/llm_models/model_selector.py delete mode 100644 src/llm_models/prompt_processor.py delete mode 100644 src/llm_models/request_executor.py delete mode 100644 src/llm_models/request_strategy.py diff --git a/src/llm_models/llm_utils.py b/src/llm_models/llm_utils.py deleted file mode 100644 index 8fba27e88..000000000 --- a/src/llm_models/llm_utils.py +++ /dev/null @@ -1,65 +0,0 @@ -# -*- coding: utf-8 -*- -""" -@File : llm_utils.py -@Time : 2024/05/24 17:00:00 -@Author : 墨墨 -@Version : 1.0 -@Desc : LLM相关通用工具函数 -""" -from typing import List, Dict, Any, Tuple - -from src.common.logger import get_logger -from .payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType - -logger = get_logger("model_utils") - -def normalize_image_format(image_format: str) -> str: - """ - 标准化图片格式名称,确保与各种API的兼容性 - """ - format_mapping = { - "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg", - "png": "png", "PNG": "png", - "webp": "webp", "WEBP": "webp", - "gif": "gif", "GIF": "gif", - "heic": "heic", "HEIC": "heic", - "heif": "heif", "HEIF": "heif", - } - normalized = format_mapping.get(image_format, image_format.lower()) - logger.debug(f"图片格式标准化: {image_format} -> {normalized}") - return normalized - -def build_tool_options(tools: List[Dict[str, Any]] | None) -> List[ToolOption] | None: - """构建工具选项列表""" - if not tools: - return None - tool_options: List[ToolOption] = [] - for tool in tools: - try: - tool_options_builder = ToolOptionBuilder() - tool_options_builder.set_name(tool.get("name", "")) - tool_options_builder.set_description(tool.get("description", "")) - parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) - for param in parameters: - # 参数校验 - assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" - assert isinstance(param[0], str), "参数名称必须是字符串" - assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" - assert isinstance(param[2], str), "参数描述必须是字符串" - assert isinstance(param[3], bool), "参数是否必填必须是布尔值" - assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" - - tool_options_builder.add_param( - name=param[0], - param_type=param[1], - description=param[2], - required=param[3], - enum_values=param[4], - ) - tool_options.append(tool_options_builder.build()) - except AssertionError as ae: - logger.error(f"工具 '{tool.get('name', 'unknown')}' 的参数定义错误: {str(ae)}") - except Exception as e: - logger.error(f"构建工具 '{tool.get('name', 'unknown')}' 失败: {str(e)}") - - return tool_options or None \ No newline at end of file diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index eeb90c265..7b997b680 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -122,7 +122,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]: def _convert_tool_param(param: ToolParam) -> dict: """转换工具参数""" - result: dict[str, Any] = { + result = { "type": param.param_type.value, "description": param.description, } @@ -132,7 +132,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]: def _convert_tool_option_item(tool_option: ToolOption) -> dict: """转换单个工具选项""" - function_declaration: dict[str, Any] = { + function_declaration = { "name": tool_option.name, "description": tool_option.description, } @@ -500,7 +500,7 @@ class AiohttpGeminiClient(BaseClient): # 直接重抛项目定义的异常 raise except Exception as e: - logger.debug(f"请求处理中发生未知异常: {e}") + logger.debug(e) # 其他异常转换为网络连接错误 raise NetworkConnectionError() from e diff --git a/src/llm_models/model_selector.py b/src/llm_models/model_selector.py deleted file mode 100644 index 61ec06938..000000000 --- a/src/llm_models/model_selector.py +++ /dev/null @@ -1,130 +0,0 @@ -# -*- coding: utf-8 -*- -""" -@File : model_selector.py -@Time : 2024/05/24 16:00:00 -@Author : 墨墨 -@Version : 1.0 -@Desc : 模型选择与负载均衡器 -""" -from typing import Dict, Tuple, Set, Optional - -from src.common.logger import get_logger -from src.config.config import model_config -from src.config.api_ada_configs import ModelInfo, APIProvider, TaskConfig -from .model_client.base_client import BaseClient, client_registry - -logger = get_logger("model_utils") - - -class ModelSelector: - """模型选择与负载均衡器""" - - def __init__(self, model_set: TaskConfig, request_type: str = ""): - """ - 初始化模型选择器 - - Args: - model_set (TaskConfig): 任务配置中定义的模型集合 - request_type (str, optional): 请求类型 (例如 "embedding"). Defaults to "". - """ - self.model_for_task = model_set - self.request_type = request_type - self.model_usage: Dict[str, Tuple[int, int, int]] = { - model: (0, 0, 0) for model in self.model_for_task.model_list - } - """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" - - def select_best_available_model( - self, failed_models_in_this_request: Set[str] - ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: - """ - 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 - - Args: - failed_models_in_this_request (Set[str]): 当前请求中已失败的模型名称集合。 - - Returns: - Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: 选定的模型详细信息,如果无可用模型则返回 None。 - """ - candidate_models_usage = { - model_name: usage_data - for model_name, usage_data in self.model_usage.items() - if model_name not in failed_models_in_this_request - } - - if not candidate_models_usage: - logger.warning("没有可用的模型供当前请求选择。") - return None - - # 根据现有公式查找分数最低的模型 - # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 - # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。 - least_used_model_name = min( - candidate_models_usage, - key=lambda k: candidate_models_usage[k][0] - + candidate_models_usage[k][1] * 300 - + candidate_models_usage[k][2] * 1000, - ) - - # --- 动态故障转移的核心逻辑 --- - # RequestStrategy 中的循环会多次调用此函数。 - # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数, - # 此时由于失败模型已被标记,且其惩罚值可能已在 RequestExecutor 中增加, - # 此函数会自动选择一个得分更低(即更可用)的模型。 - # 这种机制实现了动态的、基于当前系统状态的故障转移。 - model_info = model_config.get_model_info(least_used_model_name) - api_provider = model_config.get_provider(model_info.api_provider) - - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - - logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") - - # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。 - # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) - - return model_info, api_provider, client - - def select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: - """ - 根据总tokens和惩罚值选择的模型 (负载均衡) - """ - least_used_model_name = min( - self.model_usage, - key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, - ) - model_info = model_config.get_model_info(least_used_model_name) - api_provider = model_config.get_provider(model_info.api_provider) - - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - logger.debug(f"选择请求模型: {model_info.name}") - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) - return model_info, api_provider, client - - def update_model_penalty(self, model_name: str, penalty_increment: int): - """ - 更新指定模型的惩罚值 - - Args: - model_name (str): 模型名称 - penalty_increment (int): 惩罚增量 - """ - if model_name in self.model_usage: - total_tokens, penalty, usage_penalty = self.model_usage[model_name] - self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty) - logger.debug(f"模型 '{model_name}' 的惩罚值增加了 {penalty_increment}") - - def decrease_usage_penalty(self, model_name: str): - """ - 请求结束后,减少使用惩罚值 - - Args: - model_name (str): 模型名称 - """ - if model_name in self.model_usage: - total_tokens, penalty, usage_penalty = self.model_usage[model_name] - self.model_usage[model_name] = (total_tokens, penalty, usage_penalty - 1) \ No newline at end of file diff --git a/src/llm_models/prompt_processor.py b/src/llm_models/prompt_processor.py deleted file mode 100644 index 0ae944369..000000000 --- a/src/llm_models/prompt_processor.py +++ /dev/null @@ -1,170 +0,0 @@ -# -*- coding: utf-8 -*- -""" -@File : prompt_processor.py -@Time : 2024/05/24 16:45:00 -@Author : 墨墨 -@Version : 1.0 -@Desc : 提示词预处理(反截断、内容混淆等) -""" -import random -import string -import re -from typing import Tuple - -from src.common.logger import get_logger -from src.config.api_ada_configs import APIProvider, ModelInfo - -logger = get_logger("prompt_processor") - - -class PromptProcessor: - """ - 提示词处理器。 - 负责对发送给模型的原始prompt进行预处理,以增强模型性能或实现特定功能。 - 主要功能包括: - 1. **反截断**:在prompt末尾添加一个特殊的结束标记指令,帮助判断模型输出是否被截断。 - 2. **内容混淆**:向prompt中注入随机的“噪音”字符串,并附带指令让模型忽略它们, - 可能用于绕过某些平台的审查或内容策略。 - 3. **思维链提取**:从模型的响应中分离出思考过程(被标签包裹)和最终答案。 - """ - - def __init__(self): - """初始化Prompt处理器,定义所需的指令文本。""" - # 指导模型忽略噪音字符串的指令 - self.noise_instruction = """**【核心过滤规则】** -在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。 - -你的任务是【完全并彻底地忽略】这些随机字符串。 -**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。 -""" - # 定义一个独特的结束标记,用于反截断检查 - self.end_marker = "###MAI_RESPONSE_END###" - # 指导模型在回复末尾添加结束标记的指令 - self.anti_truncation_instruction = f""" -**【输出完成信令】** -这是一个非常重要的指令,请务-务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `{self.end_marker}` 作为结束标志。 -例如: -<你的回复内容> -{self.end_marker} - -这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。 -""" - - def process_prompt( - self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str - ) -> str: - """ - 根据模型和API提供商的配置,对输入的prompt进行预处理。 - - Args: - prompt (str): 原始的用户输入prompt。 - model_info (ModelInfo): 当前使用的模型信息。 - api_provider (APIProvider): 当前API提供商的配置。 - task_name (str): 当前任务的名称,用于日志记录。 - - Returns: - str: 经过处理后的、最终将发送给模型的prompt。 - """ - processed_prompt = prompt - - # 步骤 1: 根据模型配置添加反截断指令 - use_anti_truncation = getattr(model_info, "use_anti_truncation", False) - if use_anti_truncation: - processed_prompt += self.anti_truncation_instruction - logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。") - - # 步骤 2: 根据API提供商配置应用内容混淆 - if getattr(api_provider, "enable_content_obfuscation", False): - intensity = getattr(api_provider, "obfuscation_intensity", 1) - logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") - processed_prompt = self._apply_content_obfuscation(processed_prompt, intensity) - - return processed_prompt - - def _apply_content_obfuscation(self, text: str, intensity: int) -> str: - """ - 对文本应用内容混淆处理。 - 首先添加过滤规则指令,然后注入随机噪音。 - """ - # 在文本开头加入指导模型忽略噪音的指令 - processed_text = self.noise_instruction + "\n\n" + text - logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}") - - # 在文本中注入随机乱码 - final_text = self._inject_random_noise(processed_text, intensity) - logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}") - - return final_text - - @staticmethod - def _inject_random_noise(text: str, intensity: int) -> str: - """ - 根据指定的强度,在文本的词语之间随机注入噪音字符串。 - - Args: - text (str): 待注入噪音的文本。 - intensity (int): 混淆强度 (1, 2, or 3),决定噪音的注入概率和长度。 - - Returns: - str: 注入噪音后的文本。 - """ - def generate_noise(length: int) -> str: - """生成指定长度的随机噪音字符串。""" - chars = ( - string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?" - + "一二三四五六七八九零壹贰叁" + "αβγδεζηθικλμνξοπρστυφχψω" + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵" - ) - return "".join(random.choice(chars) for _ in range(length)) - - # 根据强度级别定义注入参数 - params = { - 1: {"probability": 15, "length": (3, 6)}, # 低强度 - 2: {"probability": 25, "length": (5, 10)}, # 中强度 - 3: {"probability": 35, "length": (8, 15)}, # 高强度 - } - config = params.get(intensity, params[1]) # 默认为低强度 - logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}") - - words = text.split() - result = [] - noise_count = 0 - for word in words: - result.append(word) - # 按概率决定是否注入噪音 - if random.randint(1, 100) <= config["probability"]: - noise_length = random.randint(*config["length"]) - noise = generate_noise(noise_length) - result.append(noise) - noise_count += 1 - - logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}") - return " ".join(result) - - @staticmethod - def extract_reasoning(content: str) -> Tuple[str, str]: - """ - 从模型返回的完整内容中提取被...标签包裹的思考过程, - 并返回清理后的内容和思考过程。 - - Args: - content (str): 模型返回的原始字符串。 - - Returns: - Tuple[str, str]: - - 清理后的内容(移除了标签及其内容)。 - - 提取出的思考过程文本(如果没有则为空字符串)。 - """ - # 使用正则表达式精确查找 ... 标签及其内容 - think_pattern = re.compile(r"(.*?)\s*", re.DOTALL) - match = think_pattern.search(content) - - if match: - # 提取思考过程 - reasoning = match.group(1).strip() - # 从原始内容中移除匹配到的整个部分(包括标签和后面的空白) - clean_content = think_pattern.sub("", content, count=1).strip() - else: - reasoning = "" - clean_content = content.strip() - - return clean_content, reasoning diff --git a/src/llm_models/request_executor.py b/src/llm_models/request_executor.py deleted file mode 100644 index d33fd88cd..000000000 --- a/src/llm_models/request_executor.py +++ /dev/null @@ -1,288 +0,0 @@ -# -*- coding: utf-8 -*- -""" -@File : request_executor.py -@Time : 2024/05/24 16:15:00 -@Author : 墨墨 -@Version : 1.0 -@Desc : 负责执行LLM请求、处理重试及异常 -""" -import asyncio -from typing import List, Callable, Optional, Tuple - -from src.common.logger import get_logger -from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig -from .exceptions import ( - NetworkConnectionError, - ReqAbortException, - RespNotOkException, - RespParseException, -) -from .model_client.base_client import APIResponse, BaseClient -from .model_selector import ModelSelector -from .payload_content.message import Message -from .payload_content.resp_format import RespFormat -from .payload_content.tool_option import ToolOption -from .utils import compress_messages - -logger = get_logger("model_utils") - - -class RequestExecutor: - """ - 请求执行器。 - 负责直接与模型客户端交互,执行API请求。 - 它包含了核心的请求重试、异常分类处理、模型惩罚机制和消息压缩等底层逻辑。 - """ - - def __init__( - self, - task_name: str, - model_set: TaskConfig, - api_provider: APIProvider, - client: BaseClient, - model_info: ModelInfo, - model_selector: ModelSelector, - ): - """ - 初始化请求执行器。 - - Args: - task_name (str): 当前任务的名称。 - model_set (TaskConfig): 任务相关的模型配置。 - api_provider (APIProvider): API提供商配置。 - client (BaseClient): 用于发送请求的客户端实例。 - model_info (ModelInfo): 当前请求要使用的模型信息。 - model_selector (ModelSelector): 模型选择器实例,用于更新模型状态(如惩罚值)。 - """ - self.task_name = task_name - self.model_set = model_set - self.api_provider = api_provider - self.client = client - self.model_info = model_info - self.model_selector = model_selector - - async def execute_request( - self, - request_type: str, - message_list: List[Message] | None = None, - tool_options: list[ToolOption] | None = None, - response_format: RespFormat | None = None, - stream_response_handler: Optional[Callable] = None, - async_response_parser: Optional[Callable] = None, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - embedding_input: str = "", - audio_base64: str = "", - ) -> APIResponse: - """ - 实际执行API请求,并包含完整的重试和异常处理逻辑。 - - Args: - request_type (str): 请求类型 ('response', 'embedding', 'audio')。 - message_list (List[Message] | None, optional): 消息列表。 Defaults to None. - tool_options (list[ToolOption] | None, optional): 工具选项。 Defaults to None. - response_format (RespFormat | None, optional): 响应格式要求。 Defaults to None. - stream_response_handler (Optional[Callable], optional): 流式响应处理器。 Defaults to None. - async_response_parser (Optional[Callable], optional): 异步响应解析器。 Defaults to None. - temperature (Optional[float], optional): 温度参数。 Defaults to None. - max_tokens (Optional[int], optional): 最大token数。 Defaults to None. - embedding_input (str, optional): embedding输入文本。 Defaults to "". - audio_base64 (str, optional): 音频base64数据。 Defaults to "". - - Returns: - APIResponse: 从模型客户端返回的API响应对象。 - - Raises: - ValueError: 如果请求类型未知。 - RuntimeError: 如果所有重试都失败。 - """ - retry_remain = self.api_provider.max_retry - compressed_messages: Optional[List[Message]] = None - - # 循环进行重试 - while retry_remain > 0: - try: - # 根据请求类型调用不同的客户端方法 - if request_type == "response": - assert message_list is not None, "message_list cannot be None for response requests" - return await self.client.get_response( - model_info=self.model_info, - message_list=(compressed_messages or message_list), - tool_options=tool_options, - max_tokens=self.model_set.max_tokens if max_tokens is None else max_tokens, - temperature=self.model_set.temperature if temperature is None else temperature, - response_format=response_format, - stream_response_handler=stream_response_handler, - async_response_parser=async_response_parser, - extra_params=self.model_info.extra_params, - ) - elif request_type == "embedding": - assert embedding_input, "embedding_input cannot be empty for embedding requests" - return await self.client.get_embedding( - model_info=self.model_info, - embedding_input=embedding_input, - extra_params=self.model_info.extra_params, - ) - elif request_type == "audio": - assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" - return await self.client.get_audio_transcriptions( - model_info=self.model_info, - audio_base64=audio_base64, - extra_params=self.model_info.extra_params, - ) - raise ValueError(f"未知的请求类型: {request_type}") - except Exception as e: - logger.debug(f"请求失败: {str(e)}") - # 对失败的模型应用惩罚 - self._apply_penalty_on_failure(e) - - # 使用默认异常处理器来决定下一步操作(等待、重试、压缩或终止) - wait_interval, compressed_messages = self._default_exception_handler( - e, - remain_try=retry_remain, - retry_interval=self.api_provider.retry_interval, - messages=(message_list, compressed_messages is not None) if message_list else None, - ) - - if wait_interval == -1: - retry_remain = 0 # 处理器决定不再重试 - elif wait_interval > 0: - logger.info(f"等待 {wait_interval} 秒后重试...") - await asyncio.sleep(wait_interval) - finally: - retry_remain -= 1 - - # 所有重试次数用尽后 - self.model_selector.decrease_usage_penalty(self.model_info.name) # 减少因使用而增加的基础惩罚 - logger.error(f"模型 '{self.model_info.name}' 请求失败,达到最大重试次数 {self.api_provider.max_retry} 次") - raise RuntimeError("请求失败,已达到最大重试次数") - - def _apply_penalty_on_failure(self, e: Exception): - """ - 根据异常类型,动态调整失败模型的惩罚值。 - 关键错误(如网络问题、服务器5xx错误)会施加更重的惩罚。 - """ - CRITICAL_PENALTY_MULTIPLIER = 5 - default_penalty_increment = 1 - penalty_increment = default_penalty_increment - - # 对严重错误施加更高的惩罚 - if isinstance(e, (NetworkConnectionError, ReqAbortException)): - penalty_increment = CRITICAL_PENALTY_MULTIPLIER - elif isinstance(e, RespNotOkException): - if e.status_code >= 500: # 服务器内部错误 - penalty_increment = CRITICAL_PENALTY_MULTIPLIER - - # 记录日志 - log_message = f"发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}" - if isinstance(e, (NetworkConnectionError, ReqAbortException)): - log_message = f"发生关键错误 ({type(e).__name__}),增加惩罚值: {penalty_increment}" - elif isinstance(e, RespNotOkException): - log_message = f"发生响应错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}" - logger.warning(f"模型 '{self.model_info.name}' {log_message}") - - # 更新模型的惩罚值 - self.model_selector.update_model_penalty(self.model_info.name, penalty_increment) - - def _default_exception_handler( - self, - e: Exception, - remain_try: int, - retry_interval: int = 10, - messages: Tuple[List[Message], bool] | None = None, - ) -> Tuple[int, List[Message] | None]: - """ - 默认的异常分类处理器。 - 根据异常类型决定是否重试、等待多久以及是否需要压缩消息。 - - Returns: - Tuple[int, List[Message] | None]: - - 等待时间(秒)。-1表示不重试。 - - 压缩后的消息列表(如果有)。 - """ - model_name = self.model_info.name - - if isinstance(e, NetworkConnectionError): - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数", - ) - elif isinstance(e, ReqAbortException): - logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}") - return -1, None # 请求被中断,不重试 - elif isinstance(e, RespNotOkException): - return self._handle_resp_not_ok(e, remain_try, retry_interval, messages) - elif isinstance(e, RespParseException): - logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}") - logger.debug(f"附加内容: {str(e.ext_info)}") - return -1, None # 解析错误通常不可重试 - else: - logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") - return -1, None # 未知异常,不重试 - - def _handle_resp_not_ok( - self, - e: RespNotOkException, - remain_try: int, - retry_interval: int = 10, - messages: tuple[list[Message], bool] | None = None, - ) -> Tuple[int, Optional[List[Message]]]: - """处理HTTP状态码非200的异常。""" - model_name = self.model_info.name - # 客户端错误 (4xx),通常不可重试 - if e.status_code in [400, 401, 402, 403, 404]: - logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}") - return -1, None - # 请求体过大 (413) - elif e.status_code == 413: - # 如果消息存在且尚未被压缩,尝试压缩后重试一次 - if messages and not messages[1]: # messages[1] is a flag indicating if it's already compressed - return self._check_retry( - remain_try, 0, # 立即重试 - can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试", - cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,压缩后仍失败", - can_retry_callable=compress_messages, messages=messages[0], - ) - logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,无法压缩,放弃请求。") - return -1, None - # 请求过于频繁 (429) - elif e.status_code == 429: - return self._check_retry( - remain_try, retry_interval, - can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数", - ) - # 服务器错误 (5xx),可以重试 - elif e.status_code >= 500: - return self._check_retry( - remain_try, retry_interval, - can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数", - ) - else: - logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}") - return -1, None - - @staticmethod - def _check_retry( - remain_try: int, - retry_interval: int, - can_retry_msg: str, - cannot_retry_msg: str, - can_retry_callable: Callable | None = None, - **kwargs, - ) -> Tuple[int, List[Message] | None]: - """ - 辅助函数:检查是否可以重试,并执行可选的回调函数(如消息压缩)。 - """ - if remain_try > 0: - logger.warning(f"{can_retry_msg}") - # 如果有可执行的回调(例如压缩函数),执行它并返回结果 - if can_retry_callable is not None: - return retry_interval, can_retry_callable(**kwargs) - return retry_interval, None - else: - logger.warning(f"{cannot_retry_msg}") - return -1, None \ No newline at end of file diff --git a/src/llm_models/request_strategy.py b/src/llm_models/request_strategy.py deleted file mode 100644 index 82575b076..000000000 --- a/src/llm_models/request_strategy.py +++ /dev/null @@ -1,274 +0,0 @@ -# -*- coding: utf-8 -*- -""" -@File : request_strategy.py -@Time : 2024/05/24 16:30:00 -@Author : 墨墨 -@Version : 1.0 -@Desc : 高级请求策略(并发、故障转移) -""" -import asyncio -import random -from typing import Optional, Dict, Any, Callable, Coroutine - -from src.common.logger import get_logger -from src.config.api_ada_configs import TaskConfig -from .model_client.base_client import APIResponse -from .model_selector import ModelSelector -from .payload_content.message import MessageBuilder -from .prompt_processor import PromptProcessor -from .request_executor import RequestExecutor - -logger = get_logger("model_utils") - - -class RequestStrategy: - """ - 高级请求策略模块。 - 负责实现复杂的请求逻辑,如模型的故障转移(fallback)和并发请求。 - """ - - def __init__(self, model_set: TaskConfig, model_selector: ModelSelector, task_name: str): - """ - 初始化请求策略。 - - Args: - model_set (TaskConfig): 特定任务的模型配置。 - model_selector (ModelSelector): 模型选择器实例。 - task_name (str): 当前任务的名称。 - """ - self.model_set = model_set - self.model_selector = model_selector - self.task_name = task_name - - async def execute_with_fallback( - self, - base_payload: Dict[str, Any], - raise_when_empty: bool = True, - ) -> Dict[str, Any]: - """ - 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 - - 该方法会按顺序尝试任务配置中的所有可用模型,直到一个模型成功返回响应。 - 如果所有模型都失败,将根据 `raise_when_empty` 参数决定是抛出异常还是返回一个失败结果。 - - Args: - base_payload (Dict[str, Any]): 基础请求载荷,包含prompt、工具选项等。 - raise_when_empty (bool, optional): 如果所有模型都失败或返回空内容,是否抛出异常。 Defaults to True. - - Returns: - Dict[str, Any]: 一个包含响应结果的字典,包括内容、模型信息、用量和成功状态。 - """ - # 记录在本次请求中已经失败的模型,避免重复尝试 - failed_models_in_this_request = set() - max_attempts = len(self.model_set.model_list) - last_exception: Optional[Exception] = None - - for attempt in range(max_attempts): - # 选择一个当前最佳且未失败的模型 - model_selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request) - - if model_selection_result is None: - logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") - break # 没有更多可用模型,跳出循环 - - model_info, api_provider, client = model_selection_result - model_name = model_info.name - logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...") - - try: - # 步骤 1: 预处理Prompt - prompt_processor: PromptProcessor = base_payload["prompt_processor"] - raw_prompt = base_payload["prompt"] - processed_prompt = prompt_processor.process_prompt( - raw_prompt, model_info, api_provider, self.task_name - ) - - # 步骤 2: 构建消息体 - message_builder = MessageBuilder().add_text_content(processed_prompt) - messages = [message_builder.build()] - - # 步骤 3: 为执行器创建载荷 - executor_payload = { - "request_type": "response", # 策略模式目前只处理'response'类型请求 - "message_list": messages, - "tool_options": base_payload["tool_options"], - "temperature": base_payload["temperature"], - "max_tokens": base_payload["max_tokens"], - } - - # 创建请求执行器实例 - executor = RequestExecutor( - task_name=self.task_name, - model_set=self.model_set, - api_provider=api_provider, - client=client, - model_info=model_info, - model_selector=self.model_selector, - ) - # 执行请求,并处理内部的空回复/截断重试 - response = await self._execute_and_handle_empty_retry(executor, executor_payload, prompt_processor) - - # 步骤 4: 后处理响应 - # 在获取到成功的、完整的响应后,提取思考过程内容 - final_content, reasoning_content = prompt_processor.extract_reasoning(response.content or "") - response.content = final_content # 使用清理后的内容更新响应对象 - - tool_calls = response.tool_calls - - # 检查最终内容是否为空 - if not final_content and not tool_calls: - if raise_when_empty: - raise RuntimeError("所选模型生成了空回复。") - logger.warning(f"模型 '{model_name}' 生成了空回复,返回默认信息。") - - logger.debug(f"模型 '{model_name}' 成功生成了回复。") - # 返回成功结果,包含用量和模型信息,供上层记录 - return { - "content": response.content, - "reasoning_content": reasoning_content, - "model_name": model_name, - "tool_calls": tool_calls, - "model_info": model_info, - "usage": response.usage, - "success": True, - } - - except Exception as e: - # 捕获请求过程中的任何异常 - logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") - failed_models_in_this_request.add(model_info.name) - last_exception = e - - # 如果循环结束仍未成功 - logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") - if raise_when_empty: - if last_exception: - raise RuntimeError("所有模型均未能生成响应。") from last_exception - raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") - - # 返回失败结果 - return { - "content": "所有模型都请求失败", - "reasoning_content": "", - "model_name": "unknown", - "tool_calls": None, - "model_info": None, - "usage": None, - "success": False, - } - - async def execute_concurrently( - self, - coro_callable: Callable[..., Coroutine[Any, Any, Any]], - concurrency_count: int, - *args, - **kwargs, - ) -> Any: - """ - 以指定的并发数执行多个协程,并从所有成功的结果中随机选择一个返回。 - - Args: - coro_callable (Callable): 要并发执行的协程函数。 - concurrency_count (int): 并发数量。 - *args: 传递给协程函数的位置参数。 - **kwargs: 传递给协程函数的关键字参数。 - - Returns: - Any: 从成功的结果中随机选择的一个。 - - Raises: - RuntimeError: 如果所有并发任务都失败了。 - """ - logger.info(f"启用并发请求模式,并发数: {concurrency_count}") - # 创建并发任务列表 - tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)] - - # 等待所有任务完成 - results = await asyncio.gather(*tasks, return_exceptions=True) - # 筛选出成功的结果 - successful_results = [ - res for res in results if isinstance(res, dict) and res.get("success") - ] - - if successful_results: - # 从成功结果中随机选择一个 - selected = random.choice(successful_results) - logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个") - return selected - - # 如果没有成功的结果,记录所有异常 - for i, res in enumerate(results): - if isinstance(res, Exception): - logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}") - - # 抛出第一个遇到的异常 - first_exception = next((res for res in results if isinstance(res, Exception)), None) - if first_exception: - raise first_exception - - raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息") - - async def _execute_and_handle_empty_retry( - self, executor: RequestExecutor, payload: Dict[str, Any], prompt_processor: PromptProcessor - ) -> APIResponse: - """ - 在单个模型内部处理因回复为空或被截断而触发的重试逻辑。 - - Args: - executor (RequestExecutor): 请求执行器实例。 - payload (Dict[str, Any]): 传递给 `execute_request` 的载荷。 - prompt_processor (PromptProcessor): 提示词处理器,用于获取反截断标记。 - - Returns: - APIResponse: 一个有效的、非空且完整的API响应。 - - Raises: - RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的回复。 - """ - empty_retry_count = 0 - max_empty_retry = executor.api_provider.max_retry - empty_retry_interval = executor.api_provider.retry_interval - # 检查模型是否启用了反截断功能 - use_anti_truncation = getattr(executor.model_info, "use_anti_truncation", False) - end_marker = prompt_processor.end_marker - - while empty_retry_count <= max_empty_retry: - response = await executor.execute_request(**payload) - - content = response.content or "" - tool_calls = response.tool_calls - - # 判断是否为空回复 - is_empty_reply = not tool_calls and (not content or content.strip() == "") - is_truncated = False - - # 如果启用了反截断,检查回复是否被截断 - if use_anti_truncation and end_marker: - if content.endswith(end_marker): - # 如果包含结束标记,说明回复完整,移除标记 - response.content = content[: -len(end_marker)].strip() - else: - # 否则,认为回复被截断 - is_truncated = True - - # 如果是空回复或截断,则进行重试 - if is_empty_reply or is_truncated: - empty_retry_count += 1 - if empty_retry_count <= max_empty_retry: - reason = "空回复" if is_empty_reply else "截断" - logger.warning( - f"模型 '{executor.model_info.name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..." - ) - if empty_retry_interval > 0: - await asyncio.sleep(empty_retry_interval) - continue # 继续下一次循环重试 - else: - # 达到最大重试次数,抛出异常 - reason = "空回复" if is_empty_reply else "截断" - raise RuntimeError(f"模型 '{executor.model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。") - - # 成功获取到有效响应,返回结果 - return response - - # 此处理论上不会到达,因为循环要么返回要么抛出异常 - raise RuntimeError("空回复/截断重试逻辑出现未知错误") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index b223f0d82..ac318fbe4 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,15 +1,14 @@ -# -*- coding: utf-8 -*- -""" -@File : utils_model.py -@Time : 2024/05/24 17:15:00 -@Author : 墨墨 -@Version : 2.0 (Refactored) -@Desc : LLM请求协调器 -""" +import re +import asyncio import time -from typing import Tuple, List, Dict, Optional, Any +import random + +from enum import Enum +from rich.traceback import install +from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator from src.common.logger import get_logger +<<<<<<< HEAD from src.config.api_ada_configs import TaskConfig, ModelInfo, UsageRecord from .llm_utils import build_tool_options, normalize_image_format from .model_selector import ModelSelector @@ -18,15 +17,128 @@ from .payload_content.tool_option import ToolCall from .prompt_processor import PromptProcessor from .request_strategy import RequestStrategy from .utils import llm_usage_recorder +======= +from src.config.config import model_config +from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig +from .payload_content.message import MessageBuilder, Message +from .payload_content.resp_format import RespFormat +from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType +from .model_client.base_client import BaseClient, APIResponse, client_registry +from .utils import compress_messages, llm_usage_recorder +from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException + +install(extra_lines=3) +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) logger = get_logger("model_utils") +# 常见Error Code Mapping +error_code_mapping = { + 400: "参数不正确", + 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确", + 402: "账号余额不足", + 403: "需要实名,或余额不足", + 404: "Not Found", + 429: "请求过于频繁,请稍后再试", + 500: "服务器内部故障", + 503: "服务器负载过高", +} + + +def _normalize_image_format(image_format: str) -> str: + """ + 标准化图片格式名称,确保与各种API的兼容性 + + Args: + image_format (str): 原始图片格式 + + Returns: + str: 标准化后的图片格式 + """ + format_mapping = { + "jpg": "jpeg", + "JPG": "jpeg", + "JPEG": "jpeg", + "jpeg": "jpeg", + "png": "png", + "PNG": "png", + "webp": "webp", + "WEBP": "webp", + "gif": "gif", + "GIF": "gif", + "heic": "heic", + "HEIC": "heic", + "heif": "heif", + "HEIF": "heif", + } + + normalized = format_mapping.get(image_format, image_format.lower()) + logger.debug(f"图片格式标准化: {image_format} -> {normalized}") + return normalized + + +class RequestType(Enum): + """请求类型枚举""" + + RESPONSE = "response" + EMBEDDING = "embedding" + AUDIO = "audio" + + +async def execute_concurrently( + coro_callable: Callable[..., Coroutine[Any, Any, Any]], + concurrency_count: int, + *args, + **kwargs, +) -> Any: + """ + 执行并发请求并从成功的结果中随机选择一个。 + + Args: + coro_callable (Callable): 要并发执行的协程函数。 + concurrency_count (int): 并发执行的次数。 + *args: 传递给协程函数的位置参数。 + **kwargs: 传递给协程函数的关键字参数。 + + Returns: + Any: 其中一个成功执行的结果。 + + Raises: + RuntimeError: 如果所有并发请求都失败。 + """ + logger.info(f"启用并发请求模式,并发数: {concurrency_count}") + tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)] + + results = await asyncio.gather(*tasks, return_exceptions=True) + successful_results = [res for res in results if not isinstance(res, Exception)] + + if successful_results: + selected = random.choice(successful_results) + logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个") + return selected + + # 如果所有请求都失败了,记录所有异常并抛出第一个 + for i, res in enumerate(results): + if isinstance(res, Exception): + logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}") + + first_exception = next((res for res in results if isinstance(res, Exception)), None) + if first_exception: + raise first_exception + + raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息") + + class LLMRequest: +<<<<<<< HEAD """ LLM请求协调器。 封装了模型选择、Prompt处理、请求执行和高级策略(如故障转移、并发)的完整流程。 为上层业务逻辑提供统一的、简化的接口来与大语言模型交互。 """ +======= + """LLM请求类""" +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) def __init__(self, model_set: TaskConfig, request_type: str = "") -> None: """ @@ -39,9 +151,30 @@ class LLMRequest: self.task_name = request_type self.model_for_task = model_set self.request_type = request_type - self.model_selector = ModelSelector(model_set, request_type) - self.prompt_processor = PromptProcessor() - self.request_strategy = RequestStrategy(model_set, self.model_selector, request_type) + self.model_usage: Dict[str, Tuple[int, int, int]] = { + model: (0, 0, 0) for model in self.model_for_task.model_list + } + """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" + + # 内容混淆过滤指令 + self.noise_instruction = """**【核心过滤规则】** +在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。 + +你的任务是【完全并彻底地忽略】这些随机字符串。 +**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。 +""" + + # 反截断指令 + self.end_marker = "###MAI_RESPONSE_END###" + self.anti_truncation_instruction = f""" +**【输出完成信令】** +这是一个非常重要的指令,请务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `{self.end_marker}` 作为结束标志。 +例如: +<你的回复内容> +{self.end_marker} + +这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。 +""" async def generate_response_for_image( self, @@ -52,6 +185,7 @@ class LLMRequest: max_tokens: Optional[int] = None, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ +<<<<<<< HEAD 为包含图像的多模态输入生成文本响应。 Args: @@ -78,8 +212,26 @@ class LLMRequest: normalized_format = normalize_image_format(image_format) # 使用MessageBuilder构建多模态消息 +======= + 为图像生成响应 + Args: + prompt (str): 提示词 + image_base64 (str): 图像的Base64编码字符串 + image_format (str): 图像格式(如 'png', 'jpeg' 等) + Returns: + (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 + """ + # 标准化图片格式以确保API兼容性 + normalized_format = _normalize_image_format(image_format) + + # 模型选择 + start_time = time.time() + model_info, api_provider, client = self._select_model() + + # 请求体构建 +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) message_builder = MessageBuilder() - message_builder.add_text_content(processed_prompt) + message_builder.add_text_content(prompt) message_builder.add_image_content( image_base64=image_base64, image_format=normalized_format, @@ -87,35 +239,54 @@ class LLMRequest: ) messages = [message_builder.build()] +<<<<<<< HEAD # 步骤 3: 执行请求 (图像请求通常不走复杂的故障转移策略,直接执行) from .request_executor import RequestExecutor executor = RequestExecutor( task_name=self.task_name, model_set=self.model_for_task, +======= + # 请求并处理返回值 + response = await self._execute_request( +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) api_provider=api_provider, client=client, + request_type=RequestType.RESPONSE, model_info=model_info, - model_selector=self.model_selector, - ) - response = await executor.execute_request( - request_type="response", message_list=messages, temperature=temperature, max_tokens=max_tokens, ) +<<<<<<< HEAD # 步骤 4: 处理响应 content, reasoning_content = self.prompt_processor.extract_reasoning(response.content or "") tool_calls = response.tool_calls # 记录用量 +======= + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) if usage := response.usage: - await self._record_usage(model_info, usage, time.time() - start_time) - + await llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", + time_cost=time.time() - start_time, + request_type=self.request_type, + endpoint="/chat/completions", + ) return content, (reasoning_content, model_info.name, tool_calls) async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: """ +<<<<<<< HEAD 将语音数据转换为文本(语音识别)。 Args: @@ -132,14 +303,31 @@ class LLMRequest: executor = RequestExecutor( task_name=self.task_name, model_set=self.model_for_task, +======= + 为语音生成响应 + Args: + voice_base64 (str): 语音的Base64编码字符串 + Returns: + (Optional[str]): 生成的文本描述或None + """ + # 模型选择 + model_info, api_provider, client = self._select_model() + + # 请求并处理返回值 + response = await self._execute_request( +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) api_provider=api_provider, client=client, + request_type=RequestType.AUDIO, model_info=model_info, +<<<<<<< HEAD model_selector=self.model_selector, ) # 执行语音转文本请求 response = await executor.execute_request( request_type="audio", +======= +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) audio_base64=voice_base64, ) return response.content or None @@ -153,6 +341,7 @@ class LLMRequest: raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ +<<<<<<< HEAD 异步生成文本响应,支持并发和故障转移等高级策略。 Args: @@ -180,18 +369,33 @@ class LLMRequest: } # 步骤 2: 根据配置选择执行策略 (并发或单次带故障转移) +======= + 异步生成响应,支持并发请求 + Args: + prompt (str): 提示词 + temperature (float, optional): 温度参数 + max_tokens (int, optional): 最大token数 + tools: 工具配置 + raise_when_empty: 是否在空回复时抛出异常 + Returns: + (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 + """ + # 检查是否需要并发请求 +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) concurrency_count = getattr(self.model_for_task, "concurrency_count", 1) - + if concurrency_count <= 1: - # 单次请求,但使用带故障转移的策略 - result = await self.request_strategy.execute_with_fallback( - base_payload, raise_when_empty - ) - else: - # 并发请求策略 - result = await self.request_strategy.execute_concurrently( - self.request_strategy.execute_with_fallback, + # 单次请求 + return await self._execute_single_request(prompt, temperature, max_tokens, tools, raise_when_empty) + + # 并发请求 + try: + # 为 _execute_single_request 传递参数时,将 raise_when_empty 设为 False, + # 这样单个请求失败时不会立即抛出异常,而是由 gather 统一处理 + content, (reasoning_content, model_name, tool_calls) = await execute_concurrently( + self._execute_single_request, concurrency_count, +<<<<<<< HEAD base_payload, raise_when_empty=False, # 在并发模式下,单个任务失败不应立即抛出异常 ) @@ -233,20 +437,195 @@ class LLMRequest: executor = RequestExecutor( task_name=self.task_name, model_set=self.model_for_task, +======= + prompt, + temperature, + max_tokens, + tools, + raise_when_empty=False, + ) + return content, (reasoning_content, model_name, tool_calls) + except Exception as e: + logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}") + if raise_when_empty: + raise e + return "所有并发请求都失败了", ("", "unknown", None) + + async def _execute_single_request( + self, + prompt: str, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + tools: Optional[List[Dict[str, Any]]] = None, + raise_when_empty: bool = True, + ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + """ + 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 + """ + failed_models_in_this_request = set() + # 迭代次数等于模型总数,以确保每个模型在当前请求中最多只尝试一次 + max_attempts = len(self.model_for_task.model_list) + last_exception: Optional[Exception] = None + + for attempt in range(max_attempts): + # 根据负载均衡和当前故障选择最佳可用模型 + model_selection_result = self._select_best_available_model(failed_models_in_this_request) + + if model_selection_result is None: + logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") + break # 没有更多模型可供尝试 + + model_info, api_provider, client = model_selection_result + model_name = model_info.name + logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...") + + start_time = time.time() + + try: + # --- 为当前模型尝试进行设置 --- + # 检查是否为该模型启用反截断 + use_anti_truncation = getattr(model_info, "use_anti_truncation", False) + processed_prompt = prompt + if use_anti_truncation: + processed_prompt += self.anti_truncation_instruction + logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。") + + processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) + + message_builder = MessageBuilder() + message_builder.add_text_content(processed_prompt) + messages = [message_builder.build()] + tool_built = self._build_tool_options(tools) + + # --- 当前选定模型内的空回复/截断重试逻辑 --- + empty_retry_count = 0 + max_empty_retry = api_provider.max_retry + empty_retry_interval = api_provider.retry_interval + + while empty_retry_count <= max_empty_retry: + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=messages, + tool_options=tool_built, + temperature=temperature, + max_tokens=max_tokens, + ) + + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + + # 向后兼容 标签(如果 reasoning_content 为空) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + + is_empty_reply = not tool_calls and (not content or content.strip() == "") + is_truncated = False + if use_anti_truncation: + if content.endswith(self.end_marker): + content = content[: -len(self.end_marker)].strip() + else: + is_truncated = True + + if is_empty_reply or is_truncated: + empty_retry_count += 1 + if empty_retry_count <= max_empty_retry: + reason = "空回复" if is_empty_reply else "截断" + logger.warning( + f"模型 '{model_name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..." + ) + if empty_retry_interval > 0: + await asyncio.sleep(empty_retry_interval) + continue # 使用当前模型重试 + else: + reason = "空回复" if is_empty_reply else "截断" + logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。将此模型标记为当前请求失败。") + raise RuntimeError(f"模型 '{model_name}' 已达到空回复/截断的最大内部重试次数。") + + # --- 从当前模型获取成功响应 --- + if usage := response.usage: + await llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + time_cost=time.time() - start_time, + user_id="system", + request_type=self.request_type, + endpoint="/chat/completions", + ) + + # 处理成功执行后响应仍然为空的情况 + if not content and not tool_calls: + if raise_when_empty: + raise RuntimeError("所选模型生成了空回复。") + content = "生成的响应为空" # Fallback message + + logger.debug(f"模型 '{model_name}' 成功生成了回复。") + return content, (reasoning_content, model_name, tool_calls) # 成功,立即返回 + + # --- 当前模型尝试过程中的异常处理 --- + except Exception as e: # 捕获当前模型尝试过程中的所有异常 + # 修复 NameError: model_name 在异常处理块中未定义,应使用 model_info.name + logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") + failed_models_in_this_request.add(model_info.name) + last_exception = e # 存储异常以供最终报告 + # 继续循环以尝试下一个可用模型 + + # 如果循环结束未能返回,则表示当前请求的所有模型都已失败 + logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") + if raise_when_empty: + if last_exception: + raise RuntimeError("所有模型均未能生成响应。") from last_exception + raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") + return "所有模型都请求失败", ("", "unknown", None) + + async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: + """获取嵌入向量 + Args: + embedding_input (str): 获取嵌入的目标 + Returns: + (Tuple[List[float], str]): (嵌入向量,使用的模型名称) + """ + # 无需构建消息体,直接使用输入文本 + start_time = time.time() + model_info, api_provider, client = self._select_model() + + # 请求并处理返回值 + response = await self._execute_request( +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) api_provider=api_provider, client=client, + request_type=RequestType.EMBEDDING, model_info=model_info, +<<<<<<< HEAD model_selector=self.model_selector, ) # 执行embedding请求 response = await executor.execute_request( request_type="embedding", +======= +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) embedding_input=embedding_input, ) - + embedding = response.embedding + + if usage := response.usage: + await llm_usage_recorder.record_usage_to_database( + model_info=model_info, + time_cost=time.time() - start_time, + model_usage=usage, + user_id="system", + request_type=self.request_type, + endpoint="/embeddings", + ) + if not embedding: raise RuntimeError("获取embedding失败") +<<<<<<< HEAD # 记录用量 if usage := response.usage: @@ -271,4 +650,479 @@ class LLMRequest: time_cost=time_cost, request_type=self.request_type, endpoint=endpoint, +======= + + return embedding, model_info.name + + def _select_best_available_model(self, failed_models_in_this_request: set) -> Tuple[ModelInfo, APIProvider, BaseClient] | None: + """ + 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 + + 参数: + failed_models_in_this_request (set): 当前请求中已失败的模型名称集合。 + + 返回: + Tuple[ModelInfo, APIProvider, BaseClient] | None: 选定的模型详细信息,如果无可用模型则返回 None。 + """ + candidate_models_usage = {} + # 过滤掉当前请求中已失败的模型 + for model_name, usage_data in self.model_usage.items(): + if model_name not in failed_models_in_this_request: + candidate_models_usage[model_name] = usage_data + + if not candidate_models_usage: + logger.warning("没有可用的模型供当前请求选择。") + return None + + # 根据现有公式查找分数最低的模型,该公式综合了总token数、模型惩罚值和使用频率惩罚值。 + # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 + # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。 + least_used_model_name = min( + candidate_models_usage, + key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, +>>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) ) + + # --- 动态故障转移的核心逻辑 --- + # _execute_single_request 中的循环会多次调用此函数。 + # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数, + # 此时由于失败模型已被标记,且其惩罚值可能已在 _execute_request 中增加, + # _select_best_available_model 会自动选择一个得分更低(即更可用)的模型。 + # 这种机制实现了动态的、基于当前系统状态的故障转移。 + + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + + # 对于嵌入任务,如果需要,强制创建新的客户端实例(从原始 _select_model 复制) + force_new_client = self.request_type == "embedding" + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + + logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") + + # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。 + # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) + + return model_info, api_provider, client + + def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: + """ + 一个模型调度器,按顺序提供模型,并跳过已失败的模型。 + """ + for model_name in self.model_for_task.model_list: + if model_name in failed_models: + continue + + model_info = model_config.get_model_info(model_name) + api_provider = model_config.get_provider(model_info.api_provider) + force_new_client = self.request_type == "embedding" + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + + yield model_info, api_provider, client + + def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: + """ + 根据总tokens和惩罚值选择的模型 (负载均衡) + """ + least_used_model_name = min( + self.model_usage, + key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, + ) + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + + # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 + force_new_client = self.request_type == "embedding" + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + logger.debug(f"选择请求模型: {model_info.name}") + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 + return model_info, api_provider, client + + async def _execute_request( + self, + api_provider: APIProvider, + client: BaseClient, + request_type: RequestType, + model_info: ModelInfo, + message_list: List[Message] | None = None, + tool_options: list[ToolOption] | None = None, + response_format: RespFormat | None = None, + stream_response_handler: Optional[Callable] = None, + async_response_parser: Optional[Callable] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + embedding_input: str = "", + audio_base64: str = "", + ) -> APIResponse: + """ + 实际执行请求的方法 + + 包含了重试和异常处理逻辑 + """ + retry_remain = api_provider.max_retry + compressed_messages: Optional[List[Message]] = None + while retry_remain > 0: + try: + if request_type == RequestType.RESPONSE: + assert message_list is not None, "message_list cannot be None for response requests" + return await client.get_response( + model_info=model_info, + message_list=(compressed_messages or message_list), + tool_options=tool_options, + max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, + temperature=self.model_for_task.temperature if temperature is None else temperature, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + extra_params=model_info.extra_params, + ) + elif request_type == RequestType.EMBEDDING: + assert embedding_input, "embedding_input cannot be empty for embedding requests" + return await client.get_embedding( + model_info=model_info, + embedding_input=embedding_input, + extra_params=model_info.extra_params, + ) + elif request_type == RequestType.AUDIO: + assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" + return await client.get_audio_transcriptions( + model_info=model_info, + audio_base64=audio_base64, + extra_params=model_info.extra_params, + ) + except Exception as e: + logger.debug(f"请求失败: {str(e)}") + # 处理异常 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + + # --- 增强动态故障转移的智能性 --- + # 根据异常类型和严重程度,动态调整模型的惩罚值。 + # 关键错误(如网络连接、服务器错误)会获得更高的惩罚, + # 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。 + CRITICAL_PENALTY_MULTIPLIER = 5 # 关键错误时的惩罚系数 + default_penalty_increment = 1 # 普通错误时的基础惩罚 + + penalty_increment = default_penalty_increment + + if isinstance(e, NetworkConnectionError): + # 网络连接问题表明模型服务器不稳定,增加较高惩罚 + penalty_increment = CRITICAL_PENALTY_MULTIPLIER + # 修复 NameError: model_name 在此处未定义,应使用 model_info.name + logger.warning(f"模型 '{model_info.name}' 发生网络连接错误,增加惩罚值: {penalty_increment}") + elif isinstance(e, ReqAbortException): + # 请求被中止,可能是服务器端原因或服务不稳定,增加较高惩罚 + penalty_increment = CRITICAL_PENALTY_MULTIPLIER + # 修复 NameError: model_name 在此处未定义,应使用 model_info.name + logger.warning(f"模型 '{model_info.name}' 请求被中止,增加惩罚值: {penalty_increment}") + elif isinstance(e, RespNotOkException): + if e.status_code >= 500: + # 服务器错误 (5xx) 表明服务器端问题,应显著增加惩罚 + penalty_increment = CRITICAL_PENALTY_MULTIPLIER + logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}") + elif e.status_code == 429: + # 请求过于频繁,是暂时性问题,但仍需惩罚,此处使用默认基础值 + # penalty_increment = 2 # 可以选择一个中间值,例如2,表示比普通错误重,但比关键错误轻 + logger.warning(f"模型 '{model_name}' 请求过于频繁 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") + else: + # 其他客户端错误 (4xx)。通常不重试,_handle_resp_not_ok 会处理。 + # 如果 _handle_resp_not_ok 返回 retry_interval, 则进入这里的 exception 块。 + logger.warning(f"模型 '{model_name}' 发生非致命的响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") + else: + # 其他未捕获的异常,增加基础惩罚 + logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") + + self.model_usage[model_info.name] = (total_tokens, penalty + penalty_increment, usage_penalty) + # --- 结束增强 --- + # 移除冗余的、错误的惩罚值更新行,保留上面正确的动态惩罚更新 + # self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) + + wait_interval, compressed_messages = self._default_exception_handler( + e, + self.task_name, + model_info=model_info, + api_provider=api_provider, + remain_try=retry_remain, + retry_interval=api_provider.retry_interval, + messages=(message_list, compressed_messages is not None) if message_list else None, + ) + + if wait_interval == -1: + retry_remain = 0 # 不再重试 + elif wait_interval > 0: + logger.info(f"等待 {wait_interval} 秒后重试...") + await asyncio.sleep(wait_interval) + finally: + # 放在finally防止死循环 + retry_remain -= 1 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值 + logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") + raise RuntimeError("请求失败,已达到最大重试次数") + + def _default_exception_handler( + self, + e: Exception, + task_name: str, + model_info: ModelInfo, + api_provider: APIProvider, + remain_try: int, + retry_interval: int = 10, + messages: Tuple[List[Message], bool] | None = None, + ) -> Tuple[int, List[Message] | None]: + """ + 默认异常处理函数 + Args: + e (Exception): 异常对象 + task_name (str): 任务名称 + model_info (ModelInfo): 模型信息 + api_provider (APIProvider): API提供商 + remain_try (int): 剩余尝试次数 + retry_interval (int): 重试间隔 + messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) + Returns: + (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + model_name = model_info.name if model_info else "unknown" + + if isinstance(e, NetworkConnectionError): # 网络连接错误 + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确", + ) + elif isinstance(e, ReqAbortException): + logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}") + return -1, None # 不再重试请求该模型 + elif isinstance(e, RespNotOkException): + return self._handle_resp_not_ok( + e, + task_name, + model_info, + api_provider, + remain_try, + retry_interval, + messages, + ) + elif isinstance(e, RespParseException): + # 响应解析错误 + logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}") + logger.debug(f"附加内容: {str(e.ext_info)}") + return -1, None # 不再重试请求该模型 + else: + logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") + return -1, None # 不再重试请求该模型 + + @staticmethod + def _check_retry( + remain_try: int, + retry_interval: int, + can_retry_msg: str, + cannot_retry_msg: str, + can_retry_callable: Callable | None = None, + **kwargs, + ) -> Tuple[int, List[Message] | None]: + """辅助函数:检查是否可以重试 + Args: + remain_try (int): 剩余尝试次数 + retry_interval (int): 重试间隔 + can_retry_msg (str): 可以重试时的提示信息 + cannot_retry_msg (str): 不可以重试时的提示信息 + can_retry_callable (Callable | None): 可以重试时调用的函数(如果有) + **kwargs: 其他参数 + + Returns: + (Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + if remain_try > 0: + # 还有重试机会 + logger.warning(f"{can_retry_msg}") + if can_retry_callable is not None: + return retry_interval, can_retry_callable(**kwargs) + else: + return retry_interval, None + else: + # 达到最大重试次数 + logger.warning(f"{cannot_retry_msg}") + return -1, None # 不再重试请求该模型 + + def _handle_resp_not_ok( + self, + e: RespNotOkException, + task_name: str, + model_info: ModelInfo, + api_provider: APIProvider, + remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, + ): + model_name = model_info.name + """ + 处理响应错误异常 + Args: + e (RespNotOkException): 响应错误异常对象 + task_name (str): 任务名称 + model_info (ModelInfo): 模型信息 + api_provider (APIProvider): API提供商 + remain_try (int): 剩余尝试次数 + retry_interval (int): 重试间隔 + messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) + Returns: + (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + # 响应错误 + if e.status_code in [400, 401, 402, 403, 404]: + model_name = model_info.name + # 客户端错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None # 不再重试请求该模型 + elif e.status_code == 413: + if messages and not messages[1]: + # 消息列表不为空且未压缩,尝试压缩消息 + return self._check_retry( + remain_try, + 0, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求", + can_retry_callable=compress_messages, + messages=messages[0], + ) + # 没有消息可压缩 + logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。") + return -1, None + elif e.status_code == 429: + # 请求过于频繁 + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求", + ) + elif e.status_code >= 500: + # 服务器错误 + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试", + ) + else: + # 未知错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None + + @staticmethod + def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: + # sourcery skip: extract-method + """构建工具选项列表""" + if not tools: + return None + tool_options: List[ToolOption] = [] + for tool in tools: + tool_legal = True + tool_options_builder = ToolOptionBuilder() + tool_options_builder.set_name(tool.get("name", "")) + tool_options_builder.set_description(tool.get("description", "")) + parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) + for param in parameters: + try: + assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" + assert isinstance(param[0], str), "参数名称必须是字符串" + assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" + assert isinstance(param[2], str), "参数描述必须是字符串" + assert isinstance(param[3], bool), "参数是否必填必须是布尔值" + assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" + tool_options_builder.add_param( + name=param[0], + param_type=param[1], + description=param[2], + required=param[3], + enum_values=param[4], + ) + except AssertionError as ae: + tool_legal = False + logger.error(f"{param[0]} 参数定义错误: {str(ae)}") + except Exception as e: + tool_legal = False + logger.error(f"构建工具参数失败: {str(e)}") + if tool_legal: + tool_options.append(tool_options_builder.build()) + return tool_options or None + + @staticmethod + def _extract_reasoning(content: str) -> Tuple[str, str]: + """CoT思维链提取,向后兼容""" + match = re.search(r"(?:)?(.*?)", content, re.DOTALL) + content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() + reasoning = match[1].strip() if match else "" + return content, reasoning + + def _apply_content_obfuscation(self, text: str, api_provider) -> str: + """根据API提供商配置对文本进行混淆处理""" + if not hasattr(api_provider, "enable_content_obfuscation") or not api_provider.enable_content_obfuscation: + logger.debug(f"API提供商 '{api_provider.name}' 未启用内容混淆") + return text + + intensity = getattr(api_provider, "obfuscation_intensity", 1) + logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") + + # 在开头加入过滤规则指令 + processed_text = self.noise_instruction + "\n\n" + text + logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}") + + # 添加随机乱码 + final_text = self._inject_random_noise(processed_text, intensity) + logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}") + + return final_text + + @staticmethod + def _inject_random_noise(text: str, intensity: int) -> str: + """在文本中注入随机乱码""" + import random + import string + + def generate_noise(length: int) -> str: + """生成指定长度的随机乱码字符""" + chars = ( + string.ascii_letters # a-z, A-Z + + string.digits # 0-9 + + "!@#$%^&*()_+-=[]{}|;:,.<>?" # 特殊符号 + + "一二三四五六七八九零壹贰叁" # 中文字符 + + "αβγδεζηθικλμνξοπρστυφχψω" # 希腊字母 + + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵" # 数学符号 + ) + return "".join(random.choice(chars) for _ in range(length)) + + # 强度参数映射 + params = { + 1: {"probability": 15, "length": (3, 6)}, # 低强度:15%概率,3-6个字符 + 2: {"probability": 25, "length": (5, 10)}, # 中强度:25%概率,5-10个字符 + 3: {"probability": 35, "length": (8, 15)}, # 高强度:35%概率,8-15个字符 + } + + config = params.get(intensity, params[1]) + logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}") + + # 按词分割处理 + words = text.split() + result = [] + noise_count = 0 + + for word in words: + result.append(word) + # 根据概率插入乱码 + if random.randint(1, 100) <= config["probability"]: + noise_length = random.randint(*config["length"]) + noise = generate_noise(noise_length) + result.append(noise) + noise_count += 1 + + logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}") + return " ".join(result) From 3207b778c371ec9c60ca2fe040b35c0dc6aa10b8 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 20:26:19 +0800 Subject: [PATCH 23/41] =?UTF-8?q?refactor(llm):=20=E8=A7=A3=E5=86=B3?= =?UTF-8?q?=E5=90=88=E5=B9=B6=E5=86=B2=E7=AA=81=E5=B9=B6=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=E9=80=BB=E8=BE=91=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 最近为解耦LLM请求逻辑而进行的重构引入了严重的合并冲突。 此提交通过移除引入的 `RequestExecutor` 和 `RequestStrategy` 等新组件,并恢复到之前的代码结构,从而解决了这些冲突。这有助于稳定开发分支并为后续重新审视重构方案做准备。 --- src/llm_models/utils_model.py | 193 ---------------------------------- 1 file changed, 193 deletions(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index ac318fbe4..c39ab8af9 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -8,16 +8,6 @@ from rich.traceback import install from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator from src.common.logger import get_logger -<<<<<<< HEAD -from src.config.api_ada_configs import TaskConfig, ModelInfo, UsageRecord -from .llm_utils import build_tool_options, normalize_image_format -from .model_selector import ModelSelector -from .payload_content.message import MessageBuilder -from .payload_content.tool_option import ToolCall -from .prompt_processor import PromptProcessor -from .request_strategy import RequestStrategy -from .utils import llm_usage_recorder -======= from src.config.config import model_config from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig from .payload_content.message import MessageBuilder, Message @@ -28,7 +18,6 @@ from .utils import compress_messages, llm_usage_recorder from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException install(extra_lines=3) ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) logger = get_logger("model_utils") @@ -185,34 +174,6 @@ class LLMRequest: max_tokens: Optional[int] = None, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ -<<<<<<< HEAD - 为包含图像的多模态输入生成文本响应。 - - Args: - prompt (str): 文本提示。 - image_base64 (str): Base64编码的图像数据。 - image_format (str): 图像格式 (例如, "png", "jpeg")。 - temperature (Optional[float], optional): 控制生成文本的随机性。 Defaults to None. - max_tokens (Optional[int], optional): 生成响应的最大长度。 Defaults to None. - - Returns: - Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - - 清理后的响应内容。 - - 一个元组,包含思考过程、模型名称和工具调用列表。 - """ - start_time = time.time() - - # 步骤 1: 选择一个支持图像处理的模型 - model_info, api_provider, client = self.model_selector.select_model() - - # 步骤 2: 准备消息体 - # 预处理文本提示 - processed_prompt = self.prompt_processor.process_prompt(prompt, model_info, api_provider, self.task_name) - # 规范化图像格式 - normalized_format = normalize_image_format(image_format) - - # 使用MessageBuilder构建多模态消息 -======= 为图像生成响应 Args: prompt (str): 提示词 @@ -229,7 +190,6 @@ class LLMRequest: model_info, api_provider, client = self._select_model() # 请求体构建 ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) message_builder = MessageBuilder() message_builder.add_text_content(prompt) message_builder.add_image_content( @@ -239,16 +199,8 @@ class LLMRequest: ) messages = [message_builder.build()] -<<<<<<< HEAD - # 步骤 3: 执行请求 (图像请求通常不走复杂的故障转移策略,直接执行) - from .request_executor import RequestExecutor - executor = RequestExecutor( - task_name=self.task_name, - model_set=self.model_for_task, -======= # 请求并处理返回值 response = await self._execute_request( ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) api_provider=api_provider, client=client, request_type=RequestType.RESPONSE, @@ -257,14 +209,6 @@ class LLMRequest: temperature=temperature, max_tokens=max_tokens, ) -<<<<<<< HEAD - - # 步骤 4: 处理响应 - content, reasoning_content = self.prompt_processor.extract_reasoning(response.content or "") - tool_calls = response.tool_calls - - # 记录用量 -======= content = response.content or "" reasoning_content = response.reasoning_content or "" tool_calls = response.tool_calls @@ -272,7 +216,6 @@ class LLMRequest: if not reasoning_content and content: content, extracted_reasoning = self._extract_reasoning(content) reasoning_content = extracted_reasoning ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) if usage := response.usage: await llm_usage_recorder.record_usage_to_database( model_info=model_info, @@ -286,24 +229,6 @@ class LLMRequest: async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: """ -<<<<<<< HEAD - 将语音数据转换为文本(语音识别)。 - - Args: - voice_base64 (str): Base64编码的语音数据。 - - Returns: - Optional[str]: 识别出的文本内容,如果失败则返回None。 - """ - # 选择一个支持语音识别的模型 - model_info, api_provider, client = self.model_selector.select_model() - - from .request_executor import RequestExecutor - # 创建请求执行器 - executor = RequestExecutor( - task_name=self.task_name, - model_set=self.model_for_task, -======= 为语音生成响应 Args: voice_base64 (str): 语音的Base64编码字符串 @@ -315,19 +240,10 @@ class LLMRequest: # 请求并处理返回值 response = await self._execute_request( ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) api_provider=api_provider, client=client, request_type=RequestType.AUDIO, model_info=model_info, -<<<<<<< HEAD - model_selector=self.model_selector, - ) - # 执行语音转文本请求 - response = await executor.execute_request( - request_type="audio", -======= ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) audio_base64=voice_base64, ) return response.content or None @@ -341,35 +257,6 @@ class LLMRequest: raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ -<<<<<<< HEAD - 异步生成文本响应,支持并发和故障转移等高级策略。 - - Args: - prompt (str): 用户输入的提示。 - temperature (Optional[float], optional): 控制生成文本的随机性。 Defaults to None. - max_tokens (Optional[int], optional): 生成响应的最大长度。 Defaults to None. - tools (Optional[List[Dict[str, Any]]], optional): 可供模型调用的工具列表。 Defaults to None. - raise_when_empty (bool, optional): 如果最终响应为空,是否抛出异常。 Defaults to True. - - Returns: - Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: - - 清理后的响应内容。 - - 一个元组,包含思考过程、最终使用的模型名称和工具调用列表。 - """ - start_time = time.time() - - # 步骤 1: 准备基础请求载荷 - tool_built = build_tool_options(tools) - base_payload = { - "prompt": prompt, - "tool_options": tool_built, - "temperature": temperature, - "max_tokens": max_tokens, - "prompt_processor": self.prompt_processor, - } - - # 步骤 2: 根据配置选择执行策略 (并发或单次带故障转移) -======= 异步生成响应,支持并发请求 Args: prompt (str): 提示词 @@ -381,7 +268,6 @@ class LLMRequest: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ # 检查是否需要并发请求 ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) concurrency_count = getattr(self.model_for_task, "concurrency_count", 1) if concurrency_count <= 1: @@ -395,49 +281,6 @@ class LLMRequest: content, (reasoning_content, model_name, tool_calls) = await execute_concurrently( self._execute_single_request, concurrency_count, -<<<<<<< HEAD - base_payload, - raise_when_empty=False, # 在并发模式下,单个任务失败不应立即抛出异常 - ) - - # 步骤 3: 处理最终结果 - content = result.get("content", "") - reasoning_content = result.get("reasoning_content", "") - model_name = result.get("model_name", "unknown") - tool_calls = result.get("tool_calls") - - # 步骤 4: 记录用量 (从策略返回的结果中获取最终使用的模型信息和用量) - final_model_info = result.get("model_info") - usage = result.get("usage") - - if final_model_info and usage: - await self._record_usage(final_model_info, usage, time.time() - start_time) - - return content, (reasoning_content, model_name, tool_calls) - - async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: - """ - 获取给定文本的嵌入向量 (Embedding)。 - - Args: - embedding_input (str): 需要进行嵌入的文本。 - - Returns: - Tuple[List[float], str]: 嵌入向量列表和所使用的模型名称。 - - Raises: - RuntimeError: 如果获取embedding失败。 - """ - start_time = time.time() - # 选择一个支持embedding的模型 - model_info, api_provider, client = self.model_selector.select_model() - - from .request_executor import RequestExecutor - # 创建请求执行器 - executor = RequestExecutor( - task_name=self.task_name, - model_set=self.model_for_task, -======= prompt, temperature, max_tokens, @@ -595,19 +438,10 @@ class LLMRequest: # 请求并处理返回值 response = await self._execute_request( ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) api_provider=api_provider, client=client, request_type=RequestType.EMBEDDING, model_info=model_info, -<<<<<<< HEAD - model_selector=self.model_selector, - ) - # 执行embedding请求 - response = await executor.execute_request( - request_type="embedding", -======= ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) embedding_input=embedding_input, ) @@ -625,32 +459,6 @@ class LLMRequest: if not embedding: raise RuntimeError("获取embedding失败") -<<<<<<< HEAD - - # 记录用量 - if usage := response.usage: - await self._record_usage(model_info, usage, time.time() - start_time, "/embeddings") - - return embedding, model_info.name - - async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord, time_cost: float, endpoint: str = "/chat/completions"): - """ - 记录模型API的调用用量到数据库。 - - Args: - model_info (ModelInfo): 使用的模型信息。 - usage (UsageRecord): 包含token用量信息的对象。 - time_cost (float): 本次请求的总耗时(秒)。 - endpoint (str, optional): 请求的API端点。 Defaults to "/chat/completions". - """ - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - user_id="system", # 当前所有请求都以系统用户身份记录 - time_cost=time_cost, - request_type=self.request_type, - endpoint=endpoint, -======= return embedding, model_info.name @@ -680,7 +488,6 @@ class LLMRequest: least_used_model_name = min( candidate_models_usage, key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) ) # --- 动态故障转移的核心逻辑 --- From a4945d1ca26867e3a22b3610ff4a8f739d95ff15 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 20:38:04 +0800 Subject: [PATCH 24/41] =?UTF-8?q?refactor(llm):=20=E5=B0=86LLM=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E9=80=BB=E8=BE=91=E8=A7=A3=E8=80=A6=E5=88=B0=E4=B8=93?= =?UTF-8?q?=E7=94=A8=E7=BB=84=E4=BB=B6=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将庞大的 `LLMRequest` 类重构,将其核心职责分解到四个独立的、遵循单一职责原则的辅助类中,以提高代码的模块化、可读性和可维护性。 - `_ModelSelector`: 专门负责模型选择、负载均衡和基于失败历史的动态惩罚策略。 - `_PromptProcessor`: 封装所有与提示词和响应内容的预处理及后处理逻辑,包括内容混淆、反截断信令处理和思维链提取。 - `_RequestExecutor`: 负责执行底层的API请求,并处理网络层面的重试逻辑。 - `_RequestStrategy`: 实现高级请求策略,如在多个模型间的故障转移(failover)和空回复/截断的内部重试。 `LLMRequest` 类现在作为外观(Facade),协调这些新组件来完成请求,使得整体架构更加清晰和易于扩展。 --- src/llm_models/utils_model.py | 1304 +++++++++++++++------------------ 1 file changed, 574 insertions(+), 730 deletions(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index c39ab8af9..acb7130b6 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -2,6 +2,7 @@ import re import asyncio import time import random +import string from enum import Enum from rich.traceback import install @@ -13,7 +14,7 @@ from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig from .payload_content.message import MessageBuilder, Message from .payload_content.resp_format import RespFormat from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType -from .model_client.base_client import BaseClient, APIResponse, client_registry +from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord from .utils import compress_messages, llm_usage_recorder from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException @@ -21,18 +22,9 @@ install(extra_lines=3) logger = get_logger("model_utils") -# 常见Error Code Mapping -error_code_mapping = { - 400: "参数不正确", - 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确", - 402: "账号余额不足", - 403: "需要实名,或余额不足", - 404: "Not Found", - 429: "请求过于频繁,请稍后再试", - 500: "服务器内部故障", - 503: "服务器负载过高", -} - +# ============================================================================== +# Standalone Utility Functions +# ============================================================================== def _normalize_image_format(image_format: str) -> str: """ @@ -45,35 +37,17 @@ def _normalize_image_format(image_format: str) -> str: str: 标准化后的图片格式 """ format_mapping = { - "jpg": "jpeg", - "JPG": "jpeg", - "JPEG": "jpeg", - "jpeg": "jpeg", - "png": "png", - "PNG": "png", - "webp": "webp", - "WEBP": "webp", - "gif": "gif", - "GIF": "gif", - "heic": "heic", - "HEIC": "heic", - "heif": "heif", - "HEIF": "heif", + "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg", + "png": "png", "PNG": "png", + "webp": "webp", "WEBP": "webp", + "gif": "gif", "GIF": "gif", + "heic": "heic", "HEIC": "heic", + "heif": "heif", "HEIF": "heif", } - normalized = format_mapping.get(image_format, image_format.lower()) logger.debug(f"图片格式标准化: {image_format} -> {normalized}") return normalized - -class RequestType(Enum): - """请求类型枚举""" - - RESPONSE = "response" - EMBEDDING = "embedding" - AUDIO = "audio" - - async def execute_concurrently( coro_callable: Callable[..., Coroutine[Any, Any, Any]], concurrency_count: int, @@ -97,7 +71,6 @@ async def execute_concurrently( """ logger.info(f"启用并发请求模式,并发数: {concurrency_count}") tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)] - results = await asyncio.gather(*tasks, return_exceptions=True) successful_results = [res for res in results if not isinstance(res, Exception)] @@ -110,41 +83,107 @@ async def execute_concurrently( for i, res in enumerate(results): if isinstance(res, Exception): logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}") - + first_exception = next((res for res in results if isinstance(res, Exception)), None) if first_exception: raise first_exception - raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息") +class RequestType(Enum): + """请求类型枚举""" + RESPONSE = "response" + EMBEDDING = "embedding" + AUDIO = "audio" -class LLMRequest: -<<<<<<< HEAD - """ - LLM请求协调器。 - 封装了模型选择、Prompt处理、请求执行和高级策略(如故障转移、并发)的完整流程。 - 为上层业务逻辑提供统一的、简化的接口来与大语言模型交互。 - """ -======= - """LLM请求类""" ->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中) +# ============================================================================== +# Helper Classes for LLMRequest Refactoring +# ============================================================================== - def __init__(self, model_set: TaskConfig, request_type: str = "") -> None: +class _ModelSelector: + """负责模型选择、负载均衡和动态故障切换的策略。""" + + CRITICAL_PENALTY_MULTIPLIER = 5 + DEFAULT_PENALTY_INCREMENT = 1 + + def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]): + self.model_list = model_list + self.model_usage = model_usage + + def select_best_available_model( + self, failed_models_in_this_request: set, request_type: str + ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: """ - 初始化LLM请求协调器。 + 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 Args: - model_set (TaskConfig): 特定任务的模型配置集合。 - request_type (str, optional): 请求类型或任务名称,用于日志和用量记录。 Defaults to "". - """ - self.task_name = request_type - self.model_for_task = model_set - self.request_type = request_type - self.model_usage: Dict[str, Tuple[int, int, int]] = { - model: (0, 0, 0) for model in self.model_for_task.model_list - } - """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" + failed_models_in_this_request (set): 当前请求中已失败的模型名称集合。 + request_type (str): 请求类型,用于确定是否强制创建新客户端。 + Returns: + Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: 选定的模型详细信息,如果无可用模型则返回 None。 + """ + candidate_models_usage = { + model_name: usage_data + for model_name, usage_data in self.model_usage.items() + if model_name not in failed_models_in_this_request + } + + if not candidate_models_usage: + logger.warning("没有可用的模型供当前请求选择。") + return None + + # 根据公式查找分数最低的模型,该公式综合了总token数、模型失败惩罚值和使用频率惩罚值。 + # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 + least_used_model_name = min( + candidate_models_usage, + key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, + ) + + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 + force_new_client = request_type == "embedding" + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + + logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") + # 增加所选模型的请求使用惩罚值,以实现动态负载均衡。 + self.update_usage_penalty(model_info.name, increase=True) + return model_info, api_provider, client + + def update_usage_penalty(self, model_name: str, increase: bool): + """更新模型的使用惩罚值,用于负载均衡。""" + total_tokens, penalty, usage_penalty = self.model_usage[model_name] + adjustment = 1 if increase else -1 + self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment) + + def update_failure_penalty(self, model_name: str, e: Exception): + """ + 根据异常类型动态调整模型的失败惩罚值。 + 关键错误(如网络连接、服务器错误)会获得更高的惩罚, + 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。 + """ + total_tokens, penalty, usage_penalty = self.model_usage[model_name] + penalty_increment = self.DEFAULT_PENALTY_INCREMENT + + if isinstance(e, (NetworkConnectionError, ReqAbortException)): + penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER + logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}") + elif isinstance(e, RespNotOkException): + if e.status_code >= 500: + penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER + logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}") + else: + logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") + else: + logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") + + self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty) + + +class _PromptProcessor: + """封装所有与提示词和响应内容的预处理和后处理逻辑。""" + + def __init__(self): # 内容混淆过滤指令 self.noise_instruction = """**【核心过滤规则】** 在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。 @@ -152,7 +191,6 @@ class LLMRequest: 你的任务是【完全并彻底地忽略】这些随机字符串。 **【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。 """ - # 反截断指令 self.end_marker = "###MAI_RESPONSE_END###" self.anti_truncation_instruction = f""" @@ -165,6 +203,372 @@ class LLMRequest: 这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。 """ + def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str: + """为请求准备最终的提示词,应用内容混淆和反截断指令。""" + processed_prompt = self._apply_content_obfuscation(prompt, api_provider) + if getattr(model_info, "use_anti_truncation", False): + processed_prompt += self.anti_truncation_instruction + logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。") + return processed_prompt + + def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]: + """ + 处理响应内容,提取思维链并检查截断。 + + Returns: + Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断) + """ + content, reasoning = self._extract_reasoning(content) + is_truncated = False + if use_anti_truncation: + if content.endswith(self.end_marker): + content = content[: -len(self.end_marker)].strip() + else: + is_truncated = True + return content, reasoning, is_truncated + + def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str: + """根据API提供商配置对文本进行混淆处理。""" + if not getattr(api_provider, "enable_content_obfuscation", False): + return text + + intensity = getattr(api_provider, "obfuscation_intensity", 1) + logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") + processed_text = self.noise_instruction + "\n\n" + text + return self._inject_random_noise(processed_text, intensity) + + @staticmethod + def _inject_random_noise(text: str, intensity: int) -> str: + """在文本中注入随机乱码。""" + params = { + 1: {"probability": 15, "length": (3, 6)}, + 2: {"probability": 25, "length": (5, 10)}, + 3: {"probability": 35, "length": (8, 15)}, + } + config = params.get(intensity, params[1]) + words = text.split() + result = [] + for word in words: + result.append(word) + if random.randint(1, 100) <= config["probability"]: + noise_length = random.randint(*config["length"]) + chars = string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?" + noise = "".join(random.choice(chars) for _ in range(noise_length)) + result.append(noise) + return " ".join(result) + + @staticmethod + def _extract_reasoning(content: str) -> Tuple[str, str]: + """ + 从模型返回的完整内容中提取被...标签包裹的思考过程, + 并返回清理后的内容和思考过程。 + + Args: + content (str): 模型返回的原始字符串。 + + Returns: + Tuple[str, str]: + - 清理后的内容(移除了标签及其内容)。 + - 提取出的思考过程文本(如果没有则为空字符串)。 + """ + # 使用正则表达式精确查找 ... 标签及其内容 + think_pattern = re.compile(r"(.*?)\s*", re.DOTALL) + match = think_pattern.search(content) + + if match: + # 提取思考过程 + reasoning = match.group(1).strip() + # 从原始内容中移除匹配到的整个部分(包括标签和后面的空白) + clean_content = think_pattern.sub("", content, count=1).strip() + else: + reasoning = "" + clean_content = content.strip() + + return clean_content, reasoning + + +class _RequestExecutor: + """负责执行实际的API请求,包含重试逻辑和底层异常处理。""" + + def __init__(self, model_selector: _ModelSelector, task_name: str): + self.model_selector = model_selector + self.task_name = task_name + + async def execute_request( + self, + api_provider: APIProvider, + client: BaseClient, + request_type: RequestType, + model_info: ModelInfo, + **kwargs, + ) -> APIResponse: + """实际执行请求的方法,包含了重试和异常处理逻辑。""" + retry_remain = api_provider.max_retry + compressed_messages: Optional[List[Message]] = None + + while retry_remain > 0: + try: + message_list = kwargs.get("message_list") + current_messages = compressed_messages or message_list + + if request_type == RequestType.RESPONSE: + assert current_messages is not None, "message_list cannot be None for response requests" + return await client.get_response(model_info=model_info, message_list=current_messages, **kwargs) + elif request_type == RequestType.EMBEDDING: + return await client.get_embedding(model_info=model_info, **kwargs) + elif request_type == RequestType.AUDIO: + return await client.get_audio_transcriptions(model_info=model_info, **kwargs) + + except Exception as e: + logger.debug(f"请求失败: {str(e)}") + self.model_selector.update_failure_penalty(model_info.name, e) + + wait_interval, new_compressed_messages = self._handle_exception( + e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None) + ) + if new_compressed_messages: + compressed_messages = new_compressed_messages + + if wait_interval == -1: + raise e # 如果不再重试,则传播异常 + elif wait_interval > 0: + await asyncio.sleep(wait_interval) + finally: + retry_remain -= 1 + + logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") + raise RuntimeError("请求失败,已达到最大重试次数") + + def _handle_exception( + self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info + ) -> Tuple[int, Optional[List[Message]]]: + """ + 默认异常处理函数,决定是否重试。 + + Returns: + (等待间隔(-1表示不再重试), 新的消息列表(适用于压缩消息)) + """ + model_name = model_info.name + retry_interval = api_provider.retry_interval + + if isinstance(e, (NetworkConnectionError, ReqAbortException)): + return self._check_retry(remain_try, retry_interval, "连接异常", model_name) + elif isinstance(e, RespNotOkException): + return self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info) + elif isinstance(e, RespParseException): + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}") + return -1, None + else: + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {str(e)}") + return -1, None + + def _handle_resp_not_ok( + self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info + ) -> Tuple[int, Optional[List[Message]]]: + """处理非200的HTTP响应异常。""" + model_name = model_info.name + if e.status_code in [400, 401, 402, 403, 404]: + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。") + return -1, None + elif e.status_code == 413: + messages, is_compressed = messages_info + if messages and not is_compressed: + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试。") + return 0, compress_messages(messages) + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大且无法压缩,放弃请求。") + return -1, None + elif e.status_code == 429 or e.status_code >= 500: + reason = "请求过于频繁" if e.status_code == 429 else "服务器错误" + return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name) + else: + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}") + return -1, None + + def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]: + """辅助函数:检查是否可以重试。""" + if remain_try > 1: # 剩余次数大于1才重试 + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。") + return interval, None + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。") + return -1, None + + +class _RequestStrategy: + """ + 封装高级请求策略,如故障转移。 + 此类协调模型选择、提示处理和请求执行,以实现健壮的请求处理, + 即使在单个模型或API端点失败的情况下也能正常工作。 + """ + + def __init__(self, model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, model_list: List[str], task_name: str): + """ + 初始化请求策略。 + + Args: + model_selector (_ModelSelector): 模型选择器实例。 + prompt_processor (_PromptProcessor): 提示处理器实例。 + executor (_RequestExecutor): 请求执行器实例。 + model_list (List[str]): 可用模型列表。 + task_name (str): 当前任务的名称。 + """ + self.model_selector = model_selector + self.prompt_processor = prompt_processor + self.executor = executor + self.model_list = model_list + self.task_name = task_name + + async def execute_with_failover( + self, + request_type: RequestType, + raise_when_empty: bool = True, + **kwargs, + ) -> Tuple[APIResponse, ModelInfo]: + """ + 执行请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 + """ + failed_models_in_this_request = set() + max_attempts = len(self.model_list) + last_exception: Optional[Exception] = None + + for attempt in range(max_attempts): + selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request, str(request_type.value)) + if selection_result is None: + logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") + break + + model_info, api_provider, client = selection_result + logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_info.name}'...") + + try: + # 准备请求参数 + request_kwargs = kwargs.copy() + if request_type == RequestType.RESPONSE and "prompt" in request_kwargs: + prompt = request_kwargs.pop("prompt") + processed_prompt = self.prompt_processor.prepare_prompt( + prompt, model_info, api_provider, self.task_name + ) + message = MessageBuilder().add_text_content(processed_prompt).build() + request_kwargs["message_list"] = [message] + + # 合并模型特定的额外参数 + if model_info.extra_params: + request_kwargs["extra_params"] = {**model_info.extra_params, **request_kwargs.get("extra_params", {})} + + response = await self._try_model_request(model_info, api_provider, client, request_type, **request_kwargs) + + # 成功,立即返回 + logger.debug(f"模型 '{model_info.name}' 成功生成了回复。") + self.model_selector.update_usage_penalty(model_info.name, increase=False) + return response, model_info + + except Exception as e: + logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") + failed_models_in_this_request.add(model_info.name) + last_exception = e + # 使用惩罚值已在 select 时增加,失败后不减少,以降低其后续被选中的概率 + + logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") + if raise_when_empty: + if last_exception: + raise RuntimeError("所有模型均未能生成响应。") from last_exception + raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") + + # 如果不抛出异常,返回一个备用响应 + fallback_model_info = model_config.get_model_info(self.model_list[0]) + return APIResponse(content="所有模型都请求失败"), fallback_model_info + + + async def _try_model_request( + self, model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, request_type: RequestType, **kwargs + ) -> APIResponse: + """ + 为单个模型尝试请求,包含空回复/截断的内部重试逻辑。 + 如果模型返回空回复或响应被截断,此方法将自动重试请求,直到达到最大重试次数。 + + Args: + model_info (ModelInfo): 要使用的模型信息。 + api_provider (APIProvider): API提供商信息。 + client (BaseClient): API客户端实例。 + request_type (RequestType): 请求类型。 + **kwargs: 传递给执行器的请求参数。 + + Returns: + APIResponse: 成功的API响应。 + + Raises: + RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的响应。 + """ + max_empty_retry = api_provider.max_retry + + for i in range(max_empty_retry + 1): + response = await self.executor.execute_request( + api_provider, client, request_type, model_info, **kwargs + ) + + if request_type != RequestType.RESPONSE: + return response # 对于非响应类型,直接返回 + + # --- 响应内容处理和空回复/截断检查 --- + content = response.content or "" + use_anti_truncation = getattr(model_info, "use_anti_truncation", False) + processed_content, reasoning, is_truncated = self.prompt_processor.process_response(content, use_anti_truncation) + + # 更新响应对象 + response.content = processed_content + response.reasoning_content = response.reasoning_content or reasoning + + is_empty_reply = not response.tool_calls and not (response.content and response.content.strip()) + + if not is_empty_reply and not is_truncated: + return response # 成功获取有效响应 + + if i < max_empty_retry: + reason = "空回复" if is_empty_reply else "截断" + logger.warning(f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})...") + if api_provider.retry_interval > 0: + await asyncio.sleep(api_provider.retry_interval) + else: + reason = "空回复" if is_empty_reply else "截断" + logger.error(f"模型 '{model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。") + raise RuntimeError(f"模型 '{model_info.name}' 已达到空回复/截断的最大内部重试次数。") + + raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里 + + +# ============================================================================== +# Main Facade Class +# ============================================================================== + +class LLMRequest: + """ + LLM请求协调器。 + 封装了模型选择、Prompt处理、请求执行和高级策略(如故障转移、并发)的完整流程。 + 为上层业务逻辑提供统一的、简化的接口来与大语言模型交互。 + """ + + def __init__(self, model_set: TaskConfig, request_type: str = ""): + """ + 初始化LLM请求协调器。 + + Args: + model_set (TaskConfig): 特定任务的模型配置集合。 + request_type (str, optional): 请求类型或任务名称,用于日志和用量记录。 Defaults to "". + """ + self.task_name = request_type + self.model_for_task = model_set + self.model_usage: Dict[str, Tuple[int, int, int]] = { + model: (0, 0, 0) for model in self.model_for_task.model_list + } + """模型使用量记录,(total_tokens, penalty, usage_penalty)""" + + # 初始化辅助类 + self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage) + self._prompt_processor = _PromptProcessor() + self._executor = _RequestExecutor(self._model_selector, self.task_name) + self._strategy = _RequestStrategy( + self._model_selector, self._prompt_processor, self._executor, self.model_for_task.model_list, self.task_name + ) + async def generate_response_for_image( self, prompt: str, @@ -174,77 +578,57 @@ class LLMRequest: max_tokens: Optional[int] = None, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 为图像生成响应 + 为图像生成响应。 + Args: prompt (str): 提示词 image_base64 (str): 图像的Base64编码字符串 image_format (str): 图像格式(如 'png', 'jpeg' 等) + Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ - # 标准化图片格式以确保API兼容性 - normalized_format = _normalize_image_format(image_format) - - # 模型选择 start_time = time.time() - model_info, api_provider, client = self._select_model() - - # 请求体构建 - message_builder = MessageBuilder() - message_builder.add_text_content(prompt) - message_builder.add_image_content( + + # 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行 + selection_result = self._model_selector.select_best_available_model(set(), "response") + if not selection_result: + raise RuntimeError("无法为图像响应选择可用模型。") + model_info, api_provider, client = selection_result + + normalized_format = _normalize_image_format(image_format) + message = MessageBuilder().add_text_content(prompt).add_image_content( image_base64=image_base64, image_format=normalized_format, support_formats=client.get_support_image_formats(), - ) - messages = [message_builder.build()] + ).build() - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, - message_list=messages, + response = await self._executor.execute_request( + api_provider, client, RequestType.RESPONSE, model_info, + message_list=[message], temperature=temperature, max_tokens=max_tokens, ) - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - # 从内容中提取标签的推理内容(向后兼容) - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - user_id="system", - time_cost=time.time() - start_time, - request_type=self.request_type, - endpoint="/chat/completions", - ) - return content, (reasoning_content, model_info.name, tool_calls) + + self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") + content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False) + reasoning = response.reasoning_content or reasoning + + return content, (reasoning, model_info.name, response.tool_calls) async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: """ - 为语音生成响应 - Args: - voice_base64 (str): 语音的Base64编码字符串 - Returns: - (Optional[str]): 生成的文本描述或None - """ - # 模型选择 - model_info, api_provider, client = self._select_model() + 为语音生成响应(语音转文字)。 + 使用故障转移策略来确保即使主模型失败也能获得结果。 - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.AUDIO, - model_info=model_info, - audio_base64=voice_base64, + Args: + voice_base64 (str): 语音的Base64编码字符串。 + + Returns: + Optional[str]: 语音转换后的文本内容,如果所有模型都失败则返回None。 + """ + response, _ = await self._strategy.execute_with_failover( + RequestType.AUDIO, audio_base64=voice_base64 ) return response.content or None @@ -257,44 +641,36 @@ class LLMRequest: raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 异步生成响应,支持并发请求 + 异步生成响应,支持并发请求。 + Args: prompt (str): 提示词 temperature (float, optional): 温度参数 max_tokens (int, optional): 最大token数 tools: 工具配置 - raise_when_empty: 是否在空回复时抛出异常 + raise_when_empty (bool): 是否在空回复时抛出异常 + Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ - # 检查是否需要并发请求 concurrency_count = getattr(self.model_for_task, "concurrency_count", 1) if concurrency_count <= 1: - # 单次请求 - return await self._execute_single_request(prompt, temperature, max_tokens, tools, raise_when_empty) - - # 并发请求 + return await self._execute_single_text_request(prompt, temperature, max_tokens, tools, raise_when_empty) + try: - # 为 _execute_single_request 传递参数时,将 raise_when_empty 设为 False, - # 这样单个请求失败时不会立即抛出异常,而是由 gather 统一处理 - content, (reasoning_content, model_name, tool_calls) = await execute_concurrently( - self._execute_single_request, + return await execute_concurrently( + self._execute_single_text_request, concurrency_count, - prompt, - temperature, - max_tokens, - tools, - raise_when_empty=False, + prompt, temperature, max_tokens, tools, raise_when_empty=False ) - return content, (reasoning_content, model_name, tool_calls) except Exception as e: logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}") if raise_when_empty: raise e return "所有并发请求都失败了", ("", "unknown", None) - async def _execute_single_request( + async def _execute_single_text_request( self, prompt: str, temperature: Optional[float] = None, @@ -303,633 +679,101 @@ class LLMRequest: raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 + 执行单次文本生成请求的内部方法。 + 这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期, + 包括工具构建、故障转移执行和用量记录。 + + Args: + prompt (str): 用户的提示。 + temperature (Optional[float]): 生成温度。 + max_tokens (Optional[int]): 最大生成令牌数。 + tools (Optional[List[Dict[str, Any]]]): 可用工具列表。 + raise_when_empty (bool): 如果响应为空是否引发异常。 + + Returns: + Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + (响应内容, (推理过程, 模型名称, 工具调用)) """ - failed_models_in_this_request = set() - # 迭代次数等于模型总数,以确保每个模型在当前请求中最多只尝试一次 - max_attempts = len(self.model_for_task.model_list) - last_exception: Optional[Exception] = None + start_time = time.time() + tool_options = self._build_tool_options(tools) - for attempt in range(max_attempts): - # 根据负载均衡和当前故障选择最佳可用模型 - model_selection_result = self._select_best_available_model(failed_models_in_this_request) + response, model_info = await self._strategy.execute_with_failover( + RequestType.RESPONSE, + raise_when_empty=raise_when_empty, + prompt=prompt, # 传递原始prompt,由strategy处理 + tool_options=tool_options, + temperature=self.model_for_task.temperature if temperature is None else temperature, + max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, + ) - if model_selection_result is None: - logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") - break # 没有更多模型可供尝试 + self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") - model_info, api_provider, client = model_selection_result - model_name = model_info.name - logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...") + if not response.content and not response.tool_calls: + if raise_when_empty: + raise RuntimeError("所选模型生成了空回复。") + response.content = "生成的响应为空" - start_time = time.time() - - try: - # --- 为当前模型尝试进行设置 --- - # 检查是否为该模型启用反截断 - use_anti_truncation = getattr(model_info, "use_anti_truncation", False) - processed_prompt = prompt - if use_anti_truncation: - processed_prompt += self.anti_truncation_instruction - logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。") - - processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) - - message_builder = MessageBuilder() - message_builder.add_text_content(processed_prompt) - messages = [message_builder.build()] - tool_built = self._build_tool_options(tools) - - # --- 当前选定模型内的空回复/截断重试逻辑 --- - empty_retry_count = 0 - max_empty_retry = api_provider.max_retry - empty_retry_interval = api_provider.retry_interval - - while empty_retry_count <= max_empty_retry: - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, - message_list=messages, - tool_options=tool_built, - temperature=temperature, - max_tokens=max_tokens, - ) - - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - - # 向后兼容 标签(如果 reasoning_content 为空) - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - - is_empty_reply = not tool_calls and (not content or content.strip() == "") - is_truncated = False - if use_anti_truncation: - if content.endswith(self.end_marker): - content = content[: -len(self.end_marker)].strip() - else: - is_truncated = True - - if is_empty_reply or is_truncated: - empty_retry_count += 1 - if empty_retry_count <= max_empty_retry: - reason = "空回复" if is_empty_reply else "截断" - logger.warning( - f"模型 '{model_name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..." - ) - if empty_retry_interval > 0: - await asyncio.sleep(empty_retry_interval) - continue # 使用当前模型重试 - else: - reason = "空回复" if is_empty_reply else "截断" - logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。将此模型标记为当前请求失败。") - raise RuntimeError(f"模型 '{model_name}' 已达到空回复/截断的最大内部重试次数。") - - # --- 从当前模型获取成功响应 --- - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - time_cost=time.time() - start_time, - user_id="system", - request_type=self.request_type, - endpoint="/chat/completions", - ) - - # 处理成功执行后响应仍然为空的情况 - if not content and not tool_calls: - if raise_when_empty: - raise RuntimeError("所选模型生成了空回复。") - content = "生成的响应为空" # Fallback message - - logger.debug(f"模型 '{model_name}' 成功生成了回复。") - return content, (reasoning_content, model_name, tool_calls) # 成功,立即返回 - - # --- 当前模型尝试过程中的异常处理 --- - except Exception as e: # 捕获当前模型尝试过程中的所有异常 - # 修复 NameError: model_name 在异常处理块中未定义,应使用 model_info.name - logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") - failed_models_in_this_request.add(model_info.name) - last_exception = e # 存储异常以供最终报告 - # 继续循环以尝试下一个可用模型 - - # 如果循环结束未能返回,则表示当前请求的所有模型都已失败 - logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") - if raise_when_empty: - if last_exception: - raise RuntimeError("所有模型均未能生成响应。") from last_exception - raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") - return "所有模型都请求失败", ("", "unknown", None) + return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls) async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: - """获取嵌入向量 + """ + 获取嵌入向量。 + Args: embedding_input (str): 获取嵌入的目标 + Returns: (Tuple[List[float], str]): (嵌入向量,使用的模型名称) """ - # 无需构建消息体,直接使用输入文本 start_time = time.time() - model_info, api_provider, client = self._select_model() - - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.EMBEDDING, - model_info=model_info, - embedding_input=embedding_input, + response, model_info = await self._strategy.execute_with_failover( + RequestType.EMBEDDING, + embedding_input=embedding_input ) + + self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings") + + if not response.embedding: + raise RuntimeError("获取embedding失败") + + return response.embedding, model_info.name - embedding = response.embedding - - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( + def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str): + """异步记录用量到数据库。""" + if usage: + # 更新内存中的token计数 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty) + + asyncio.create_task(llm_usage_recorder.record_usage_to_database( model_info=model_info, - time_cost=time.time() - start_time, model_usage=usage, user_id="system", - request_type=self.request_type, - endpoint="/embeddings", - ) - - if not embedding: - raise RuntimeError("获取embedding失败") - - return embedding, model_info.name - - def _select_best_available_model(self, failed_models_in_this_request: set) -> Tuple[ModelInfo, APIProvider, BaseClient] | None: - """ - 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 - - 参数: - failed_models_in_this_request (set): 当前请求中已失败的模型名称集合。 - - 返回: - Tuple[ModelInfo, APIProvider, BaseClient] | None: 选定的模型详细信息,如果无可用模型则返回 None。 - """ - candidate_models_usage = {} - # 过滤掉当前请求中已失败的模型 - for model_name, usage_data in self.model_usage.items(): - if model_name not in failed_models_in_this_request: - candidate_models_usage[model_name] = usage_data - - if not candidate_models_usage: - logger.warning("没有可用的模型供当前请求选择。") - return None - - # 根据现有公式查找分数最低的模型,该公式综合了总token数、模型惩罚值和使用频率惩罚值。 - # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 - # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。 - least_used_model_name = min( - candidate_models_usage, - key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, - ) - - # --- 动态故障转移的核心逻辑 --- - # _execute_single_request 中的循环会多次调用此函数。 - # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数, - # 此时由于失败模型已被标记,且其惩罚值可能已在 _execute_request 中增加, - # _select_best_available_model 会自动选择一个得分更低(即更可用)的模型。 - # 这种机制实现了动态的、基于当前系统状态的故障转移。 - - model_info = model_config.get_model_info(least_used_model_name) - api_provider = model_config.get_provider(model_info.api_provider) - - # 对于嵌入任务,如果需要,强制创建新的客户端实例(从原始 _select_model 复制) - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - - logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") - - # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。 - # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) - - return model_info, api_provider, client - - def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: - """ - 一个模型调度器,按顺序提供模型,并跳过已失败的模型。 - """ - for model_name in self.model_for_task.model_list: - if model_name in failed_models: - continue - - model_info = model_config.get_model_info(model_name) - api_provider = model_config.get_provider(model_info.api_provider) - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - - yield model_info, api_provider, client - - def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: - """ - 根据总tokens和惩罚值选择的模型 (负载均衡) - """ - least_used_model_name = min( - self.model_usage, - key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, - ) - model_info = model_config.get_model_info(least_used_model_name) - api_provider = model_config.get_provider(model_info.api_provider) - - # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - logger.debug(f"选择请求模型: {model_info.name}") - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 - return model_info, api_provider, client - - async def _execute_request( - self, - api_provider: APIProvider, - client: BaseClient, - request_type: RequestType, - model_info: ModelInfo, - message_list: List[Message] | None = None, - tool_options: list[ToolOption] | None = None, - response_format: RespFormat | None = None, - stream_response_handler: Optional[Callable] = None, - async_response_parser: Optional[Callable] = None, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - embedding_input: str = "", - audio_base64: str = "", - ) -> APIResponse: - """ - 实际执行请求的方法 - - 包含了重试和异常处理逻辑 - """ - retry_remain = api_provider.max_retry - compressed_messages: Optional[List[Message]] = None - while retry_remain > 0: - try: - if request_type == RequestType.RESPONSE: - assert message_list is not None, "message_list cannot be None for response requests" - return await client.get_response( - model_info=model_info, - message_list=(compressed_messages or message_list), - tool_options=tool_options, - max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, - temperature=self.model_for_task.temperature if temperature is None else temperature, - response_format=response_format, - stream_response_handler=stream_response_handler, - async_response_parser=async_response_parser, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.EMBEDDING: - assert embedding_input, "embedding_input cannot be empty for embedding requests" - return await client.get_embedding( - model_info=model_info, - embedding_input=embedding_input, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.AUDIO: - assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" - return await client.get_audio_transcriptions( - model_info=model_info, - audio_base64=audio_base64, - extra_params=model_info.extra_params, - ) - except Exception as e: - logger.debug(f"请求失败: {str(e)}") - # 处理异常 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - - # --- 增强动态故障转移的智能性 --- - # 根据异常类型和严重程度,动态调整模型的惩罚值。 - # 关键错误(如网络连接、服务器错误)会获得更高的惩罚, - # 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。 - CRITICAL_PENALTY_MULTIPLIER = 5 # 关键错误时的惩罚系数 - default_penalty_increment = 1 # 普通错误时的基础惩罚 - - penalty_increment = default_penalty_increment - - if isinstance(e, NetworkConnectionError): - # 网络连接问题表明模型服务器不稳定,增加较高惩罚 - penalty_increment = CRITICAL_PENALTY_MULTIPLIER - # 修复 NameError: model_name 在此处未定义,应使用 model_info.name - logger.warning(f"模型 '{model_info.name}' 发生网络连接错误,增加惩罚值: {penalty_increment}") - elif isinstance(e, ReqAbortException): - # 请求被中止,可能是服务器端原因或服务不稳定,增加较高惩罚 - penalty_increment = CRITICAL_PENALTY_MULTIPLIER - # 修复 NameError: model_name 在此处未定义,应使用 model_info.name - logger.warning(f"模型 '{model_info.name}' 请求被中止,增加惩罚值: {penalty_increment}") - elif isinstance(e, RespNotOkException): - if e.status_code >= 500: - # 服务器错误 (5xx) 表明服务器端问题,应显著增加惩罚 - penalty_increment = CRITICAL_PENALTY_MULTIPLIER - logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}") - elif e.status_code == 429: - # 请求过于频繁,是暂时性问题,但仍需惩罚,此处使用默认基础值 - # penalty_increment = 2 # 可以选择一个中间值,例如2,表示比普通错误重,但比关键错误轻 - logger.warning(f"模型 '{model_name}' 请求过于频繁 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") - else: - # 其他客户端错误 (4xx)。通常不重试,_handle_resp_not_ok 会处理。 - # 如果 _handle_resp_not_ok 返回 retry_interval, 则进入这里的 exception 块。 - logger.warning(f"模型 '{model_name}' 发生非致命的响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") - else: - # 其他未捕获的异常,增加基础惩罚 - logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") - - self.model_usage[model_info.name] = (total_tokens, penalty + penalty_increment, usage_penalty) - # --- 结束增强 --- - # 移除冗余的、错误的惩罚值更新行,保留上面正确的动态惩罚更新 - # self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) - - wait_interval, compressed_messages = self._default_exception_handler( - e, - self.task_name, - model_info=model_info, - api_provider=api_provider, - remain_try=retry_remain, - retry_interval=api_provider.retry_interval, - messages=(message_list, compressed_messages is not None) if message_list else None, - ) - - if wait_interval == -1: - retry_remain = 0 # 不再重试 - elif wait_interval > 0: - logger.info(f"等待 {wait_interval} 秒后重试...") - await asyncio.sleep(wait_interval) - finally: - # 放在finally防止死循环 - retry_remain -= 1 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值 - logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") - raise RuntimeError("请求失败,已达到最大重试次数") - - def _default_exception_handler( - self, - e: Exception, - task_name: str, - model_info: ModelInfo, - api_provider: APIProvider, - remain_try: int, - retry_interval: int = 10, - messages: Tuple[List[Message], bool] | None = None, - ) -> Tuple[int, List[Message] | None]: - """ - 默认异常处理函数 - Args: - e (Exception): 异常对象 - task_name (str): 任务名称 - model_info (ModelInfo): 模型信息 - api_provider (APIProvider): API提供商 - remain_try (int): 剩余尝试次数 - retry_interval (int): 重试间隔 - messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) - Returns: - (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - model_name = model_info.name if model_info else "unknown" - - if isinstance(e, NetworkConnectionError): # 网络连接错误 - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确", - ) - elif isinstance(e, ReqAbortException): - logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}") - return -1, None # 不再重试请求该模型 - elif isinstance(e, RespNotOkException): - return self._handle_resp_not_ok( - e, - task_name, - model_info, - api_provider, - remain_try, - retry_interval, - messages, - ) - elif isinstance(e, RespParseException): - # 响应解析错误 - logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}") - logger.debug(f"附加内容: {str(e.ext_info)}") - return -1, None # 不再重试请求该模型 - else: - logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") - return -1, None # 不再重试请求该模型 - - @staticmethod - def _check_retry( - remain_try: int, - retry_interval: int, - can_retry_msg: str, - cannot_retry_msg: str, - can_retry_callable: Callable | None = None, - **kwargs, - ) -> Tuple[int, List[Message] | None]: - """辅助函数:检查是否可以重试 - Args: - remain_try (int): 剩余尝试次数 - retry_interval (int): 重试间隔 - can_retry_msg (str): 可以重试时的提示信息 - cannot_retry_msg (str): 不可以重试时的提示信息 - can_retry_callable (Callable | None): 可以重试时调用的函数(如果有) - **kwargs: 其他参数 - - Returns: - (Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - if remain_try > 0: - # 还有重试机会 - logger.warning(f"{can_retry_msg}") - if can_retry_callable is not None: - return retry_interval, can_retry_callable(**kwargs) - else: - return retry_interval, None - else: - # 达到最大重试次数 - logger.warning(f"{cannot_retry_msg}") - return -1, None # 不再重试请求该模型 - - def _handle_resp_not_ok( - self, - e: RespNotOkException, - task_name: str, - model_info: ModelInfo, - api_provider: APIProvider, - remain_try: int, - retry_interval: int = 10, - messages: tuple[list[Message], bool] | None = None, - ): - model_name = model_info.name - """ - 处理响应错误异常 - Args: - e (RespNotOkException): 响应错误异常对象 - task_name (str): 任务名称 - model_info (ModelInfo): 模型信息 - api_provider (APIProvider): API提供商 - remain_try (int): 剩余尝试次数 - retry_interval (int): 重试间隔 - messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) - Returns: - (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - # 响应错误 - if e.status_code in [400, 401, 402, 403, 404]: - model_name = model_info.name - # 客户端错误 - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}" - ) - return -1, None # 不再重试请求该模型 - elif e.status_code == 413: - if messages and not messages[1]: - # 消息列表不为空且未压缩,尝试压缩消息 - return self._check_retry( - remain_try, - 0, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求", - can_retry_callable=compress_messages, - messages=messages[0], - ) - # 没有消息可压缩 - logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。") - return -1, None - elif e.status_code == 429: - # 请求过于频繁 - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求", - ) - elif e.status_code >= 500: - # 服务器错误 - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试", - ) - else: - # 未知错误 - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}" - ) - return -1, None + time_cost=time_cost, + request_type=self.task_name, + endpoint=endpoint, + )) @staticmethod def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: - # sourcery skip: extract-method - """构建工具选项列表""" + """构建工具选项列表。""" if not tools: return None tool_options: List[ToolOption] = [] for tool in tools: - tool_legal = True - tool_options_builder = ToolOptionBuilder() - tool_options_builder.set_name(tool.get("name", "")) - tool_options_builder.set_description(tool.get("description", "")) - parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) - for param in parameters: - try: + try: + builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", "")) + for param in tool.get("parameters", []): + # 参数格式验证 assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" - assert isinstance(param[0], str), "参数名称必须是字符串" - assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" - assert isinstance(param[2], str), "参数描述必须是字符串" - assert isinstance(param[3], bool), "参数是否必填必须是布尔值" - assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" - tool_options_builder.add_param( + builder.add_param( name=param[0], param_type=param[1], description=param[2], required=param[3], enum_values=param[4], ) - except AssertionError as ae: - tool_legal = False - logger.error(f"{param[0]} 参数定义错误: {str(ae)}") - except Exception as e: - tool_legal = False - logger.error(f"构建工具参数失败: {str(e)}") - if tool_legal: - tool_options.append(tool_options_builder.build()) + tool_options.append(builder.build()) + except (KeyError, IndexError, TypeError, AssertionError) as e: + logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}") return tool_options or None - - @staticmethod - def _extract_reasoning(content: str) -> Tuple[str, str]: - """CoT思维链提取,向后兼容""" - match = re.search(r"(?:)?(.*?)", content, re.DOTALL) - content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() - reasoning = match[1].strip() if match else "" - return content, reasoning - - def _apply_content_obfuscation(self, text: str, api_provider) -> str: - """根据API提供商配置对文本进行混淆处理""" - if not hasattr(api_provider, "enable_content_obfuscation") or not api_provider.enable_content_obfuscation: - logger.debug(f"API提供商 '{api_provider.name}' 未启用内容混淆") - return text - - intensity = getattr(api_provider, "obfuscation_intensity", 1) - logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") - - # 在开头加入过滤规则指令 - processed_text = self.noise_instruction + "\n\n" + text - logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}") - - # 添加随机乱码 - final_text = self._inject_random_noise(processed_text, intensity) - logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}") - - return final_text - - @staticmethod - def _inject_random_noise(text: str, intensity: int) -> str: - """在文本中注入随机乱码""" - import random - import string - - def generate_noise(length: int) -> str: - """生成指定长度的随机乱码字符""" - chars = ( - string.ascii_letters # a-z, A-Z - + string.digits # 0-9 - + "!@#$%^&*()_+-=[]{}|;:,.<>?" # 特殊符号 - + "一二三四五六七八九零壹贰叁" # 中文字符 - + "αβγδεζηθικλμνξοπρστυφχψω" # 希腊字母 - + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵" # 数学符号 - ) - return "".join(random.choice(chars) for _ in range(length)) - - # 强度参数映射 - params = { - 1: {"probability": 15, "length": (3, 6)}, # 低强度:15%概率,3-6个字符 - 2: {"probability": 25, "length": (5, 10)}, # 中强度:25%概率,5-10个字符 - 3: {"probability": 35, "length": (8, 15)}, # 高强度:35%概率,8-15个字符 - } - - config = params.get(intensity, params[1]) - logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}") - - # 按词分割处理 - words = text.split() - result = [] - noise_count = 0 - - for word in words: - result.append(word) - # 根据概率插入乱码 - if random.randint(1, 100) <= config["probability"]: - noise_length = random.randint(*config["length"]) - noise = generate_noise(noise_length) - result.append(noise) - noise_count += 1 - - logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}") - return " ".join(result) From 15d82e602e835d3c6806e2a3a7f9d715ca4a8e76 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 20:41:04 +0800 Subject: [PATCH 25/41] =?UTF-8?q?fix(llm):=20=E9=98=B2=E6=AD=A2=20get=5Fre?= =?UTF-8?q?sponse=20=E8=B0=83=E7=94=A8=E4=B8=AD=20message=5Flist=20?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E9=87=8D=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当 `kwargs` 中已包含 `message_list` 时,直接将其与 `message_list=current_messages` 一同传递给 `get_response` 方法会导致 `TypeError`。 此更改通过在传递参数前从 `kwargs` 的副本中移除 `message_list` 键,确保该参数不会被重复传递,从而解决了这个问题。 --- src/llm_models/utils_model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index acb7130b6..b24b1843c 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -313,7 +313,14 @@ class _RequestExecutor: if request_type == RequestType.RESPONSE: assert current_messages is not None, "message_list cannot be None for response requests" - return await client.get_response(model_info=model_info, message_list=current_messages, **kwargs) + + # 修复: 防止 'message_list' 在 kwargs 中重复 + request_params = kwargs.copy() + request_params.pop("message_list", None) + + return await client.get_response( + model_info=model_info, message_list=current_messages, **request_params + ) elif request_type == RequestType.EMBEDDING: return await client.get_embedding(model_info=model_info, **kwargs) elif request_type == RequestType.AUDIO: From fe40b8873131dfeed2123b4746b7ef69fefda846 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Fri, 26 Sep 2025 20:43:23 +0800 Subject: [PATCH 26/41] =?UTF-8?q?=E8=BF=99=E6=98=AF=E7=AC=AC900=E4=B8=AA?= =?UTF-8?q?=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 移除了用户协议 --- EULA.md | 94 --------------------------------------------------------- 1 file changed, 94 deletions(-) delete mode 100644 EULA.md diff --git a/EULA.md b/EULA.md deleted file mode 100644 index bebfedd91..000000000 --- a/EULA.md +++ /dev/null @@ -1,94 +0,0 @@ -# **欢迎使用 MoFox_Bot (第三方修改版)!** - -**版本:V2.1** -**更新日期:2025年8月30日** - ---- - -你好!感谢你选择 MoFox_Bot。在开始之前,请花几分钟时间阅读这份协议。我们用问答的形式,帮助你快速了解使用这个**第三方修改版**软件时,你的权利和责任。 - -**简单来说,你需要同意这份协议才能使用我们的软件。** 如果你是未成年人,请确保你的监护人也阅读并同意了哦。 - ---- - -### **1. 这个软件和原版有什么关系?** - -这是一个非常重要的问题! - -* **第三方修改版**:首先,你需要清楚地知道,MoFox_Bot 是一个基于[MaiCore](https://mai-mai.org/)开源项目的**第三方修改版**。我们(MoFox_Bot 团队)与原始项目的开发者**没有任何关联**。 -* **独立维护**:我们独立负责这个修改版的维护、开发和更新。因此,原始项目的开发者**不会**为 MoFox_Bot 提供任何技术支持,也**不会**对因使用本修改版产生的任何问题负责。 -* **责任划分**:我们只对我们修改、添加的功能负责。对于原始代码中可能存在的任何问题,我们不承担责任。 - ---- - -### **2. 这个软件是免费和开源的吗?** - -**是的,核心代码是开源的!** - -* **遵循开源协议**:本项目继承了原始项目的 **GPLv3 开源协议**。这意味着你可以自由地使用、复制、研究、修改和重新分发它。 -* **你的义务**:当你修改或分发这个软件时,你同样需要遵守 GPLv3 协议的要求,确保你的衍生作品也是开源的。你可以在项目根目录找到 `LICENSE` 文件来了解更多细节。 -* **包含第三方代码**:请注意,项目中可能还包含了其他第三方库或组件,它们各自有独立的开源许可证。你在使用时需要同时遵守这些许可证。 - ---- - -### **3. 我的个人数据是如何被处理的?** - -你的隐私对我们很重要。了解数据如何流动,能帮助你更好地保护自己。 - -* **工作流程**:当你与机器人互动时,你的**输入内容**(比如文字、指令)、**配置信息**以及机器人**生成的回复**,会被发送给第三方的 API 服务(例如 OpenAI、Google 等大语言模型提供商)以获得智能回复。 -* **你的明确授权**:一旦你开始使用,即表示你授权我们利用你的数据进行以下操作: - 1. **调用外部 API**:这是机器人能与你对话的核心。 - 2. **建立本地知识库与记忆**:为了让机器人更个性化、更懂你,软件会在**你自己的设备上**创建和存储知识库、记忆库和对话日志。**这些数据存储在本地,我们无法访问。** - 3. **记录本地日志**:为了方便排查可能出现的技术问题,软件会在你的设备上记录运行日志。 -* **第三方服务的风险**:我们无法控制第三方 API 提供商的服务质量、数据处理政策、稳定性或安全性。使用这些服务时,你同样受到该第三方服务条款和隐私政策的约束。我们建议你自行了解这些条款。 - ---- - -### **4. 关于强大的插件系统,我需要了解什么?** - -MoFox_Bot 通过插件系统实现功能扩展,但这需要你承担相应的责任。 - -* **谁开发的插件?**:绝大多数插件是由**社区里的第三方开发者**创建和维护的,他们并**不是 MoFox_Bot 核心团队的成员**。 -* **责任完全自负**:插件的功能、质量、安全性和合法性**完全由其各自的开发者负责**。我们只提供了一个能让插件运行的技术平台,但**不对任何第三方插件的内容、行为或造成的后果承担任何责任**。 -* **你的使用风险**:使用任何第三方插件的风险**完全由你自行承担**。在安装和使用插件前,我们强烈建议你: - * 仔细阅读并理解插件开发者提供的许可协议和说明文档。 - * **只从你完全信任的来源获取和安装插件**。 - * 自行评估插件的安全性、合法性及其对你数据隐私的影响。 - ---- - -### **5. 我在使用时,有哪些行为准则?** - -请务必合法、合规地使用本软件。 - -* **禁止内容**:严禁输入、处理或传播任何违反你所在地法律法规的内容,包括但不限于:涉及国家秘密、商业机密、侵犯他人知识产权、个人隐私的内容,以及任何形式的非法、骚扰、诽谤、淫秽信息。 -* **合法用途**:你承诺不会将本项目用于任何非法目的或活动,例如网络攻击、诈骗等。 -* **数据安全**:你对自己存储在本地知识库、记忆库和日志中的所有内容的合法性负全部责任。 -* **插件规范**:不要使用任何已知包含恶意代码、安全漏洞或违法内容的插件。 - -**你将对自己使用本项目(包括所有第三方插件)的全部行为及其产生的一切后果,承担完全的法律责任。** - ---- - -### **6. 免责声明(非常重要!)** - -* **“按原样”提供**:本项目是“按原样”提供的,我们**不提供任何形式的明示或暗示的担保**,包括但不限于对适销性、特定用途适用性和不侵权的保证。 -* **AI 回复的立场**:机器人的所有回复均由第三方大语言模型生成,其观点和信息**不代表 MoFox_Bot 团队的立场**。我们不对其准确性、完整性或可靠性负责。 -* **无责任声明**:在任何情况下,MoFox_Bot 团队均不对因使用或无法使用本项目(特别是第三方插件)而导致的任何直接、间接、偶然、特殊或后果性的损害(包括但不限于数据丢失、利润损失、业务中断)承担责任。 -* **插件支持**:所有第三方插件的技术支持、功能更新和 bug 修复,都应**直接联系相应的插件开发者**。 - ---- - -### **7. 其他条款** - -* **协议的修改**:我们保留随时修改本协议的权利。修改后的协议将在新版本发布时生效。我们建议你定期检查以获取最新版本。继续使用本项目即表示你接受修订后的协议。 -* **最终解释权**:在法律允许的范围内,MoFox_Bot 团队保留对本协议的最终解释权。 -* **适用法律**:本协议的订立、执行和解释及争议的解决均应适用中国法律。 - ---- - -### **风险提示(请再次确认你已理解!)** - -* **隐私风险**:你的对话数据会被发送到不受我们控制的第三方 API。请**绝对不要**在对话中包含任何个人身份信息、财务信息、密码或其他敏感数据。 -* **精神健康风险**:AI 机器人只是一个程序,无法提供真正的情感支持或专业的心理建议。如果遇到任何心理困扰,请务必寻求专业人士的帮助(例如,全国心理援助热线:12355)。 -* **插件风险**:这是最大的风险之一。第三方插件可能带来严重的安全漏洞、系统不稳定、性能下降甚至隐私数据泄露的风险。请务必谨慎选择和使用,并为自己的选择承担全部后果。 From 439bbc8163261b7fa8a972888da65d7785812539 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 20:55:11 +0800 Subject: [PATCH 27/41] =?UTF-8?q?docs(llm):=20=E4=B8=BA=20LLM=20=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E6=A8=A1=E5=9D=97=E6=B7=BB=E5=8A=A0=E5=85=A8=E9=9D=A2?= =?UTF-8?q?=E7=9A=84=E6=96=87=E6=A1=A3=E5=92=8C=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为 `utils_model.py` 模块及其核心类(`_ModelSelector`、`_PromptProcessor`、`_RequestExecutor`)添加了详细的文档字符串。 同时,增加了大量的行内注释,以阐明复杂的逻辑,例如: - 模型选择的负载均衡算法 - 针对不同错误的失败惩罚计算 - 对嵌入任务的特殊客户端处理 此举旨在提高 LLM 交互核心逻辑的可读性和可维护性。 --- src/llm_models/utils_model.py | 96 +++++++++++++++++++++++++++++++---- 1 file changed, 86 insertions(+), 10 deletions(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index b24b1843c..e3699d540 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,3 +1,29 @@ +# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- +""" +@software: +@file: utils_model.py +@time: 2024/7/28 上午1:09 +@author: Mai ū +@contact: 2496955113@qq.com +@desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。 +它被设计为一个高度容错和可扩展的系统,包含以下主要组件: + +- **模型选择器 (_ModelSelector)**: + 实现了基于负载均衡和失败惩罚的动态模型选择策略,确保在高并发或部分模型失效时系统的稳定性。 + +- **提示处理器 (_PromptProcessor)**: + 负责对输入模型的提示词进行预处理(如内容混淆、反截断指令注入)和对模型输出进行后处理(如提取思考过程、检查截断)。 + +- **请求执行器 (_RequestExecutor)**: + 封装了底层的API请求逻辑,包括自动重试、异常分类处理和消息体压缩等功能。 + +- **请求策略 (_RequestStrategy)**: + 实现了高阶请求策略,如模型间的故障转移(Failover),确保单个模型的失败不会导致整个请求失败。 + +- **LLMRequest (主接口)**: + 作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。 +""" import re import asyncio import time @@ -102,10 +128,18 @@ class RequestType(Enum): class _ModelSelector: """负责模型选择、负载均衡和动态故障切换的策略。""" - CRITICAL_PENALTY_MULTIPLIER = 5 - DEFAULT_PENALTY_INCREMENT = 1 + CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数 + DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量 def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]): + """ + 初始化模型选择器。 + + Args: + model_list (List[str]): 可用模型名称列表。 + model_usage (Dict[str, Tuple[int, int, int]]): 模型的初始使用情况, + 格式为 {model_name: (total_tokens, penalty, usage_penalty)}。 + """ self.model_list = model_list self.model_usage = model_usage @@ -132,8 +166,12 @@ class _ModelSelector: logger.warning("没有可用的模型供当前请求选择。") return None - # 根据公式查找分数最低的模型,该公式综合了总token数、模型失败惩罚值和使用频率惩罚值。 + # 核心负载均衡算法:选择一个综合得分最低的模型。 # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 + # 设计思路: + # - `total_tokens`: 基础成本,优先使用累计token少的模型,实现长期均衡。 + # - `penalty * 300`: 失败惩罚项。每次失败会增加penalty,使其在短期内被选中的概率降低。权重300意味着一次失败大致相当于300个token的成本。 + # - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。 least_used_model_name = min( candidate_models_usage, key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, @@ -141,7 +179,8 @@ class _ModelSelector: model_info = model_config.get_model_info(least_used_model_name) api_provider = model_config.get_provider(model_info.api_provider) - # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 + # 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。 + # 这是为了避免在某些高并发场景下,共享的ClientSession可能引发的事件循环相关问题。 force_new_client = request_type == "embedding" client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) @@ -165,16 +204,22 @@ class _ModelSelector: total_tokens, penalty, usage_penalty = self.model_usage[model_name] penalty_increment = self.DEFAULT_PENALTY_INCREMENT + # 对严重错误施加更高的惩罚,以便快速将问题模型移出候选池 if isinstance(e, (NetworkConnectionError, ReqAbortException)): + # 网络连接错误或请求被中断,通常是基础设施问题,应重罚 penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}") elif isinstance(e, RespNotOkException): + # 对于HTTP响应错误,重点关注服务器端错误 if e.status_code >= 500: + # 5xx 错误表明服务器端出现问题,应重罚 penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}") else: + # 4xx 客户端错误通常不代表模型本身不可用,给予基础惩罚 logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") else: + # 其他未知异常,给予基础惩罚 logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty) @@ -184,7 +229,11 @@ class _PromptProcessor: """封装所有与提示词和响应内容的预处理和后处理逻辑。""" def __init__(self): - # 内容混淆过滤指令 + """ + 初始化提示处理器。 + 这里定义了用于内容混淆和反截断的指令文本。 + """ + # 内容混淆过滤指令: 用于在API传输中增加噪音,防止内容被审查或修改 self.noise_instruction = """**【核心过滤规则】** 在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。 @@ -291,6 +340,13 @@ class _RequestExecutor: """负责执行实际的API请求,包含重试逻辑和底层异常处理。""" def __init__(self, model_selector: _ModelSelector, task_name: str): + """ + 初始化请求执行器。 + + Args: + model_selector (_ModelSelector): 模型选择器实例,用于在请求失败时更新惩罚。 + task_name (str): 当前任务的名称,用于日志记录。 + """ self.model_selector = model_selector self.task_name = task_name @@ -302,19 +358,37 @@ class _RequestExecutor: model_info: ModelInfo, **kwargs, ) -> APIResponse: - """实际执行请求的方法,包含了重试和异常处理逻辑。""" + """ + 实际执行请求的方法,包含了重试和异常处理逻辑。 + + Args: + api_provider (APIProvider): API提供商配置。 + client (BaseClient): 用于发送请求的客户端实例。 + request_type (RequestType): 请求的类型 (e.g., RESPONSE, EMBEDDING)。 + model_info (ModelInfo): 正在使用的模型的信息。 + **kwargs: 传递给客户端方法的具体参数。 + + Returns: + APIResponse: 来自API的成功响应。 + + Raises: + Exception: 如果重试后请求仍然失败,则抛出最终的异常。 + RuntimeError: 如果达到最大重试次数。 + """ retry_remain = api_provider.max_retry compressed_messages: Optional[List[Message]] = None while retry_remain > 0: try: + # 优先使用压缩后的消息列表 message_list = kwargs.get("message_list") current_messages = compressed_messages or message_list + # 根据请求类型调用不同的客户端方法 if request_type == RequestType.RESPONSE: assert current_messages is not None, "message_list cannot be None for response requests" - # 修复: 防止 'message_list' 在 kwargs 中重复 + # 修复: 防止 'message_list' 在 kwargs 中重复传递 request_params = kwargs.copy() request_params.pop("message_list", None) @@ -328,18 +402,20 @@ class _RequestExecutor: except Exception as e: logger.debug(f"请求失败: {str(e)}") + # 记录失败并更新模型的惩罚值 self.model_selector.update_failure_penalty(model_info.name, e) + # 处理异常,决定是否重试以及等待多久 wait_interval, new_compressed_messages = self._handle_exception( e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None) ) if new_compressed_messages: - compressed_messages = new_compressed_messages + compressed_messages = new_compressed_messages # 更新为压缩后的消息 if wait_interval == -1: - raise e # 如果不再重试,则传播异常 + raise e # 如果决定不再重试,则传播异常 elif wait_interval > 0: - await asyncio.sleep(wait_interval) + await asyncio.sleep(wait_interval) # 等待指定时间后重试 finally: retry_remain -= 1 From 9eb940ca960fc5d13284a638c5fab3261f82536b Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 21:04:37 +0800 Subject: [PATCH 28/41] =?UTF-8?q?docs(llm):=20=E4=B8=BA=20utils=5Fmodel=20?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E8=A1=A5=E5=85=85=E8=AF=A6=E7=BB=86=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=92=8C=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为 `utils_model.py` 中的关键类和方法添加了全面的文档字符串和内联注释,以提升代码的可读性和可维护性。 主要变更包括: - 为 `_ModelSelector`, `_PromptProcessor`, `_RequestExecutor`, 和 `LLMRequest` 类中的核心方法扩充了详细的文档,解释其功能、参数和返回值。 - 在复杂的逻辑块(如重试机制、错误处理、内容混淆)中增加了内联注释,以阐明其实现细节。 - 移除了文件中旧的、多余的作者信息头。 --- src/llm_models/utils_model.py | 178 ++++++++++++++++++++++++++++++---- 1 file changed, 158 insertions(+), 20 deletions(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index e3699d540..3efa9cd2d 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,11 +1,6 @@ # -*- coding: utf-8 -*- # -*- coding: utf-8 -*- """ -@software: -@file: utils_model.py -@time: 2024/7/28 上午1:09 -@author: Mai ū -@contact: 2496955113@qq.com @desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。 它被设计为一个高度容错和可扩展的系统,包含以下主要组件: @@ -190,9 +185,21 @@ class _ModelSelector: return model_info, api_provider, client def update_usage_penalty(self, model_name: str, increase: bool): - """更新模型的使用惩罚值,用于负载均衡。""" + """ + 更新模型的使用惩罚值。 + + 在模型被选中时增加惩罚值,请求完成后减少惩罚值。 + 这有助于在短期内将请求分散到不同的模型,实现更动态的负载均衡。 + + Args: + model_name (str): 要更新惩罚值的模型名称。 + increase (bool): True表示增加惩罚值,False表示减少。 + """ + # 获取当前模型的统计数据 total_tokens, penalty, usage_penalty = self.model_usage[model_name] + # 根据操作是增加还是减少来确定调整量 adjustment = 1 if increase else -1 + # 更新模型的惩罚值 self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment) def update_failure_penalty(self, model_name: str, e: Exception): @@ -253,11 +260,29 @@ class _PromptProcessor: """ def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str: - """为请求准备最终的提示词,应用内容混淆和反截断指令。""" + """ + 为请求准备最终的提示词。 + + 此方法会根据API提供商和模型配置,对原始提示词应用内容混淆和反截断指令, + 生成最终发送给模型的完整提示内容。 + + Args: + prompt (str): 原始的用户提示词。 + model_info (ModelInfo): 目标模型的信息。 + api_provider (APIProvider): API提供商的配置。 + task_name (str): 当前任务的名称,用于日志记录。 + + Returns: + str: 处理后的、可以直接发送给模型的完整提示词。 + """ + # 步骤1: 根据API提供商的配置应用内容混淆 processed_prompt = self._apply_content_obfuscation(prompt, api_provider) + + # 步骤2: 检查模型是否需要注入反截断指令 if getattr(model_info, "use_anti_truncation", False): processed_prompt += self.anti_truncation_instruction logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。") + return processed_prompt def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]: @@ -277,33 +302,73 @@ class _PromptProcessor: return content, reasoning, is_truncated def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str: - """根据API提供商配置对文本进行混淆处理。""" + """ + 根据API提供商的配置对文本进行内容混淆。 + + 如果提供商配置中启用了内容混淆,此方法会在文本前部加入抗审查指令, + 并在文本中注入随机噪音,以降低内容被审查或修改的风险。 + + Args: + text (str): 原始文本内容。 + api_provider (APIProvider): API提供商的配置。 + + Returns: + str: 经过混淆处理的文本。 + """ + # 检查当前API提供商是否启用了内容混淆功能 if not getattr(api_provider, "enable_content_obfuscation", False): return text + # 获取混淆强度,默认为1 intensity = getattr(api_provider, "obfuscation_intensity", 1) logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") + + # 将抗审查指令和原始文本拼接 processed_text = self.noise_instruction + "\n\n" + text + + # 在拼接后的文本中注入随机噪音 return self._inject_random_noise(processed_text, intensity) @staticmethod def _inject_random_noise(text: str, intensity: int) -> str: - """在文本中注入随机乱码。""" + """ + 在文本中按指定强度注入随机噪音字符串。 + + 该方法通过在文本的单词之间随机插入无意义的字符串(噪音)来实现内容混淆。 + 强度越高,插入噪音的概率和长度就越大。 + + Args: + text (str): 待处理的文本。 + intensity (int): 混淆强度 (1-3),决定噪音的概率和长度。 + + Returns: + str: 注入噪音后的文本。 + """ + # 定义不同强度级别的噪音参数:概率和长度范围 params = { - 1: {"probability": 15, "length": (3, 6)}, - 2: {"probability": 25, "length": (5, 10)}, - 3: {"probability": 35, "length": (8, 15)}, + 1: {"probability": 15, "length": (3, 6)}, # 低强度 + 2: {"probability": 25, "length": (5, 10)}, # 中强度 + 3: {"probability": 35, "length": (8, 15)}, # 高强度 } + # 根据传入的强度选择配置,如果强度无效则使用默认值 config = params.get(intensity, params[1]) + words = text.split() result = [] + # 遍历每个单词 for word in words: result.append(word) + # 根据概率决定是否在此单词后注入噪音 if random.randint(1, 100) <= config["probability"]: + # 确定噪音的长度 noise_length = random.randint(*config["length"]) + # 定义噪音字符集 chars = string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?" + # 生成噪音字符串 noise = "".join(random.choice(chars) for _ in range(noise_length)) result.append(noise) + + # 将处理后的单词列表重新组合成字符串 return " ".join(result) @staticmethod @@ -448,30 +513,68 @@ class _RequestExecutor: def _handle_resp_not_ok( self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info ) -> Tuple[int, Optional[List[Message]]]: - """处理非200的HTTP响应异常。""" + """ + 处理非200的HTTP响应异常。 + + 根据不同的HTTP状态码决定下一步操作: + - 4xx 客户端错误:通常不可重试,直接放弃。 + - 413 (Payload Too Large): 尝试压缩消息体后重试一次。 + - 429 (Too Many Requests) / 5xx 服务器错误:可重试。 + + Args: + e (RespNotOkException): 捕获到的响应异常。 + model_info (ModelInfo): 当前模型信息。 + api_provider (APIProvider): API提供商配置。 + remain_try (int): 剩余重试次数。 + messages_info (tuple): 包含消息列表和是否已压缩的标志。 + + Returns: + Tuple[int, Optional[List[Message]]]: (等待间隔, 新的消息列表)。 + 等待间隔为-1表示不再重试。新的消息列表用于压缩后重试。 + """ model_name = model_info.name + # 处理客户端错误 (400-404),这些错误通常是请求本身有问题,不应重试 if e.status_code in [400, 401, 402, 403, 404]: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。") return -1, None + # 处理请求体过大的情况 elif e.status_code == 413: messages, is_compressed = messages_info + # 如果消息存在且尚未被压缩,则尝试压缩后立即重试 if messages and not is_compressed: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试。") return 0, compress_messages(messages) + # 如果已经压缩过或没有消息体,则放弃 logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大且无法压缩,放弃请求。") return -1, None + # 处理请求频繁或服务器端错误,这些情况适合重试 elif e.status_code == 429 or e.status_code >= 500: reason = "请求过于频繁" if e.status_code == 429 else "服务器错误" return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name) + # 处理其他未知的HTTP错误 else: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}") return -1, None def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]: - """辅助函数:检查是否可以重试。""" - if remain_try > 1: # 剩余次数大于1才重试 + """ + 辅助函数,根据剩余次数决定是否进行下一次重试。 + + Args: + remain_try (int): 剩余的重试次数。 + interval (int): 重试前的等待间隔(秒)。 + reason (str): 本次失败的原因。 + model_name (str): 失败的模型名称。 + + Returns: + Tuple[int, None]: (等待间隔, None)。如果等待间隔为-1,表示不应再重试。 + """ + # 只有在剩余重试次数大于1时才进行下一次重试(因为当前这次失败已经消耗掉一次) + if remain_try > 1: logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。") return interval, None + + # 如果已无剩余重试次数,则记录错误并返回-1表示放弃 logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。") return -1, None @@ -822,16 +925,28 @@ class LLMRequest: return response.embedding, model_info.name def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str): - """异步记录用量到数据库。""" + """ + 记录模型使用情况。 + + 此方法首先在内存中更新模型的累计token使用量,然后创建一个异步任务, + 将详细的用量数据(包括模型信息、token数、耗时等)写入数据库。 + + Args: + model_info (ModelInfo): 使用的模型信息。 + usage (Optional[UsageRecord]): API返回的用量记录。 + time_cost (float): 本次请求的总耗时。 + endpoint (str): 请求的API端点 (e.g., "/chat/completions")。 + """ if usage: - # 更新内存中的token计数 + # 步骤1: 更新内存中的token计数,用于负载均衡 total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty) + # 步骤2: 创建一个后台任务,将用量数据异步写入数据库 asyncio.create_task(llm_usage_recorder.record_usage_to_database( model_info=model_info, model_usage=usage, - user_id="system", + user_id="system", # 此处可根据业务需求修改 time_cost=time_cost, request_type=self.task_name, endpoint=endpoint, @@ -839,15 +954,34 @@ class LLMRequest: @staticmethod def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: - """构建工具选项列表。""" + """ + 根据输入的字典列表构建并验证 `ToolOption` 对象列表。 + + 此方法将标准化的工具定义(字典格式)转换为内部使用的 `ToolOption` 对象, + 同时会验证参数格式的正确性。 + + Args: + tools (Optional[List[Dict[str, Any]]]): 工具定义的列表。 + 每个工具是一个字典,包含 "name", "description", 和 "parameters"。 + "parameters" 是一个元组列表,每个元组包含 (name, type, desc, required, enum)。 + + Returns: + Optional[List[ToolOption]]: 构建好的 `ToolOption` 对象列表,如果输入为空则返回 None。 + """ + # 如果没有提供工具,直接返回 None if not tools: return None + tool_options: List[ToolOption] = [] + # 遍历每个工具定义 for tool in tools: try: + # 使用建造者模式创建 ToolOption builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", "")) + + # 遍历工具的参数 for param in tool.get("parameters", []): - # 参数格式验证 + # 严格验证参数格式是否为包含5个元素的元组 assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" builder.add_param( name=param[0], @@ -856,7 +990,11 @@ class LLMRequest: required=param[3], enum_values=param[4], ) + # 将构建好的 ToolOption 添加到列表中 tool_options.append(builder.build()) except (KeyError, IndexError, TypeError, AssertionError) as e: + # 如果构建过程中出现任何错误,记录日志并跳过该工具 logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}") + + # 如果列表非空则返回列表,否则返回 None return tool_options or None From 900b9af443339b565f37e1271f6f6a42e617f3c9 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 22:00:07 +0800 Subject: [PATCH 29/41] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E6=96=87=E4=BB=B6=E7=9A=84=E5=8F=8C=E5=90=91?= =?UTF-8?q?=E8=BF=81=E7=A7=BB=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_system/base/plugin_base.py | 208 +++++++++-------------- src/plugin_system/base/plus_plugin.py | 0 src/plugin_system/core/plugin_manager.py | 70 -------- 3 files changed, 83 insertions(+), 195 deletions(-) delete mode 100644 src/plugin_system/base/plus_plugin.py diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 9ef95182d..12797bafd 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -5,6 +5,7 @@ import toml import orjson import shutil import datetime +from pathlib import Path from src.common.logger import get_logger from src.config.config import CONFIG_DIR @@ -268,100 +269,64 @@ class PluginBase(ABC): except IOError as e: logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True) - def _get_expected_config_version(self) -> str: - """获取插件期望的配置版本号""" - # 从config_schema的plugin.config_version字段获取 - if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict): - config_version_field = self.config_schema["plugin"].get("config_version") - if isinstance(config_version_field, ConfigField): - return config_version_field.default - return "1.0.0" - - @staticmethod - def _get_current_config_version(config: Dict[str, Any]) -> str: - """从配置文件中获取当前版本号""" - if "plugin" in config and "config_version" in config["plugin"]: - return str(config["plugin"]["config_version"]) - # 如果没有config_version字段,视为最早的版本 - return "0.0.0" - def _backup_config_file(self, config_file_path: str) -> str: - """备份配置文件""" - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - backup_path = f"{config_file_path}.backup_{timestamp}" - + """备份配置文件到指定的 backup 子目录""" try: + config_path = Path(config_file_path) + backup_dir = config_path.parent / "backup" + backup_dir.mkdir(exist_ok=True) + + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + backup_filename = f"{config_path.name}.backup_{timestamp}" + backup_path = backup_dir / backup_filename + shutil.copy2(config_file_path, backup_path) logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}") - return backup_path + return str(backup_path) except Exception as e: - logger.error(f"{self.log_prefix} 备份配置文件失败: {e}") + logger.error(f"{self.log_prefix} 备份配置文件失败: {e}", exc_info=True) return "" - def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]: - """将旧配置值迁移到新配置结构中 + def _synchronize_config( + self, schema_config: Dict[str, Any], user_config: Dict[str, Any] + ) -> tuple[Dict[str, Any], bool]: + """递归地将用户配置与 schema 同步,返回同步后的配置和是否发生变化的标志""" + changed = False - Args: - old_config: 旧配置数据 - new_config: 基于新schema生成的默认配置 - - Returns: - Dict[str, Any]: 迁移后的配置 - """ - - def migrate_section( - old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str + # 内部递归函数 + def _sync_dicts( + schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "" ) -> Dict[str, Any]: - """迁移单个配置节""" - result = new_section.copy() + nonlocal changed + synced_dict = schema_dict.copy() - for key, value in old_section.items(): - if key in new_section: - # 特殊处理:config_version字段总是使用新版本 - if section_name == "plugin" and key == "config_version": - # 保持新的版本号,不迁移旧值 - logger.debug( - f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})" - ) - continue + # 检查并记录用户配置中多余的、在 schema 中不存在的键 + for key in user_dict: + if key not in schema_dict: + logger.warning(f"{self.log_prefix} 发现废弃配置项 '{parent_key}{key}',将被移除。") + changed = True - # 键存在于新配置中,复制值 - if isinstance(value, dict) and isinstance(new_section[key], dict): - # 递归处理嵌套字典 - result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}") + # 以 schema 为基准进行遍历,保留用户的值,补全缺失的项 + for key, schema_value in schema_dict.items(): + full_key = f"{parent_key}{key}" + if key in user_dict: + user_value = user_dict[key] + if isinstance(schema_value, dict) and isinstance(user_value, dict): + # 递归同步嵌套的字典 + synced_dict[key] = _sync_dicts(schema_value, user_value, f"{full_key}.") else: - result[key] = value - logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}") + # 键存在,保留用户的值 + synced_dict[key] = user_value else: - # 键在新配置中不存在,记录警告 - logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除") + # 键在用户配置中缺失,补全 + logger.info(f"{self.log_prefix} 补全缺失的配置项: '{full_key}' = {schema_value}") + changed = True + # synced_dict[key] 已经包含了来自 schema_dict.copy() 的默认值 - return result + return synced_dict - migrated_config = {} - - # 迁移每个配置节 - for section_name, new_section_data in new_config.items(): - if ( - section_name in old_config - and isinstance(old_config[section_name], dict) - and isinstance(new_section_data, dict) - ): - migrated_config[section_name] = migrate_section( - old_config[section_name], new_section_data, section_name - ) - else: - # 新增的节或类型不匹配,使用默认值 - migrated_config[section_name] = new_section_data - if section_name in old_config: - logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值") - - # 检查旧配置中是否有新配置没有的节 - for section_name in old_config: - if section_name not in migrated_config: - logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除") - - return migrated_config + final_config = _sync_dicts(schema_config, user_config) + return final_config, changed def _generate_config_from_schema(self) -> Dict[str, Any]: # sourcery skip: dict-comprehension @@ -393,11 +358,7 @@ class PluginBase(ABC): toml_str = f"# {self.plugin_name} - 配置文件\n" plugin_description = self.get_manifest_info("description", "插件配置文件") - toml_str += f"# {plugin_description}\n" - - # 获取当前期望的配置版本 - expected_version = self._get_expected_config_version() - toml_str += f"# 配置版本: {expected_version}\n\n" + toml_str += f"# {plugin_description}\n\n" # 遍历每个配置节 for section, fields in self.config_schema.items(): @@ -456,77 +417,74 @@ class PluginBase(ABC): def _load_plugin_config(self): # sourcery skip: extract-method """ - 加载插件配置文件,实现集中化管理和自动迁移。 + 加载并同步插件配置文件。 处理逻辑: - 1. 确定用户配置文件路径(位于 `config/plugins/` 目录下)。 - 2. 如果用户配置文件不存在,则根据 config_schema 直接在中央目录生成一份。 - 3. 加载用户配置文件,并进行版本检查和自动迁移(如果需要)。 - 4. 最终加载的配置是用户配置文件。 + 1. 确定用户配置文件路径和插件自带的配置文件路径。 + 2. 如果用户配置文件不存在,尝试从插件目录迁移(移动)一份。 + 3. 如果迁移后(或原本)用户配置文件仍不存在,则根据 schema 生成一份。 + 4. 加载用户配置文件。 + 5. 以 schema 为基准,与用户配置进行同步,补全缺失项并移除废弃项。 + 6. 如果同步过程发现不一致,则先备份原始文件,然后将同步后的完整配置写回用户目录。 + 7. 将最终同步后的配置加载到 self.config。 """ if not self.config_file_name: logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载") return - # 1. 确定并确保用户配置文件路径存在 user_config_path = os.path.join(CONFIG_DIR, "plugins", self.plugin_name, self.config_file_name) + plugin_config_path = os.path.join(self.plugin_dir, self.config_file_name) os.makedirs(os.path.dirname(user_config_path), exist_ok=True) - # 2. 如果用户配置文件不存在,直接在中央目录生成 + # 首次加载迁移:如果用户配置不存在,但插件目录中存在,则移动过来 + if not os.path.exists(user_config_path) and os.path.exists(plugin_config_path): + try: + shutil.move(plugin_config_path, user_config_path) + logger.info(f"{self.log_prefix} 已将配置文件从 {plugin_config_path} 迁移到 {user_config_path}") + except OSError as e: + logger.error(f"{self.log_prefix} 迁移配置文件失败: {e}", exc_info=True) + + # 如果用户配置文件仍然不存在,生成默认的 if not os.path.exists(user_config_path): logger.info(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在,将生成默认配置。") self._generate_and_save_default_config(user_config_path) - # 检查最终的用户配置文件是否存在 if not os.path.exists(user_config_path): - # 如果插件没有定义config_schema,那么不创建文件是正常行为 if not self.config_schema: - logger.debug(f"{self.log_prefix} 插件未定义config_schema,使用空的配置.") + logger.debug(f"{self.log_prefix} 插件未定义 config_schema,使用空配置。") self.config = {} - return - - logger.warning(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在且无法创建。") + else: + logger.warning(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在且无法创建。") return - # 3. 加载、检查和迁移用户配置文件 - _, file_ext = os.path.splitext(self.config_file_name) - if file_ext.lower() != ".toml": - logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml") - self.config = {} - return try: with open(user_config_path, "r", encoding="utf-8") as f: - existing_config = toml.load(f) or {} + user_config = toml.load(f) or {} except Exception as e: logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True) - self.config = {} + self.config = self._generate_config_from_schema() # 加载失败时使用默认 schema return - current_version = self._get_current_config_version(existing_config) - expected_version = self._get_expected_config_version() + # 生成基于 schema 的理想配置结构 + schema_config = self._generate_config_from_schema() - if current_version == "0.0.0": - logger.debug(f"{self.log_prefix} 用户配置文件无版本信息,跳过版本检查") - self.config = existing_config - elif current_version != expected_version: - logger.info( - f"{self.log_prefix} 检测到用户配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}" - ) - new_config_structure = self._generate_config_from_schema() - migrated_config = self._migrate_config_values(existing_config, new_config_structure) - self._save_config_to_file(migrated_config, user_config_path) - logger.info(f"{self.log_prefix} 用户配置文件已从 v{current_version} 更新到 v{expected_version}") - self.config = migrated_config - else: - logger.debug(f"{self.log_prefix} 用户配置版本匹配 (v{current_version}),直接加载") - self.config = existing_config + # 将用户配置与 schema 同步 + synced_config, was_changed = self._synchronize_config(schema_config, user_config) - logger.debug(f"{self.log_prefix} 配置已从 {user_config_path} 加载") + # 如果配置发生了变化(补全或移除),则备份并重写配置文件 + if was_changed: + logger.info(f"{self.log_prefix} 检测到配置结构不匹配,将自动同步并更新配置文件。") + self._backup_config_file(user_config_path) + self._save_config_to_file(synced_config, user_config_path) + logger.info(f"{self.log_prefix} 配置文件已同步更新。") - # 从配置中更新 enable_plugin 状态 + self.config = synced_config + logger.debug(f"{self.log_prefix} 配置已从 {user_config_path} 加载并同步。") + + # 从最终配置中更新插件启用状态 if "plugin" in self.config and "enabled" in self.config["plugin"]: self._is_enabled = self.config["plugin"]["enabled"] - logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self._is_enabled}") + logger.info(f"{self.log_prefix} 从配置更新插件启用状态: {self._is_enabled}") def _check_dependencies(self) -> bool: """检查插件依赖""" diff --git a/src/plugin_system/base/plus_plugin.py b/src/plugin_system/base/plus_plugin.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 237cb6429..0367add33 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -39,76 +39,6 @@ class PluginManager: self._ensure_plugin_directories() logger.info("插件管理器初始化完成") - def _synchronize_plugin_config(self, plugin_name: str, plugin_dir: str): - """ - 同步单个插件的配置。 - - 此过程确保中央配置与插件本地配置保持同步,包含两个主要步骤: - 1. 如果中央配置不存在,则从插件目录复制默认配置到中央配置目录。 - 2. 使用中央配置覆盖插件的本地配置,以确保插件运行时使用的是最新的用户配置。 - """ - try: - plugin_path = Path(plugin_dir) - # 修正:插件的配置文件路径应为 config.toml 文件,而不是目录 - plugin_config_file = plugin_path / "config.toml" - central_config_dir = Path("config") / "plugins" / plugin_name - - # 确保中央配置目录存在 - central_config_dir.mkdir(parents=True, exist_ok=True) - - # 步骤 1: 从插件目录复制默认配置到中央目录 - self._copy_default_config_to_central(plugin_name, plugin_config_file, central_config_dir) - - # 步骤 2: 从中央目录同步配置到插件目录 - self._sync_central_config_to_plugin(plugin_name, plugin_config_file, central_config_dir) - - except OSError as e: - logger.error(f"处理插件 '{plugin_name}' 的配置时发生文件操作错误: {e}") - except Exception as e: - logger.error(f"同步插件 '{plugin_name}' 配置时发生未知错误: {e}") - - @staticmethod - def _copy_default_config_to_central(plugin_name: str, plugin_config_file: Path, central_config_dir: Path): - """ - 如果中央配置不存在,则将插件的默认 config.toml 复制到中央目录。 - """ - if not plugin_config_file.is_file(): - return # 插件没有提供默认配置文件,直接跳过 - - central_config_file = central_config_dir / plugin_config_file.name - if not central_config_file.exists(): - shutil.copy2(plugin_config_file, central_config_file) - logger.info(f"为插件 '{plugin_name}' 从模板复制了默认配置: {plugin_config_file.name}") - - def _sync_central_config_to_plugin(self, plugin_name: str, plugin_config_file: Path, central_config_dir: Path): - """ - 将中央配置同步(覆盖)到插件的本地配置。 - """ - # 遍历中央配置目录中的所有文件 - for central_file in central_config_dir.iterdir(): - if not central_file.is_file(): - continue - - # 目标文件应与中央配置文件同名,这里我们强制它为 config.toml - target_plugin_file = plugin_config_file - - # 仅在文件内容不同时才执行复制,以减少不必要的IO操作 - if not self._is_file_content_identical(central_file, target_plugin_file): - shutil.copy2(central_file, target_plugin_file) - logger.info(f"已将中央配置 '{central_file.name}' 同步到插件 '{plugin_name}'") - - @staticmethod - def _is_file_content_identical(file1: Path, file2: Path) -> bool: - """ - 通过比较 MD5 哈希值检查两个文件的内容是否相同。 - """ - if not file2.exists(): - return False # 目标文件不存在,视为不同 - - # 使用 'rb' 模式以二进制方式读取文件,确保哈希值计算的一致性 - with open(file1, "rb") as f1, open(file2, "rb") as f2: - return hashlib.md5(f1.read()).hexdigest() == hashlib.md5(f2.read()).hexdigest() - # === 插件目录管理 === def add_plugin_directory(self, directory: str) -> bool: From df2e8b1aa85e5304ebe3ebd94f03741c4ba134b5 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 26 Sep 2025 22:08:54 +0800 Subject: [PATCH 30/41] =?UTF-8?q?=E8=AF=95=E5=9B=BEmaizone=E8=A7=A3?= =?UTF-8?q?=E5=86=B3=E9=97=AE=E9=A2=98=E4=BD=86=E4=B8=8D=E4=BF=9D=E8=AF=81?= =?UTF-8?q?=E6=88=90=E5=8A=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../services/qzone_service.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 67a3669db..e9b9303a1 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -657,20 +657,30 @@ class QZoneService: end_idx = resp_text.rfind("}") + 1 if start_idx != -1 and end_idx != -1: json_str = resp_text[start_idx:end_idx] - upload_result = eval(json_str) # 与原版保持一致使用eval + try: + upload_result = orjson.loads(json_str) + except orjson.JSONDecodeError: + logger.error(f"图片上传响应JSON解析失败,原始响应: {resp_text}") + return None - logger.info(f"图片上传解析结果: {upload_result}") + logger.debug(f"图片上传解析结果: {upload_result}") if upload_result.get("ret") == 0: - # 使用原版的参数提取逻辑 - picbo, richval = _get_picbo_and_richval(upload_result) - logger.info(f"图片 {index + 1} 上传成功: picbo={picbo}") - return {"pic_bo": picbo, "richval": richval} + try: + # 使用原版的参数提取逻辑 + picbo, richval = _get_picbo_and_richval(upload_result) + logger.info(f"图片 {index + 1} 上传成功: picbo={picbo}") + return {"pic_bo": picbo, "richval": richval} + except Exception as e: + logger.error( + f"从上传结果中提取图片参数失败: {e}, 上传结果: {upload_result}", exc_info=True + ) + return None else: logger.error(f"图片 {index + 1} 上传失败: {upload_result}") return None else: - logger.error("无法解析上传响应") + logger.error(f"无法从响应中提取JSON内容: {resp_text}") return None else: error_text = await response.text() From f968d134c7bd0027a8e2e3109b7280c3fec9660e Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 14:05:05 +0800 Subject: [PATCH 31/41] =?UTF-8?q?fix(plugin):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E6=8F=92=E4=BB=B6=E9=85=8D=E7=BD=AE=E5=90=8C=E6=AD=A5=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E5=B9=B6=E5=A2=9E=E5=BC=BA=E9=94=99=E8=AF=AF=E6=97=A5?= =?UTF-8?q?=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在插件管理器中,移除了加载插件时对已废弃的配置同步方法的调用。 同时,为了更好地排查 `maizone` 插件发送动态失败的问题,增强了其命令的异常日志,现在会额外记录异常类型。 --- src/plugin_system/core/plugin_manager.py | 2 -- .../built_in/maizone_refactored/commands/send_feed_command.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 0367add33..e0a39ac25 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -106,8 +106,6 @@ class PluginManager: if not plugin_dir: return False, 1 - # 同步插件配置 - self._synchronize_plugin_config(plugin_name, plugin_dir) plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败) if not plugin_instance: diff --git a/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py b/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py index 819655e84..631ca430d 100644 --- a/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py +++ b/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py @@ -53,6 +53,6 @@ class SendFeedCommand(PlusCommand): return False, result.get("message", "未知错误"), True except Exception as e: - logger.error(f"执行发送说说命令时发生未知异常: {e}", exc_info=True) + logger.error(f"执行发送说说命令时发生未知异常: {e},它的类型是:{type(e)}", exc_info=True) await self.send_text("呜... 发送过程中好像出了点问题。") return False, "命令执行异常", True From 93b0a6a8629de26d03a1f660db198192ed87d8c7 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 14:06:22 +0800 Subject: [PATCH 32/41] =?UTF-8?q?fix(tool):=20=E5=A2=9E=E5=BC=BA=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E6=8F=90=E5=8F=96=E5=A4=B1=E8=B4=A5=E6=97=B6=E7=9A=84?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在信息提取过程中,当大语言模型(LLM)返回的 JSON 格式不正确时,先前的日志只会记录一个通用的解析错误,而不会显示导致失败的原始响应内容,这使得调试变得困难。 此次更新通过在捕获到 JSON 解析异常时,额外记录 LLM 的原始输出内容来解决此问题。这有助于快速诊断并定位是模型输出不稳定还是提示词需要调整,从而提高了脚本的健壮性和可维护性。 此外,还对代码进行了一些格式化调整以提高可读性。 --- scripts/lpmm_learning_tool.py | 97 ++++++++++++++++++++++++----------- 1 file changed, 67 insertions(+), 30 deletions(-) diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 941494bc0..3fe26eb93 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -38,11 +38,13 @@ file_lock = Lock() # --- 模块一:数据预处理 --- + def process_text_file(file_path): with open(file_path, "r", encoding="utf-8") as f: raw = f.read() return [p.strip() for p in raw.split("\n\n") if p.strip()] + def preprocess_raw_data(): logger.info("--- 步骤 1: 开始数据预处理 ---") os.makedirs(RAW_DATA_PATH, exist_ok=True) @@ -50,7 +52,7 @@ def preprocess_raw_data(): if not raw_files: logger.warning(f"警告: 在 '{RAW_DATA_PATH}' 中没有找到任何 .txt 文件") return [] - + all_paragraphs = [] for file in raw_files: logger.info(f"正在处理文件: {file.name}") @@ -61,8 +63,10 @@ def preprocess_raw_data(): logger.info("--- 数据预处理完成 ---") return unique_paragraphs + # --- 模块二:信息提取 --- + def get_extraction_prompt(paragraph: str) -> str: return f""" 请从以下段落中提取关键信息。你需要提取两种类型的信息: @@ -81,6 +85,7 @@ def get_extraction_prompt(paragraph: str) -> str: --- """ + async def extract_info_async(pg_hash, paragraph, llm_api): temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json") with file_lock: @@ -92,11 +97,13 @@ async def extract_info_async(pg_hash, paragraph, llm_api): os.remove(temp_file_path) prompt = get_extraction_prompt(paragraph) + content = None try: content, (_, _, _) = await llm_api.generate_response_async(prompt) extracted_data = orjson.loads(content) doc_item = { - "idx": pg_hash, "passage": paragraph, + "idx": pg_hash, + "passage": paragraph, "extracted_entities": extracted_data.get("entities", []), "extracted_triples": extracted_data.get("triples", []), } @@ -106,27 +113,45 @@ async def extract_info_async(pg_hash, paragraph, llm_api): return doc_item, None except Exception as e: logger.error(f"提取信息失败:{pg_hash}, 错误:{e}") + if content: + logger.error(f"导致解析失败的原始输出: {content}") return None, pg_hash + def extract_info_sync(pg_hash, paragraph, llm_api): return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api)) + def extract_information(paragraphs_dict, model_set): logger.info("--- 步骤 2: 开始信息提取 ---") os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True) os.makedirs(TEMP_DIR, exist_ok=True) - + llm_api = LLMRequest(model_set=model_set) failed_hashes, open_ie_docs = [], [] with ThreadPoolExecutor(max_workers=5) as executor: - f_to_hash = {executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()} - with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn()) as progress: + f_to_hash = { + executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items() + } + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + ) as progress: task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict)) for future in as_completed(f_to_hash): doc_item, failed_hash = future.result() - if failed_hash: failed_hashes.append(failed_hash) - elif doc_item: open_ie_docs.append(doc_item) + if failed_hash: + failed_hashes.append(failed_hash) + elif doc_item: + open_ie_docs.append(doc_item) progress.update(task, advance=1) if open_ie_docs: @@ -135,19 +160,22 @@ def extract_information(paragraphs_dict, model_set): avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0 avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0 openie_obj = OpenIE(docs=open_ie_docs, avg_ent_chars=avg_ent_chars, avg_ent_words=avg_ent_words) - + now = datetime.datetime.now() filename = now.strftime("%Y-%m-%d-%H-%M-%S-openie.json") output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) with open(output_path, "wb") as f: f.write(orjson.dumps(openie_obj._to_dict())) logger.info(f"信息提取结果已保存到: {output_path}") - - if failed_hashes: logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}") + + if failed_hashes: + logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}") logger.info("--- 信息提取完成 ---") + # --- 模块三:数据导入 --- + async def import_data(openie_obj: Optional[OpenIE] = None): """ 将OpenIE数据导入知识库(Embedding Store 和 KG) @@ -159,15 +187,19 @@ async def import_data(openie_obj: Optional[OpenIE] = None): """ logger.info("--- 步骤 3: 开始数据导入 ---") embed_manager, kg_manager = EmbeddingManager(), KGManager() - + logger.info("正在加载现有的 Embedding 库...") - try: embed_manager.load_from_file() - except Exception as e: logger.warning(f"加载 Embedding 库失败: {e}。") + try: + embed_manager.load_from_file() + except Exception as e: + logger.warning(f"加载 Embedding 库失败: {e}。") logger.info("正在加载现有的 KG...") - try: kg_manager.load_from_file() - except Exception as e: logger.warning(f"加载 KG 失败: {e}。") - + try: + kg_manager.load_from_file() + except Exception as e: + logger.warning(f"加载 KG 失败: {e}。") + try: if openie_obj: openie_data = openie_obj @@ -180,7 +212,7 @@ async def import_data(openie_obj: Optional[OpenIE] = None): raw_paragraphs = openie_data.extract_raw_paragraph_dict() triple_list_data = openie_data.extract_triple_dict() - + new_raw_paragraphs, new_triple_list_data = {}, {} stored_embeds = embed_manager.stored_pg_hashes stored_kgs = kg_manager.stored_paragraph_hashes @@ -189,7 +221,7 @@ async def import_data(openie_obj: Optional[OpenIE] = None): if p_hash not in stored_embeds and p_hash not in stored_kgs: new_raw_paragraphs[p_hash] = raw_p new_triple_list_data[p_hash] = triple_list_data.get(p_hash, []) - + if not new_raw_paragraphs: logger.info("没有新的段落需要处理。") else: @@ -207,32 +239,35 @@ async def import_data(openie_obj: Optional[OpenIE] = None): logger.info("--- 数据导入完成 ---") + def import_from_specific_file(): """从用户指定的 openie.json 文件导入数据""" file_path = input("请输入 openie.json 文件的完整路径: ").strip() - + if not os.path.exists(file_path): logger.error(f"文件路径不存在: {file_path}") return - + if not file_path.endswith(".json"): logger.error("请输入一个有效的 .json 文件路径。") return try: logger.info(f"正在从 {file_path} 加载 OpenIE 数据...") - openie_obj = OpenIE.load(filepath=file_path) + openie_obj = OpenIE.load() asyncio.run(import_data(openie_obj=openie_obj)) except Exception as e: logger.error(f"从指定文件导入数据时发生错误: {e}") + # --- 主函数 --- + def main(): # 使用 os.path.relpath 创建相对于项目根目录的友好路径 raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")) - + print("=== LPMM 知识库学习工具 ===") print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)") print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)") @@ -243,24 +278,26 @@ def main(): print("-" * 30) choice = input("请输入你的选择 (0-5): ").strip() - if choice == '1': + if choice == "1": preprocess_raw_data() - elif choice == '2': + elif choice == "2": paragraphs = preprocess_raw_data() - if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa) - elif choice == '3': + if paragraphs: + extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + elif choice == "3": asyncio.run(import_data()) - elif choice == '4': + elif choice == "4": paragraphs = preprocess_raw_data() if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa) asyncio.run(import_data()) - elif choice == '5': + elif choice == "5": import_from_specific_file() - elif choice == '0': + elif choice == "0": sys.exit(0) else: print("无效输入,请重新运行脚本。") + if __name__ == "__main__": - main() \ No newline at end of file + main() From ddb7ef4d93350b44b04e53986f2085c9f9e3cc57 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 14:13:09 +0800 Subject: [PATCH 33/41] =?UTF-8?q?feat(tool):=20=E4=B8=BA=E5=AD=A6=E4=B9=A0?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E6=B7=BB=E5=8A=A0=E7=BC=93=E5=AD=98=E6=B8=85?= =?UTF-8?q?=E7=90=86=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为 lpmm_learning_tool.py 脚本添加了清理缓存的选项。 用户现在可以通过菜单选项 '6' 来删除 `temp/lpmm_cache` 目录下的所有临时文件。 此功能有助于释放磁盘空间,并可以在缓存数据陈旧或损坏时进行重置,提高了工具的可维护性。 --- scripts/lpmm_learning_tool.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 3fe26eb93..9a67a82ee 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -1,5 +1,6 @@ import asyncio import os +import shutil import sys import orjson import datetime @@ -36,6 +37,21 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache") file_lock = Lock() +# --- 缓存清理 --- + +def clear_cache(): + """清理 lpmm_learning_tool.py 生成的缓存文件""" + logger.info("--- 开始清理缓存 ---") + if os.path.exists(TEMP_DIR): + try: + shutil.rmtree(TEMP_DIR) + logger.info(f"成功删除缓存目录: {TEMP_DIR}") + except OSError as e: + logger.error(f"删除缓存时出错: {e}") + else: + logger.info("缓存目录不存在,无需清理。") + logger.info("--- 缓存清理完成 ---") + # --- 模块一:数据预处理 --- @@ -274,6 +290,7 @@ def main(): print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识") print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") + print("6. [清理缓存] -> 删除所有已提取信息的缓存") print("0. [退出]") print("-" * 30) choice = input("请输入你的选择 (0-5): ").strip() @@ -293,6 +310,8 @@ def main(): asyncio.run(import_data()) elif choice == "5": import_from_specific_file() + elif choice == "6": + clear_cache() elif choice == "0": sys.exit(0) else: From 866d50c6dc3b4b0230ba511bb06dd43ffa435680 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 14:19:43 +0800 Subject: [PATCH 34/41] =?UTF-8?q?=E5=8A=A0=E4=BA=86=E4=B8=80=E4=B8=AA=20?= =?UTF-8?q?=20=20=20=E5=B0=9D=E8=AF=95=E8=A7=A3=E6=9E=90JSON=E5=AD=97?= =?UTF-8?q?=E7=AC=A6=E4=B8=B2=EF=BC=8C=E5=A6=82=E6=9E=9C=E5=A4=B1=E8=B4=A5?= =?UTF-8?q?=E5=88=99=E5=B0=9D=E8=AF=95=E4=BF=AE=E5=A4=8D=E5=B9=B6=E9=87=8D?= =?UTF-8?q?=E6=96=B0=E8=A7=A3=E6=9E=90=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/lpmm_learning_tool.py | 57 ++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 9a67a82ee..f0888d552 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -8,6 +8,7 @@ from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed from threading import Lock from typing import Optional +from json_repair import repair_json # 将项目根目录添加到 sys.path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -83,6 +84,53 @@ def preprocess_raw_data(): # --- 模块二:信息提取 --- +def _parse_and_repair_json(json_string: str) -> Optional[dict]: + """ + 尝试解析JSON字符串,如果失败则尝试修复并重新解析。 + + 该函数首先会清理字符串,去除常见的Markdown代码块标记, + 然后尝试直接解析。如果解析失败,它会调用 `repair_json` + 进行修复,并再次尝试解析。 + + Args: + json_string: 从LLM获取的、可能格式不正确的JSON字符串。 + + Returns: + 解析后的字典。如果最终无法解析,则返回 None,并记录详细错误日志。 + """ + if not isinstance(json_string, str): + logger.error(f"输入内容非字符串,无法解析: {type(json_string)}") + return None + + # 1. 预处理:去除常见的多余字符,如Markdown代码块标记 + cleaned_string = json_string.strip() + if cleaned_string.startswith("```json"): + cleaned_string = cleaned_string[7:].strip() + elif cleaned_string.startswith("```"): + cleaned_string = cleaned_string[3:].strip() + + if cleaned_string.endswith("```"): + cleaned_string = cleaned_string[:-3].strip() + + # 2. 性能优化:乐观地尝试直接解析 + try: + return orjson.loads(cleaned_string) + except orjson.JSONDecodeError: + logger.warning("直接解析JSON失败,将尝试修复...") + + # 3. 修复与最终解析 + repaired_json_str = "" + try: + repaired_json_str = repair_json(cleaned_string) + return orjson.loads(repaired_json_str) + except Exception as e: + # 4. 增强错误处理:记录详细的失败信息 + logger.error(f"修复并解析JSON后依然失败: {e}") + logger.error(f"原始字符串 (清理后): {cleaned_string}") + logger.error(f"修复后尝试解析的字符串: {repaired_json_str}") + return None + + def get_extraction_prompt(paragraph: str) -> str: return f""" 请从以下段落中提取关键信息。你需要提取两种类型的信息: @@ -116,7 +164,14 @@ async def extract_info_async(pg_hash, paragraph, llm_api): content = None try: content, (_, _, _) = await llm_api.generate_response_async(prompt) - extracted_data = orjson.loads(content) + + # 改进点:调用封装好的函数处理JSON解析和修复 + extracted_data = _parse_and_repair_json(content) + + if extracted_data is None: + # 如果解析失败,抛出异常以触发统一的错误处理逻辑 + raise ValueError("无法从LLM输出中解析有效的JSON数据") + doc_item = { "idx": pg_hash, "passage": paragraph, From 3cded7220a471b3b41c437c2ed0ccb1ecefb1699 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 14:37:06 +0800 Subject: [PATCH 35/41] =?UTF-8?q?fix(chat):=20=E5=AE=8C=E5=96=84LLM?= =?UTF-8?q?=E5=88=86=E5=8F=A5=E9=80=BB=E8=BE=91=EF=BC=8C=E5=9C=A8=E6=97=A0?= =?UTF-8?q?=E5=88=86=E5=89=B2=E6=A0=87=E8=AE=B0=E6=97=B6=E5=9B=9E=E9=80=80?= =?UTF-8?q?=E8=87=B3=E6=A0=87=E7=82=B9=E5=88=86=E5=89=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当使用 "llm" 模式进行分句时,如果模型未能按预期生成 `[SPLIT]` 标记,之前的逻辑会直接返回整个未分割的文本。 这可能导致过长的句子被发送到下游模块(如TTS),影响体验。本次修改添加了回退机制,当未检测到 `[SPLIT]` 标记时,会自动切换到基于标点的传统分句方法,以提高分句的鲁棒性。 --- src/chat/utils/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 85e665328..746b13e63 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -341,9 +341,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese split_sentences = [s.strip() for s in split_sentences_raw if s.strip()] else: if split_mode == "llm": - logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。") - split_sentences = [cleaned_text] - else: # mode == "punctuation" + logger.debug("未检测到 [SPLIT] 标记,回退到基于标点的传统模式进行分割。") + split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) + else: # mode == "punctuation" logger.debug("使用基于标点的传统模式进行分割。") split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) else: From e544173c3f33a9be88903248558a26e7025a0476 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 14:45:58 +0800 Subject: [PATCH 36/41] =?UTF-8?q?Revert=20"feat(proactive):=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E4=B8=BB=E5=8A=A8=E8=81=8A=E5=A4=A9=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=8A=A0=E6=90=9C=E7=B4=A2=E5=89=8D=E5=88=A4?= =?UTF-8?q?=E6=96=AD=E4=B8=8E=E5=9B=9E=E5=A4=8D=E5=89=8D=E6=A3=80=E6=9F=A5?= =?UTF-8?q?"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 34bce03f17325b321df94b4eee846b12b9d72dda. --- .../chat_loop/proactive/proactive_thinker.py | 120 ++++-------------- 1 file changed, 23 insertions(+), 97 deletions(-) diff --git a/src/chat/chat_loop/proactive/proactive_thinker.py b/src/chat/chat_loop/proactive/proactive_thinker.py index 4dea5ec99..34abf7803 100644 --- a/src/chat/chat_loop/proactive/proactive_thinker.py +++ b/src/chat/chat_loop/proactive/proactive_thinker.py @@ -162,107 +162,33 @@ class ProactiveThinker: news_block = "暂时没有获取到最新资讯。" if trigger_event.source != "reminder_system": - # 升级决策模型 - should_search_prompt = f""" -# 搜索决策 - -## 任务 -分析话题“{topic}”,判断它的展开更依赖于“外部信息”还是“内部信息”,并决定是否需要进行网络搜索。 - -## 判断原则 -- **需要搜索 (SEARCH)**:当话题的有效讨论**必须**依赖于现实世界的、客观的、可被检索的外部信息时。这包括但不限于: - - 新闻时事、公共事件 - - 专业知识、科学概念 - - 天气、股价等实时数据 - - 对具体实体(如电影、书籍、地点)的客观描述查询 - -- **无需搜索 (SKIP)**:当话题的展开主要依赖于**已有的对话上下文、个人情感、主观体验或社交互动**时。这包括但不限于: - - 延续之前的对话、追问细节 - - 表达关心、问候或个人感受 - - 分享主观看法或经历 - - 纯粹的社交性互动 - -## 你的决策 -根据以上原则,对“{topic}”这个话题进行分析,并严格输出`SEARCH`或`SKIP`。 -""" - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config - - decision_llm = LLMRequest( - model_set=model_config.model_task_config.planner, - request_type="planner" - ) - - decision, _ = await decision_llm.generate_response_async(prompt=should_search_prompt) - - if "SEARCH" in decision: - try: - if topic and topic.strip(): - web_search_tool = tool_api.get_tool_instance("web_search") - if web_search_tool: - try: - search_result_dict = await web_search_tool.execute( - function_args={"query": topic, "max_results": 10} - ) - if search_result_dict and not search_result_dict.get("error"): - news_block = search_result_dict.get("content", "未能提取有效资讯。") - elif search_result_dict: - logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}") - except Exception as e: - logger.error(f"{self.context.log_prefix} 网络搜索执行失败: {e}") - else: - logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。") - else: - logger.warning(f"{self.context.log_prefix} 主题为空,跳过网络搜索。") - except Exception as e: - logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}") - message_list = await get_raw_msg_before_timestamp_with_chat( + try: + web_search_tool = tool_api.get_tool_instance("web_search") + if web_search_tool: + try: + search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10}) + except TypeError: + try: + search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10}) + except TypeError: + logger.warning(f"{self.context.log_prefix} 网络搜索工具参数不匹配,跳过搜索") + news_block = "跳过网络搜索。" + search_result_dict = None + + if search_result_dict and not search_result_dict.get("error"): + news_block = search_result_dict.get("content", "未能提取有效资讯。") + elif search_result_dict: + logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}") + else: + logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。") + except Exception as e: + logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}") + message_list = get_raw_msg_before_timestamp_with_chat( chat_id=self.context.stream_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.3), ) - chat_context_block, _ = await build_readable_messages_with_id(messages=message_list) - - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config - - bot_name = global_config.bot.nickname - - confirmation_prompt = f"""# 主动回复二次确认 - -## 基本信息 -你的名字是{bot_name},准备主动发起关于"{topic}"的话题。 - -## 最近的聊天内容 -{chat_context_block} - -## 合理判断标准 -请检查以下条件,如果**所有条件都合理**就可以回复: - -1. **回应检查**:检查你({bot_name})发送的最后一条消息之后,是否有其他人发言。如果没有,则大概率应该保持沉默。 -2. **话题补充**:只有当你认为准备发起的话题是对上一条无人回应消息的**有价值的补充**时,才可以在上一条消息无人回应的情况下继续发言。 -3. **时间合理性**:当前时间是否在深夜(凌晨2点-6点)这种不适合主动聊天的时段? -4. **内容价值**:这个话题"{topic}"是否有意义,不是完全无关紧要的内容? -5. **重复避免**:你准备说的话题是否与你自己的上一条消息明显重复? -6. **自然性**:在当前上下文中主动提起这个话题是否自然合理? - -## 输出要求 -如果判断应该跳过(比如上一条消息无人回应、深夜时段、无意义话题、重复内容),输出:SKIP_PROACTIVE_REPLY -其他情况都应该输出:PROCEED_TO_REPLY - -请严格按照上述格式输出,不要添加任何解释。""" - - planner_llm = LLMRequest( - model_set=model_config.model_task_config.planner, - request_type="planner" - ) - - confirmation_result, _ = await planner_llm.generate_response_async(prompt=confirmation_prompt) - - if not confirmation_result or "SKIP_PROACTIVE_REPLY" in confirmation_result: - logger.info(f"{self.context.log_prefix} 决策模型二次确认决定跳过主动回复") - return - + chat_context_block, _ = await build_readable_messages_with_id(messages=message_list) bot_name = global_config.bot.nickname personality = global_config.personality identity_block = ( From 9775b9e73198b39a0eb0d69bc90f94c255f77600 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 14:59:54 +0800 Subject: [PATCH 37/41] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E4=B8=80?= =?UTF-8?q?=E7=82=B9=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/chat_loop/cycle_processor.py | 7 ------- src/chat/planner_actions/plan_filter.py | 4 ++-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/chat/chat_loop/cycle_processor.py b/src/chat/chat_loop/cycle_processor.py index e571b5ac5..79d4eca9d 100644 --- a/src/chat/chat_loop/cycle_processor.py +++ b/src/chat/chat_loop/cycle_processor.py @@ -206,13 +206,6 @@ class CycleProcessor: raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成") with Timer("规划器", cycle_timers): actions, _ = await self.action_planner.plan(mode=mode) - - # 在这里添加日志,清晰地显示最终选择的动作 - if actions: - chosen_actions = [a.get("action_type", "unknown") for a in actions] - logger.info(f"{self.log_prefix} LLM最终选择的动作: {chosen_actions}") - else: - logger.info(f"{self.log_prefix} LLM最终没有选择任何动作") async def execute_action(action_info): """执行单个动作的通用函数""" diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index fccda0230..6aaefba18 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -46,12 +46,12 @@ class PlanFilter: try: prompt, used_message_id_list = await self._build_prompt(plan) plan.llm_prompt = prompt - logger.debug(f"墨墨在这里加了日志 -> LLM prompt: {prompt}") + logger.info(f"规划器原始提示词: {prompt}") llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt) if llm_content: - logger.debug(f"墨墨在这里加了日志 -> LLM a原始返回: {llm_content}") + logger.info(f"规划器原始返回: {llm_content}") parsed_json = orjson.loads(repair_json(llm_content)) logger.debug(f"墨墨在这里加了日志 -> 解析后的 JSON: {parsed_json}") From b32a9343935e589c571ca26b8f848df778574f2b Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 15:06:02 +0800 Subject: [PATCH 38/41] =?UTF-8?q?fix(chat):=20=E7=A1=AE=E4=BF=9D=20planner?= =?UTF-8?q?=20=E6=8F=90=E7=A4=BA=E8=AF=8D=E6=A8=A1=E5=9D=97=E8=A2=AB?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 通过显式导入 planner_prompts 模块,确保其中的提示词在 planner 实例化之前被正确注册,避免潜在的引用问题。 --- src/chat/planner_actions/planner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 6e45b7907..93be817fd 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -10,7 +10,7 @@ from src.chat.planner_actions.plan_filter import PlanFilter from src.chat.planner_actions.plan_generator import PlanGenerator from src.common.logger import get_logger from src.plugin_system.base.component_types import ChatMode - +import src.chat.planner_actions.planner_prompts #noga # 导入提示词模块以确保其被初始化 logger = get_logger("planner") From 86c3c78259d87563902696c31b9cbfc9e53b20ee Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 15:06:44 +0800 Subject: [PATCH 39/41] =?UTF-8?q?style(chat):=20=E4=B8=BA=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E8=AF=8D=E5=AF=BC=E5=85=A5=E6=B7=BB=E5=8A=A0=20noqa?= =?UTF-8?q?=20=E4=BB=A5=E5=BF=BD=E7=95=A5=E6=9C=AA=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E8=AD=A6=E5=91=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/planner_actions/planner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 93be817fd..0e3d1afc3 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -10,7 +10,7 @@ from src.chat.planner_actions.plan_filter import PlanFilter from src.chat.planner_actions.plan_generator import PlanGenerator from src.common.logger import get_logger from src.plugin_system.base.component_types import ChatMode -import src.chat.planner_actions.planner_prompts #noga +import src.chat.planner_actions.planner_prompts #noga # noqa: F401 # 导入提示词模块以确保其被初始化 logger = get_logger("planner") From fe201c389e662b17d1df9425c5a8a1f2d61f9921 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 15:47:30 +0800 Subject: [PATCH 40/41] 1 --- src/chat/chat_loop/heartFC_chat.py | 221 +++++++++++++++-------------- 1 file changed, 111 insertions(+), 110 deletions(-) diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index fca7df847..adc868117 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -39,6 +39,7 @@ class HeartFChatting: """ self.context = HfcContext(chat_id) self.context.new_message_queue = asyncio.Queue() + self._processing_lock = asyncio.Lock() self.cycle_tracker = CycleTracker(self.context) self.response_handler = ResponseHandler(self.context) @@ -357,130 +358,130 @@ class HeartFChatting: - FOCUS模式:直接处理所有消息并检查退出条件 - NORMAL模式:检查进入FOCUS模式的条件,并通过normal_mode_handler处理消息 """ - # --- 核心状态更新 --- - await self.sleep_manager.update_sleep_state(self.wakeup_manager) - current_sleep_state = self.sleep_manager.get_current_sleep_state() - is_sleeping = current_sleep_state == SleepState.SLEEPING - is_in_insomnia = current_sleep_state == SleepState.INSOMNIA + async with self._processing_lock: + # --- 核心状态更新 --- + await self.sleep_manager.update_sleep_state(self.wakeup_manager) + current_sleep_state = self.sleep_manager.get_current_sleep_state() + is_sleeping = current_sleep_state == SleepState.SLEEPING + is_in_insomnia = current_sleep_state == SleepState.INSOMNIA - # 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收 - filter_command_flag = not (is_sleeping or is_in_insomnia) + # 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收 + filter_command_flag = not (is_sleeping or is_in_insomnia) - # 从队列中获取所有待处理的新消息 - recent_messages = [] - while not self.context.new_message_queue.empty(): - recent_messages.append(await self.context.new_message_queue.get()) + # 从队列中获取所有待处理的新消息 + recent_messages = [] + while not self.context.new_message_queue.empty(): + recent_messages.append(await self.context.new_message_queue.get()) - has_new_messages = bool(recent_messages) - new_message_count = len(recent_messages) + has_new_messages = bool(recent_messages) + new_message_count = len(recent_messages) - # 只有在有新消息时才进行思考循环处理 - if has_new_messages: - self.context.last_message_time = time.time() - self.context.last_read_time = time.time() + # 只有在有新消息时才进行思考循环处理 + if has_new_messages: + self.context.last_message_time = time.time() + self.context.last_read_time = time.time() - # --- 专注模式安静群组检查 --- - quiet_groups = global_config.chat.focus_mode_quiet_groups - if quiet_groups and self.context.chat_stream: - is_group_chat = self.context.chat_stream.group_info is not None - if is_group_chat: - try: - platform = self.context.chat_stream.platform - group_id = self.context.chat_stream.group_info.group_id - - # 兼容不同QQ适配器的平台名称 - is_qq_platform = platform in ["qq", "napcat"] - - current_chat_identifier = f"{platform}:{group_id}" - config_identifier_for_qq = f"qq:{group_id}" - - is_in_quiet_list = (current_chat_identifier in quiet_groups or - (is_qq_platform and config_identifier_for_qq in quiet_groups)) - - if is_in_quiet_list: - is_mentioned_in_batch = False - for msg in recent_messages: - if msg.get("is_mentioned"): - is_mentioned_in_batch = True - break + # --- 专注模式安静群组检查 --- + quiet_groups = global_config.chat.focus_mode_quiet_groups + if quiet_groups and self.context.chat_stream: + is_group_chat = self.context.chat_stream.group_info is not None + if is_group_chat: + try: + platform = self.context.chat_stream.platform + group_id = self.context.chat_stream.group_info.group_id - if not is_mentioned_in_batch: - logger.info(f"{self.context.log_prefix} 在专注安静模式下,因未被提及而忽略了消息。") - return True # 消耗消息但不做回复 + # 兼容不同QQ适配器的平台名称 + is_qq_platform = platform in ["qq", "napcat"] + + current_chat_identifier = f"{platform}:{group_id}" + config_identifier_for_qq = f"qq:{group_id}" + + is_in_quiet_list = (current_chat_identifier in quiet_groups or + (is_qq_platform and config_identifier_for_qq in quiet_groups)) + + if is_in_quiet_list: + is_mentioned_in_batch = False + for msg in recent_messages: + if msg.get("is_mentioned"): + is_mentioned_in_batch = True + break + + if not is_mentioned_in_batch: + logger.info(f"{self.context.log_prefix} 在专注安静模式下,因未被提及而忽略了消息。") + return True # 消耗消息但不做回复 + except Exception as e: + logger.error(f"{self.context.log_prefix} 检查专注安静群组时出错: {e}") + + # 处理唤醒度逻辑 + if current_sleep_state in [SleepState.SLEEPING, SleepState.PREPARING_SLEEP, SleepState.INSOMNIA]: + self._handle_wakeup_messages(recent_messages) + + # 再次获取最新状态,因为 handle_wakeup 可能导致状态变为 WOKEN_UP + current_sleep_state = self.sleep_manager.get_current_sleep_state() + + if current_sleep_state == SleepState.SLEEPING: + # 只有在纯粹的 SLEEPING 状态下才跳过消息处理 + return True + + if current_sleep_state == SleepState.WOKEN_UP: + logger.info(f"{self.context.log_prefix} 从睡眠中被唤醒,将处理积压的消息。") + + # 根据聊天模式处理新消息 + should_process, interest_value = await self._should_process_messages(recent_messages) + if not should_process: + # 消息数量不足或兴趣不够,等待 + await asyncio.sleep(0.5) + return True # Skip rest of the logic for this iteration + + # Messages should be processed + action_type = await self.cycle_processor.observe(interest_value=interest_value) + + # 尝试触发表达学习 + if self.context.expression_learner: + try: + await self.context.expression_learner.trigger_learning_for_chat() except Exception as e: - logger.error(f"{self.context.log_prefix} 检查专注安静群组时出错: {e}") + logger.error(f"{self.context.log_prefix} 表达学习触发失败: {e}") - # 处理唤醒度逻辑 - if current_sleep_state in [SleepState.SLEEPING, SleepState.PREPARING_SLEEP, SleepState.INSOMNIA]: - self._handle_wakeup_messages(recent_messages) + # 管理no_reply计数器 + if action_type != "no_reply": + self.recent_interest_records.clear() + self.context.no_reply_consecutive = 0 + logger.debug(f"{self.context.log_prefix} 执行了{action_type}动作,重置no_reply计数器") + else: # action_type == "no_reply" + self.context.no_reply_consecutive += 1 + self._determine_form_type() - # 再次获取最新状态,因为 handle_wakeup 可能导致状态变为 WOKEN_UP - current_sleep_state = self.sleep_manager.get_current_sleep_state() + # 在一轮动作执行完毕后,增加睡眠压力 + if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system: + if action_type not in ["no_reply", "no_action"]: + self.context.energy_manager.increase_sleep_pressure() - if current_sleep_state == SleepState.SLEEPING: - # 只有在纯粹的 SLEEPING 状态下才跳过消息处理 - return True - - if current_sleep_state == SleepState.WOKEN_UP: - logger.info(f"{self.context.log_prefix} 从睡眠中被唤醒,将处理积压的消息。") - - # 根据聊天模式处理新消息 - should_process, interest_value = await self._should_process_messages(recent_messages) - if not should_process: - # 消息数量不足或兴趣不够,等待 - await asyncio.sleep(0.5) - return True # Skip rest of the logic for this iteration - - # Messages should be processed - action_type = await self.cycle_processor.observe(interest_value=interest_value) - - # 尝试触发表达学习 - if self.context.expression_learner: - try: - await self.context.expression_learner.trigger_learning_for_chat() - except Exception as e: - logger.error(f"{self.context.log_prefix} 表达学习触发失败: {e}") - - # 管理no_reply计数器 - if action_type != "no_reply": - self.recent_interest_records.clear() - self.context.no_reply_consecutive = 0 - logger.debug(f"{self.context.log_prefix} 执行了{action_type}动作,重置no_reply计数器") - else: # action_type == "no_reply" - self.context.no_reply_consecutive += 1 - self._determine_form_type() - - # 在一轮动作执行完毕后,增加睡眠压力 - if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system: - if action_type not in ["no_reply", "no_action"]: - self.context.energy_manager.increase_sleep_pressure() - - # 如果成功观察,增加能量值并重置累积兴趣值 - self.context.energy_value += 1 / global_config.chat.focus_value - # 重置累积兴趣值,因为消息已经被成功处理 - self.context.breaking_accumulated_interest = 0.0 - logger.info( - f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值" - ) - - # 更新上一帧的睡眠状态 - self.context.was_sleeping = is_sleeping - - # --- 重新入睡逻辑 --- - # 如果被吵醒了,并且在一定时间内没有新消息,则尝试重新入睡 - if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages: - re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60 - # 使用 last_message_time 来判断空闲时间 - if time.time() - self.context.last_message_time > re_sleep_delay: + # 如果成功观察,增加能量值并重置累积兴趣值 + self.context.energy_value += 1 / global_config.chat.focus_value + # 重置累积兴趣值,因为消息已经被成功处理 + self.context.breaking_accumulated_interest = 0.0 logger.info( - f"{self.context.log_prefix} 已被唤醒且超过 {re_sleep_delay / 60} 分钟无新消息,尝试重新入睡。" + f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值" ) - self.sleep_manager.reset_sleep_state_after_wakeup() - # 保存HFC上下文状态 - self.context.save_context_state() + # 更新上一帧的睡眠状态 + self.context.was_sleeping = is_sleeping - return has_new_messages + # --- 重新入睡逻辑 --- + # 如果被吵醒了,并且在一定时间内没有新消息,则尝试重新入睡 + if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages: + re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60 + # 使用 last_message_time 来判断空闲时间 + if time.time() - self.context.last_message_time > re_sleep_delay: + logger.info( + f"{self.context.log_prefix} 已被唤醒且超过 {re_sleep_delay / 60} 分钟无新消息,尝试重新入睡。" + ) + self.sleep_manager.reset_sleep_state_after_wakeup() + + # 保存HFC上下文状态 + self.context.save_context_state() + return has_new_messages def _handle_wakeup_messages(self, messages): """ From f9fbfe319f48a9b9e9c291cb34bc90a4f0620d0b Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 27 Sep 2025 15:57:00 +0800 Subject: [PATCH 41/41] 1 --- .../maizone_refactored/services/content_service.py | 7 ++----- .../maizone_refactored/services/cookie_service.py | 3 +-- .../maizone_refactored/services/image_service.py | 3 +-- .../maizone_refactored/services/qzone_service.py | 9 +++------ .../maizone_refactored/services/reply_tracker_service.py | 3 +-- .../maizone_refactored/services/scheduler_service.py | 8 +++----- .../built_in/maizone_refactored/utils/history_utils.py | 3 +-- 7 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index 9f7da7ccf..27f2a0ee9 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -119,12 +119,10 @@ class ContentService: logger.error(f"生成说说内容时发生异常: {e}") return "" - async def generate_comment(self, content: str, target_name: str, rt_con: str = "", images=None) -> str: + async def generate_comment(self, content: str, target_name: str, rt_con: str = "", images: list = []) -> str: """ 针对一条具体的说说内容生成评论。 """ - if images is None: - images = [] for i in range(3): # 重试3次 try: chat_manager = get_chat_manager() @@ -182,8 +180,7 @@ class ContentService: return "" return "" - @staticmethod - async def generate_comment_reply(story_content: str, comment_content: str, commenter_name: str) -> str: + async def generate_comment_reply(self, story_content: str, comment_content: str, commenter_name: str) -> str: """ 针对自己说说的评论,生成回复。 """ diff --git a/src/plugins/built_in/maizone_refactored/services/cookie_service.py b/src/plugins/built_in/maizone_refactored/services/cookie_service.py index 1c61a29fd..b4aedf322 100644 --- a/src/plugins/built_in/maizone_refactored/services/cookie_service.py +++ b/src/plugins/built_in/maizone_refactored/services/cookie_service.py @@ -50,8 +50,7 @@ class CookieService: logger.error(f"无法读取或解析Cookie文件 {cookie_file_path}: {e}") return None - @staticmethod - async def _get_cookies_from_adapter(stream_id: Optional[str]) -> Optional[Dict[str, str]]: + async def _get_cookies_from_adapter(self, stream_id: Optional[str]) -> Optional[Dict[str, str]]: """通过Adapter API获取Cookie""" try: params = {"domain": "user.qzone.qq.com"} diff --git a/src/plugins/built_in/maizone_refactored/services/image_service.py b/src/plugins/built_in/maizone_refactored/services/image_service.py index 1ffcd7d70..cbb411da7 100644 --- a/src/plugins/built_in/maizone_refactored/services/image_service.py +++ b/src/plugins/built_in/maizone_refactored/services/image_service.py @@ -59,8 +59,7 @@ class ImageService: logger.error(f"处理AI配图时发生异常: {e}") return False - @staticmethod - async def _call_siliconflow_api(api_key: str, story: str, image_dir: str, batch_size: int) -> bool: + async def _call_siliconflow_api(self, api_key: str, story: str, image_dir: str, batch_size: int) -> bool: """ 调用硅基流动(SiliconFlow)的API来生成图片。 diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index e9b9303a1..752e27dfa 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -187,8 +187,7 @@ class QZoneService: # --- Internal Helper Methods --- - @staticmethod - async def _get_intercom_context(stream_id: str) -> Optional[str]: + async def _get_intercom_context(self, stream_id: str) -> Optional[str]: """ 根据 stream_id 查找其所属的互通组,并构建该组的聊天上下文。 @@ -399,8 +398,7 @@ class QZoneService: logger.error(f"加载本地图片失败: {e}") return [] - @staticmethod - def _generate_gtk(skey: str) -> str: + def _generate_gtk(self, skey: str) -> str: hash_val = 5381 for char in skey: hash_val += (hash_val << 5) + ord(char) @@ -437,8 +435,7 @@ class QZoneService: logger.error(f"更新或加载Cookie时发生异常: {e}") return None - @staticmethod - async def _fetch_cookies_http(host: str, port: str, napcat_token: str) -> Optional[Dict]: + async def _fetch_cookies_http(self, host: str, port: str, napcat_token: str) -> Optional[Dict]: """通过HTTP服务器获取Cookie""" url = f"http://{host}:{port}/get_cookies" max_retries = 5 diff --git a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py index 3aabc88b6..0fa7edb99 100644 --- a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py +++ b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py @@ -36,8 +36,7 @@ class ReplyTrackerService: self._load_data() logger.debug(f"ReplyTrackerService initialized with data file: {self.reply_record_file}") - @staticmethod - def _validate_data(data: Any) -> bool: + def _validate_data(self, data: Any) -> bool: """验证加载的数据格式是否正确""" if not isinstance(data, dict): logger.error("加载的数据不是字典格式") diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 6124f4f06..ed32da48d 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -129,8 +129,7 @@ class SchedulerService: logger.error(f"定时任务循环中发生未知错误: {e}\n{traceback.format_exc()}") await asyncio.sleep(300) # 发生错误后,等待一段时间再重试 - @staticmethod - async def _is_processed(hour_str: str, activity: str) -> bool: + async def _is_processed(self, hour_str: str, activity: str) -> bool: """ 检查指定的任务(某个小时的某个活动)是否已经被成功处理过。 @@ -153,8 +152,7 @@ class SchedulerService: logger.error(f"检查日程处理状态时发生数据库错误: {e}") return False # 数据库异常时,默认为未处理,允许重试 - @staticmethod - async def _mark_as_processed(hour_str: str, activity: str, success: bool, content: str): + async def _mark_as_processed(self, hour_str: str, activity: str, success: bool, content: str): """ 将任务的处理状态和结果写入数据库。 @@ -187,7 +185,7 @@ class SchedulerService: send_success=success, ) session.add(new_record) - await session.commit() + session.commit() logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}") except Exception as e: logger.error(f"更新日程处理状态时发生数据库错误: {e}") diff --git a/src/plugins/built_in/maizone_refactored/utils/history_utils.py b/src/plugins/built_in/maizone_refactored/utils/history_utils.py index 171396de2..19b3e7baa 100644 --- a/src/plugins/built_in/maizone_refactored/utils/history_utils.py +++ b/src/plugins/built_in/maizone_refactored/utils/history_utils.py @@ -49,8 +49,7 @@ class _SimpleQZoneAPI: if p_skey: self.gtk2 = self._generate_gtk(p_skey) - @staticmethod - def _generate_gtk(skey: str) -> str: + def _generate_gtk(self, skey: str) -> str: hash_val = 5381 for char in skey: hash_val += (hash_val << 5) + ord(char)