diff --git a/.github/workflows/create-prerelease.yml b/.github/workflows/create-prerelease.yml deleted file mode 100644 index ea0cedcdf..000000000 --- a/.github/workflows/create-prerelease.yml +++ /dev/null @@ -1,34 +0,0 @@ -# 当代码推送到 master 分支时,自动创建一个 pre-release - -name: Create Pre-release - -on: - push: - branches: - - master - -jobs: - create-prerelease: - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - # 获取所有提交历史,以便生成 release notes - fetch-depth: 0 - - - name: Generate tag name - id: generate_tag - run: echo "TAG_NAME=MoFox-prerelease-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT - - - name: Create Pre-release - env: - # 使用仓库自带的 GITHUB_TOKEN 进行认证 - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - gh release create ${{ steps.generate_tag.outputs.TAG_NAME }} \ - --title "Pre-release ${{ steps.generate_tag.outputs.TAG_NAME }}" \ - --prerelease \ - --generate-notes \ No newline at end of file diff --git a/EULA.md b/EULA.md deleted file mode 100644 index bebfedd91..000000000 --- a/EULA.md +++ /dev/null @@ -1,94 +0,0 @@ -# **欢迎使用 MoFox_Bot (第三方修改版)!** - -**版本:V2.1** -**更新日期:2025年8月30日** - ---- - -你好!感谢你选择 MoFox_Bot。在开始之前,请花几分钟时间阅读这份协议。我们用问答的形式,帮助你快速了解使用这个**第三方修改版**软件时,你的权利和责任。 - -**简单来说,你需要同意这份协议才能使用我们的软件。** 如果你是未成年人,请确保你的监护人也阅读并同意了哦。 - ---- - -### **1. 这个软件和原版有什么关系?** - -这是一个非常重要的问题! - -* **第三方修改版**:首先,你需要清楚地知道,MoFox_Bot 是一个基于[MaiCore](https://mai-mai.org/)开源项目的**第三方修改版**。我们(MoFox_Bot 团队)与原始项目的开发者**没有任何关联**。 -* **独立维护**:我们独立负责这个修改版的维护、开发和更新。因此,原始项目的开发者**不会**为 MoFox_Bot 提供任何技术支持,也**不会**对因使用本修改版产生的任何问题负责。 -* **责任划分**:我们只对我们修改、添加的功能负责。对于原始代码中可能存在的任何问题,我们不承担责任。 - ---- - -### **2. 这个软件是免费和开源的吗?** - -**是的,核心代码是开源的!** - -* **遵循开源协议**:本项目继承了原始项目的 **GPLv3 开源协议**。这意味着你可以自由地使用、复制、研究、修改和重新分发它。 -* **你的义务**:当你修改或分发这个软件时,你同样需要遵守 GPLv3 协议的要求,确保你的衍生作品也是开源的。你可以在项目根目录找到 `LICENSE` 文件来了解更多细节。 -* **包含第三方代码**:请注意,项目中可能还包含了其他第三方库或组件,它们各自有独立的开源许可证。你在使用时需要同时遵守这些许可证。 - ---- - -### **3. 我的个人数据是如何被处理的?** - -你的隐私对我们很重要。了解数据如何流动,能帮助你更好地保护自己。 - -* **工作流程**:当你与机器人互动时,你的**输入内容**(比如文字、指令)、**配置信息**以及机器人**生成的回复**,会被发送给第三方的 API 服务(例如 OpenAI、Google 等大语言模型提供商)以获得智能回复。 -* **你的明确授权**:一旦你开始使用,即表示你授权我们利用你的数据进行以下操作: - 1. **调用外部 API**:这是机器人能与你对话的核心。 - 2. **建立本地知识库与记忆**:为了让机器人更个性化、更懂你,软件会在**你自己的设备上**创建和存储知识库、记忆库和对话日志。**这些数据存储在本地,我们无法访问。** - 3. **记录本地日志**:为了方便排查可能出现的技术问题,软件会在你的设备上记录运行日志。 -* **第三方服务的风险**:我们无法控制第三方 API 提供商的服务质量、数据处理政策、稳定性或安全性。使用这些服务时,你同样受到该第三方服务条款和隐私政策的约束。我们建议你自行了解这些条款。 - ---- - -### **4. 关于强大的插件系统,我需要了解什么?** - -MoFox_Bot 通过插件系统实现功能扩展,但这需要你承担相应的责任。 - -* **谁开发的插件?**:绝大多数插件是由**社区里的第三方开发者**创建和维护的,他们并**不是 MoFox_Bot 核心团队的成员**。 -* **责任完全自负**:插件的功能、质量、安全性和合法性**完全由其各自的开发者负责**。我们只提供了一个能让插件运行的技术平台,但**不对任何第三方插件的内容、行为或造成的后果承担任何责任**。 -* **你的使用风险**:使用任何第三方插件的风险**完全由你自行承担**。在安装和使用插件前,我们强烈建议你: - * 仔细阅读并理解插件开发者提供的许可协议和说明文档。 - * **只从你完全信任的来源获取和安装插件**。 - * 自行评估插件的安全性、合法性及其对你数据隐私的影响。 - ---- - -### **5. 我在使用时,有哪些行为准则?** - -请务必合法、合规地使用本软件。 - -* **禁止内容**:严禁输入、处理或传播任何违反你所在地法律法规的内容,包括但不限于:涉及国家秘密、商业机密、侵犯他人知识产权、个人隐私的内容,以及任何形式的非法、骚扰、诽谤、淫秽信息。 -* **合法用途**:你承诺不会将本项目用于任何非法目的或活动,例如网络攻击、诈骗等。 -* **数据安全**:你对自己存储在本地知识库、记忆库和日志中的所有内容的合法性负全部责任。 -* **插件规范**:不要使用任何已知包含恶意代码、安全漏洞或违法内容的插件。 - -**你将对自己使用本项目(包括所有第三方插件)的全部行为及其产生的一切后果,承担完全的法律责任。** - ---- - -### **6. 免责声明(非常重要!)** - -* **“按原样”提供**:本项目是“按原样”提供的,我们**不提供任何形式的明示或暗示的担保**,包括但不限于对适销性、特定用途适用性和不侵权的保证。 -* **AI 回复的立场**:机器人的所有回复均由第三方大语言模型生成,其观点和信息**不代表 MoFox_Bot 团队的立场**。我们不对其准确性、完整性或可靠性负责。 -* **无责任声明**:在任何情况下,MoFox_Bot 团队均不对因使用或无法使用本项目(特别是第三方插件)而导致的任何直接、间接、偶然、特殊或后果性的损害(包括但不限于数据丢失、利润损失、业务中断)承担责任。 -* **插件支持**:所有第三方插件的技术支持、功能更新和 bug 修复,都应**直接联系相应的插件开发者**。 - ---- - -### **7. 其他条款** - -* **协议的修改**:我们保留随时修改本协议的权利。修改后的协议将在新版本发布时生效。我们建议你定期检查以获取最新版本。继续使用本项目即表示你接受修订后的协议。 -* **最终解释权**:在法律允许的范围内,MoFox_Bot 团队保留对本协议的最终解释权。 -* **适用法律**:本协议的订立、执行和解释及争议的解决均应适用中国法律。 - ---- - -### **风险提示(请再次确认你已理解!)** - -* **隐私风险**:你的对话数据会被发送到不受我们控制的第三方 API。请**绝对不要**在对话中包含任何个人身份信息、财务信息、密码或其他敏感数据。 -* **精神健康风险**:AI 机器人只是一个程序,无法提供真正的情感支持或专业的心理建议。如果遇到任何心理困扰,请务必寻求专业人士的帮助(例如,全国心理援助热线:12355)。 -* **插件风险**:这是最大的风险之一。第三方插件可能带来严重的安全漏洞、系统不稳定、性能下降甚至隐私数据泄露的风险。请务必谨慎选择和使用,并为自己的选择承担全部后果。 diff --git a/bot.py b/bot.py index f382df1e1..472ee5f08 100644 --- a/bot.py +++ b/bot.py @@ -229,10 +229,10 @@ if __name__ == "__main__": asyncio.set_event_loop(loop) try: - # 执行初始化和任务调度 - loop.run_until_complete(main_system.initialize()) # 异步初始化数据库表结构 loop.run_until_complete(maibot.initialize_database_async()) + # 执行初始化和任务调度 + loop.run_until_complete(main_system.initialize()) initialize_lpmm_knowledge() # Schedule tasks returns a future that runs forever. # We can run console_input_loop concurrently. diff --git a/plugins/set_emoji_like/plugin.py b/plugins/set_emoji_like/plugin.py index 9e569cbb2..810f0639e 100644 --- a/plugins/set_emoji_like/plugin.py +++ b/plugins/set_emoji_like/plugin.py @@ -10,7 +10,6 @@ from src.plugin_system import ( ConfigField, ) from src.common.logger import get_logger -from src.plugin_system.apis import send_api from .qq_emoji_list import qq_face from src.plugin_system.base.component_types import ChatType @@ -125,31 +124,25 @@ class SetEmojiLikeAction(BaseAction): try: # 使用适配器API发送贴表情命令 - response = await send_api.adapter_command_to_stream( - action="set_msg_emoji_like", - params={"message_id": message_id, "emoji_id": emoji_id, "set": set_like}, - stream_id=self.chat_stream.stream_id if self.chat_stream else None, - timeout=30.0, - storage_message=False, + success = await self.send_command( + command_name="set_emoji_like", args={"message_id": message_id, "emoji_id": emoji_id, "set": set_like}, storage_message=False ) - - if response["status"] == "ok": - logger.info(f"设置表情回应成功: {response}") + if success: + logger.info("设置表情回应成功") await self.store_action_info( action_build_into_prompt=True, action_prompt_display=f"执行了set_emoji_like动作,{emoji_input},设置表情回应: {emoji_id}, 是否设置: {set_like}", action_done=True, ) - return True, f"成功设置表情回应: {response.get('message', '成功')}" + return True, "成功设置表情回应" else: - error_msg = response.get("message", "未知错误") - logger.error(f"设置表情回应失败: {error_msg}") + logger.error("设置表情回应失败") await self.store_action_info( action_build_into_prompt=True, - action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: {error_msg}", + action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败", action_done=False, ) - return False, f"设置表情回应失败: {error_msg}" + return False, "设置表情回应失败" except Exception as e: logger.error(f"设置表情回应失败: {e}") diff --git a/rust_image/Cargo.toml b/rust_image/Cargo.toml deleted file mode 100644 index e69de29bb..000000000 diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 5a61eeebc..f0888d552 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -1,13 +1,14 @@ import asyncio import os +import shutil import sys -import glob import orjson import datetime from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed from threading import Lock from typing import Optional +from json_repair import repair_json # 将项目根目录添加到 sys.path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -37,13 +38,30 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache") file_lock = Lock() +# --- 缓存清理 --- + +def clear_cache(): + """清理 lpmm_learning_tool.py 生成的缓存文件""" + logger.info("--- 开始清理缓存 ---") + if os.path.exists(TEMP_DIR): + try: + shutil.rmtree(TEMP_DIR) + logger.info(f"成功删除缓存目录: {TEMP_DIR}") + except OSError as e: + logger.error(f"删除缓存时出错: {e}") + else: + logger.info("缓存目录不存在,无需清理。") + logger.info("--- 缓存清理完成 ---") + # --- 模块一:数据预处理 --- + def process_text_file(file_path): with open(file_path, "r", encoding="utf-8") as f: raw = f.read() return [p.strip() for p in raw.split("\n\n") if p.strip()] + def preprocess_raw_data(): logger.info("--- 步骤 1: 开始数据预处理 ---") os.makedirs(RAW_DATA_PATH, exist_ok=True) @@ -51,7 +69,7 @@ def preprocess_raw_data(): if not raw_files: logger.warning(f"警告: 在 '{RAW_DATA_PATH}' 中没有找到任何 .txt 文件") return [] - + all_paragraphs = [] for file in raw_files: logger.info(f"正在处理文件: {file.name}") @@ -62,8 +80,57 @@ def preprocess_raw_data(): logger.info("--- 数据预处理完成 ---") return unique_paragraphs + # --- 模块二:信息提取 --- + +def _parse_and_repair_json(json_string: str) -> Optional[dict]: + """ + 尝试解析JSON字符串,如果失败则尝试修复并重新解析。 + + 该函数首先会清理字符串,去除常见的Markdown代码块标记, + 然后尝试直接解析。如果解析失败,它会调用 `repair_json` + 进行修复,并再次尝试解析。 + + Args: + json_string: 从LLM获取的、可能格式不正确的JSON字符串。 + + Returns: + 解析后的字典。如果最终无法解析,则返回 None,并记录详细错误日志。 + """ + if not isinstance(json_string, str): + logger.error(f"输入内容非字符串,无法解析: {type(json_string)}") + return None + + # 1. 预处理:去除常见的多余字符,如Markdown代码块标记 + cleaned_string = json_string.strip() + if cleaned_string.startswith("```json"): + cleaned_string = cleaned_string[7:].strip() + elif cleaned_string.startswith("```"): + cleaned_string = cleaned_string[3:].strip() + + if cleaned_string.endswith("```"): + cleaned_string = cleaned_string[:-3].strip() + + # 2. 性能优化:乐观地尝试直接解析 + try: + return orjson.loads(cleaned_string) + except orjson.JSONDecodeError: + logger.warning("直接解析JSON失败,将尝试修复...") + + # 3. 修复与最终解析 + repaired_json_str = "" + try: + repaired_json_str = repair_json(cleaned_string) + return orjson.loads(repaired_json_str) + except Exception as e: + # 4. 增强错误处理:记录详细的失败信息 + logger.error(f"修复并解析JSON后依然失败: {e}") + logger.error(f"原始字符串 (清理后): {cleaned_string}") + logger.error(f"修复后尝试解析的字符串: {repaired_json_str}") + return None + + def get_extraction_prompt(paragraph: str) -> str: return f""" 请从以下段落中提取关键信息。你需要提取两种类型的信息: @@ -82,6 +149,7 @@ def get_extraction_prompt(paragraph: str) -> str: --- """ + async def extract_info_async(pg_hash, paragraph, llm_api): temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json") with file_lock: @@ -93,11 +161,20 @@ async def extract_info_async(pg_hash, paragraph, llm_api): os.remove(temp_file_path) prompt = get_extraction_prompt(paragraph) + content = None try: content, (_, _, _) = await llm_api.generate_response_async(prompt) - extracted_data = orjson.loads(content) + + # 改进点:调用封装好的函数处理JSON解析和修复 + extracted_data = _parse_and_repair_json(content) + + if extracted_data is None: + # 如果解析失败,抛出异常以触发统一的错误处理逻辑 + raise ValueError("无法从LLM输出中解析有效的JSON数据") + doc_item = { - "idx": pg_hash, "passage": paragraph, + "idx": pg_hash, + "passage": paragraph, "extracted_entities": extracted_data.get("entities", []), "extracted_triples": extracted_data.get("triples", []), } @@ -107,27 +184,45 @@ async def extract_info_async(pg_hash, paragraph, llm_api): return doc_item, None except Exception as e: logger.error(f"提取信息失败:{pg_hash}, 错误:{e}") + if content: + logger.error(f"导致解析失败的原始输出: {content}") return None, pg_hash + def extract_info_sync(pg_hash, paragraph, llm_api): return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api)) + def extract_information(paragraphs_dict, model_set): logger.info("--- 步骤 2: 开始信息提取 ---") os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True) os.makedirs(TEMP_DIR, exist_ok=True) - + llm_api = LLMRequest(model_set=model_set) failed_hashes, open_ie_docs = [], [] with ThreadPoolExecutor(max_workers=5) as executor: - f_to_hash = {executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()} - with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn()) as progress: + f_to_hash = { + executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items() + } + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + ) as progress: task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict)) for future in as_completed(f_to_hash): doc_item, failed_hash = future.result() - if failed_hash: failed_hashes.append(failed_hash) - elif doc_item: open_ie_docs.append(doc_item) + if failed_hash: + failed_hashes.append(failed_hash) + elif doc_item: + open_ie_docs.append(doc_item) progress.update(task, advance=1) if open_ie_docs: @@ -136,19 +231,22 @@ def extract_information(paragraphs_dict, model_set): avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0 avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0 openie_obj = OpenIE(docs=open_ie_docs, avg_ent_chars=avg_ent_chars, avg_ent_words=avg_ent_words) - + now = datetime.datetime.now() filename = now.strftime("%Y-%m-%d-%H-%M-%S-openie.json") output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) with open(output_path, "wb") as f: f.write(orjson.dumps(openie_obj._to_dict())) logger.info(f"信息提取结果已保存到: {output_path}") - - if failed_hashes: logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}") + + if failed_hashes: + logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}") logger.info("--- 信息提取完成 ---") + # --- 模块三:数据导入 --- + async def import_data(openie_obj: Optional[OpenIE] = None): """ 将OpenIE数据导入知识库(Embedding Store 和 KG) @@ -160,15 +258,19 @@ async def import_data(openie_obj: Optional[OpenIE] = None): """ logger.info("--- 步骤 3: 开始数据导入 ---") embed_manager, kg_manager = EmbeddingManager(), KGManager() - + logger.info("正在加载现有的 Embedding 库...") - try: embed_manager.load_from_file() - except Exception as e: logger.warning(f"加载 Embedding 库失败: {e}。") + try: + embed_manager.load_from_file() + except Exception as e: + logger.warning(f"加载 Embedding 库失败: {e}。") logger.info("正在加载现有的 KG...") - try: kg_manager.load_from_file() - except Exception as e: logger.warning(f"加载 KG 失败: {e}。") - + try: + kg_manager.load_from_file() + except Exception as e: + logger.warning(f"加载 KG 失败: {e}。") + try: if openie_obj: openie_data = openie_obj @@ -181,7 +283,7 @@ async def import_data(openie_obj: Optional[OpenIE] = None): raw_paragraphs = openie_data.extract_raw_paragraph_dict() triple_list_data = openie_data.extract_triple_dict() - + new_raw_paragraphs, new_triple_list_data = {}, {} stored_embeds = embed_manager.stored_pg_hashes stored_kgs = kg_manager.stored_paragraph_hashes @@ -190,7 +292,7 @@ async def import_data(openie_obj: Optional[OpenIE] = None): if p_hash not in stored_embeds and p_hash not in stored_kgs: new_raw_paragraphs[p_hash] = raw_p new_triple_list_data[p_hash] = triple_list_data.get(p_hash, []) - + if not new_raw_paragraphs: logger.info("没有新的段落需要处理。") else: @@ -208,60 +310,68 @@ async def import_data(openie_obj: Optional[OpenIE] = None): logger.info("--- 数据导入完成 ---") + def import_from_specific_file(): """从用户指定的 openie.json 文件导入数据""" file_path = input("请输入 openie.json 文件的完整路径: ").strip() - + if not os.path.exists(file_path): logger.error(f"文件路径不存在: {file_path}") return - + if not file_path.endswith(".json"): logger.error("请输入一个有效的 .json 文件路径。") return try: logger.info(f"正在从 {file_path} 加载 OpenIE 数据...") - openie_obj = OpenIE.load(filepath=file_path) + openie_obj = OpenIE.load() asyncio.run(import_data(openie_obj=openie_obj)) except Exception as e: logger.error(f"从指定文件导入数据时发生错误: {e}") + # --- 主函数 --- + def main(): # 使用 os.path.relpath 创建相对于项目根目录的友好路径 raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")) - + print("=== LPMM 知识库学习工具 ===") print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)") print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)") print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识") print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") + print("6. [清理缓存] -> 删除所有已提取信息的缓存") print("0. [退出]") print("-" * 30) choice = input("请输入你的选择 (0-5): ").strip() - if choice == '1': + if choice == "1": preprocess_raw_data() - elif choice == '2': + elif choice == "2": paragraphs = preprocess_raw_data() - if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa) - elif choice == '3': + if paragraphs: + extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + elif choice == "3": asyncio.run(import_data()) - elif choice == '4': + elif choice == "4": paragraphs = preprocess_raw_data() if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa) asyncio.run(import_data()) - elif choice == '5': + elif choice == "5": import_from_specific_file() - elif choice == '0': + elif choice == "6": + clear_cache() + elif choice == "0": sys.exit(0) else: print("无效输入,请重新运行脚本。") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/chat/antipromptinjector/core/shield.py b/src/chat/antipromptinjector/core/shield.py index c4ab8afa8..c7a2e78bc 100644 --- a/src/chat/antipromptinjector/core/shield.py +++ b/src/chat/antipromptinjector/core/shield.py @@ -233,6 +233,5 @@ class MessageShield: def create_default_shield() -> MessageShield: """创建默认的消息加盾器""" - from .config import default_config return MessageShield() diff --git a/src/chat/chat_loop/cycle_processor.py b/src/chat/chat_loop/cycle_processor.py index 441d14b1f..79d4eca9d 100644 --- a/src/chat/chat_loop/cycle_processor.py +++ b/src/chat/chat_loop/cycle_processor.py @@ -206,13 +206,6 @@ class CycleProcessor: raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成") with Timer("规划器", cycle_timers): actions, _ = await self.action_planner.plan(mode=mode) - - # 在这里添加日志,清晰地显示最终选择的动作 - if actions: - chosen_actions = [a.get("action_type", "unknown") for a in actions] - logger.info(f"{self.log_prefix} LLM最终选择的动作: {chosen_actions}") - else: - logger.info(f"{self.log_prefix} LLM最终没有选择任何动作") async def execute_action(action_info): """执行单个动作的通用函数""" @@ -267,6 +260,7 @@ class CycleProcessor: enable_tool=global_config.tool.enable_tool, request_type="chat.replyer", from_plugin=False, + read_mark=action_info.get("action_message", {}).get("time", 0.0), ) if not success or not response_set: logger.info( diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index bf282da5e..adc868117 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -10,7 +10,6 @@ from src.config.config import global_config from src.person_info.relationship_builder_manager import relationship_builder_manager from src.chat.express.expression_learner import expression_learner_manager from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager, SleepState -from src.plugin_system.apis import message_api from .hfc_context import HfcContext from .energy_manager import EnergyManager @@ -40,6 +39,7 @@ class HeartFChatting: """ self.context = HfcContext(chat_id) self.context.new_message_queue = asyncio.Queue() + self._processing_lock = asyncio.Lock() self.cycle_tracker = CycleTracker(self.context) self.response_handler = ResponseHandler(self.context) @@ -358,130 +358,130 @@ class HeartFChatting: - FOCUS模式:直接处理所有消息并检查退出条件 - NORMAL模式:检查进入FOCUS模式的条件,并通过normal_mode_handler处理消息 """ - # --- 核心状态更新 --- - await self.sleep_manager.update_sleep_state(self.wakeup_manager) - current_sleep_state = self.sleep_manager.get_current_sleep_state() - is_sleeping = current_sleep_state == SleepState.SLEEPING - is_in_insomnia = current_sleep_state == SleepState.INSOMNIA + async with self._processing_lock: + # --- 核心状态更新 --- + await self.sleep_manager.update_sleep_state(self.wakeup_manager) + current_sleep_state = self.sleep_manager.get_current_sleep_state() + is_sleeping = current_sleep_state == SleepState.SLEEPING + is_in_insomnia = current_sleep_state == SleepState.INSOMNIA - # 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收 - filter_command_flag = not (is_sleeping or is_in_insomnia) + # 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收 + filter_command_flag = not (is_sleeping or is_in_insomnia) - # 从队列中获取所有待处理的新消息 - recent_messages = [] - while not self.context.new_message_queue.empty(): - recent_messages.append(await self.context.new_message_queue.get()) + # 从队列中获取所有待处理的新消息 + recent_messages = [] + while not self.context.new_message_queue.empty(): + recent_messages.append(await self.context.new_message_queue.get()) - has_new_messages = bool(recent_messages) - new_message_count = len(recent_messages) + has_new_messages = bool(recent_messages) + new_message_count = len(recent_messages) - # 只有在有新消息时才进行思考循环处理 - if has_new_messages: - self.context.last_message_time = time.time() - self.context.last_read_time = time.time() + # 只有在有新消息时才进行思考循环处理 + if has_new_messages: + self.context.last_message_time = time.time() + self.context.last_read_time = time.time() - # --- 专注模式安静群组检查 --- - quiet_groups = global_config.chat.focus_mode_quiet_groups - if quiet_groups and self.context.chat_stream: - is_group_chat = self.context.chat_stream.group_info is not None - if is_group_chat: - try: - platform = self.context.chat_stream.platform - group_id = self.context.chat_stream.group_info.group_id - - # 兼容不同QQ适配器的平台名称 - is_qq_platform = platform in ["qq", "napcat"] - - current_chat_identifier = f"{platform}:{group_id}" - config_identifier_for_qq = f"qq:{group_id}" - - is_in_quiet_list = (current_chat_identifier in quiet_groups or - (is_qq_platform and config_identifier_for_qq in quiet_groups)) - - if is_in_quiet_list: - is_mentioned_in_batch = False - for msg in recent_messages: - if msg.get("is_mentioned"): - is_mentioned_in_batch = True - break + # --- 专注模式安静群组检查 --- + quiet_groups = global_config.chat.focus_mode_quiet_groups + if quiet_groups and self.context.chat_stream: + is_group_chat = self.context.chat_stream.group_info is not None + if is_group_chat: + try: + platform = self.context.chat_stream.platform + group_id = self.context.chat_stream.group_info.group_id - if not is_mentioned_in_batch: - logger.info(f"{self.context.log_prefix} 在专注安静模式下,因未被提及而忽略了消息。") - return True # 消耗消息但不做回复 + # 兼容不同QQ适配器的平台名称 + is_qq_platform = platform in ["qq", "napcat"] + + current_chat_identifier = f"{platform}:{group_id}" + config_identifier_for_qq = f"qq:{group_id}" + + is_in_quiet_list = (current_chat_identifier in quiet_groups or + (is_qq_platform and config_identifier_for_qq in quiet_groups)) + + if is_in_quiet_list: + is_mentioned_in_batch = False + for msg in recent_messages: + if msg.get("is_mentioned"): + is_mentioned_in_batch = True + break + + if not is_mentioned_in_batch: + logger.info(f"{self.context.log_prefix} 在专注安静模式下,因未被提及而忽略了消息。") + return True # 消耗消息但不做回复 + except Exception as e: + logger.error(f"{self.context.log_prefix} 检查专注安静群组时出错: {e}") + + # 处理唤醒度逻辑 + if current_sleep_state in [SleepState.SLEEPING, SleepState.PREPARING_SLEEP, SleepState.INSOMNIA]: + self._handle_wakeup_messages(recent_messages) + + # 再次获取最新状态,因为 handle_wakeup 可能导致状态变为 WOKEN_UP + current_sleep_state = self.sleep_manager.get_current_sleep_state() + + if current_sleep_state == SleepState.SLEEPING: + # 只有在纯粹的 SLEEPING 状态下才跳过消息处理 + return True + + if current_sleep_state == SleepState.WOKEN_UP: + logger.info(f"{self.context.log_prefix} 从睡眠中被唤醒,将处理积压的消息。") + + # 根据聊天模式处理新消息 + should_process, interest_value = await self._should_process_messages(recent_messages) + if not should_process: + # 消息数量不足或兴趣不够,等待 + await asyncio.sleep(0.5) + return True # Skip rest of the logic for this iteration + + # Messages should be processed + action_type = await self.cycle_processor.observe(interest_value=interest_value) + + # 尝试触发表达学习 + if self.context.expression_learner: + try: + await self.context.expression_learner.trigger_learning_for_chat() except Exception as e: - logger.error(f"{self.context.log_prefix} 检查专注安静群组时出错: {e}") + logger.error(f"{self.context.log_prefix} 表达学习触发失败: {e}") - # 处理唤醒度逻辑 - if current_sleep_state in [SleepState.SLEEPING, SleepState.PREPARING_SLEEP, SleepState.INSOMNIA]: - self._handle_wakeup_messages(recent_messages) + # 管理no_reply计数器 + if action_type != "no_reply": + self.recent_interest_records.clear() + self.context.no_reply_consecutive = 0 + logger.debug(f"{self.context.log_prefix} 执行了{action_type}动作,重置no_reply计数器") + else: # action_type == "no_reply" + self.context.no_reply_consecutive += 1 + self._determine_form_type() - # 再次获取最新状态,因为 handle_wakeup 可能导致状态变为 WOKEN_UP - current_sleep_state = self.sleep_manager.get_current_sleep_state() + # 在一轮动作执行完毕后,增加睡眠压力 + if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system: + if action_type not in ["no_reply", "no_action"]: + self.context.energy_manager.increase_sleep_pressure() - if current_sleep_state == SleepState.SLEEPING: - # 只有在纯粹的 SLEEPING 状态下才跳过消息处理 - return True - - if current_sleep_state == SleepState.WOKEN_UP: - logger.info(f"{self.context.log_prefix} 从睡眠中被唤醒,将处理积压的消息。") - - # 根据聊天模式处理新消息 - should_process, interest_value = await self._should_process_messages(recent_messages) - if not should_process: - # 消息数量不足或兴趣不够,等待 - await asyncio.sleep(0.5) - return True # Skip rest of the logic for this iteration - - # Messages should be processed - action_type = await self.cycle_processor.observe(interest_value=interest_value) - - # 尝试触发表达学习 - if self.context.expression_learner: - try: - await self.context.expression_learner.trigger_learning_for_chat() - except Exception as e: - logger.error(f"{self.context.log_prefix} 表达学习触发失败: {e}") - - # 管理no_reply计数器 - if action_type != "no_reply": - self.recent_interest_records.clear() - self.context.no_reply_consecutive = 0 - logger.debug(f"{self.context.log_prefix} 执行了{action_type}动作,重置no_reply计数器") - else: # action_type == "no_reply" - self.context.no_reply_consecutive += 1 - self._determine_form_type() - - # 在一轮动作执行完毕后,增加睡眠压力 - if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system: - if action_type not in ["no_reply", "no_action"]: - self.context.energy_manager.increase_sleep_pressure() - - # 如果成功观察,增加能量值并重置累积兴趣值 - self.context.energy_value += 1 / global_config.chat.focus_value - # 重置累积兴趣值,因为消息已经被成功处理 - self.context.breaking_accumulated_interest = 0.0 - logger.info( - f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值" - ) - - # 更新上一帧的睡眠状态 - self.context.was_sleeping = is_sleeping - - # --- 重新入睡逻辑 --- - # 如果被吵醒了,并且在一定时间内没有新消息,则尝试重新入睡 - if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages: - re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60 - # 使用 last_message_time 来判断空闲时间 - if time.time() - self.context.last_message_time > re_sleep_delay: + # 如果成功观察,增加能量值并重置累积兴趣值 + self.context.energy_value += 1 / global_config.chat.focus_value + # 重置累积兴趣值,因为消息已经被成功处理 + self.context.breaking_accumulated_interest = 0.0 logger.info( - f"{self.context.log_prefix} 已被唤醒且超过 {re_sleep_delay / 60} 分钟无新消息,尝试重新入睡。" + f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值" ) - self.sleep_manager.reset_sleep_state_after_wakeup() - # 保存HFC上下文状态 - self.context.save_context_state() + # 更新上一帧的睡眠状态 + self.context.was_sleeping = is_sleeping - return has_new_messages + # --- 重新入睡逻辑 --- + # 如果被吵醒了,并且在一定时间内没有新消息,则尝试重新入睡 + if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages: + re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60 + # 使用 last_message_time 来判断空闲时间 + if time.time() - self.context.last_message_time > re_sleep_delay: + logger.info( + f"{self.context.log_prefix} 已被唤醒且超过 {re_sleep_delay / 60} 分钟无新消息,尝试重新入睡。" + ) + self.sleep_manager.reset_sleep_state_after_wakeup() + + # 保存HFC上下文状态 + self.context.save_context_state() + return has_new_messages def _handle_wakeup_messages(self, messages): """ diff --git a/src/chat/chat_loop/proactive/proactive_thinker.py b/src/chat/chat_loop/proactive/proactive_thinker.py index 4dea5ec99..34abf7803 100644 --- a/src/chat/chat_loop/proactive/proactive_thinker.py +++ b/src/chat/chat_loop/proactive/proactive_thinker.py @@ -162,107 +162,33 @@ class ProactiveThinker: news_block = "暂时没有获取到最新资讯。" if trigger_event.source != "reminder_system": - # 升级决策模型 - should_search_prompt = f""" -# 搜索决策 - -## 任务 -分析话题“{topic}”,判断它的展开更依赖于“外部信息”还是“内部信息”,并决定是否需要进行网络搜索。 - -## 判断原则 -- **需要搜索 (SEARCH)**:当话题的有效讨论**必须**依赖于现实世界的、客观的、可被检索的外部信息时。这包括但不限于: - - 新闻时事、公共事件 - - 专业知识、科学概念 - - 天气、股价等实时数据 - - 对具体实体(如电影、书籍、地点)的客观描述查询 - -- **无需搜索 (SKIP)**:当话题的展开主要依赖于**已有的对话上下文、个人情感、主观体验或社交互动**时。这包括但不限于: - - 延续之前的对话、追问细节 - - 表达关心、问候或个人感受 - - 分享主观看法或经历 - - 纯粹的社交性互动 - -## 你的决策 -根据以上原则,对“{topic}”这个话题进行分析,并严格输出`SEARCH`或`SKIP`。 -""" - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config - - decision_llm = LLMRequest( - model_set=model_config.model_task_config.planner, - request_type="planner" - ) - - decision, _ = await decision_llm.generate_response_async(prompt=should_search_prompt) - - if "SEARCH" in decision: - try: - if topic and topic.strip(): - web_search_tool = tool_api.get_tool_instance("web_search") - if web_search_tool: - try: - search_result_dict = await web_search_tool.execute( - function_args={"query": topic, "max_results": 10} - ) - if search_result_dict and not search_result_dict.get("error"): - news_block = search_result_dict.get("content", "未能提取有效资讯。") - elif search_result_dict: - logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}") - except Exception as e: - logger.error(f"{self.context.log_prefix} 网络搜索执行失败: {e}") - else: - logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。") - else: - logger.warning(f"{self.context.log_prefix} 主题为空,跳过网络搜索。") - except Exception as e: - logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}") - message_list = await get_raw_msg_before_timestamp_with_chat( + try: + web_search_tool = tool_api.get_tool_instance("web_search") + if web_search_tool: + try: + search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10}) + except TypeError: + try: + search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10}) + except TypeError: + logger.warning(f"{self.context.log_prefix} 网络搜索工具参数不匹配,跳过搜索") + news_block = "跳过网络搜索。" + search_result_dict = None + + if search_result_dict and not search_result_dict.get("error"): + news_block = search_result_dict.get("content", "未能提取有效资讯。") + elif search_result_dict: + logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}") + else: + logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。") + except Exception as e: + logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}") + message_list = get_raw_msg_before_timestamp_with_chat( chat_id=self.context.stream_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.3), ) - chat_context_block, _ = await build_readable_messages_with_id(messages=message_list) - - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config - - bot_name = global_config.bot.nickname - - confirmation_prompt = f"""# 主动回复二次确认 - -## 基本信息 -你的名字是{bot_name},准备主动发起关于"{topic}"的话题。 - -## 最近的聊天内容 -{chat_context_block} - -## 合理判断标准 -请检查以下条件,如果**所有条件都合理**就可以回复: - -1. **回应检查**:检查你({bot_name})发送的最后一条消息之后,是否有其他人发言。如果没有,则大概率应该保持沉默。 -2. **话题补充**:只有当你认为准备发起的话题是对上一条无人回应消息的**有价值的补充**时,才可以在上一条消息无人回应的情况下继续发言。 -3. **时间合理性**:当前时间是否在深夜(凌晨2点-6点)这种不适合主动聊天的时段? -4. **内容价值**:这个话题"{topic}"是否有意义,不是完全无关紧要的内容? -5. **重复避免**:你准备说的话题是否与你自己的上一条消息明显重复? -6. **自然性**:在当前上下文中主动提起这个话题是否自然合理? - -## 输出要求 -如果判断应该跳过(比如上一条消息无人回应、深夜时段、无意义话题、重复内容),输出:SKIP_PROACTIVE_REPLY -其他情况都应该输出:PROCEED_TO_REPLY - -请严格按照上述格式输出,不要添加任何解释。""" - - planner_llm = LLMRequest( - model_set=model_config.model_task_config.planner, - request_type="planner" - ) - - confirmation_result, _ = await planner_llm.generate_response_async(prompt=confirmation_prompt) - - if not confirmation_result or "SKIP_PROACTIVE_REPLY" in confirmation_result: - logger.info(f"{self.context.log_prefix} 决策模型二次确认决定跳过主动回复") - return - + chat_context_block, _ = await build_readable_messages_with_id(messages=message_list) bot_name = global_config.bot.nickname personality = global_config.personality identity_block = ( diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index fb22a4115..bb663a1ad 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -4,7 +4,7 @@ import orjson import os from datetime import datetime -from typing import List, Dict, Optional, Any, Tuple, Coroutine +from typing import List, Dict, Optional, Any, Tuple from src.common.logger import get_logger from src.common.database.sqlalchemy_database_api import get_db_session diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index 275a25a57..136b1cb41 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -24,7 +24,7 @@ class SubHeartflow: self.subheartflow_id = subheartflow_id self.chat_id = subheartflow_id - self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id) + self.is_group_chat, self.chat_target_info = (None, None) self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id # focus模式退出冷却时间管理 @@ -38,4 +38,5 @@ class SubHeartflow: async def initialize(self): """异步初始化方法,创建兴趣流并确定聊天类型""" + self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_id) await self.heart_fc_instance.start() diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 3b68190a7..53cb00345 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -415,7 +415,6 @@ class ChatBot: return get_chat_manager().register_message(message) - chat = await get_chat_manager().get_or_create_stream( platform=message.message_info.platform, # type: ignore user_info=user_info, # type: ignore @@ -427,11 +426,11 @@ class ChatBot: # 处理消息内容,生成纯文本 await message.process() - # 过滤检查 - if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore - message.raw_message, # type: ignore - chat, - user_info, # type: ignore + # 过滤检查 (在消息处理之后进行) + if _check_ban_words( + message.processed_plain_text, chat, user_info # type: ignore + ) or _check_ban_regex( + message.processed_plain_text, chat, user_info # type: ignore ): return diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index c42654aa3..de2fb62e9 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -254,7 +254,7 @@ class ChatManager: model_instance = await _db_find_stream_async(stream_id) if model_instance: - # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 + # 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式 user_info_data = { "platform": model_instance.user_platform, "user_id": model_instance.user_id, @@ -382,7 +382,7 @@ class ChatManager: await _db_save_stream_async(stream_data_dict) stream.saved = True except Exception as e: - logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) + logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (SQLAlchemy): {e}", exc_info=True) async def _save_all_streams(self): """保存所有聊天流""" @@ -435,7 +435,7 @@ class ChatManager: if stream.stream_id in self.last_messages: stream.set_context(self.last_messages[stream.stream_id]) except Exception as e: - logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) + logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True) chat_manager = None diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 22e57edf0..22c3e3776 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -2,7 +2,7 @@ import base64 import time from abc import abstractmethod, ABCMeta from dataclasses import dataclass -from typing import Optional, Any, TYPE_CHECKING +from typing import Optional, Any import urllib3 from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 154fe62a7..bcd01934d 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -72,7 +72,7 @@ class ActionModifier: from src.chat.utils.utils import get_chat_type_and_target_info # 获取聊天类型 - is_group_chat, _ = get_chat_type_and_target_info(self.chat_id) + is_group_chat, _ = await get_chat_type_and_target_info(self.chat_id) all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION) chat_type_removals = [] diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index 4ef8de2d8..6aaefba18 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -9,7 +9,6 @@ from typing import Any, Dict, List, Optional from json_repair import repair_json -from . import planner_prompts from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.utils.chat_message_builder import ( build_readable_actions, @@ -47,12 +46,12 @@ class PlanFilter: try: prompt, used_message_id_list = await self._build_prompt(plan) plan.llm_prompt = prompt - logger.debug(f"墨墨在这里加了日志 -> LLM prompt: {prompt}") + logger.info(f"规划器原始提示词: {prompt}") llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt) if llm_content: - logger.debug(f"墨墨在这里加了日志 -> LLM a原始返回: {llm_content}") + logger.info(f"规划器原始返回: {llm_content}") parsed_json = orjson.loads(repair_json(llm_content)) logger.debug(f"墨墨在这里加了日志 -> 解析后的 JSON: {parsed_json}") diff --git a/src/chat/planner_actions/plan_generator.py b/src/chat/planner_actions/plan_generator.py index 5d1ab9c38..ec0a11691 100644 --- a/src/chat/planner_actions/plan_generator.py +++ b/src/chat/planner_actions/plan_generator.py @@ -51,7 +51,7 @@ class PlanGenerator: Returns: Plan: 一个填充了初始上下文信息的 Plan 对象。 """ - _is_group_chat, chat_target_info_dict = get_chat_type_and_target_info(self.chat_id) + _is_group_chat, chat_target_info_dict = await get_chat_type_and_target_info(self.chat_id) target_info = None if chat_target_info_dict: diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 6e45b7907..0e3d1afc3 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -10,7 +10,7 @@ from src.chat.planner_actions.plan_filter import PlanFilter from src.chat.planner_actions.plan_generator import PlanGenerator from src.common.logger import get_logger from src.plugin_system.base.component_types import ChatMode - +import src.chat.planner_actions.planner_prompts #noga # noqa: F401 # 导入提示词模块以确保其被初始化 logger = get_logger("planner") diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 127779e1e..76221ac1c 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -83,12 +83,12 @@ def init_prompt(): - {schedule_block} ## 历史记录 -### 当前群聊中的所有人的聊天记录: +### {chat_context_type}中的所有人的聊天记录: {background_dialogue_prompt} {cross_context_block} -### 当前群聊中正在与你对话的聊天记录 +### {chat_context_type}中正在与你对话的聊天记录 {core_dialogue_prompt} ## 表达方式 @@ -110,12 +110,11 @@ def init_prompt(): ## 任务 -*你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。* +*你正在一个{chat_context_type}里聊天,你需要理解整个{chat_context_type}的聊天动态和话题走向,并做出自然的回应。* ### 核心任务 -- 你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与聊天,你可以参考他们的回复内容,但是你现在想回复{sender_name}的发言。 - -- {reply_target_block} ,你需要生成一段紧密相关且能推动对话的回复。 +- 你现在的主要任务是和 {sender_name} 聊天。 +- {reply_target_block} ,你需要生成一段紧密相关且能推动对话的回复。 ## 规则 {safety_guidelines_block} @@ -203,7 +202,9 @@ class DefaultReplyer: ): self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream - self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) + self.is_group_chat: Optional[bool] = None + self.chat_target_info: Optional[Dict[str, Any]] = None + self._initialized = False self.heart_fc_sender = HeartFCSender() self.memory_activator = MemoryActivator() @@ -236,6 +237,7 @@ class DefaultReplyer: from_plugin: bool = True, stream_id: Optional[str] = None, reply_message: Optional[Dict[str, Any]] = None, + read_mark: float = 0.0, ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: # sourcery skip: merge-nested-ifs """ @@ -268,6 +270,7 @@ class DefaultReplyer: available_actions=available_actions, enable_tool=enable_tool, reply_message=reply_message, + read_mark=read_mark, ) if not prompt: @@ -723,10 +726,8 @@ class DefaultReplyer: truncate=True, show_actions=True, ) - core_dialogue_prompt = f"""-------------------------------- -这是你和{sender}的对话,你们正在交流中: + core_dialogue_prompt = f""" {core_dialogue_prompt_str} --------------------------------- """ return core_dialogue_prompt, all_dialogue_prompt @@ -776,6 +777,12 @@ class DefaultReplyer: mai_think.target = target return mai_think + async def _async_init(self): + if self._initialized: + return + self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_stream.stream_id) + self._initialized = True + async def build_prompt_reply_context( self, reply_to: str, @@ -783,6 +790,7 @@ class DefaultReplyer: available_actions: Optional[Dict[str, ActionInfo]] = None, enable_tool: bool = True, reply_message: Optional[Dict[str, Any]] = None, + read_mark: float = 0.0, ) -> str: """ 构建回复器上下文 @@ -800,10 +808,11 @@ class DefaultReplyer: """ if available_actions is None: available_actions = {} + await self._async_init() chat_stream = self.chat_stream chat_id = chat_stream.stream_id person_info_manager = get_person_info_manager() - is_group_chat = bool(chat_stream.group_info) + is_group_chat = self.is_group_chat if global_config.mood.enable_mood: chat_mood = mood_manager.get_mood_by_chat_id(chat_id) @@ -859,7 +868,7 @@ class DefaultReplyer: target = "(无消息内容)" person_info_manager = get_person_info_manager() - person_id = await person_info_manager.get_person_id_by_person_name(sender) + person_id = person_info_manager.get_person_id(platform, reply_message.get("user_id")) if reply_message else None platform = chat_stream.platform target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) @@ -891,7 +900,7 @@ class DefaultReplyer: replace_bot_name=True, merge_messages=False, timestamp_mode="relative", - read_mark=0.0, + read_mark=read_mark, show_actions=True, ) # 获取目标用户信息,用于s4u模式 @@ -1081,6 +1090,7 @@ class DefaultReplyer: reply_target_block=reply_target_block, mood_prompt=mood_prompt, action_descriptions=action_descriptions, + read_mark=read_mark, ) # 使用新的统一Prompt系统 - 使用正确的模板名称 @@ -1127,9 +1137,10 @@ class DefaultReplyer: reply_to: str, reply_message: Optional[Dict[str, Any]] = None, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if + await self._async_init() chat_stream = self.chat_stream chat_id = chat_stream.stream_id - is_group_chat = bool(chat_stream.group_info) + is_group_chat = self.is_group_chat if reply_message: sender = reply_message.get("sender") @@ -1167,7 +1178,7 @@ class DefaultReplyer: replace_bot_name=True, merge_messages=False, timestamp_mode="relative", - read_mark=0.0, + read_mark=read_mark, show_actions=True, ) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index b3110a8e6..e2d0a4fb9 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -520,6 +520,7 @@ async def _build_readable_messages_internal( pic_counter: int = 1, show_pic: bool = True, message_id_list: Optional[List[Dict[str, Any]]] = None, + read_mark: float = 0.0, ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: """ 内部辅助函数,构建可读消息字符串和原始消息详情列表。 @@ -642,6 +643,10 @@ async def _build_readable_messages_internal( else: person_name = "某人" + # 在用户名后面添加 QQ 号, 但机器人本体不用 + if user_id != global_config.bot.qq_account: + person_name = f"{person_name}({user_id})" + # 使用独立函数处理用户引用格式 content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name) @@ -726,11 +731,10 @@ async def _build_readable_messages_internal( "is_action": is_action, } continue - # 如果是同一个人发送的连续消息且时间间隔小于等于60秒 if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60): current_merge["content"].append(content) - current_merge["end_time"] = timestamp # 更新最后消息时间 + current_merge["end_time"] = timestamp else: # 保存上一个合并块 merged_messages.append(current_merge) @@ -758,8 +762,14 @@ async def _build_readable_messages_internal( # 4 & 5: 格式化为字符串 output_lines = [] + read_mark_inserted = False for _i, merged in enumerate(merged_messages): + # 检查是否需要插入已读标记 + if read_mark > 0 and not read_mark_inserted and merged["start_time"] >= read_mark: + output_lines.append("\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n") + read_mark_inserted = True + # 使用指定的 timestamp_mode 格式化时间 readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode) diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index db31acfa5..3d97b622e 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -78,6 +78,7 @@ class PromptParameters: # 可用动作信息 available_actions: Optional[Dict[str, Any]] = None + read_mark: float = 0.0 def validate(self) -> List[str]: """参数验证""" @@ -449,7 +450,8 @@ class Prompt: core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts( self.parameters.message_list_before_now_long, self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "", - self.parameters.sender + self.parameters.sender, + read_mark=self.parameters.read_mark, ) context_data["core_dialogue_prompt"] = core_dialogue @@ -465,7 +467,7 @@ class Prompt: @staticmethod async def _build_s4u_chat_history_prompts( - message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str + message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, read_mark: float = 0.0 ) -> Tuple[str, str]: """构建S4U风格的分离对话prompt""" # 实现逻辑与原有SmartPromptBuilder相同 @@ -491,6 +493,7 @@ class Prompt: replace_bot_name=True, timestamp_mode="normal", truncate=True, + read_mark=read_mark, ) all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}" @@ -510,7 +513,7 @@ class Prompt: replace_bot_name=True, merge_messages=False, timestamp_mode="normal_no_YMD", - read_mark=0.0, + read_mark=read_mark, truncate=True, show_actions=True, ) @@ -764,6 +767,7 @@ class Prompt: "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), + "chat_context_type": "群聊" if self.parameters.is_group_chat else "私聊", } def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 5eb4cc991..746b13e63 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -341,9 +341,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese split_sentences = [s.strip() for s in split_sentences_raw if s.strip()] else: if split_mode == "llm": - logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。") - split_sentences = [cleaned_text] - else: # mode == "punctuation" + logger.debug("未检测到 [SPLIT] 标记,回退到基于标点的传统模式进行分割。") + split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) + else: # mode == "punctuation" logger.debug("使用基于标点的传统模式进行分割。") split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) else: @@ -619,7 +619,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal" return time.strftime("%H:%M:%S", time.localtime(timestamp)) -def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: +async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: """ 获取聊天类型(是否群聊)和私聊对象信息。 @@ -663,7 +663,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: if person_id: # get_value is async, so await it directly person_info_manager = get_person_info_manager() - person_name = person_info_manager.get_value(person_id, "person_name") + person_data = await person_info_manager.get_values(person_id, ["person_name"]) + person_name = person_data.get("person_name") target_info["person_id"] = person_id target_info["person_name"] = person_name diff --git a/src/common/database/database.py b/src/common/database/database.py index 6a34d900e..1815a98ff 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,6 +1,4 @@ import os -from pymongo import MongoClient -from pymongo.database import Database from rich.traceback import install from src.common.logger import get_logger @@ -10,8 +8,6 @@ from src.common.database.sqlalchemy_models import get_engine, get_db_session install(extra_lines=3) -_client = None -_db = None _sql_engine = None logger = get_logger("database") @@ -64,43 +60,6 @@ class SQLAlchemyTransaction: db = DatabaseProxy() -def __create_database_instance(): - uri = os.getenv("MONGODB_URI") - host = os.getenv("MONGODB_HOST", "127.0.0.1") - port = int(os.getenv("MONGODB_PORT", "27017")) - # db_name 变量在创建连接时不需要,在获取数据库实例时才使用 - username = os.getenv("MONGODB_USERNAME") - password = os.getenv("MONGODB_PASSWORD") - auth_source = os.getenv("MONGODB_AUTH_SOURCE") - - if uri: - # 支持标准mongodb://和mongodb+srv://连接字符串 - if uri.startswith(("mongodb://", "mongodb+srv://")): - return MongoClient(uri) - else: - raise ValueError( - "Invalid MongoDB URI format. URI must start with 'mongodb://' or 'mongodb+srv://'. " - "For MongoDB Atlas, use 'mongodb+srv://' format. " - "See: https://www.mongodb.com/docs/manual/reference/connection-string/" - ) - - if username and password: - # 如果有用户名和密码,使用认证连接 - return MongoClient(host, port, username=username, password=password, authSource=auth_source) - - # 否则使用无认证连接 - return MongoClient(host, port) - - -def get_db(): - """获取MongoDB连接实例,延迟初始化。""" - global _client, _db - if _client is None: - _client = __create_database_instance() - _db = _client[os.getenv("DATABASE_NAME", "MegBot")] - return _db - - async def initialize_sql_database(database_config): """ 根据配置初始化SQL数据库连接(SQLAlchemy版本) @@ -141,17 +100,3 @@ async def initialize_sql_database(database_config): except Exception as e: logger.error(f"初始化SQL数据库失败: {e}") return None - - -class DBWrapper: - """数据库代理类,保持接口兼容性同时实现懒加载。""" - - def __getattr__(self, name): - return getattr(get_db(), name) - - def __getitem__(self, key): - return get_db()[key] # type: ignore - - -# 全局MongoDB数据库访问点 -memory_db: Database = DBWrapper() # type: ignore diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index aedff3676..085c277a3 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -1,7 +1,6 @@ # mmc/src/common/database/db_migration.py from sqlalchemy import inspect -from sqlalchemy.schema import CreateIndex from sqlalchemy.sql import text from src.common.database.sqlalchemy_models import Base, get_engine @@ -70,24 +69,32 @@ async def check_and_migrate_database(): def add_columns_sync(conn): dialect = conn.dialect + compiler = dialect.ddl_compiler(dialect, None) + for column_name in missing_columns: column = table.c[column_name] - - # 使用DDLCompiler为特定方言编译列 - compiler = dialect.ddl_compiler(dialect, None) - - # 编译列的数据类型 column_type = compiler.get_column_specification(column) - - # 构建原生SQL sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}" - - # 添加默认值(如果存在) + if column.default: - default_value = compiler.render_literal_value(column.default.arg, column.type) + # 手动处理不同方言的默认值 + default_arg = column.default.arg + if dialect.name == "sqlite" and isinstance(default_arg, bool): + # SQLite 将布尔值存储为 0 或 1 + default_value = "1" if default_arg else "0" + elif hasattr(compiler, 'render_literal_value'): + try: + # 尝试使用 render_literal_value + default_value = compiler.render_literal_value(default_arg, column.type) + except AttributeError: + # 如果失败,则回退到简单的字符串转换 + default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + else: + # 对于没有 render_literal_value 的旧版或特定方言 + default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) + sql += f" DEFAULT {default_value}" - - # 添加非空约束(如果存在) + if not column.nullable: sql += " NOT NULL" @@ -109,12 +116,11 @@ async def check_and_migrate_database(): logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}") def add_indexes_sync(conn): - with conn.begin(): - for index_name in missing_indexes: - index_obj = next((idx for idx in table.indexes if idx.name == index_name), None) - if index_obj is not None: - conn.execute(CreateIndex(index_obj)) - logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。") + for index_name in missing_indexes: + index_obj = next((idx for idx in table.indexes if idx.name == index_name), None) + if index_obj is not None: + index_obj.create(conn) + logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。") await connection.run_sync(add_indexes_sync) else: diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 5e5e035dd..0b1984a3c 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,5 +1,5 @@ from typing import List, Dict, Any, Literal, Union -from pydantic import Field, field_validator +from pydantic import Field from threading import Lock from src.config.config_base import ValidatedConfigBase diff --git a/src/config/config.py b/src/config/config.py index 3fbd7e9e6..ac6204689 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -8,7 +8,7 @@ from tomlkit import TOMLDocument from tomlkit.items import Table, KeyType from rich.traceback import install from typing import List, Optional -from pydantic import Field, field_validator +from pydantic import Field from src.common.logger import get_logger from src.config.config_base import ValidatedConfigBase @@ -164,6 +164,18 @@ def _version_tuple(v): return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split(".")) +def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDocument | dict | Table): + """ + 递归地从目标字典中移除所有不存在于参考字典中的键。 + """ + # 使用 list() 创建键的副本,以便在迭代期间安全地修改字典 + for key in list(target.keys()): + if key not in reference: + del target[key] + elif isinstance(target.get(key), (dict, Table)) and isinstance(reference.get(key), (dict, Table)): + _remove_obsolete_keys(target[key], reference[key]) + + def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): """ 将source字典的值更新到target字典中 @@ -334,6 +346,13 @@ def _update_config_generic(config_name: str, template_name: str): logger.info(f"开始合并{config_name}新旧配置...") _update_dict(new_config, old_config) + # 移除在新模板中已不存在的旧配置项 + logger.info(f"开始移除{config_name}中已废弃的配置项...") + with open(template_path, "r", encoding="utf-8") as f: + template_doc = tomlkit.load(f) + _remove_obsolete_keys(new_config, template_doc) + logger.info(f"已移除{config_name}中已废弃的配置项") + # 保存更新后的配置(保留注释和格式) with open(new_config_path, "w", encoding="utf-8") as f: f.write(tomlkit.dumps(new_config)) diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 4ab0af5f7..7b997b680 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -20,6 +20,26 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall logger = get_logger("AioHTTP-Gemini客户端") +# gemini_thinking参数(默认范围) +# 不同模型的思考预算范围配置 +THINKING_BUDGET_LIMITS = { + "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True}, + "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True}, + "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False}, +} +# 思维预算特殊值 +THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定 +THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用) + +gemini_safe_settings = [ + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, +] + + def _format_to_mime_type(image_format: str) -> str: """ 将图片格式转换为正确的MIME类型 @@ -130,7 +150,11 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]: def _build_generation_config( - max_tokens: int, temperature: float, response_format: RespFormat | None = None, extra_params: dict | None = None + max_tokens: int, + temperature: float, + thinking_budget: int, + response_format: RespFormat | None = None, + extra_params: dict | None = None, ) -> dict: """构建生成配置""" config = { @@ -138,6 +162,8 @@ def _build_generation_config( "temperature": temperature, "topK": 1, "topP": 1, + "safetySettings": gemini_safe_settings, + "thinkingConfig": {"includeThoughts": True, "thinkingBudget": thinking_budget}, } # 处理响应格式 @@ -150,7 +176,11 @@ def _build_generation_config( # 合并额外参数 if extra_params: - config.update(extra_params) + # 拷贝一份以防修改原始字典 + safe_extra_params = extra_params.copy() + # 移除已单独处理的 thinking_budget + safe_extra_params.pop("thinking_budget", None) + config.update(safe_extra_params) return config @@ -317,6 +347,41 @@ class AiohttpGeminiClient(BaseClient): if api_provider.base_url: self.base_url = api_provider.base_url.rstrip("/") + @staticmethod + def clamp_thinking_budget(tb: int, model_id: str) -> int: + """ + 按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本) + """ + limits = None + + # 优先尝试精确匹配 + if model_id in THINKING_BUDGET_LIMITS: + limits = THINKING_BUDGET_LIMITS[model_id] + else: + # 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先 + sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True) + for key in sorted_keys: + # 必须满足:完全等于 或者 前缀匹配(带 "-" 边界) + if model_id == key or model_id.startswith(f"{key}-"): + limits = THINKING_BUDGET_LIMITS[key] + break + + # 特殊值处理 + if tb == THINKING_BUDGET_AUTO: + return THINKING_BUDGET_AUTO + if tb == THINKING_BUDGET_DISABLED: + if limits and limits.get("can_disable", False): + return THINKING_BUDGET_DISABLED + return limits["min"] if limits else THINKING_BUDGET_AUTO + + # 已知模型裁剪到范围 + if limits: + return max(limits["min"], min(tb, limits["max"])) + + # 未知模型,返回动态模式 + logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。") + return tb + # 移除全局 session,全部请求都用 with aiohttp.ClientSession() as session: async def _make_request( @@ -376,10 +441,21 @@ class AiohttpGeminiClient(BaseClient): # 转换消息格式 contents, system_instructions = _convert_messages(message_list) + # 处理思考预算 + tb = THINKING_BUDGET_AUTO + if extra_params and "thinking_budget" in extra_params: + try: + tb = int(extra_params["thinking_budget"]) + except (ValueError, TypeError): + logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}") + tb = self.clamp_thinking_budget(tb, model_info.model_identifier) + # 构建请求体 request_data = { "contents": contents, - "generationConfig": _build_generation_config(max_tokens, temperature, response_format, extra_params), + "generationConfig": _build_generation_config( + max_tokens, temperature, tb, response_format, extra_params + ), } # 添加系统指令 @@ -475,7 +551,7 @@ class AiohttpGeminiClient(BaseClient): request_data = { "contents": contents, - "generationConfig": _build_generation_config(2048, 0.1, None, extra_params), + "generationConfig": _build_generation_config(2048, 0.1, THINKING_BUDGET_AUTO, None, extra_params), } try: diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 146e5eb46..3efa9cd2d 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,7 +1,29 @@ +# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- +""" +@desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。 +它被设计为一个高度容错和可扩展的系统,包含以下主要组件: + +- **模型选择器 (_ModelSelector)**: + 实现了基于负载均衡和失败惩罚的动态模型选择策略,确保在高并发或部分模型失效时系统的稳定性。 + +- **提示处理器 (_PromptProcessor)**: + 负责对输入模型的提示词进行预处理(如内容混淆、反截断指令注入)和对模型输出进行后处理(如提取思考过程、检查截断)。 + +- **请求执行器 (_RequestExecutor)**: + 封装了底层的API请求逻辑,包括自动重试、异常分类处理和消息体压缩等功能。 + +- **请求策略 (_RequestStrategy)**: + 实现了高阶请求策略,如模型间的故障转移(Failover),确保单个模型的失败不会导致整个请求失败。 + +- **LLMRequest (主接口)**: + 作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。 +""" import re import asyncio import time import random +import string from enum import Enum from rich.traceback import install @@ -13,7 +35,7 @@ from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig from .payload_content.message import MessageBuilder, Message from .payload_content.resp_format import RespFormat from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType -from .model_client.base_client import BaseClient, APIResponse, client_registry +from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord from .utils import compress_messages, llm_usage_recorder from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException @@ -21,18 +43,9 @@ install(extra_lines=3) logger = get_logger("model_utils") -# 常见Error Code Mapping -error_code_mapping = { - 400: "参数不正确", - 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确", - 402: "账号余额不足", - 403: "需要实名,或余额不足", - 404: "Not Found", - 429: "请求过于频繁,请稍后再试", - 500: "服务器内部故障", - 503: "服务器负载过高", -} - +# ============================================================================== +# Standalone Utility Functions +# ============================================================================== def _normalize_image_format(image_format: str) -> str: """ @@ -45,35 +58,17 @@ def _normalize_image_format(image_format: str) -> str: str: 标准化后的图片格式 """ format_mapping = { - "jpg": "jpeg", - "JPG": "jpeg", - "JPEG": "jpeg", - "jpeg": "jpeg", - "png": "png", - "PNG": "png", - "webp": "webp", - "WEBP": "webp", - "gif": "gif", - "GIF": "gif", - "heic": "heic", - "HEIC": "heic", - "heif": "heif", - "HEIF": "heif", + "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg", + "png": "png", "PNG": "png", + "webp": "webp", "WEBP": "webp", + "gif": "gif", "GIF": "gif", + "heic": "heic", "HEIC": "heic", + "heif": "heif", "HEIF": "heif", } - normalized = format_mapping.get(image_format, image_format.lower()) logger.debug(f"图片格式标准化: {image_format} -> {normalized}") return normalized - -class RequestType(Enum): - """请求类型枚举""" - - RESPONSE = "response" - EMBEDDING = "embedding" - AUDIO = "audio" - - async def execute_concurrently( coro_callable: Callable[..., Coroutine[Any, Any, Any]], concurrency_count: int, @@ -97,7 +92,6 @@ async def execute_concurrently( """ logger.info(f"启用并发请求模式,并发数: {concurrency_count}") tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)] - results = await asyncio.gather(*tasks, return_exceptions=True) successful_results = [res for res in results if not isinstance(res, Exception)] @@ -110,34 +104,149 @@ async def execute_concurrently( for i, res in enumerate(results): if isinstance(res, Exception): logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}") - + first_exception = next((res for res in results if isinstance(res, Exception)), None) if first_exception: raise first_exception - raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息") +class RequestType(Enum): + """请求类型枚举""" + RESPONSE = "response" + EMBEDDING = "embedding" + AUDIO = "audio" -class LLMRequest: - """LLM请求类""" +# ============================================================================== +# Helper Classes for LLMRequest Refactoring +# ============================================================================== - def __init__(self, model_set: TaskConfig, request_type: str = "") -> None: - self.task_name = request_type - self.model_for_task = model_set - self.request_type = request_type - self.model_usage: Dict[str, Tuple[int, int, int]] = { - model: (0, 0, 0) for model in self.model_for_task.model_list +class _ModelSelector: + """负责模型选择、负载均衡和动态故障切换的策略。""" + + CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数 + DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量 + + def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]): + """ + 初始化模型选择器。 + + Args: + model_list (List[str]): 可用模型名称列表。 + model_usage (Dict[str, Tuple[int, int, int]]): 模型的初始使用情况, + 格式为 {model_name: (total_tokens, penalty, usage_penalty)}。 + """ + self.model_list = model_list + self.model_usage = model_usage + + def select_best_available_model( + self, failed_models_in_this_request: set, request_type: str + ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: + """ + 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。 + + Args: + failed_models_in_this_request (set): 当前请求中已失败的模型名称集合。 + request_type (str): 请求类型,用于确定是否强制创建新客户端。 + + Returns: + Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: 选定的模型详细信息,如果无可用模型则返回 None。 + """ + candidate_models_usage = { + model_name: usage_data + for model_name, usage_data in self.model_usage.items() + if model_name not in failed_models_in_this_request } - """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" - # 内容混淆过滤指令 + if not candidate_models_usage: + logger.warning("没有可用的模型供当前请求选择。") + return None + + # 核心负载均衡算法:选择一个综合得分最低的模型。 + # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 + # 设计思路: + # - `total_tokens`: 基础成本,优先使用累计token少的模型,实现长期均衡。 + # - `penalty * 300`: 失败惩罚项。每次失败会增加penalty,使其在短期内被选中的概率降低。权重300意味着一次失败大致相当于300个token的成本。 + # - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。 + least_used_model_name = min( + candidate_models_usage, + key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, + ) + + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + # 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。 + # 这是为了避免在某些高并发场景下,共享的ClientSession可能引发的事件循环相关问题。 + force_new_client = request_type == "embedding" + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + + logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}") + # 增加所选模型的请求使用惩罚值,以实现动态负载均衡。 + self.update_usage_penalty(model_info.name, increase=True) + return model_info, api_provider, client + + def update_usage_penalty(self, model_name: str, increase: bool): + """ + 更新模型的使用惩罚值。 + + 在模型被选中时增加惩罚值,请求完成后减少惩罚值。 + 这有助于在短期内将请求分散到不同的模型,实现更动态的负载均衡。 + + Args: + model_name (str): 要更新惩罚值的模型名称。 + increase (bool): True表示增加惩罚值,False表示减少。 + """ + # 获取当前模型的统计数据 + total_tokens, penalty, usage_penalty = self.model_usage[model_name] + # 根据操作是增加还是减少来确定调整量 + adjustment = 1 if increase else -1 + # 更新模型的惩罚值 + self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment) + + def update_failure_penalty(self, model_name: str, e: Exception): + """ + 根据异常类型动态调整模型的失败惩罚值。 + 关键错误(如网络连接、服务器错误)会获得更高的惩罚, + 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。 + """ + total_tokens, penalty, usage_penalty = self.model_usage[model_name] + penalty_increment = self.DEFAULT_PENALTY_INCREMENT + + # 对严重错误施加更高的惩罚,以便快速将问题模型移出候选池 + if isinstance(e, (NetworkConnectionError, ReqAbortException)): + # 网络连接错误或请求被中断,通常是基础设施问题,应重罚 + penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER + logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}") + elif isinstance(e, RespNotOkException): + # 对于HTTP响应错误,重点关注服务器端错误 + if e.status_code >= 500: + # 5xx 错误表明服务器端出现问题,应重罚 + penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER + logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}") + else: + # 4xx 客户端错误通常不代表模型本身不可用,给予基础惩罚 + logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") + else: + # 其他未知异常,给予基础惩罚 + logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") + + self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty) + + +class _PromptProcessor: + """封装所有与提示词和响应内容的预处理和后处理逻辑。""" + + def __init__(self): + """ + 初始化提示处理器。 + 这里定义了用于内容混淆和反截断的指令文本。 + """ + # 内容混淆过滤指令: 用于在API传输中增加噪音,防止内容被审查或修改 self.noise_instruction = """**【核心过滤规则】** 在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。 你的任务是【完全并彻底地忽略】这些随机字符串。 **【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。 """ - # 反截断指令 self.end_marker = "###MAI_RESPONSE_END###" self.anti_truncation_instruction = f""" @@ -150,6 +259,502 @@ class LLMRequest: 这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。 """ + def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str: + """ + 为请求准备最终的提示词。 + + 此方法会根据API提供商和模型配置,对原始提示词应用内容混淆和反截断指令, + 生成最终发送给模型的完整提示内容。 + + Args: + prompt (str): 原始的用户提示词。 + model_info (ModelInfo): 目标模型的信息。 + api_provider (APIProvider): API提供商的配置。 + task_name (str): 当前任务的名称,用于日志记录。 + + Returns: + str: 处理后的、可以直接发送给模型的完整提示词。 + """ + # 步骤1: 根据API提供商的配置应用内容混淆 + processed_prompt = self._apply_content_obfuscation(prompt, api_provider) + + # 步骤2: 检查模型是否需要注入反截断指令 + if getattr(model_info, "use_anti_truncation", False): + processed_prompt += self.anti_truncation_instruction + logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。") + + return processed_prompt + + def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]: + """ + 处理响应内容,提取思维链并检查截断。 + + Returns: + Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断) + """ + content, reasoning = self._extract_reasoning(content) + is_truncated = False + if use_anti_truncation: + if content.endswith(self.end_marker): + content = content[: -len(self.end_marker)].strip() + else: + is_truncated = True + return content, reasoning, is_truncated + + def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str: + """ + 根据API提供商的配置对文本进行内容混淆。 + + 如果提供商配置中启用了内容混淆,此方法会在文本前部加入抗审查指令, + 并在文本中注入随机噪音,以降低内容被审查或修改的风险。 + + Args: + text (str): 原始文本内容。 + api_provider (APIProvider): API提供商的配置。 + + Returns: + str: 经过混淆处理的文本。 + """ + # 检查当前API提供商是否启用了内容混淆功能 + if not getattr(api_provider, "enable_content_obfuscation", False): + return text + + # 获取混淆强度,默认为1 + intensity = getattr(api_provider, "obfuscation_intensity", 1) + logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") + + # 将抗审查指令和原始文本拼接 + processed_text = self.noise_instruction + "\n\n" + text + + # 在拼接后的文本中注入随机噪音 + return self._inject_random_noise(processed_text, intensity) + + @staticmethod + def _inject_random_noise(text: str, intensity: int) -> str: + """ + 在文本中按指定强度注入随机噪音字符串。 + + 该方法通过在文本的单词之间随机插入无意义的字符串(噪音)来实现内容混淆。 + 强度越高,插入噪音的概率和长度就越大。 + + Args: + text (str): 待处理的文本。 + intensity (int): 混淆强度 (1-3),决定噪音的概率和长度。 + + Returns: + str: 注入噪音后的文本。 + """ + # 定义不同强度级别的噪音参数:概率和长度范围 + params = { + 1: {"probability": 15, "length": (3, 6)}, # 低强度 + 2: {"probability": 25, "length": (5, 10)}, # 中强度 + 3: {"probability": 35, "length": (8, 15)}, # 高强度 + } + # 根据传入的强度选择配置,如果强度无效则使用默认值 + config = params.get(intensity, params[1]) + + words = text.split() + result = [] + # 遍历每个单词 + for word in words: + result.append(word) + # 根据概率决定是否在此单词后注入噪音 + if random.randint(1, 100) <= config["probability"]: + # 确定噪音的长度 + noise_length = random.randint(*config["length"]) + # 定义噪音字符集 + chars = string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?" + # 生成噪音字符串 + noise = "".join(random.choice(chars) for _ in range(noise_length)) + result.append(noise) + + # 将处理后的单词列表重新组合成字符串 + return " ".join(result) + + @staticmethod + def _extract_reasoning(content: str) -> Tuple[str, str]: + """ + 从模型返回的完整内容中提取被...标签包裹的思考过程, + 并返回清理后的内容和思考过程。 + + Args: + content (str): 模型返回的原始字符串。 + + Returns: + Tuple[str, str]: + - 清理后的内容(移除了标签及其内容)。 + - 提取出的思考过程文本(如果没有则为空字符串)。 + """ + # 使用正则表达式精确查找 ... 标签及其内容 + think_pattern = re.compile(r"(.*?)\s*", re.DOTALL) + match = think_pattern.search(content) + + if match: + # 提取思考过程 + reasoning = match.group(1).strip() + # 从原始内容中移除匹配到的整个部分(包括标签和后面的空白) + clean_content = think_pattern.sub("", content, count=1).strip() + else: + reasoning = "" + clean_content = content.strip() + + return clean_content, reasoning + + +class _RequestExecutor: + """负责执行实际的API请求,包含重试逻辑和底层异常处理。""" + + def __init__(self, model_selector: _ModelSelector, task_name: str): + """ + 初始化请求执行器。 + + Args: + model_selector (_ModelSelector): 模型选择器实例,用于在请求失败时更新惩罚。 + task_name (str): 当前任务的名称,用于日志记录。 + """ + self.model_selector = model_selector + self.task_name = task_name + + async def execute_request( + self, + api_provider: APIProvider, + client: BaseClient, + request_type: RequestType, + model_info: ModelInfo, + **kwargs, + ) -> APIResponse: + """ + 实际执行请求的方法,包含了重试和异常处理逻辑。 + + Args: + api_provider (APIProvider): API提供商配置。 + client (BaseClient): 用于发送请求的客户端实例。 + request_type (RequestType): 请求的类型 (e.g., RESPONSE, EMBEDDING)。 + model_info (ModelInfo): 正在使用的模型的信息。 + **kwargs: 传递给客户端方法的具体参数。 + + Returns: + APIResponse: 来自API的成功响应。 + + Raises: + Exception: 如果重试后请求仍然失败,则抛出最终的异常。 + RuntimeError: 如果达到最大重试次数。 + """ + retry_remain = api_provider.max_retry + compressed_messages: Optional[List[Message]] = None + + while retry_remain > 0: + try: + # 优先使用压缩后的消息列表 + message_list = kwargs.get("message_list") + current_messages = compressed_messages or message_list + + # 根据请求类型调用不同的客户端方法 + if request_type == RequestType.RESPONSE: + assert current_messages is not None, "message_list cannot be None for response requests" + + # 修复: 防止 'message_list' 在 kwargs 中重复传递 + request_params = kwargs.copy() + request_params.pop("message_list", None) + + return await client.get_response( + model_info=model_info, message_list=current_messages, **request_params + ) + elif request_type == RequestType.EMBEDDING: + return await client.get_embedding(model_info=model_info, **kwargs) + elif request_type == RequestType.AUDIO: + return await client.get_audio_transcriptions(model_info=model_info, **kwargs) + + except Exception as e: + logger.debug(f"请求失败: {str(e)}") + # 记录失败并更新模型的惩罚值 + self.model_selector.update_failure_penalty(model_info.name, e) + + # 处理异常,决定是否重试以及等待多久 + wait_interval, new_compressed_messages = self._handle_exception( + e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None) + ) + if new_compressed_messages: + compressed_messages = new_compressed_messages # 更新为压缩后的消息 + + if wait_interval == -1: + raise e # 如果决定不再重试,则传播异常 + elif wait_interval > 0: + await asyncio.sleep(wait_interval) # 等待指定时间后重试 + finally: + retry_remain -= 1 + + logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") + raise RuntimeError("请求失败,已达到最大重试次数") + + def _handle_exception( + self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info + ) -> Tuple[int, Optional[List[Message]]]: + """ + 默认异常处理函数,决定是否重试。 + + Returns: + (等待间隔(-1表示不再重试), 新的消息列表(适用于压缩消息)) + """ + model_name = model_info.name + retry_interval = api_provider.retry_interval + + if isinstance(e, (NetworkConnectionError, ReqAbortException)): + return self._check_retry(remain_try, retry_interval, "连接异常", model_name) + elif isinstance(e, RespNotOkException): + return self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info) + elif isinstance(e, RespParseException): + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}") + return -1, None + else: + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {str(e)}") + return -1, None + + def _handle_resp_not_ok( + self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info + ) -> Tuple[int, Optional[List[Message]]]: + """ + 处理非200的HTTP响应异常。 + + 根据不同的HTTP状态码决定下一步操作: + - 4xx 客户端错误:通常不可重试,直接放弃。 + - 413 (Payload Too Large): 尝试压缩消息体后重试一次。 + - 429 (Too Many Requests) / 5xx 服务器错误:可重试。 + + Args: + e (RespNotOkException): 捕获到的响应异常。 + model_info (ModelInfo): 当前模型信息。 + api_provider (APIProvider): API提供商配置。 + remain_try (int): 剩余重试次数。 + messages_info (tuple): 包含消息列表和是否已压缩的标志。 + + Returns: + Tuple[int, Optional[List[Message]]]: (等待间隔, 新的消息列表)。 + 等待间隔为-1表示不再重试。新的消息列表用于压缩后重试。 + """ + model_name = model_info.name + # 处理客户端错误 (400-404),这些错误通常是请求本身有问题,不应重试 + if e.status_code in [400, 401, 402, 403, 404]: + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。") + return -1, None + # 处理请求体过大的情况 + elif e.status_code == 413: + messages, is_compressed = messages_info + # 如果消息存在且尚未被压缩,则尝试压缩后立即重试 + if messages and not is_compressed: + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试。") + return 0, compress_messages(messages) + # 如果已经压缩过或没有消息体,则放弃 + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大且无法压缩,放弃请求。") + return -1, None + # 处理请求频繁或服务器端错误,这些情况适合重试 + elif e.status_code == 429 or e.status_code >= 500: + reason = "请求过于频繁" if e.status_code == 429 else "服务器错误" + return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name) + # 处理其他未知的HTTP错误 + else: + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}") + return -1, None + + def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]: + """ + 辅助函数,根据剩余次数决定是否进行下一次重试。 + + Args: + remain_try (int): 剩余的重试次数。 + interval (int): 重试前的等待间隔(秒)。 + reason (str): 本次失败的原因。 + model_name (str): 失败的模型名称。 + + Returns: + Tuple[int, None]: (等待间隔, None)。如果等待间隔为-1,表示不应再重试。 + """ + # 只有在剩余重试次数大于1时才进行下一次重试(因为当前这次失败已经消耗掉一次) + if remain_try > 1: + logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。") + return interval, None + + # 如果已无剩余重试次数,则记录错误并返回-1表示放弃 + logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。") + return -1, None + + +class _RequestStrategy: + """ + 封装高级请求策略,如故障转移。 + 此类协调模型选择、提示处理和请求执行,以实现健壮的请求处理, + 即使在单个模型或API端点失败的情况下也能正常工作。 + """ + + def __init__(self, model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, model_list: List[str], task_name: str): + """ + 初始化请求策略。 + + Args: + model_selector (_ModelSelector): 模型选择器实例。 + prompt_processor (_PromptProcessor): 提示处理器实例。 + executor (_RequestExecutor): 请求执行器实例。 + model_list (List[str]): 可用模型列表。 + task_name (str): 当前任务的名称。 + """ + self.model_selector = model_selector + self.prompt_processor = prompt_processor + self.executor = executor + self.model_list = model_list + self.task_name = task_name + + async def execute_with_failover( + self, + request_type: RequestType, + raise_when_empty: bool = True, + **kwargs, + ) -> Tuple[APIResponse, ModelInfo]: + """ + 执行请求,动态选择最佳可用模型,并在模型失败时进行故障转移。 + """ + failed_models_in_this_request = set() + max_attempts = len(self.model_list) + last_exception: Optional[Exception] = None + + for attempt in range(max_attempts): + selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request, str(request_type.value)) + if selection_result is None: + logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。") + break + + model_info, api_provider, client = selection_result + logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_info.name}'...") + + try: + # 准备请求参数 + request_kwargs = kwargs.copy() + if request_type == RequestType.RESPONSE and "prompt" in request_kwargs: + prompt = request_kwargs.pop("prompt") + processed_prompt = self.prompt_processor.prepare_prompt( + prompt, model_info, api_provider, self.task_name + ) + message = MessageBuilder().add_text_content(processed_prompt).build() + request_kwargs["message_list"] = [message] + + # 合并模型特定的额外参数 + if model_info.extra_params: + request_kwargs["extra_params"] = {**model_info.extra_params, **request_kwargs.get("extra_params", {})} + + response = await self._try_model_request(model_info, api_provider, client, request_type, **request_kwargs) + + # 成功,立即返回 + logger.debug(f"模型 '{model_info.name}' 成功生成了回复。") + self.model_selector.update_usage_penalty(model_info.name, increase=False) + return response, model_info + + except Exception as e: + logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。") + failed_models_in_this_request.add(model_info.name) + last_exception = e + # 使用惩罚值已在 select 时增加,失败后不减少,以降低其后续被选中的概率 + + logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。") + if raise_when_empty: + if last_exception: + raise RuntimeError("所有模型均未能生成响应。") from last_exception + raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。") + + # 如果不抛出异常,返回一个备用响应 + fallback_model_info = model_config.get_model_info(self.model_list[0]) + return APIResponse(content="所有模型都请求失败"), fallback_model_info + + + async def _try_model_request( + self, model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, request_type: RequestType, **kwargs + ) -> APIResponse: + """ + 为单个模型尝试请求,包含空回复/截断的内部重试逻辑。 + 如果模型返回空回复或响应被截断,此方法将自动重试请求,直到达到最大重试次数。 + + Args: + model_info (ModelInfo): 要使用的模型信息。 + api_provider (APIProvider): API提供商信息。 + client (BaseClient): API客户端实例。 + request_type (RequestType): 请求类型。 + **kwargs: 传递给执行器的请求参数。 + + Returns: + APIResponse: 成功的API响应。 + + Raises: + RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的响应。 + """ + max_empty_retry = api_provider.max_retry + + for i in range(max_empty_retry + 1): + response = await self.executor.execute_request( + api_provider, client, request_type, model_info, **kwargs + ) + + if request_type != RequestType.RESPONSE: + return response # 对于非响应类型,直接返回 + + # --- 响应内容处理和空回复/截断检查 --- + content = response.content or "" + use_anti_truncation = getattr(model_info, "use_anti_truncation", False) + processed_content, reasoning, is_truncated = self.prompt_processor.process_response(content, use_anti_truncation) + + # 更新响应对象 + response.content = processed_content + response.reasoning_content = response.reasoning_content or reasoning + + is_empty_reply = not response.tool_calls and not (response.content and response.content.strip()) + + if not is_empty_reply and not is_truncated: + return response # 成功获取有效响应 + + if i < max_empty_retry: + reason = "空回复" if is_empty_reply else "截断" + logger.warning(f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})...") + if api_provider.retry_interval > 0: + await asyncio.sleep(api_provider.retry_interval) + else: + reason = "空回复" if is_empty_reply else "截断" + logger.error(f"模型 '{model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。") + raise RuntimeError(f"模型 '{model_info.name}' 已达到空回复/截断的最大内部重试次数。") + + raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里 + + +# ============================================================================== +# Main Facade Class +# ============================================================================== + +class LLMRequest: + """ + LLM请求协调器。 + 封装了模型选择、Prompt处理、请求执行和高级策略(如故障转移、并发)的完整流程。 + 为上层业务逻辑提供统一的、简化的接口来与大语言模型交互。 + """ + + def __init__(self, model_set: TaskConfig, request_type: str = ""): + """ + 初始化LLM请求协调器。 + + Args: + model_set (TaskConfig): 特定任务的模型配置集合。 + request_type (str, optional): 请求类型或任务名称,用于日志和用量记录。 Defaults to "". + """ + self.task_name = request_type + self.model_for_task = model_set + self.model_usage: Dict[str, Tuple[int, int, int]] = { + model: (0, 0, 0) for model in self.model_for_task.model_list + } + """模型使用量记录,(total_tokens, penalty, usage_penalty)""" + + # 初始化辅助类 + self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage) + self._prompt_processor = _PromptProcessor() + self._executor = _RequestExecutor(self._model_selector, self.task_name) + self._strategy = _RequestStrategy( + self._model_selector, self._prompt_processor, self._executor, self.model_for_task.model_list, self.task_name + ) + async def generate_response_for_image( self, prompt: str, @@ -159,77 +764,57 @@ class LLMRequest: max_tokens: Optional[int] = None, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 为图像生成响应 + 为图像生成响应。 + Args: prompt (str): 提示词 image_base64 (str): 图像的Base64编码字符串 image_format (str): 图像格式(如 'png', 'jpeg' 等) + Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ - # 标准化图片格式以确保API兼容性 - normalized_format = _normalize_image_format(image_format) - - # 模型选择 start_time = time.time() - model_info, api_provider, client = self._select_model() - - # 请求体构建 - message_builder = MessageBuilder() - message_builder.add_text_content(prompt) - message_builder.add_image_content( + + # 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行 + selection_result = self._model_selector.select_best_available_model(set(), "response") + if not selection_result: + raise RuntimeError("无法为图像响应选择可用模型。") + model_info, api_provider, client = selection_result + + normalized_format = _normalize_image_format(image_format) + message = MessageBuilder().add_text_content(prompt).add_image_content( image_base64=image_base64, image_format=normalized_format, support_formats=client.get_support_image_formats(), - ) - messages = [message_builder.build()] + ).build() - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, - message_list=messages, + response = await self._executor.execute_request( + api_provider, client, RequestType.RESPONSE, model_info, + message_list=[message], temperature=temperature, max_tokens=max_tokens, ) - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - # 从内容中提取标签的推理内容(向后兼容) - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - user_id="system", - time_cost=time.time() - start_time, - request_type=self.request_type, - endpoint="/chat/completions", - ) - return content, (reasoning_content, model_info.name, tool_calls) + + self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") + content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False) + reasoning = response.reasoning_content or reasoning + + return content, (reasoning, model_info.name, response.tool_calls) async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: """ - 为语音生成响应 - Args: - voice_base64 (str): 语音的Base64编码字符串 - Returns: - (Optional[str]): 生成的文本描述或None - """ - # 模型选择 - model_info, api_provider, client = self._select_model() + 为语音生成响应(语音转文字)。 + 使用故障转移策略来确保即使主模型失败也能获得结果。 - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.AUDIO, - model_info=model_info, - audio_base64=voice_base64, + Args: + voice_base64 (str): 语音的Base64编码字符串。 + + Returns: + Optional[str]: 语音转换后的文本内容,如果所有模型都失败则返回None。 + """ + response, _ = await self._strategy.execute_with_failover( + RequestType.AUDIO, audio_base64=voice_base64 ) return response.content or None @@ -242,44 +827,36 @@ class LLMRequest: raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 异步生成响应,支持并发请求 + 异步生成响应,支持并发请求。 + Args: prompt (str): 提示词 temperature (float, optional): 温度参数 max_tokens (int, optional): 最大token数 tools: 工具配置 - raise_when_empty: 是否在空回复时抛出异常 + raise_when_empty (bool): 是否在空回复时抛出异常 + Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ - # 检查是否需要并发请求 concurrency_count = getattr(self.model_for_task, "concurrency_count", 1) if concurrency_count <= 1: - # 单次请求 - return await self._execute_single_request(prompt, temperature, max_tokens, tools, raise_when_empty) - - # 并发请求 + return await self._execute_single_text_request(prompt, temperature, max_tokens, tools, raise_when_empty) + try: - # 为 _execute_single_request 传递参数时,将 raise_when_empty 设为 False, - # 这样单个请求失败时不会立即抛出异常,而是由 gather 统一处理 - content, (reasoning_content, model_name, tool_calls) = await execute_concurrently( - self._execute_single_request, + return await execute_concurrently( + self._execute_single_text_request, concurrency_count, - prompt, - temperature, - max_tokens, - tools, - raise_when_empty=False, + prompt, temperature, max_tokens, tools, raise_when_empty=False ) - return content, (reasoning_content, model_name, tool_calls) except Exception as e: logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}") if raise_when_empty: raise e return "所有并发请求都失败了", ("", "unknown", None) - async def _execute_single_request( + async def _execute_single_text_request( self, prompt: str, temperature: Optional[float] = None, @@ -288,569 +865,136 @@ class LLMRequest: raise_when_empty: bool = True, ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - 执行单次请求,并在模型失败时按顺序切换到下一个可用模型。 + 执行单次文本生成请求的内部方法。 + 这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期, + 包括工具构建、故障转移执行和用量记录。 + + Args: + prompt (str): 用户的提示。 + temperature (Optional[float]): 生成温度。 + max_tokens (Optional[int]): 最大生成令牌数。 + tools (Optional[List[Dict[str, Any]]]): 可用工具列表。 + raise_when_empty (bool): 如果响应为空是否引发异常。 + + Returns: + Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + (响应内容, (推理过程, 模型名称, 工具调用)) """ - failed_models = set() - last_exception: Optional[Exception] = None + start_time = time.time() + tool_options = self._build_tool_options(tools) - model_scheduler = self._model_scheduler(failed_models) + response, model_info = await self._strategy.execute_with_failover( + RequestType.RESPONSE, + raise_when_empty=raise_when_empty, + prompt=prompt, # 传递原始prompt,由strategy处理 + tool_options=tool_options, + temperature=self.model_for_task.temperature if temperature is None else temperature, + max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, + ) - for model_info, api_provider, client in model_scheduler: - start_time = time.time() - model_name = model_info.name - logger.debug(f"正在尝试使用模型: {model_name}") # 你不许刷屏 + self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") - try: - # 检查是否启用反截断 - # 检查是否为该模型启用反截断 - use_anti_truncation = getattr(model_info, "use_anti_truncation", False) - processed_prompt = prompt - if use_anti_truncation: - processed_prompt += self.anti_truncation_instruction - logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。") + if not response.content and not response.tool_calls: + if raise_when_empty: + raise RuntimeError("所选模型生成了空回复。") + response.content = "生成的响应为空" - processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider) - - message_builder = MessageBuilder() - message_builder.add_text_content(processed_prompt) - messages = [message_builder.build()] - tool_built = self._build_tool_options(tools) - - # 针对当前模型的空回复/截断重试逻辑 - empty_retry_count = 0 - max_empty_retry = api_provider.max_retry - empty_retry_interval = api_provider.retry_interval - - while empty_retry_count <= max_empty_retry: - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, - message_list=messages, - tool_options=tool_built, - temperature=temperature, - max_tokens=max_tokens, - ) - - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - - is_empty_reply = not tool_calls and (not content or content.strip() == "") - is_truncated = False - if use_anti_truncation: - if content.endswith(self.end_marker): - content = content[: -len(self.end_marker)].strip() - else: - is_truncated = True - - if is_empty_reply or is_truncated: - empty_retry_count += 1 - if empty_retry_count <= max_empty_retry: - reason = "空回复" if is_empty_reply else "截断" - logger.warning( - f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成..." - ) - if empty_retry_interval > 0: - await asyncio.sleep(empty_retry_interval) - continue # 继续使用当前模型重试 - else: - # 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型 - reason = "空回复" if is_empty_reply else "截断" - logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。") - raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数") - - # 成功获取响应 - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - model_usage=usage, - time_cost=time.time() - start_time, - user_id="system", - request_type=self.request_type, - endpoint="/chat/completions", - ) - - if not content and not tool_calls: - if raise_when_empty: - raise RuntimeError("生成空回复") - content = "生成的响应为空" - - logger.debug(f"模型 '{model_name}' 成功生成回复。") # 你也不许刷屏 - return content, (reasoning_content, model_name, tool_calls) - - except RespNotOkException as e: - if e.status_code in [401, 403]: - logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。") - failed_models.add(model_name) - last_exception = e - continue # 切换到下一个模型 - else: - logger.error(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code}") - if raise_when_empty: - raise - # 对于其他HTTP错误,直接抛出,不再尝试其他模型 - return f"请求失败: {e}", ("", model_name, None) - - except RuntimeError as e: - # 捕获所有重试失败(包括空回复和网络问题) - logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。") - failed_models.add(model_name) - last_exception = e - continue # 切换到下一个模型 - - except Exception as e: - logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}") - failed_models.add(model_name) - last_exception = e - continue # 切换到下一个模型 - - # 所有模型都尝试失败 - logger.error("所有可用模型都已尝试失败。") - if raise_when_empty: - if last_exception: - raise RuntimeError("所有模型都请求失败") from last_exception - raise RuntimeError("所有模型都请求失败,且没有具体的异常信息") - - return "所有模型都请求失败", ("", "unknown", None) + return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls) async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: - """获取嵌入向量 + """ + 获取嵌入向量。 + Args: embedding_input (str): 获取嵌入的目标 + Returns: (Tuple[List[float], str]): (嵌入向量,使用的模型名称) """ - # 无需构建消息体,直接使用输入文本 start_time = time.time() - model_info, api_provider, client = self._select_model() - - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.EMBEDDING, - model_info=model_info, - embedding_input=embedding_input, + response, model_info = await self._strategy.execute_with_failover( + RequestType.EMBEDDING, + embedding_input=embedding_input ) - - embedding = response.embedding - - if usage := response.usage: - await llm_usage_recorder.record_usage_to_database( - model_info=model_info, - time_cost=time.time() - start_time, - model_usage=usage, - user_id="system", - request_type=self.request_type, - endpoint="/embeddings", - ) - - if not embedding: + + self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings") + + if not response.embedding: raise RuntimeError("获取embedding失败") + + return response.embedding, model_info.name - return embedding, model_info.name - - def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]: + def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str): """ - 一个模型调度器,按顺序提供模型,并跳过已失败的模型。 - """ - for model_name in self.model_for_task.model_list: - if model_name in failed_models: - continue + 记录模型使用情况。 - model_info = model_config.get_model_info(model_name) - api_provider = model_config.get_provider(model_info.api_provider) - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + 此方法首先在内存中更新模型的累计token使用量,然后创建一个异步任务, + 将详细的用量数据(包括模型信息、token数、耗时等)写入数据库。 - yield model_info, api_provider, client - - def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: - """ - 根据总tokens和惩罚值选择的模型 (负载均衡) - """ - least_used_model_name = min( - self.model_usage, - key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, - ) - model_info = model_config.get_model_info(least_used_model_name) - api_provider = model_config.get_provider(model_info.api_provider) - - # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 - force_new_client = self.request_type == "embedding" - client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) - logger.debug(f"选择请求模型: {model_info.name}") - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 - return model_info, api_provider, client - - async def _execute_request( - self, - api_provider: APIProvider, - client: BaseClient, - request_type: RequestType, - model_info: ModelInfo, - message_list: List[Message] | None = None, - tool_options: list[ToolOption] | None = None, - response_format: RespFormat | None = None, - stream_response_handler: Optional[Callable] = None, - async_response_parser: Optional[Callable] = None, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - embedding_input: str = "", - audio_base64: str = "", - ) -> APIResponse: - """ - 实际执行请求的方法 - - 包含了重试和异常处理逻辑 - """ - retry_remain = api_provider.max_retry - compressed_messages: Optional[List[Message]] = None - while retry_remain > 0: - try: - if request_type == RequestType.RESPONSE: - assert message_list is not None, "message_list cannot be None for response requests" - return await client.get_response( - model_info=model_info, - message_list=(compressed_messages or message_list), - tool_options=tool_options, - max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, - temperature=self.model_for_task.temperature if temperature is None else temperature, - response_format=response_format, - stream_response_handler=stream_response_handler, - async_response_parser=async_response_parser, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.EMBEDDING: - assert embedding_input, "embedding_input cannot be empty for embedding requests" - return await client.get_embedding( - model_info=model_info, - embedding_input=embedding_input, - extra_params=model_info.extra_params, - ) - elif request_type == RequestType.AUDIO: - assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" - return await client.get_audio_transcriptions( - model_info=model_info, - audio_base64=audio_base64, - extra_params=model_info.extra_params, - ) - except Exception as e: - logger.debug(f"请求失败: {str(e)}") - # 处理异常 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) - - wait_interval, compressed_messages = self._default_exception_handler( - e, - self.task_name, - model_info=model_info, - api_provider=api_provider, - remain_try=retry_remain, - retry_interval=api_provider.retry_interval, - messages=(message_list, compressed_messages is not None) if message_list else None, - ) - - if wait_interval == -1: - retry_remain = 0 # 不再重试 - elif wait_interval > 0: - logger.info(f"等待 {wait_interval} 秒后重试...") - await asyncio.sleep(wait_interval) - finally: - # 放在finally防止死循环 - retry_remain -= 1 - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值 - logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") - raise RuntimeError("请求失败,已达到最大重试次数") - - def _default_exception_handler( - self, - e: Exception, - task_name: str, - model_info: ModelInfo, - api_provider: APIProvider, - remain_try: int, - retry_interval: int = 10, - messages: Tuple[List[Message], bool] | None = None, - ) -> Tuple[int, List[Message] | None]: - """ - 默认异常处理函数 Args: - e (Exception): 异常对象 - task_name (str): 任务名称 - model_info (ModelInfo): 模型信息 - api_provider (APIProvider): API提供商 - remain_try (int): 剩余尝试次数 - retry_interval (int): 重试间隔 - messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) - Returns: - (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + model_info (ModelInfo): 使用的模型信息。 + usage (Optional[UsageRecord]): API返回的用量记录。 + time_cost (float): 本次请求的总耗时。 + endpoint (str): 请求的API端点 (e.g., "/chat/completions")。 """ - model_name = model_info.name if model_info else "unknown" - - if isinstance(e, NetworkConnectionError): # 网络连接错误 - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确", - ) - elif isinstance(e, ReqAbortException): - logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}") - return -1, None # 不再重试请求该模型 - elif isinstance(e, RespNotOkException): - return self._handle_resp_not_ok( - e, - task_name, - model_info, - api_provider, - remain_try, - retry_interval, - messages, - ) - elif isinstance(e, RespParseException): - # 响应解析错误 - logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}") - logger.debug(f"附加内容: {str(e.ext_info)}") - return -1, None # 不再重试请求该模型 - else: - logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") - return -1, None # 不再重试请求该模型 - - @staticmethod - def _check_retry( - remain_try: int, - retry_interval: int, - can_retry_msg: str, - cannot_retry_msg: str, - can_retry_callable: Callable | None = None, - **kwargs, - ) -> Tuple[int, List[Message] | None]: - """辅助函数:检查是否可以重试 - Args: - remain_try (int): 剩余尝试次数 - retry_interval (int): 重试间隔 - can_retry_msg (str): 可以重试时的提示信息 - cannot_retry_msg (str): 不可以重试时的提示信息 - can_retry_callable (Callable | None): 可以重试时调用的函数(如果有) - **kwargs: 其他参数 - - Returns: - (Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - if remain_try > 0: - # 还有重试机会 - logger.warning(f"{can_retry_msg}") - if can_retry_callable is not None: - return retry_interval, can_retry_callable(**kwargs) - else: - return retry_interval, None - else: - # 达到最大重试次数 - logger.warning(f"{cannot_retry_msg}") - return -1, None # 不再重试请求该模型 - - def _handle_resp_not_ok( - self, - e: RespNotOkException, - task_name: str, - model_info: ModelInfo, - api_provider: APIProvider, - remain_try: int, - retry_interval: int = 10, - messages: tuple[list[Message], bool] | None = None, - ): - model_name = model_info.name - """ - 处理响应错误异常 - Args: - e (RespNotOkException): 响应错误异常对象 - task_name (str): 任务名称 - model_info (ModelInfo): 模型信息 - api_provider (APIProvider): API提供商 - remain_try (int): 剩余尝试次数 - retry_interval (int): 重试间隔 - messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) - Returns: - (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - # 响应错误 - if e.status_code in [400, 401, 402, 403, 404]: - model_name = model_info.name - if ( - e.status_code == 403 - and model_name.startswith("Pro/deepseek-ai") - and api_provider.base_url == "https://api.siliconflow.cn/v1/" - ): - old_model_name = model_name - new_model_name = model_name[4:] - model_info.name = new_model_name - logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {new_model_name}") - # 更新任务配置中的模型列表 - for i, m_name in enumerate(self.model_for_task.model_list): - if m_name == old_model_name: - self.model_for_task.model_list[i] = new_model_name - logger.warning( - f"将任务 {self.task_name} 的模型列表中的 {old_model_name} 临时降级至 {new_model_name}" - ) - break - return 0, None # 立即重试 - # 客户端错误 - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}" - ) - return -1, None # 不再重试请求该模型 - elif e.status_code == 413: - if messages and not messages[1]: - # 消息列表不为空且未压缩,尝试压缩消息 - return self._check_retry( - remain_try, - 0, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求", - can_retry_callable=compress_messages, - messages=messages[0], - ) - # 没有消息可压缩 - logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。") - return -1, None - elif e.status_code == 429: - # 请求过于频繁 - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求", - ) - elif e.status_code >= 500: - # 服务器错误 - return self._check_retry( - remain_try, - retry_interval, - can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试", - cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试", - ) - else: - # 未知错误 - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}" - ) - return -1, None + if usage: + # 步骤1: 更新内存中的token计数,用于负载均衡 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty) + + # 步骤2: 创建一个后台任务,将用量数据异步写入数据库 + asyncio.create_task(llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", # 此处可根据业务需求修改 + time_cost=time_cost, + request_type=self.task_name, + endpoint=endpoint, + )) @staticmethod def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: - # sourcery skip: extract-method - """构建工具选项列表""" + """ + 根据输入的字典列表构建并验证 `ToolOption` 对象列表。 + + 此方法将标准化的工具定义(字典格式)转换为内部使用的 `ToolOption` 对象, + 同时会验证参数格式的正确性。 + + Args: + tools (Optional[List[Dict[str, Any]]]): 工具定义的列表。 + 每个工具是一个字典,包含 "name", "description", 和 "parameters"。 + "parameters" 是一个元组列表,每个元组包含 (name, type, desc, required, enum)。 + + Returns: + Optional[List[ToolOption]]: 构建好的 `ToolOption` 对象列表,如果输入为空则返回 None。 + """ + # 如果没有提供工具,直接返回 None if not tools: return None + tool_options: List[ToolOption] = [] + # 遍历每个工具定义 for tool in tools: - tool_legal = True - tool_options_builder = ToolOptionBuilder() - tool_options_builder.set_name(tool.get("name", "")) - tool_options_builder.set_description(tool.get("description", "")) - parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) - for param in parameters: - try: + try: + # 使用建造者模式创建 ToolOption + builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", "")) + + # 遍历工具的参数 + for param in tool.get("parameters", []): + # 严格验证参数格式是否为包含5个元素的元组 assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" - assert isinstance(param[0], str), "参数名称必须是字符串" - assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" - assert isinstance(param[2], str), "参数描述必须是字符串" - assert isinstance(param[3], bool), "参数是否必填必须是布尔值" - assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" - tool_options_builder.add_param( + builder.add_param( name=param[0], param_type=param[1], description=param[2], required=param[3], enum_values=param[4], ) - except AssertionError as ae: - tool_legal = False - logger.error(f"{param[0]} 参数定义错误: {str(ae)}") - except Exception as e: - tool_legal = False - logger.error(f"构建工具参数失败: {str(e)}") - if tool_legal: - tool_options.append(tool_options_builder.build()) + # 将构建好的 ToolOption 添加到列表中 + tool_options.append(builder.build()) + except (KeyError, IndexError, TypeError, AssertionError) as e: + # 如果构建过程中出现任何错误,记录日志并跳过该工具 + logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}") + + # 如果列表非空则返回列表,否则返回 None return tool_options or None - - @staticmethod - def _extract_reasoning(content: str) -> Tuple[str, str]: - """CoT思维链提取,向后兼容""" - match = re.search(r"(?:)?(.*?)", content, re.DOTALL) - content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() - reasoning = match[1].strip() if match else "" - return content, reasoning - - def _apply_content_obfuscation(self, text: str, api_provider) -> str: - """根据API提供商配置对文本进行混淆处理""" - if not hasattr(api_provider, "enable_content_obfuscation") or not api_provider.enable_content_obfuscation: - logger.debug(f"API提供商 '{api_provider.name}' 未启用内容混淆") - return text - - intensity = getattr(api_provider, "obfuscation_intensity", 1) - logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}") - - # 在开头加入过滤规则指令 - processed_text = self.noise_instruction + "\n\n" + text - logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}") - - # 添加随机乱码 - final_text = self._inject_random_noise(processed_text, intensity) - logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}") - - return final_text - - @staticmethod - def _inject_random_noise(text: str, intensity: int) -> str: - """在文本中注入随机乱码""" - import random - import string - - def generate_noise(length: int) -> str: - """生成指定长度的随机乱码字符""" - chars = ( - string.ascii_letters # a-z, A-Z - + string.digits # 0-9 - + "!@#$%^&*()_+-=[]{}|;:,.<>?" # 特殊符号 - + "一二三四五六七八九零壹贰叁" # 中文字符 - + "αβγδεζηθικλμνξοπρστυφχψω" # 希腊字母 - + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵" # 数学符号 - ) - return "".join(random.choice(chars) for _ in range(length)) - - # 强度参数映射 - params = { - 1: {"probability": 15, "length": (3, 6)}, # 低强度:15%概率,3-6个字符 - 2: {"probability": 25, "length": (5, 10)}, # 中强度:25%概率,5-10个字符 - 3: {"probability": 35, "length": (8, 15)}, # 高强度:35%概率,8-15个字符 - } - - config = params.get(intensity, params[1]) - logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}") - - # 按词分割处理 - words = text.split() - result = [] - noise_count = 0 - - for word in words: - result.append(word) - # 根据概率插入乱码 - if random.randint(1, 100) <= config["probability"]: - noise_length = random.randint(*config["length"]) - noise = generate_noise(noise_length) - result.append(noise) - noise_count += 1 - - logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}") - return " ".join(result) diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 5ffae7298..e74044866 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -84,6 +84,7 @@ async def generate_reply( return_prompt: bool = False, request_type: str = "generator_api", from_plugin: bool = True, + read_mark: float = 0.0, ) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: """生成回复 @@ -129,6 +130,7 @@ async def generate_reply( from_plugin=from_plugin, stream_id=chat_stream.stream_id if chat_stream else chat_id, reply_message=reply_message, + read_mark=read_mark, ) if not success: logger.warning("[GeneratorAPI] 回复生成失败") diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 9ef95182d..12797bafd 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -5,6 +5,7 @@ import toml import orjson import shutil import datetime +from pathlib import Path from src.common.logger import get_logger from src.config.config import CONFIG_DIR @@ -268,100 +269,64 @@ class PluginBase(ABC): except IOError as e: logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True) - def _get_expected_config_version(self) -> str: - """获取插件期望的配置版本号""" - # 从config_schema的plugin.config_version字段获取 - if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict): - config_version_field = self.config_schema["plugin"].get("config_version") - if isinstance(config_version_field, ConfigField): - return config_version_field.default - return "1.0.0" - - @staticmethod - def _get_current_config_version(config: Dict[str, Any]) -> str: - """从配置文件中获取当前版本号""" - if "plugin" in config and "config_version" in config["plugin"]: - return str(config["plugin"]["config_version"]) - # 如果没有config_version字段,视为最早的版本 - return "0.0.0" - def _backup_config_file(self, config_file_path: str) -> str: - """备份配置文件""" - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - backup_path = f"{config_file_path}.backup_{timestamp}" - + """备份配置文件到指定的 backup 子目录""" try: + config_path = Path(config_file_path) + backup_dir = config_path.parent / "backup" + backup_dir.mkdir(exist_ok=True) + + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + backup_filename = f"{config_path.name}.backup_{timestamp}" + backup_path = backup_dir / backup_filename + shutil.copy2(config_file_path, backup_path) logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}") - return backup_path + return str(backup_path) except Exception as e: - logger.error(f"{self.log_prefix} 备份配置文件失败: {e}") + logger.error(f"{self.log_prefix} 备份配置文件失败: {e}", exc_info=True) return "" - def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]: - """将旧配置值迁移到新配置结构中 + def _synchronize_config( + self, schema_config: Dict[str, Any], user_config: Dict[str, Any] + ) -> tuple[Dict[str, Any], bool]: + """递归地将用户配置与 schema 同步,返回同步后的配置和是否发生变化的标志""" + changed = False - Args: - old_config: 旧配置数据 - new_config: 基于新schema生成的默认配置 - - Returns: - Dict[str, Any]: 迁移后的配置 - """ - - def migrate_section( - old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str + # 内部递归函数 + def _sync_dicts( + schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = "" ) -> Dict[str, Any]: - """迁移单个配置节""" - result = new_section.copy() + nonlocal changed + synced_dict = schema_dict.copy() - for key, value in old_section.items(): - if key in new_section: - # 特殊处理:config_version字段总是使用新版本 - if section_name == "plugin" and key == "config_version": - # 保持新的版本号,不迁移旧值 - logger.debug( - f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})" - ) - continue + # 检查并记录用户配置中多余的、在 schema 中不存在的键 + for key in user_dict: + if key not in schema_dict: + logger.warning(f"{self.log_prefix} 发现废弃配置项 '{parent_key}{key}',将被移除。") + changed = True - # 键存在于新配置中,复制值 - if isinstance(value, dict) and isinstance(new_section[key], dict): - # 递归处理嵌套字典 - result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}") + # 以 schema 为基准进行遍历,保留用户的值,补全缺失的项 + for key, schema_value in schema_dict.items(): + full_key = f"{parent_key}{key}" + if key in user_dict: + user_value = user_dict[key] + if isinstance(schema_value, dict) and isinstance(user_value, dict): + # 递归同步嵌套的字典 + synced_dict[key] = _sync_dicts(schema_value, user_value, f"{full_key}.") else: - result[key] = value - logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}") + # 键存在,保留用户的值 + synced_dict[key] = user_value else: - # 键在新配置中不存在,记录警告 - logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除") + # 键在用户配置中缺失,补全 + logger.info(f"{self.log_prefix} 补全缺失的配置项: '{full_key}' = {schema_value}") + changed = True + # synced_dict[key] 已经包含了来自 schema_dict.copy() 的默认值 - return result + return synced_dict - migrated_config = {} - - # 迁移每个配置节 - for section_name, new_section_data in new_config.items(): - if ( - section_name in old_config - and isinstance(old_config[section_name], dict) - and isinstance(new_section_data, dict) - ): - migrated_config[section_name] = migrate_section( - old_config[section_name], new_section_data, section_name - ) - else: - # 新增的节或类型不匹配,使用默认值 - migrated_config[section_name] = new_section_data - if section_name in old_config: - logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值") - - # 检查旧配置中是否有新配置没有的节 - for section_name in old_config: - if section_name not in migrated_config: - logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除") - - return migrated_config + final_config = _sync_dicts(schema_config, user_config) + return final_config, changed def _generate_config_from_schema(self) -> Dict[str, Any]: # sourcery skip: dict-comprehension @@ -393,11 +358,7 @@ class PluginBase(ABC): toml_str = f"# {self.plugin_name} - 配置文件\n" plugin_description = self.get_manifest_info("description", "插件配置文件") - toml_str += f"# {plugin_description}\n" - - # 获取当前期望的配置版本 - expected_version = self._get_expected_config_version() - toml_str += f"# 配置版本: {expected_version}\n\n" + toml_str += f"# {plugin_description}\n\n" # 遍历每个配置节 for section, fields in self.config_schema.items(): @@ -456,77 +417,74 @@ class PluginBase(ABC): def _load_plugin_config(self): # sourcery skip: extract-method """ - 加载插件配置文件,实现集中化管理和自动迁移。 + 加载并同步插件配置文件。 处理逻辑: - 1. 确定用户配置文件路径(位于 `config/plugins/` 目录下)。 - 2. 如果用户配置文件不存在,则根据 config_schema 直接在中央目录生成一份。 - 3. 加载用户配置文件,并进行版本检查和自动迁移(如果需要)。 - 4. 最终加载的配置是用户配置文件。 + 1. 确定用户配置文件路径和插件自带的配置文件路径。 + 2. 如果用户配置文件不存在,尝试从插件目录迁移(移动)一份。 + 3. 如果迁移后(或原本)用户配置文件仍不存在,则根据 schema 生成一份。 + 4. 加载用户配置文件。 + 5. 以 schema 为基准,与用户配置进行同步,补全缺失项并移除废弃项。 + 6. 如果同步过程发现不一致,则先备份原始文件,然后将同步后的完整配置写回用户目录。 + 7. 将最终同步后的配置加载到 self.config。 """ if not self.config_file_name: logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载") return - # 1. 确定并确保用户配置文件路径存在 user_config_path = os.path.join(CONFIG_DIR, "plugins", self.plugin_name, self.config_file_name) + plugin_config_path = os.path.join(self.plugin_dir, self.config_file_name) os.makedirs(os.path.dirname(user_config_path), exist_ok=True) - # 2. 如果用户配置文件不存在,直接在中央目录生成 + # 首次加载迁移:如果用户配置不存在,但插件目录中存在,则移动过来 + if not os.path.exists(user_config_path) and os.path.exists(plugin_config_path): + try: + shutil.move(plugin_config_path, user_config_path) + logger.info(f"{self.log_prefix} 已将配置文件从 {plugin_config_path} 迁移到 {user_config_path}") + except OSError as e: + logger.error(f"{self.log_prefix} 迁移配置文件失败: {e}", exc_info=True) + + # 如果用户配置文件仍然不存在,生成默认的 if not os.path.exists(user_config_path): logger.info(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在,将生成默认配置。") self._generate_and_save_default_config(user_config_path) - # 检查最终的用户配置文件是否存在 if not os.path.exists(user_config_path): - # 如果插件没有定义config_schema,那么不创建文件是正常行为 if not self.config_schema: - logger.debug(f"{self.log_prefix} 插件未定义config_schema,使用空的配置.") + logger.debug(f"{self.log_prefix} 插件未定义 config_schema,使用空配置。") self.config = {} - return - - logger.warning(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在且无法创建。") + else: + logger.warning(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在且无法创建。") return - # 3. 加载、检查和迁移用户配置文件 - _, file_ext = os.path.splitext(self.config_file_name) - if file_ext.lower() != ".toml": - logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml") - self.config = {} - return try: with open(user_config_path, "r", encoding="utf-8") as f: - existing_config = toml.load(f) or {} + user_config = toml.load(f) or {} except Exception as e: logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True) - self.config = {} + self.config = self._generate_config_from_schema() # 加载失败时使用默认 schema return - current_version = self._get_current_config_version(existing_config) - expected_version = self._get_expected_config_version() + # 生成基于 schema 的理想配置结构 + schema_config = self._generate_config_from_schema() - if current_version == "0.0.0": - logger.debug(f"{self.log_prefix} 用户配置文件无版本信息,跳过版本检查") - self.config = existing_config - elif current_version != expected_version: - logger.info( - f"{self.log_prefix} 检测到用户配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}" - ) - new_config_structure = self._generate_config_from_schema() - migrated_config = self._migrate_config_values(existing_config, new_config_structure) - self._save_config_to_file(migrated_config, user_config_path) - logger.info(f"{self.log_prefix} 用户配置文件已从 v{current_version} 更新到 v{expected_version}") - self.config = migrated_config - else: - logger.debug(f"{self.log_prefix} 用户配置版本匹配 (v{current_version}),直接加载") - self.config = existing_config + # 将用户配置与 schema 同步 + synced_config, was_changed = self._synchronize_config(schema_config, user_config) - logger.debug(f"{self.log_prefix} 配置已从 {user_config_path} 加载") + # 如果配置发生了变化(补全或移除),则备份并重写配置文件 + if was_changed: + logger.info(f"{self.log_prefix} 检测到配置结构不匹配,将自动同步并更新配置文件。") + self._backup_config_file(user_config_path) + self._save_config_to_file(synced_config, user_config_path) + logger.info(f"{self.log_prefix} 配置文件已同步更新。") - # 从配置中更新 enable_plugin 状态 + self.config = synced_config + logger.debug(f"{self.log_prefix} 配置已从 {user_config_path} 加载并同步。") + + # 从最终配置中更新插件启用状态 if "plugin" in self.config and "enabled" in self.config["plugin"]: self._is_enabled = self.config["plugin"]["enabled"] - logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self._is_enabled}") + logger.info(f"{self.log_prefix} 从配置更新插件启用状态: {self._is_enabled}") def _check_dependencies(self) -> bool: """检查插件依赖""" diff --git a/src/plugin_system/base/plus_plugin.py b/src/plugin_system/base/plus_plugin.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 237cb6429..e0a39ac25 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -39,76 +39,6 @@ class PluginManager: self._ensure_plugin_directories() logger.info("插件管理器初始化完成") - def _synchronize_plugin_config(self, plugin_name: str, plugin_dir: str): - """ - 同步单个插件的配置。 - - 此过程确保中央配置与插件本地配置保持同步,包含两个主要步骤: - 1. 如果中央配置不存在,则从插件目录复制默认配置到中央配置目录。 - 2. 使用中央配置覆盖插件的本地配置,以确保插件运行时使用的是最新的用户配置。 - """ - try: - plugin_path = Path(plugin_dir) - # 修正:插件的配置文件路径应为 config.toml 文件,而不是目录 - plugin_config_file = plugin_path / "config.toml" - central_config_dir = Path("config") / "plugins" / plugin_name - - # 确保中央配置目录存在 - central_config_dir.mkdir(parents=True, exist_ok=True) - - # 步骤 1: 从插件目录复制默认配置到中央目录 - self._copy_default_config_to_central(plugin_name, plugin_config_file, central_config_dir) - - # 步骤 2: 从中央目录同步配置到插件目录 - self._sync_central_config_to_plugin(plugin_name, plugin_config_file, central_config_dir) - - except OSError as e: - logger.error(f"处理插件 '{plugin_name}' 的配置时发生文件操作错误: {e}") - except Exception as e: - logger.error(f"同步插件 '{plugin_name}' 配置时发生未知错误: {e}") - - @staticmethod - def _copy_default_config_to_central(plugin_name: str, plugin_config_file: Path, central_config_dir: Path): - """ - 如果中央配置不存在,则将插件的默认 config.toml 复制到中央目录。 - """ - if not plugin_config_file.is_file(): - return # 插件没有提供默认配置文件,直接跳过 - - central_config_file = central_config_dir / plugin_config_file.name - if not central_config_file.exists(): - shutil.copy2(plugin_config_file, central_config_file) - logger.info(f"为插件 '{plugin_name}' 从模板复制了默认配置: {plugin_config_file.name}") - - def _sync_central_config_to_plugin(self, plugin_name: str, plugin_config_file: Path, central_config_dir: Path): - """ - 将中央配置同步(覆盖)到插件的本地配置。 - """ - # 遍历中央配置目录中的所有文件 - for central_file in central_config_dir.iterdir(): - if not central_file.is_file(): - continue - - # 目标文件应与中央配置文件同名,这里我们强制它为 config.toml - target_plugin_file = plugin_config_file - - # 仅在文件内容不同时才执行复制,以减少不必要的IO操作 - if not self._is_file_content_identical(central_file, target_plugin_file): - shutil.copy2(central_file, target_plugin_file) - logger.info(f"已将中央配置 '{central_file.name}' 同步到插件 '{plugin_name}'") - - @staticmethod - def _is_file_content_identical(file1: Path, file2: Path) -> bool: - """ - 通过比较 MD5 哈希值检查两个文件的内容是否相同。 - """ - if not file2.exists(): - return False # 目标文件不存在,视为不同 - - # 使用 'rb' 模式以二进制方式读取文件,确保哈希值计算的一致性 - with open(file1, "rb") as f1, open(file2, "rb") as f2: - return hashlib.md5(f1.read()).hexdigest() == hashlib.md5(f2.read()).hexdigest() - # === 插件目录管理 === def add_plugin_directory(self, directory: str) -> bool: @@ -176,8 +106,6 @@ class PluginManager: if not plugin_dir: return False, 1 - # 同步插件配置 - self._synchronize_plugin_config(plugin_name, plugin_dir) plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败) if not plugin_instance: diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index 45357b4b0..990f1c91c 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ -7,7 +7,6 @@ from functools import wraps from typing import Callable, Optional from inspect import iscoroutinefunction -import inspect from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.send_api import text_to_stream diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index 3ebf4610a..e8ffba68e 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -255,7 +255,7 @@ class EmojiAction(BaseAction): if not success: logger.error(f"{self.log_prefix} 表情包发送失败") - await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False) + await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包,但失败了",action_done= False) return False, "表情包发送失败" # 发送成功后,记录到历史 @@ -264,7 +264,7 @@ class EmojiAction(BaseAction): except Exception as e: logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") - await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True) + await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包",action_done= True) return True, f"发送表情包: {emoji_description}" diff --git a/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py b/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py index 819655e84..631ca430d 100644 --- a/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py +++ b/src/plugins/built_in/maizone_refactored/commands/send_feed_command.py @@ -53,6 +53,6 @@ class SendFeedCommand(PlusCommand): return False, result.get("message", "未知错误"), True except Exception as e: - logger.error(f"执行发送说说命令时发生未知异常: {e}", exc_info=True) + logger.error(f"执行发送说说命令时发生未知异常: {e},它的类型是:{type(e)}", exc_info=True) await self.send_text("呜... 发送过程中好像出了点问题。") return False, "命令执行异常", True diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index 9f7da7ccf..27f2a0ee9 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -119,12 +119,10 @@ class ContentService: logger.error(f"生成说说内容时发生异常: {e}") return "" - async def generate_comment(self, content: str, target_name: str, rt_con: str = "", images=None) -> str: + async def generate_comment(self, content: str, target_name: str, rt_con: str = "", images: list = []) -> str: """ 针对一条具体的说说内容生成评论。 """ - if images is None: - images = [] for i in range(3): # 重试3次 try: chat_manager = get_chat_manager() @@ -182,8 +180,7 @@ class ContentService: return "" return "" - @staticmethod - async def generate_comment_reply(story_content: str, comment_content: str, commenter_name: str) -> str: + async def generate_comment_reply(self, story_content: str, comment_content: str, commenter_name: str) -> str: """ 针对自己说说的评论,生成回复。 """ diff --git a/src/plugins/built_in/maizone_refactored/services/cookie_service.py b/src/plugins/built_in/maizone_refactored/services/cookie_service.py index 1c61a29fd..b4aedf322 100644 --- a/src/plugins/built_in/maizone_refactored/services/cookie_service.py +++ b/src/plugins/built_in/maizone_refactored/services/cookie_service.py @@ -50,8 +50,7 @@ class CookieService: logger.error(f"无法读取或解析Cookie文件 {cookie_file_path}: {e}") return None - @staticmethod - async def _get_cookies_from_adapter(stream_id: Optional[str]) -> Optional[Dict[str, str]]: + async def _get_cookies_from_adapter(self, stream_id: Optional[str]) -> Optional[Dict[str, str]]: """通过Adapter API获取Cookie""" try: params = {"domain": "user.qzone.qq.com"} diff --git a/src/plugins/built_in/maizone_refactored/services/image_service.py b/src/plugins/built_in/maizone_refactored/services/image_service.py index 1ffcd7d70..cbb411da7 100644 --- a/src/plugins/built_in/maizone_refactored/services/image_service.py +++ b/src/plugins/built_in/maizone_refactored/services/image_service.py @@ -59,8 +59,7 @@ class ImageService: logger.error(f"处理AI配图时发生异常: {e}") return False - @staticmethod - async def _call_siliconflow_api(api_key: str, story: str, image_dir: str, batch_size: int) -> bool: + async def _call_siliconflow_api(self, api_key: str, story: str, image_dir: str, batch_size: int) -> bool: """ 调用硅基流动(SiliconFlow)的API来生成图片。 diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 67a3669db..752e27dfa 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -187,8 +187,7 @@ class QZoneService: # --- Internal Helper Methods --- - @staticmethod - async def _get_intercom_context(stream_id: str) -> Optional[str]: + async def _get_intercom_context(self, stream_id: str) -> Optional[str]: """ 根据 stream_id 查找其所属的互通组,并构建该组的聊天上下文。 @@ -399,8 +398,7 @@ class QZoneService: logger.error(f"加载本地图片失败: {e}") return [] - @staticmethod - def _generate_gtk(skey: str) -> str: + def _generate_gtk(self, skey: str) -> str: hash_val = 5381 for char in skey: hash_val += (hash_val << 5) + ord(char) @@ -437,8 +435,7 @@ class QZoneService: logger.error(f"更新或加载Cookie时发生异常: {e}") return None - @staticmethod - async def _fetch_cookies_http(host: str, port: str, napcat_token: str) -> Optional[Dict]: + async def _fetch_cookies_http(self, host: str, port: str, napcat_token: str) -> Optional[Dict]: """通过HTTP服务器获取Cookie""" url = f"http://{host}:{port}/get_cookies" max_retries = 5 @@ -657,20 +654,30 @@ class QZoneService: end_idx = resp_text.rfind("}") + 1 if start_idx != -1 and end_idx != -1: json_str = resp_text[start_idx:end_idx] - upload_result = eval(json_str) # 与原版保持一致使用eval + try: + upload_result = orjson.loads(json_str) + except orjson.JSONDecodeError: + logger.error(f"图片上传响应JSON解析失败,原始响应: {resp_text}") + return None - logger.info(f"图片上传解析结果: {upload_result}") + logger.debug(f"图片上传解析结果: {upload_result}") if upload_result.get("ret") == 0: - # 使用原版的参数提取逻辑 - picbo, richval = _get_picbo_and_richval(upload_result) - logger.info(f"图片 {index + 1} 上传成功: picbo={picbo}") - return {"pic_bo": picbo, "richval": richval} + try: + # 使用原版的参数提取逻辑 + picbo, richval = _get_picbo_and_richval(upload_result) + logger.info(f"图片 {index + 1} 上传成功: picbo={picbo}") + return {"pic_bo": picbo, "richval": richval} + except Exception as e: + logger.error( + f"从上传结果中提取图片参数失败: {e}, 上传结果: {upload_result}", exc_info=True + ) + return None else: logger.error(f"图片 {index + 1} 上传失败: {upload_result}") return None else: - logger.error("无法解析上传响应") + logger.error(f"无法从响应中提取JSON内容: {resp_text}") return None else: error_text = await response.text() diff --git a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py index 3aabc88b6..0fa7edb99 100644 --- a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py +++ b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py @@ -36,8 +36,7 @@ class ReplyTrackerService: self._load_data() logger.debug(f"ReplyTrackerService initialized with data file: {self.reply_record_file}") - @staticmethod - def _validate_data(data: Any) -> bool: + def _validate_data(self, data: Any) -> bool: """验证加载的数据格式是否正确""" if not isinstance(data, dict): logger.error("加载的数据不是字典格式") diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 6124f4f06..ed32da48d 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -129,8 +129,7 @@ class SchedulerService: logger.error(f"定时任务循环中发生未知错误: {e}\n{traceback.format_exc()}") await asyncio.sleep(300) # 发生错误后,等待一段时间再重试 - @staticmethod - async def _is_processed(hour_str: str, activity: str) -> bool: + async def _is_processed(self, hour_str: str, activity: str) -> bool: """ 检查指定的任务(某个小时的某个活动)是否已经被成功处理过。 @@ -153,8 +152,7 @@ class SchedulerService: logger.error(f"检查日程处理状态时发生数据库错误: {e}") return False # 数据库异常时,默认为未处理,允许重试 - @staticmethod - async def _mark_as_processed(hour_str: str, activity: str, success: bool, content: str): + async def _mark_as_processed(self, hour_str: str, activity: str, success: bool, content: str): """ 将任务的处理状态和结果写入数据库。 @@ -187,7 +185,7 @@ class SchedulerService: send_success=success, ) session.add(new_record) - await session.commit() + session.commit() logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}") except Exception as e: logger.error(f"更新日程处理状态时发生数据库错误: {e}") diff --git a/src/plugins/built_in/maizone_refactored/utils/history_utils.py b/src/plugins/built_in/maizone_refactored/utils/history_utils.py index 171396de2..19b3e7baa 100644 --- a/src/plugins/built_in/maizone_refactored/utils/history_utils.py +++ b/src/plugins/built_in/maizone_refactored/utils/history_utils.py @@ -49,8 +49,7 @@ class _SimpleQZoneAPI: if p_skey: self.gtk2 = self._generate_gtk(p_skey) - @staticmethod - def _generate_gtk(skey: str) -> str: + def _generate_gtk(self, skey: str) -> str: hash_val = 5381 for char in skey: hash_val += (hash_val << 5) + ord(char) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index c50f17e7b..a19ca85e5 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -26,7 +26,7 @@ import json import websockets as Server import base64 from pathlib import Path -from typing import List, Tuple, Optional, Dict, Any, Coroutine +from typing import List, Tuple, Optional, Dict, Any import uuid from maim_message import ( diff --git a/template/model_config_template.toml b/template/model_config_template.toml index ea200accb..8c9763c2f 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -53,14 +53,14 @@ price_out = 8.0 # 输出价格(用于API调用统计,单 #use_anti_truncation = true # [可选] 启用反截断功能。当模型输出不完整时,系统会自动重试。建议只为有需要的模型(如Gemini)开启。 [[models]] -model_identifier = "Pro/deepseek-ai/DeepSeek-V3" +model_identifier = "deepseek-ai/DeepSeek-V3" name = "siliconflow-deepseek-v3" api_provider = "SiliconFlow" price_in = 2.0 price_out = 8.0 [[models]] -model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +model_identifier = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" name = "deepseek-r1-distill-qwen-32b" api_provider = "SiliconFlow" price_in = 4.0