From c9b712d8fa4010c8637424d04690541484cc855b Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 6 Sep 2025 00:10:54 +0800 Subject: [PATCH 1/4] =?UTF-8?q?refactor(prompt):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E5=B9=B6=E7=BB=9F=E4=B8=80=E6=8F=90=E7=A4=BA=E8=AF=8D=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 删除旧的智能提示词系统(smart_prompt.py)和相关参数模块(prompt_parameters.py) - 将 prompt_builder.py 重命名为 prompt.py 并精简功能 - 更新所有模块的导入路径从 `src.chat.utils.prompt_builder` 到 `src.chat.utils.prompt` - 统一提示词构建接口,使用新的 PromptContext 替代 SmartPromptParameters - 移除重复和冗余代码,简化系统架构 BREAKING CHANGE: 旧的 SmartPrompt 系统已被完全移除,所有相关模块需要改用新的统一 Prompt 系统 --- scripts/update_prompt_imports.py | 61 ++ src/chat/chat_loop/response_handler.py | 2 +- src/chat/express/expression_learner.py | 2 +- src/chat/express/expression_selector.py | 2 +- src/chat/memory_system/memory_activator.py | 2 +- src/chat/message_receive/bot.py | 2 +- src/chat/planner_actions/planner.py | 2 +- src/chat/replyer/default_generator.py | 31 +- src/chat/utils/prompt.py | 693 +++++++++++++ src/chat/utils/prompt_builder.py | 299 ------ src/chat/utils/prompt_parameters.py | 156 --- src/chat/utils/prompt_utils.py | 2 +- src/chat/utils/smart_prompt.py | 938 ------------------ src/mais4u/mai_think.py | 2 +- .../body_emotion_action_manager.py | 2 +- src/mais4u/mais4u_chat/s4u_mood_manager.py | 2 +- src/mais4u/mais4u_chat/s4u_prompt.py | 2 +- src/mood/mood_manager.py | 2 +- src/person_info/relationship_fetcher.py | 2 +- src/plugin_system/core/tool_use.py | 2 +- 20 files changed, 782 insertions(+), 1424 deletions(-) create mode 100644 scripts/update_prompt_imports.py create mode 100644 src/chat/utils/prompt.py delete mode 100644 src/chat/utils/prompt_builder.py delete mode 100644 src/chat/utils/prompt_parameters.py delete mode 100644 src/chat/utils/smart_prompt.py diff --git a/scripts/update_prompt_imports.py b/scripts/update_prompt_imports.py new file mode 100644 index 000000000..672659086 --- /dev/null +++ b/scripts/update_prompt_imports.py @@ -0,0 +1,61 @@ +""" +更新Prompt类导入脚本 +将旧的prompt_builder.Prompt导入更新为unified_prompt.Prompt +""" + +import os +import re +from pathlib import Path + +# 需要更新的文件列表 +files_to_update = [ + "src/person_info/relationship_fetcher.py", + "src/mood/mood_manager.py", + "src/mais4u/mais4u_chat/body_emotion_action_manager.py", + "src/chat/express/expression_learner.py", + "src/chat/planner_actions/planner.py", + "src/mais4u/mais4u_chat/s4u_prompt.py", + "src/chat/message_receive/bot.py", + "src/chat/replyer/default_generator.py", + "src/chat/express/expression_selector.py", + "src/mais4u/mai_think.py", + "src/mais4u/mais4u_chat/s4u_mood_manager.py", + "src/plugin_system/core/tool_use.py", + "src/chat/memory_system/memory_activator.py", + "src/chat/utils/smart_prompt.py" +] + +def update_prompt_imports(file_path): + """更新文件中的Prompt导入""" + if not os.path.exists(file_path): + print(f"文件不存在: {file_path}") + return False + + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + # 替换导入语句 + old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager" + new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager" + + if old_import in content: + new_content = content.replace(old_import, new_import) + with open(file_path, 'w', encoding='utf-8') as f: + f.write(new_content) + print(f"已更新: {file_path}") + return True + else: + print(f"无需更新: {file_path}") + return False + +def main(): + """主函数""" + updated_count = 0 + for file_path in files_to_update: + if update_prompt_imports(file_path): + updated_count += 1 + + print(f"\n更新完成!共更新了 {updated_count} 个文件") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/chat/chat_loop/response_handler.py b/src/chat/chat_loop/response_handler.py index ecfc6addb..55dca45a3 100644 --- a/src/chat/chat_loop/response_handler.py +++ b/src/chat/chat_loop/response_handler.py @@ -12,7 +12,7 @@ from .hfc_context import HfcContext # 导入反注入系统 from src.chat.antipromptinjector import get_anti_injector from src.chat.antipromptinjector.types import ProcessResult -from src.chat.utils.prompt_builder import Prompt +from src.chat.utils.prompt import Prompt logger = get_logger("hfc") anti_injector_logger = get_logger("anti_injector") diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 7f127f0a5..1b9fcf267 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -13,7 +13,7 @@ from src.common.database.sqlalchemy_models import Expression from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index f0991c7c7..2883ec82d 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -11,7 +11,7 @@ from src.config.config import global_config, model_config from src.common.logger import get_logger from sqlalchemy import select from src.common.database.sqlalchemy_models import Expression -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.database.sqlalchemy_database_api import get_db_session logger = get_logger("expression_selector") diff --git a/src/chat/memory_system/memory_activator.py b/src/chat/memory_system/memory_activator.py index 4067363f0..33d22a5dd 100644 --- a/src/chat/memory_system/memory_activator.py +++ b/src/chat/memory_system/memory_activator.py @@ -8,7 +8,7 @@ from datetime import datetime from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.memory_system.Hippocampus import hippocampus_manager diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index e71616892..260a42170 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -12,7 +12,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream from src.chat.message_receive.message import MessageRecv, MessageRecvS4U from src.chat.message_receive.storage import MessageStorage from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.plugin_system.core import component_registry, event_manager, global_announcement_manager from src.plugin_system.base import BaseCommand, EventType from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 291c19a66..c08d52029 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -9,7 +9,7 @@ from json_repair import repair_json from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import ( build_readable_actions, get_actions_by_timestamp_with_chat, diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index be58c5426..973565d37 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -1,6 +1,6 @@ """ -默认回复生成器 - 集成SmartPrompt系统 -使用重构后的SmartPrompt系统替换原有的复杂提示词构建逻辑 +默认回复生成器 - 集成统一Prompt系统 +使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑 """ import traceback @@ -23,7 +23,7 @@ from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_chat_type_and_target_info -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import ( build_readable_messages, get_raw_msg_before_timestamp_with_chat, @@ -39,8 +39,8 @@ from src.plugin_system.base.component_types import ActionInfo, EventType from src.plugin_system.apis import llm_api from src.schedule.schedule_manager import schedule_manager -# 导入新的智能Prompt系统 -from src.chat.utils.smart_prompt import SmartPrompt, SmartPromptParameters +# 导入新的统一Prompt系统 +from src.chat.utils.prompt import Prompt, PromptContext logger = get_logger("replyer") @@ -971,8 +971,8 @@ class DefaultReplyer: # 根据配置选择模板 current_prompt_mode = global_config.personality.prompt_mode - # 使用重构后的SmartPrompt系统 - prompt_params = SmartPromptParameters( + # 使用新的统一Prompt系统 + prompt_context = PromptContext( chat_id=chat_id, is_group_chat=is_group_chat, sender=sender, @@ -1005,12 +1005,9 @@ class DefaultReplyer: action_descriptions=action_descriptions, ) - # 使用重构后的SmartPrompt系统 - smart_prompt = SmartPrompt( - template_name=None, # 由current_prompt_mode自动选择 - parameters=prompt_params, - ) - prompt_text = await smart_prompt.build_prompt() + # 使用新的统一Prompt系统 + prompt = Prompt(template_name=None, context=prompt_context) # 由current_prompt_mode自动选择 + prompt_text = await prompt.build_prompt() return prompt_text @@ -1111,8 +1108,8 @@ class DefaultReplyer: template_name = "default_expressor_prompt" - # 使用重构后的SmartPrompt系统 - Expressor模式 - prompt_params = SmartPromptParameters( + # 使用新的统一Prompt系统 - Expressor模式 + prompt_context = PromptContext( chat_id=chat_id, is_group_chat=is_group_chat, sender=sender, @@ -1132,8 +1129,8 @@ class DefaultReplyer: relation_info_block=relation_info, ) - smart_prompt = SmartPrompt(parameters=prompt_params) - prompt_text = await smart_prompt.build_prompt() + prompt = Prompt(template_name=template_name, context=prompt_context) + prompt_text = await prompt.build_prompt() return prompt_text diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py new file mode 100644 index 000000000..1e44b72d8 --- /dev/null +++ b/src/chat/utils/prompt.py @@ -0,0 +1,693 @@ +""" +统一提示词系统 - 合并模板管理和智能构建功能 +将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类 +""" + +import re +import asyncio +import time +import contextvars +from dataclasses import dataclass, field +from typing import Dict, Any, Optional, List, Union, Literal, Tuple +from contextlib import asynccontextmanager + +from rich.traceback import install +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.utils.chat_message_builder import build_readable_messages +from src.chat.utils.prompt_utils import PromptUtils +from src.person_info.person_info import get_person_info_manager + +install(extra_lines=3) +logger = get_logger("unified_prompt") + + +@dataclass +class PromptParameters: + """统一提示词参数系统""" + + # 基础参数 + chat_id: str = "" + is_group_chat: bool = False + sender: str = "" + target: str = "" + reply_to: str = "" + extra_info: str = "" + prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u" + + # 功能开关 + enable_tool: bool = True + enable_memory: bool = True + enable_expression: bool = True + enable_relation: bool = True + enable_cross_context: bool = True + enable_knowledge: bool = True + + # 性能控制 + max_context_messages: int = 50 + + # 调试选项 + debug_mode: bool = False + + # 聊天历史和上下文 + chat_target_info: Optional[Dict[str, Any]] = None + message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list) + message_list_before_short: List[Dict[str, Any]] = field(default_factory=list) + chat_talking_prompt_short: str = "" + target_user_info: Optional[Dict[str, Any]] = None + + # 已构建的内容块 + expression_habits_block: str = "" + relation_info_block: str = "" + memory_block: str = "" + tool_info_block: str = "" + knowledge_prompt: str = "" + cross_context_block: str = "" + + # 其他内容块 + keywords_reaction_prompt: str = "" + extra_info_block: str = "" + time_block: str = "" + identity_block: str = "" + schedule_block: str = "" + moderation_prompt_block: str = "" + reply_target_block: str = "" + mood_prompt: str = "" + action_descriptions: str = "" + + # 可用动作信息 + available_actions: Optional[Dict[str, Any]] = None + + def validate(self) -> List[str]: + """参数验证""" + errors = [] + if not self.chat_id: + errors.append("chat_id不能为空") + if self.prompt_mode not in ["s4u", "normal", "minimal"]: + errors.append("prompt_mode必须是's4u'、'normal'或'minimal'") + if self.max_context_messages <= 0: + errors.append("max_context_messages必须大于0") + return errors + + +class PromptContext: + """提示词上下文管理器""" + + def __init__(self): + self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} + self._current_context_var = contextvars.ContextVar("current_context", default=None) + self._context_lock = asyncio.Lock() + + @property + def _current_context(self) -> Optional[str]: + """获取当前协程的上下文ID""" + return self._current_context_var.get() + + @_current_context.setter + def _current_context(self, value: Optional[str]): + """设置当前协程的上下文ID""" + self._current_context_var.set(value) # type: ignore + + @asynccontextmanager + async def async_scope(self, context_id: Optional[str] = None): + """创建一个异步的临时提示模板作用域""" + if context_id is not None: + try: + await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0) + try: + if context_id not in self._context_prompts: + self._context_prompts[context_id] = {} + finally: + self._context_lock.release() + except asyncio.TimeoutError: + logger.warning(f"获取上下文锁超时,context_id: {context_id}") + context_id = None + + previous_context = self._current_context + token = self._current_context_var.set(context_id) if context_id else None + else: + previous_context = self._current_context + token = None + + try: + yield self + finally: + if context_id is not None and token is not None: + try: + self._current_context_var.reset(token) + except Exception as e: + logger.warning(f"恢复上下文时出错: {e}") + try: + self._current_context = previous_context + except Exception: + ... + + async def get_prompt_async(self, name: str) -> Optional["Prompt"]: + """异步获取当前作用域中的提示模板""" + async with self._context_lock: + current_context = self._current_context + logger.debug(f"获取提示词: {name} 当前上下文: {current_context}") + if ( + current_context + and current_context in self._context_prompts + and name in self._context_prompts[current_context] + ): + return self._context_prompts[current_context][name] + return None + + async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: + """异步注册提示模板到指定作用域""" + async with self._context_lock: + if target_context := context_id or self._current_context: + if prompt.name: + self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt + + +class PromptManager: + """统一提示词管理器""" + + def __init__(self): + self._prompts = {} + self._counter = 0 + self._context = PromptContext() + self._lock = asyncio.Lock() + + @asynccontextmanager + async def async_message_scope(self, message_id: Optional[str] = None): + """为消息处理创建异步临时作用域""" + async with self._context.async_scope(message_id): + yield self + + async def get_prompt_async(self, name: str) -> "Prompt": + """异步获取提示模板""" + context_prompt = await self._context.get_prompt_async(name) + if context_prompt is not None: + logger.debug(f"从上下文中获取提示词: {name} {context_prompt}") + return context_prompt + + async with self._lock: + if name not in self._prompts: + raise KeyError(f"Prompt '{name}' not found") + return self._prompts[name] + + def generate_name(self, template: str) -> str: + """为未命名的prompt生成名称""" + self._counter += 1 + return f"prompt_{self._counter}" + + def register(self, prompt: "Prompt") -> None: + """注册一个prompt""" + if not prompt.name: + prompt.name = self.generate_name(prompt.template) + self._prompts[prompt.name] = prompt + + def add_prompt(self, name: str, fstr: str) -> "Prompt": + """添加新提示模板""" + prompt = Prompt(fstr, name=name) + if prompt.name: + self._prompts[prompt.name] = prompt + return prompt + + async def format_prompt(self, name: str, **kwargs) -> str: + """格式化提示模板""" + prompt = await self.get_prompt_async(name) + result = prompt.format(**kwargs) + return result + + +# 全局单例 +global_prompt_manager = PromptManager() + + +class Prompt: + """ + 统一提示词类 - 合并模板管理和智能构建功能 + 真正的Prompt类,支持模板管理和智能上下文构建 + """ + + # 临时标记,作为类常量 + _TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__" + _TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__" + + def __init__( + self, + template: str, + name: Optional[str] = None, + parameters: Optional[PromptParameters] = None, + should_register: bool = True + ): + """ + 初始化统一提示词 + + Args: + template: 提示词模板字符串 + name: 提示词名称 + parameters: 构建参数 + should_register: 是否自动注册到全局管理器 + """ + self.template = template + self.name = name + self.parameters = parameters or PromptParameters() + self.args = self._parse_template_args(template) + self._formatted_result = "" + + # 预处理模板中的转义花括号 + self._processed_template = self._process_escaped_braces(template) + + # 自动注册 + if should_register and not global_prompt_manager._context._current_context: + global_prompt_manager.register(self) + + @staticmethod + def _process_escaped_braces(template) -> str: + """处理模板中的转义花括号""" + if isinstance(template, list): + template = "\n".join(str(item) for item in template) + elif not isinstance(template, str): + template = str(template) + + return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE) + + @staticmethod + def _restore_escaped_braces(template: str) -> str: + """将临时标记还原为实际的花括号字符""" + return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}") + + def _parse_template_args(self, template: str) -> List[str]: + """解析模板参数""" + template_args = [] + processed_template = self._process_escaped_braces(template) + result = re.findall(r"\{(.*?)}", processed_template) + for expr in result: + if expr and expr not in template_args: + template_args.append(expr) + return template_args + + async def build(self) -> str: + """ + 构建完整的提示词,包含智能上下文 + + Returns: + str: 构建完成的提示词文本 + """ + # 参数验证 + errors = self.parameters.validate() + if errors: + logger.error(f"参数验证失败: {', '.join(errors)}") + raise ValueError(f"参数验证失败: {', '.join(errors)}") + + start_time = time.time() + try: + # 构建上下文数据 + context_data = await self._build_context_data() + + # 格式化模板 + result = await self._format_with_context(context_data) + + total_time = time.time() - start_time + logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s") + + self._formatted_result = result + return result + + except asyncio.TimeoutError as e: + logger.error(f"构建Prompt超时: {e}") + raise TimeoutError(f"构建Prompt超时: {e}") + except Exception as e: + logger.error(f"构建Prompt失败: {e}") + raise RuntimeError(f"构建Prompt失败: {e}") + + async def _build_context_data(self) -> Dict[str, Any]: + """构建智能上下文数据""" + # 并行执行所有构建任务 + start_time = time.time() + timing_logs = {} + + try: + # 准备构建任务 + tasks = [] + task_names = [] + + # 初始化预构建参数 + pre_built_params = {} + if self.parameters.expression_habits_block: + pre_built_params["expression_habits_block"] = self.parameters.expression_habits_block + if self.parameters.relation_info_block: + pre_built_params["relation_info_block"] = self.parameters.relation_info_block + if self.parameters.memory_block: + pre_built_params["memory_block"] = self.parameters.memory_block + if self.parameters.tool_info_block: + pre_built_params["tool_info_block"] = self.parameters.tool_info_block + if self.parameters.knowledge_prompt: + pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt + if self.parameters.cross_context_block: + pre_built_params["cross_context_block"] = self.parameters.cross_context_block + + # 根据参数确定要构建的项 + if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"): + tasks.append(self._build_expression_habits()) + task_names.append("expression_habits") + + if self.parameters.enable_memory and not pre_built_params.get("memory_block"): + tasks.append(self._build_memory_block()) + task_names.append("memory_block") + + if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"): + tasks.append(self._build_relation_info()) + task_names.append("relation_info") + + if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"): + tasks.append(self._build_tool_info()) + task_names.append("tool_info") + + if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"): + tasks.append(self._build_knowledge_info()) + task_names.append("knowledge_info") + + if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"): + tasks.append(self._build_cross_context()) + task_names.append("cross_context") + + # 性能优化 + base_timeout = 10.0 + task_timeout = 2.0 + timeout_seconds = min( + max(base_timeout, len(tasks) * task_timeout), + 30.0, + ) + + max_concurrent_tasks = 5 + if len(tasks) > max_concurrent_tasks: + results = [] + for i in range(0, len(tasks), max_concurrent_tasks): + batch_tasks = tasks[i : i + max_concurrent_tasks] + batch_names = task_names[i : i + max_concurrent_tasks] + + batch_results = await asyncio.wait_for( + asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds + ) + results.extend(batch_results) + else: + results = await asyncio.wait_for( + asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds + ) + + # 处理结果 + context_data = {} + for i, result in enumerate(results): + task_name = task_names[i] if i < len(task_names) else f"task_{i}" + + if isinstance(result, Exception): + logger.error(f"构建任务{task_name}失败: {str(result)}") + elif isinstance(result, dict): + context_data.update(result) + + # 添加预构建的参数 + for key, value in pre_built_params.items(): + if value: + context_data[key] = value + + except asyncio.TimeoutError: + logger.error(f"构建超时 ({timeout_seconds}s)") + context_data = {} + for key, value in pre_built_params.items(): + if value: + context_data[key] = value + + # 构建聊天历史 + if self.parameters.prompt_mode == "s4u": + await self._build_s4u_chat_context(context_data) + else: + await self._build_normal_chat_context(context_data) + + # 补充基础信息 + context_data.update({ + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt, + "extra_info_block": self.parameters.extra_info_block, + "time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}", + "identity": self.parameters.identity_block, + "schedule_block": self.parameters.schedule_block, + "moderation_prompt": self.parameters.moderation_prompt_block, + "reply_target_block": self.parameters.reply_target_block, + "mood_state": self.parameters.mood_prompt, + "action_descriptions": self.parameters.action_descriptions, + }) + + total_time = time.time() - start_time + logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s") + + return context_data + + async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None: + """构建S4U模式的聊天上下文""" + if not self.parameters.message_list_before_now_long: + return + + core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts( + self.parameters.message_list_before_now_long, + self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "", + self.parameters.sender + ) + + context_data["core_dialogue_prompt"] = core_dialogue + context_data["background_dialogue_prompt"] = background_dialogue + + async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None: + """构建normal模式的聊天上下文""" + if not self.parameters.chat_talking_prompt_short: + return + + context_data["chat_info"] = f"""群里的聊天内容: +{self.parameters.chat_talking_prompt_short}""" + + async def _build_s4u_chat_history_prompts( + self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str + ) -> Tuple[str, str]: + """构建S4U风格的分离对话prompt""" + # 实现逻辑与原有SmartPromptBuilder相同 + core_dialogue_list = [] + bot_id = str(global_config.bot.qq_account) + + for msg_dict in message_list_before_now: + try: + msg_user_id = str(msg_dict.get("user_id")) + reply_to = msg_dict.get("reply_to", "") + platform, reply_to_user_id = PromptUtils.parse_reply_target(reply_to) + if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: + core_dialogue_list.append(msg_dict) + except Exception as e: + logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}") + + # 构建背景对话 prompt + all_dialogue_prompt = "" + if message_list_before_now: + latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :] + all_dialogue_prompt_str = build_readable_messages( + latest_25_msgs, + replace_bot_name=True, + timestamp_mode="normal", + truncate=True, + ) + all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}" + + # 构建核心对话 prompt + core_dialogue_prompt = "" + if core_dialogue_list: + latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list + has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages) + + if not has_bot_message: + core_dialogue_prompt = "" + else: + core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] + + core_dialogue_prompt_str = build_readable_messages( + core_dialogue_list, + replace_bot_name=True, + merge_messages=False, + timestamp_mode="normal_no_YMD", + read_mark=0.0, + truncate=True, + show_actions=True, + ) + core_dialogue_prompt = f"""-------------------------------- +这是你和{sender}的对话,你们正在交流中: +{core_dialogue_prompt_str} +-------------------------------- +""" + + return core_dialogue_prompt, all_dialogue_prompt + + async def _build_expression_habits(self) -> Dict[str, Any]: + """构建表达习惯""" + # 简化的实现,完整实现需要导入相关模块 + return {"expression_habits_block": ""} + + async def _build_memory_block(self) -> Dict[str, Any]: + """构建记忆块""" + # 简化的实现 + return {"memory_block": ""} + + async def _build_relation_info(self) -> Dict[str, Any]: + """构建关系信息""" + try: + relation_info = await PromptUtils.build_relation_info(self.parameters.chat_id, self.parameters.reply_to) + return {"relation_info_block": relation_info} + except Exception as e: + logger.error(f"构建关系信息失败: {e}") + return {"relation_info_block": ""} + + async def _build_tool_info(self) -> Dict[str, Any]: + """构建工具信息""" + # 简化的实现 + return {"tool_info_block": ""} + + async def _build_knowledge_info(self) -> Dict[str, Any]: + """构建知识信息""" + # 简化的实现 + return {"knowledge_prompt": ""} + + async def _build_cross_context(self) -> Dict[str, Any]: + """构建跨群上下文""" + try: + cross_context = await PromptUtils.build_cross_context( + self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info + ) + return {"cross_context_block": cross_context} + except Exception as e: + logger.error(f"构建跨群上下文失败: {e}") + return {"cross_context_block": ""} + + async def _format_with_context(self, context_data: Dict[str, Any]) -> str: + """使用上下文数据格式化模板""" + if self.parameters.prompt_mode == "s4u": + params = self._prepare_s4u_params(context_data) + elif self.parameters.prompt_mode == "normal": + params = self._prepare_normal_params(context_data) + else: + params = self._prepare_default_params(context_data) + + return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params) + + def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: + """准备S4U模式的参数""" + return { + **context_data, + "expression_habits_block": context_data.get("expression_habits_block", ""), + "tool_info_block": context_data.get("tool_info_block", ""), + "knowledge_prompt": context_data.get("knowledge_prompt", ""), + "memory_block": context_data.get("memory_block", ""), + "relation_info_block": context_data.get("relation_info_block", ""), + "extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""), + "cross_context_block": context_data.get("cross_context_block", ""), + "identity": self.parameters.identity_block or context_data.get("identity", ""), + "action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""), + "sender_name": self.parameters.sender, + "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), + "background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""), + "time_block": context_data.get("time_block", ""), + "core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""), + "reply_target_block": context_data.get("reply_target_block", ""), + "reply_style": global_config.personality.reply_style, + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), + "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), + } + + def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: + """准备Normal模式的参数""" + return { + **context_data, + "expression_habits_block": context_data.get("expression_habits_block", ""), + "tool_info_block": context_data.get("tool_info_block", ""), + "knowledge_prompt": context_data.get("knowledge_prompt", ""), + "memory_block": context_data.get("memory_block", ""), + "relation_info_block": context_data.get("relation_info_block", ""), + "extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""), + "cross_context_block": context_data.get("cross_context_block", ""), + "identity": self.parameters.identity_block or context_data.get("identity", ""), + "action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""), + "schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""), + "time_block": context_data.get("time_block", ""), + "chat_info": context_data.get("chat_info", ""), + "reply_target_block": context_data.get("reply_target_block", ""), + "config_expression_style": global_config.personality.reply_style, + "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), + "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), + } + + def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: + """准备默认模式的参数""" + return { + "expression_habits_block": context_data.get("expression_habits_block", ""), + "relation_info_block": context_data.get("relation_info_block", ""), + "chat_target": "", + "time_block": context_data.get("time_block", ""), + "chat_info": context_data.get("chat_info", ""), + "identity": self.parameters.identity_block or context_data.get("identity", ""), + "chat_target_2": "", + "reply_target_block": context_data.get("reply_target_block", ""), + "raw_reply": self.parameters.target, + "reason": "", + "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), + "reply_style": global_config.personality.reply_style, + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), + "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), + } + + def format(self, *args, **kwargs) -> str: + """格式化模板,支持位置参数和关键字参数""" + try: + # 先用位置参数格式化 + if args: + formatted_args = {} + for i in range(len(args)): + if i < len(self.args): + formatted_args[self.args[i]] = args[i] + processed_template = self._processed_template.format(**formatted_args) + else: + processed_template = self._processed_template + + # 再用关键字参数格式化 + if kwargs: + processed_template = processed_template.format(**kwargs) + + # 将临时标记还原为实际的花括号 + result = self._restore_escaped_braces(processed_template) + return result + except (IndexError, KeyError) as e: + raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e + + def __str__(self) -> str: + """返回格式化后的结果或原始模板""" + return self._formatted_result if self._formatted_result else self.template + + def __repr__(self) -> str: + """返回提示词的表示形式""" + return f"Prompt(template='{self.template}', name='{self.name}')" + + +# 工厂函数 +def create_prompt( + template: str, + name: Optional[str] = None, + parameters: Optional[PromptParameters] = None, + **kwargs +) -> Prompt: + """快速创建Prompt实例的工厂函数""" + if parameters is None: + parameters = PromptParameters(**kwargs) + return Prompt(template, name, parameters) + + +async def create_prompt_async( + template: str, + name: Optional[str] = None, + parameters: Optional[PromptParameters] = None, + **kwargs +) -> Prompt: + """异步创建Prompt实例""" + prompt = create_prompt(template, name, parameters, **kwargs) + if global_prompt_manager._context._current_context: + await global_prompt_manager._context.register_async(prompt) + return prompt \ No newline at end of file diff --git a/src/chat/utils/prompt_builder.py b/src/chat/utils/prompt_builder.py deleted file mode 100644 index 3585b5959..000000000 --- a/src/chat/utils/prompt_builder.py +++ /dev/null @@ -1,299 +0,0 @@ -import re -import asyncio -import contextvars - -from rich.traceback import install -from contextlib import asynccontextmanager -from typing import Dict, Any, Optional, List, Union - -from src.common.logger import get_logger - -install(extra_lines=3) - -logger = get_logger("prompt_build") - - -class PromptContext: - def __init__(self): - self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} - # 使用contextvars创建协程上下文变量 - self._current_context_var = contextvars.ContextVar("current_context", default=None) - self._context_lock = asyncio.Lock() # 保留锁用于其他操作 - - @property - def _current_context(self) -> Optional[str]: - """获取当前协程的上下文ID""" - return self._current_context_var.get() - - @_current_context.setter - def _current_context(self, value: Optional[str]): - """设置当前协程的上下文ID""" - self._current_context_var.set(value) # type: ignore - - @asynccontextmanager - async def async_scope(self, context_id: Optional[str] = None): - # sourcery skip: hoist-statement-from-if, use-contextlib-suppress - """创建一个异步的临时提示模板作用域""" - # 保存当前上下文并设置新上下文 - if context_id is not None: - try: - # 添加超时保护,避免长时间等待锁 - await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0) - try: - if context_id not in self._context_prompts: - self._context_prompts[context_id] = {} - finally: - self._context_lock.release() - except asyncio.TimeoutError: - logger.warning(f"获取上下文锁超时,context_id: {context_id}") - # 超时时直接进入,不设置上下文 - context_id = None - - # 保存当前协程的上下文值,不影响其他协程 - previous_context = self._current_context - # 设置当前协程的新上下文 - token = self._current_context_var.set(context_id) if context_id else None # type: ignore - else: - # 如果没有提供新上下文,保持当前上下文不变 - previous_context = self._current_context - token = None - - try: - yield self - finally: - # 恢复之前的上下文,添加异常保护 - if context_id is not None and token is not None: - try: - self._current_context_var.reset(token) - except Exception as e: - logger.warning(f"恢复上下文时出错: {e}") - # 如果reset失败,尝试直接设置 - try: - self._current_context = previous_context - except Exception: - ... - # 静默忽略恢复失败 - - async def get_prompt_async(self, name: str) -> Optional["Prompt"]: - """异步获取当前作用域中的提示模板""" - async with self._context_lock: - current_context = self._current_context - logger.debug(f"获取提示词: {name} 当前上下文: {current_context}") - if ( - current_context - and current_context in self._context_prompts - and name in self._context_prompts[current_context] - ): - return self._context_prompts[current_context][name] - return None - - async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: - """异步注册提示模板到指定作用域""" - async with self._context_lock: - if target_context := context_id or self._current_context: - if prompt.name: - self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt - - -class PromptManager: - def __init__(self): - self._prompts = {} - self._counter = 0 - self._context = PromptContext() - self._lock = asyncio.Lock() - - @asynccontextmanager - async def async_message_scope(self, message_id: Optional[str] = None): - """为消息处理创建异步临时作用域,支持 message_id 为 None 的情况""" - async with self._context.async_scope(message_id): - yield self - - async def get_prompt_async(self, name: str) -> "Prompt": - # 首先尝试从当前上下文获取 - context_prompt = await self._context.get_prompt_async(name) - if context_prompt is not None: - logger.debug(f"从上下文中获取提示词: {name} {context_prompt}") - return context_prompt - # 如果上下文中不存在,则使用全局提示模板 - async with self._lock: - # logger.debug(f"从全局获取提示词: {name}") - if name not in self._prompts: - raise KeyError(f"Prompt '{name}' not found") - return self._prompts[name] - - def generate_name(self, template: str) -> str: - """为未命名的prompt生成名称""" - self._counter += 1 - return f"prompt_{self._counter}" - - def register(self, prompt: "Prompt") -> None: - """注册一个prompt""" - if not prompt.name: - prompt.name = self.generate_name(prompt.template) - self._prompts[prompt.name] = prompt - - def add_prompt(self, name: str, fstr: str) -> "Prompt": - prompt = Prompt(fstr, name=name) - if prompt.name: - self._prompts[prompt.name] = prompt - return prompt - - async def format_prompt(self, name: str, **kwargs) -> str: - # 获取当前提示词 - prompt = await self.get_prompt_async(name) - # 获取基本格式化结果 - result = prompt.format(**kwargs) - return result - - -# 全局单例 -global_prompt_manager = PromptManager() - - -class Prompt(str): - template: str - name: Optional[str] - args: List[str] - _args: List[Any] - _kwargs: Dict[str, Any] - # 临时标记,作为类常量 - _TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__" - _TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__" - - @staticmethod - def _process_escaped_braces(template) -> str: - """处理模板中的转义花括号,将 \\{ 和 \\} 替换为临时标记""" # type: ignore - # 如果传入的是列表,将其转换为字符串 - if isinstance(template, list): - template = "\n".join(str(item) for item in template) - elif not isinstance(template, str): - template = str(template) - - return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE) - - @staticmethod - def _restore_escaped_braces(template: str) -> str: - """将临时标记还原为实际的花括号字符""" - return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}") - - def __new__( - cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs - ): - # 如果传入的是元组,转换为列表 - if isinstance(args, tuple): - args = list(args) - should_register = kwargs.pop("_should_register", True) - - # 预处理模板中的转义花括号 - processed_fstr = cls._process_escaped_braces(fstr) - - # 解析模板 - template_args = [] - result = re.findall(r"\{(.*?)}", processed_fstr) - for expr in result: - if expr and expr not in template_args: - template_args.append(expr) - - # 如果提供了初始参数,立即格式化 - if kwargs or args: - formatted = cls._format_template(fstr, args=args, kwargs=kwargs) - obj = super().__new__(cls, formatted) - else: - obj = super().__new__(cls, "") - - obj.template = fstr - obj.name = name - obj.args = template_args - obj._args = args or [] - obj._kwargs = kwargs - - # 修改自动注册逻辑 - if should_register and not global_prompt_manager._context._current_context: - global_prompt_manager.register(obj) - return obj - - @classmethod - async def create_async( - cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs - ): - """异步创建Prompt实例""" - prompt = cls(fstr, name, args, **kwargs) - if global_prompt_manager._context._current_context: - await global_prompt_manager._context.register_async(prompt) - return prompt - - @classmethod - def _format_template( - cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None - ) -> str: - if kwargs is None: - kwargs = {} - # 预处理模板中的转义花括号 - processed_template = cls._process_escaped_braces(template) - - template_args = [] - result = re.findall(r"\{(.*?)}", processed_template) - for expr in result: - if expr and expr not in template_args: - template_args.append(expr) - formatted_args = {} - formatted_kwargs = {} - - # 处理位置参数 - if args: - # print(len(template_args), len(args), template_args, args) - for i in range(len(args)): - if i < len(template_args): - arg = args[i] - if isinstance(arg, Prompt): - formatted_args[template_args[i]] = arg.format(**kwargs) - else: - formatted_args[template_args[i]] = arg - else: - logger.error( - f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}" - ) - raise ValueError("格式化模板失败") - - # 处理关键字参数 - if kwargs: - for key, value in kwargs.items(): - if isinstance(value, Prompt): - remaining_kwargs = {k: v for k, v in kwargs.items() if k != key} - formatted_kwargs[key] = value.format(**remaining_kwargs) - else: - formatted_kwargs[key] = value - - try: - # 先用位置参数格式化 - if args: - processed_template = processed_template.format(**formatted_args) - # 再用关键字参数格式化 - if kwargs: - processed_template = processed_template.format(**formatted_kwargs) - - # 将临时标记还原为实际的花括号 - result = cls._restore_escaped_braces(processed_template) - return result - except (IndexError, KeyError) as e: - raise ValueError( - f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}" - ) from e - - def format(self, *args, **kwargs) -> "str": - """支持位置参数和关键字参数的格式化,使用""" - ret = type(self)( - self.template, - self.name, - args=list(args) if args else self._args, - _should_register=False, - **kwargs or self._kwargs, - ) - # print(f"prompt build result: {ret} name: {ret.name} ") - return str(ret) - - def __str__(self) -> str: - return super().__str__() if self._kwargs or self._args else self.template - - def __repr__(self) -> str: - return f"Prompt(template='{self.template}', name='{self.name}')" diff --git a/src/chat/utils/prompt_parameters.py b/src/chat/utils/prompt_parameters.py deleted file mode 100644 index 2558917d4..000000000 --- a/src/chat/utils/prompt_parameters.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -智能提示词参数模块 - 优化参数结构 -简化SmartPromptParameters,减少冗余和重复 -""" - -from dataclasses import dataclass, field -from typing import Dict, Any, Optional, List, Literal - - -@dataclass -class SmartPromptParameters: - """简化的智能提示词参数系统""" - - # 基础参数 - chat_id: str = "" - is_group_chat: bool = False - sender: str = "" - target: str = "" - reply_to: str = "" - extra_info: str = "" - prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u" - - # 功能开关 - enable_tool: bool = True - enable_memory: bool = True - enable_expression: bool = True - enable_relation: bool = True - enable_cross_context: bool = True - enable_knowledge: bool = True - - # 性能控制 - max_context_messages: int = 50 - - # 调试选项 - debug_mode: bool = False - - # 聊天历史和上下文 - chat_target_info: Optional[Dict[str, Any]] = None - message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list) - message_list_before_short: List[Dict[str, Any]] = field(default_factory=list) - chat_talking_prompt_short: str = "" - target_user_info: Optional[Dict[str, Any]] = None - - # 已构建的内容块 - expression_habits_block: str = "" - relation_info_block: str = "" - memory_block: str = "" - tool_info_block: str = "" - knowledge_prompt: str = "" - cross_context_block: str = "" - - # 其他内容块 - keywords_reaction_prompt: str = "" - extra_info_block: str = "" - time_block: str = "" - identity_block: str = "" - schedule_block: str = "" - moderation_prompt_block: str = "" - reply_target_block: str = "" - mood_prompt: str = "" - action_descriptions: str = "" - - # 可用动作信息 - available_actions: Optional[Dict[str, Any]] = None - - def validate(self) -> List[str]: - """统一的参数验证""" - errors = [] - if not self.chat_id: - errors.append("chat_id不能为空") - if self.prompt_mode not in ["s4u", "normal", "minimal"]: - errors.append("prompt_mode必须是's4u'、'normal'或'minimal'") - if self.max_context_messages <= 0: - errors.append("max_context_messages必须大于0") - return errors - - def get_needed_build_tasks(self) -> List[str]: - """获取需要执行的任务列表""" - tasks = [] - - if self.enable_expression and not self.expression_habits_block: - tasks.append("expression_habits") - - if self.enable_memory and not self.memory_block: - tasks.append("memory_block") - - if self.enable_relation and not self.relation_info_block: - tasks.append("relation_info") - - if self.enable_tool and not self.tool_info_block: - tasks.append("tool_info") - - if self.enable_knowledge and not self.knowledge_prompt: - tasks.append("knowledge_info") - - if self.enable_cross_context and not self.cross_context_block: - tasks.append("cross_context") - - return tasks - - @classmethod - def from_legacy_params(cls, **kwargs) -> "SmartPromptParameters": - """ - 从旧版参数创建新参数对象 - - Args: - **kwargs: 旧版参数 - - Returns: - SmartPromptParameters: 新参数对象 - """ - return cls( - # 基础参数 - chat_id=kwargs.get("chat_id", ""), - is_group_chat=kwargs.get("is_group_chat", False), - sender=kwargs.get("sender", ""), - target=kwargs.get("target", ""), - reply_to=kwargs.get("reply_to", ""), - extra_info=kwargs.get("extra_info", ""), - prompt_mode=kwargs.get("current_prompt_mode", "s4u"), - # 功能开关 - enable_tool=kwargs.get("enable_tool", True), - enable_memory=kwargs.get("enable_memory", True), - enable_expression=kwargs.get("enable_expression", True), - enable_relation=kwargs.get("enable_relation", True), - enable_cross_context=kwargs.get("enable_cross_context", True), - enable_knowledge=kwargs.get("enable_knowledge", True), - # 性能控制 - max_context_messages=kwargs.get("max_context_messages", 50), - debug_mode=kwargs.get("debug_mode", False), - # 聊天历史和上下文 - chat_target_info=kwargs.get("chat_target_info"), - message_list_before_now_long=kwargs.get("message_list_before_now_long", []), - message_list_before_short=kwargs.get("message_list_before_short", []), - chat_talking_prompt_short=kwargs.get("chat_talking_prompt_short", ""), - target_user_info=kwargs.get("target_user_info"), - # 已构建的内容块 - expression_habits_block=kwargs.get("expression_habits_block", ""), - relation_info_block=kwargs.get("relation_info", ""), - memory_block=kwargs.get("memory_block", ""), - tool_info_block=kwargs.get("tool_info", ""), - knowledge_prompt=kwargs.get("knowledge_prompt", ""), - cross_context_block=kwargs.get("cross_context_block", ""), - # 其他内容块 - keywords_reaction_prompt=kwargs.get("keywords_reaction_prompt", ""), - extra_info_block=kwargs.get("extra_info_block", ""), - time_block=kwargs.get("time_block", ""), - identity_block=kwargs.get("identity_block", ""), - schedule_block=kwargs.get("schedule_block", ""), - moderation_prompt_block=kwargs.get("moderation_prompt_block", ""), - reply_target_block=kwargs.get("reply_target_block", ""), - mood_prompt=kwargs.get("mood_prompt", ""), - action_descriptions=kwargs.get("action_descriptions", ""), - # 可用动作信息 - available_actions=kwargs.get("available_actions", None), - ) diff --git a/src/chat/utils/prompt_utils.py b/src/chat/utils/prompt_utils.py index f9985be53..a5bb931dd 100644 --- a/src/chat/utils/prompt_utils.py +++ b/src/chat/utils/prompt_utils.py @@ -1,6 +1,6 @@ """ 共享提示词工具模块 - 消除重复代码 -提供统一的工具函数供DefaultReplyer和SmartPrompt使用 +提供统一的工具函数供DefaultReplyer和统一Prompt系统使用 """ import re diff --git a/src/chat/utils/smart_prompt.py b/src/chat/utils/smart_prompt.py deleted file mode 100644 index aba79f7ec..000000000 --- a/src/chat/utils/smart_prompt.py +++ /dev/null @@ -1,938 +0,0 @@ -""" -智能Prompt系统 - 完全重构版本 -基于原有DefaultReplyer的完整功能集成,使用新的参数结构 -解决实现质量不高、功能集成不完整和错误处理不足的问题 -""" - -import asyncio -import time -from datetime import datetime -from dataclasses import dataclass, field -from typing import Dict, Any, Optional, List, Tuple - -from src.chat.utils.prompt_builder import global_prompt_manager, Prompt -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.utils.chat_message_builder import ( - build_readable_messages, -) -from src.person_info.person_info import get_person_info_manager -from src.chat.utils.prompt_utils import PromptUtils -from src.chat.utils.prompt_parameters import SmartPromptParameters - -logger = get_logger("smart_prompt") - - -@dataclass -class ChatContext: - """聊天上下文信息""" - - chat_id: str = "" - platform: str = "" - is_group: bool = False - user_id: str = "" - user_nickname: str = "" - group_id: Optional[str] = None - timestamp: datetime = field(default_factory=datetime.now) - - -class SmartPromptBuilder: - """重构的智能提示词构建器 - 统一错误处理和功能集成,移除缓存机制和依赖检查""" - - def __init__(self): - # 移除缓存相关初始化 - pass - - async def build_context_data(self, params: SmartPromptParameters) -> Dict[str, Any]: - """并行构建完整的上下文数据 - 移除缓存机制和依赖检查""" - - # 并行执行所有构建任务 - start_time = time.time() - timing_logs = {} - - try: - # 准备构建任务 - tasks = [] - task_names = [] - - # 初始化预构建参数,使用新的结构 - pre_built_params = {} - if params.expression_habits_block: - pre_built_params["expression_habits_block"] = params.expression_habits_block - if params.relation_info_block: - pre_built_params["relation_info_block"] = params.relation_info_block - if params.memory_block: - pre_built_params["memory_block"] = params.memory_block - if params.tool_info_block: - pre_built_params["tool_info_block"] = params.tool_info_block - if params.knowledge_prompt: - pre_built_params["knowledge_prompt"] = params.knowledge_prompt - if params.cross_context_block: - pre_built_params["cross_context_block"] = params.cross_context_block - - # 根据新的参数结构确定要构建的项 - if params.enable_expression and not pre_built_params.get("expression_habits_block"): - tasks.append(self._build_expression_habits(params)) - task_names.append("expression_habits") - - if params.enable_memory and not pre_built_params.get("memory_block"): - tasks.append(self._build_memory_block(params)) - task_names.append("memory_block") - - if params.enable_relation and not pre_built_params.get("relation_info_block"): - tasks.append(self._build_relation_info(params)) - task_names.append("relation_info") - - # 添加mai_think上下文构建任务 - if not pre_built_params.get("mai_think"): - tasks.append(self._build_mai_think_context(params)) - task_names.append("mai_think_context") - - if params.enable_tool and not pre_built_params.get("tool_info_block"): - tasks.append(self._build_tool_info(params)) - task_names.append("tool_info") - - if params.enable_knowledge and not pre_built_params.get("knowledge_prompt"): - tasks.append(self._build_knowledge_info(params)) - task_names.append("knowledge_info") - - if params.enable_cross_context and not pre_built_params.get("cross_context_block"): - tasks.append(self._build_cross_context(params)) - task_names.append("cross_context") - - # 性能优化:根据任务数量动态调整超时时间 - base_timeout = 10.0 # 基础超时时间 - task_timeout = 2.0 # 每个任务的超时时间 - timeout_seconds = min( - max(base_timeout, len(tasks) * task_timeout), # 根据任务数量计算超时 - 30.0, # 最大超时时间 - ) - - # 性能优化:限制并发任务数量,避免资源耗尽 - max_concurrent_tasks = 5 # 最大并发任务数 - if len(tasks) > max_concurrent_tasks: - # 分批执行任务 - results = [] - for i in range(0, len(tasks), max_concurrent_tasks): - batch_tasks = tasks[i : i + max_concurrent_tasks] - batch_names = task_names[i : i + max_concurrent_tasks] - - batch_results = await asyncio.wait_for( - asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds - ) - results.extend(batch_results) - else: - # 一次性执行所有任务 - results = await asyncio.wait_for( - asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds - ) - - # 处理结果并收集性能数据 - context_data = {} - for i, result in enumerate(results): - task_name = task_names[i] if i < len(task_names) else f"task_{i}" - - if isinstance(result, Exception): - logger.error(f"构建任务{task_name}失败: {str(result)}") - elif isinstance(result, dict): - # 结果格式: {component_name: value} - context_data.update(result) - - # 记录耗时过长的任务 - if task_name in timing_logs and timing_logs[task_name] > 8.0: - logger.warning(f"构建任务{task_name}耗时过长: {timing_logs[task_name]:.2f}s") - - # 添加预构建的参数 - for key, value in pre_built_params.items(): - if value: - context_data[key] = value - - except asyncio.TimeoutError: - logger.error(f"构建超时 ({timeout_seconds}s)") - context_data = {} - - # 添加预构建的参数,即使在超时情况下 - for key, value in pre_built_params.items(): - if value: - context_data[key] = value - - # 构建聊天历史 - 根据模式不同 - if params.prompt_mode == "s4u": - await self._build_s4u_chat_context(context_data, params) - else: - await self._build_normal_chat_context(context_data, params) - - # 补充基础信息 - context_data.update( - { - "keywords_reaction_prompt": params.keywords_reaction_prompt, - "extra_info_block": params.extra_info_block, - "time_block": params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", - "identity": params.identity_block, - "schedule_block": params.schedule_block, - "moderation_prompt": params.moderation_prompt_block, - "reply_target_block": params.reply_target_block, - "mood_state": params.mood_prompt, - "action_descriptions": params.action_descriptions, - } - ) - - total_time = time.time() - start_time - if timing_logs: - timing_str = "; ".join([f"{name}: {time:.2f}s" for name, time in timing_logs.items()]) - logger.info(f"构建任务耗时: {timing_str}") - logger.debug(f"构建完成,总耗时: {total_time:.2f}s") - - return context_data - - async def _build_s4u_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None: - """构建S4U模式的聊天上下文 - 使用新参数结构""" - if not params.message_list_before_now_long: - return - - # 使用共享工具构建分离历史 - core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts( - params.message_list_before_now_long, - params.target_user_info.get("user_id") if params.target_user_info else "", - params.sender - ) - - context_data["core_dialogue_prompt"] = core_dialogue - context_data["background_dialogue_prompt"] = background_dialogue - - async def _build_normal_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None: - """构建normal模式的聊天上下文 - 使用新参数结构""" - if not params.chat_talking_prompt_short: - return - - context_data["chat_info"] = f"""群里的聊天内容: -{params.chat_talking_prompt_short}""" - - async def _build_s4u_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str - ) -> Tuple[str, str]: - """构建S4U风格的分离对话prompt - 完整实现""" - core_dialogue_list = [] - bot_id = str(global_config.bot.qq_account) - - # 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话 - for msg_dict in message_list_before_now: - try: - msg_user_id = str(msg_dict.get("user_id")) - reply_to = msg_dict.get("reply_to", "") - _platform, reply_to_user_id = self._parse_reply_target(reply_to) - if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: - # bot 和目标用户的对话 - core_dialogue_list.append(msg_dict) - except Exception as e: - logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}") - - # 构建背景对话 prompt - all_dialogue_prompt = "" - if message_list_before_now: - latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :] - all_dialogue_prompt_str = build_readable_messages( - latest_25_msgs, - replace_bot_name=True, - timestamp_mode="normal", - truncate=True, - ) - all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}" - - # 构建核心对话 prompt - core_dialogue_prompt = "" - if core_dialogue_list: - # 检查最新五条消息中是否包含bot自己说的消息 - latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list - has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages) - - # logger.info(f"最新五条消息:{latest_5_messages}") - # logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}") - - # 如果最新五条消息中不包含bot的消息,则返回空字符串 - if not has_bot_message: - core_dialogue_prompt = "" - else: - core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量 - - core_dialogue_prompt_str = build_readable_messages( - core_dialogue_list, - replace_bot_name=True, - merge_messages=False, - timestamp_mode="normal_no_YMD", - read_mark=0.0, - truncate=True, - show_actions=True, - ) - core_dialogue_prompt = f"""-------------------------------- -这是你和{sender}的对话,你们正在交流中: -{core_dialogue_prompt_str} --------------------------------- -""" - - return core_dialogue_prompt, all_dialogue_prompt - - async def _build_mai_think_context(self, params: SmartPromptParameters) -> Any: - """构建mai_think上下文 - 完全继承DefaultReplyer功能""" - from src.mais4u.mai_think import mai_thinking_manager - - # 获取mai_think实例 - mai_think = mai_thinking_manager.get_mai_think(params.chat_id) - - # 设置mai_think的上下文信息 - mai_think.memory_block = params.memory_block or "" - mai_think.relation_info_block = params.relation_info_block or "" - mai_think.time_block = params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - - # 设置聊天目标信息 - if params.is_group_chat: - chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") - chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2") - else: - chat_target_name = "对方" - if params.chat_target_info: - chat_target_name = ( - params.chat_target_info.get("person_name") or params.chat_target_info.get("user_nickname") or "对方" - ) - chat_target_1 = await global_prompt_manager.format_prompt( - "chat_target_private1", sender_name=chat_target_name - ) - chat_target_2 = await global_prompt_manager.format_prompt( - "chat_target_private2", sender_name=chat_target_name - ) - - mai_think.chat_target = chat_target_1 - mai_think.chat_target_2 = chat_target_2 - mai_think.chat_info = params.chat_talking_prompt_short or "" - mai_think.mood_state = params.mood_prompt or "" - mai_think.identity = params.identity_block or "" - mai_think.sender = params.sender - mai_think.target = params.target - - # 返回mai_think实例,以便后续使用 - return mai_think - - def _parse_reply_target_id(self, reply_to: str) -> str: - """解析回复目标中的用户ID""" - if not reply_to: - return "" - - # 复用_parse_reply_target方法的逻辑 - sender, _ = self._parse_reply_target(reply_to) - if not sender: - return "" - - # 获取用户ID - person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) - if person_id: - user_id = person_info_manager.get_value_sync(person_id, "user_id") - return str(user_id) if user_id else "" - - async def _build_expression_habits(self, params: SmartPromptParameters) -> Dict[str, Any]: - """构建表达习惯 - 使用共享工具类,完全继承DefaultReplyer功能""" - # 检查是否允许在此聊天流中使用表达 - use_expression, _, _ = global_config.expression.get_expression_config_for_chat(params.chat_id) - if not use_expression: - return {"expression_habits_block": ""} - - from src.chat.express.expression_selector import expression_selector - - style_habits = [] - grammar_habits = [] - - # 使用从处理器传来的选中表达方式 - # LLM模式:调用LLM选择5-10个,然后随机选5个 - try: - selected_expressions = await expression_selector.select_suitable_expressions_llm( - params.chat_id, params.chat_talking_prompt_short, max_num=8, min_num=2, target_message=params.target - ) - except Exception as e: - logger.error(f"选择表达方式失败: {e}") - selected_expressions = [] - - if selected_expressions: - logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式") - for expr in selected_expressions: - if isinstance(expr, dict) and "situation" in expr and "style" in expr: - expr_type = expr.get("type", "style") - if expr_type == "grammar": - grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") - else: - style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}") - else: - logger.debug("没有从处理器获得表达方式,将使用空的表达方式") - # 不再在replyer中进行随机选择,全部交给处理器处理 - - style_habits_str = "\n".join(style_habits) - grammar_habits_str = "\n".join(grammar_habits) - - # 动态构建expression habits块 - expression_habits_block = "" - expression_habits_title = "" - if style_habits_str.strip(): - expression_habits_title = ( - "你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:" - ) - expression_habits_block += f"{style_habits_str}\n" - if grammar_habits_str.strip(): - expression_habits_title = ( - "你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:" - ) - expression_habits_block += f"{grammar_habits_str}\n" - - if style_habits_str.strip() and grammar_habits_str.strip(): - expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中。" - - return {"expression_habits_block": f"{expression_habits_title}\n{expression_habits_block}"} - - async def _build_memory_block(self, params: SmartPromptParameters) -> Dict[str, Any]: - """构建记忆块 - 使用共享工具类,完全继承DefaultReplyer功能""" - if not global_config.memory.enable_memory: - return {"memory_block": ""} - - from src.chat.memory_system.memory_activator import MemoryActivator - from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2 - - instant_memory = None - - # 初始化记忆激活器 - try: - memory_activator = MemoryActivator() - - # 获取长期记忆 - running_memories = await memory_activator.activate_memory_with_chat_history( - target_message=params.target, chat_history_prompt=params.chat_talking_prompt_short - ) - except Exception as e: - logger.error(f"激活记忆失败: {e}") - running_memories = [] - - # 处理瞬时记忆 - if global_config.memory.enable_instant_memory: - # 使用异步记忆包装器(最优化的非阻塞模式) - try: - from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory - - # 获取异步记忆包装器 - async_memory = get_async_instant_memory(params.chat_id) - - # 后台存储聊天历史(完全非阻塞) - async_memory.store_memory_background(params.chat_talking_prompt_short) - - # 快速检索记忆,最大超时2秒 - instant_memory = await async_memory.get_memory_with_fallback(params.target, max_timeout=2.0) - - logger.info(f"异步瞬时记忆:{instant_memory}") - - except ImportError: - # 如果异步包装器不可用,尝试使用异步记忆管理器 - try: - from src.chat.memory_system.async_memory_optimizer import ( - retrieve_memory_nonblocking, - store_memory_nonblocking, - ) - - # 异步存储聊天历史(非阻塞) - asyncio.create_task( - store_memory_nonblocking(chat_id=params.chat_id, content=params.chat_talking_prompt_short) - ) - - # 尝试从缓存获取瞬时记忆 - instant_memory = await retrieve_memory_nonblocking(chat_id=params.chat_id, query=params.target) - - # 如果没有缓存结果,快速检索一次 - if instant_memory is None: - try: - # 使用VectorInstantMemoryV2实例 - instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1) - instant_memory = await asyncio.wait_for( - instant_memory_system.get_memory_for_context(params.target), timeout=1.5 - ) - except asyncio.TimeoutError: - logger.warning("瞬时记忆检索超时,使用空结果") - instant_memory = "" - - logger.info(f"向量瞬时记忆:{instant_memory}") - - except ImportError: - # 最后的fallback:使用原有逻辑但加上超时控制 - logger.warning("异步记忆系统不可用,使用带超时的同步方式") - - # 使用VectorInstantMemoryV2实例 - instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1) - - # 异步存储聊天历史 - asyncio.create_task(instant_memory_system.store_message(params.chat_talking_prompt_short)) - - # 带超时的记忆检索 - try: - instant_memory = await asyncio.wait_for( - instant_memory_system.get_memory_for_context(params.target), - timeout=1.0, # 最保守的1秒超时 - ) - except asyncio.TimeoutError: - logger.warning("瞬时记忆检索超时,跳过记忆获取") - instant_memory = "" - except Exception as e: - logger.error(f"瞬时记忆检索失败: {e}") - instant_memory = "" - - logger.info(f"同步瞬时记忆:{instant_memory}") - - except Exception as e: - logger.error(f"瞬时记忆系统异常: {e}") - instant_memory = "" - - # 构建记忆字符串,即使某种记忆为空也要继续 - memory_str = "" - has_any_memory = False - - # 添加长期记忆 - if running_memories: - if not memory_str: - memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" - for running_memory in running_memories: - memory_str += f"- {running_memory['content']}\n" - has_any_memory = True - - # 添加瞬时记忆 - if instant_memory: - if not memory_str: - memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" - memory_str += f"- {instant_memory}\n" - has_any_memory = True - - # 注入视频分析结果引导语 - memory_str = self._inject_video_prompt_if_needed(params.target, memory_str) - - # 只有当完全没有任何记忆时才返回空字符串 - return {"memory_block": memory_str if has_any_memory else ""} - - def _inject_video_prompt_if_needed(self, target: str, memory_str: str) -> str: - """统一视频分析结果注入逻辑""" - if target and ("[视频内容]" in target or "好的,我将根据您提供的" in target): - video_prompt_injection = ( - "\n请注意,以上内容是你刚刚观看的视频,请以第一人称分享你的观后感,而不是在分析一份报告。" - ) - return memory_str + video_prompt_injection - return memory_str - - async def _build_relation_info(self, params: SmartPromptParameters) -> Dict[str, Any]: - """构建关系信息 - 使用共享工具类""" - try: - relation_info = await PromptUtils.build_relation_info(params.chat_id, params.reply_to) - return {"relation_info_block": relation_info} - except Exception as e: - logger.error(f"构建关系信息失败: {e}") - return {"relation_info_block": ""} - - async def _build_tool_info(self, params: SmartPromptParameters) -> Dict[str, Any]: - """构建工具信息 - 使用共享工具类,完全继承DefaultReplyer功能""" - if not params.enable_tool: - return {"tool_info_block": ""} - - if not params.reply_to: - return {"tool_info_block": ""} - - sender, text = PromptUtils.parse_reply_target(params.reply_to) - - if not text: - return {"tool_info_block": ""} - - from src.plugin_system.core.tool_use import ToolExecutor - - # 使用工具执行器获取信息 - try: - tool_executor = ToolExecutor(chat_id=params.chat_id) - tool_results, _, _ = await tool_executor.execute_from_chat_message( - sender=sender, target_message=text, chat_history=params.chat_talking_prompt_short, return_details=False - ) - - if tool_results: - tool_info_str = "以下是你通过工具获取到的实时信息:\n" - for tool_result in tool_results: - tool_name = tool_result.get("tool_name", "unknown") - content = tool_result.get("content", "") - result_type = tool_result.get("type", "tool_result") - - tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n" - - tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。" - logger.info(f"获取到 {len(tool_results)} 个工具结果") - - return {"tool_info_block": tool_info_str} - else: - logger.debug("未获取到任何工具结果") - return {"tool_info_block": ""} - - except Exception as e: - logger.error(f"工具信息获取失败: {e}") - return {"tool_info_block": ""} - - async def _build_knowledge_info(self, params: SmartPromptParameters) -> Dict[str, Any]: - """构建知识信息 - 使用共享工具类,完全继承DefaultReplyer功能""" - if not params.reply_to: - logger.debug("没有回复对象,跳过获取知识库内容") - return {"knowledge_prompt": ""} - - sender, content = PromptUtils.parse_reply_target(params.reply_to) - if not content: - logger.debug("回复对象内容为空,跳过获取知识库内容") - return {"knowledge_prompt": ""} - - logger.debug( - f"获取知识库内容,元消息:{params.chat_talking_prompt_short[:30]}...,消息长度: {len(params.chat_talking_prompt_short)}" - ) - - # 从LPMM知识库获取知识 - try: - # 检查LPMM知识库是否启用 - if not global_config.lpmm_knowledge.enable: - logger.debug("LPMM知识库未启用,跳过获取知识库内容") - return {"knowledge_prompt": ""} - - from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool - from src.plugin_system.apis import llm_api - from src.config.config import model_config - - time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - bot_name = global_config.bot.nickname - - prompt = await global_prompt_manager.format_prompt( - "lpmm_get_knowledge_prompt", - bot_name=bot_name, - time_now=time_now, - chat_history=params.chat_talking_prompt_short, - sender=sender, - target_message=content, - ) - - _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools( - prompt, - model_config=model_config.model_task_config.tool_use, - tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()], - ) - - if tool_calls: - from src.plugin_system.core.tool_use import ToolExecutor - - tool_executor = ToolExecutor(chat_id=params.chat_id) - result = await tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool()) - - if not result or not result.get("content"): - logger.debug("从LPMM知识库获取知识失败,返回空知识...") - return {"knowledge_prompt": ""} - - found_knowledge_from_lpmm = result.get("content", "") - logger.debug( - f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}" - ) - - return { - "knowledge_prompt": f"你有以下这些**知识**:\n{found_knowledge_from_lpmm}\n请你**记住上面的知识**,之后可能会用到。\n" - } - else: - logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...") - return {"knowledge_prompt": ""} - - except Exception as e: - logger.error(f"获取知识库内容时发生异常: {str(e)}") - return {"knowledge_prompt": ""} - - async def _build_cross_context(self, params: SmartPromptParameters) -> Dict[str, Any]: - """构建跨群上下文 - 使用共享工具类""" - try: - cross_context = await PromptUtils.build_cross_context( - params.chat_id, params.prompt_mode, params.target_user_info - ) - return {"cross_context_block": cross_context} - except Exception as e: - logger.error(f"构建跨群上下文失败: {e}") - return {"cross_context_block": ""} - - def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: - """解析回复目标消息 - 使用共享工具类""" - return PromptUtils.parse_reply_target(target_message) - - -class SmartPrompt: - """重构的智能提示词核心类 - 移除缓存机制和依赖检查,简化架构""" - - def __init__( - self, - template_name: Optional[str] = None, - parameters: Optional[SmartPromptParameters] = None, - ): - self.parameters = parameters or SmartPromptParameters() - self.template_name = template_name or self._get_default_template() - self.builder = SmartPromptBuilder() - - def _get_default_template(self) -> str: - """根据模式选择默认模板""" - if self.parameters.prompt_mode == "s4u": - return "s4u_style_prompt" - elif self.parameters.prompt_mode == "normal": - return "normal_style_prompt" - else: - return "default_expressor_prompt" - - async def build_prompt(self) -> str: - """构建最终的Prompt文本 - 移除缓存机制和依赖检查""" - # 参数验证 - errors = self.parameters.validate() - if errors: - logger.error(f"参数验证失败: {', '.join(errors)}") - raise ValueError(f"参数验证失败: {', '.join(errors)}") - - start_time = time.time() - try: - # 构建基础上下文的完整映射 - context_data = await self.builder.build_context_data(self.parameters) - - # 检查关键上下文数据 - if not context_data or not isinstance(context_data, dict): - logger.error("构建的上下文数据无效") - raise ValueError("构建的上下文数据无效") - - # 获取模板 - template = await self._get_template() - if template is None: - logger.error("无法获取模板") - raise ValueError("无法获取模板") - - # 根据模式传递不同的参数 - if self.parameters.prompt_mode == "s4u": - result = await self._build_s4u_prompt(template, context_data) - elif self.parameters.prompt_mode == "normal": - result = await self._build_normal_prompt(template, context_data) - else: - result = await self._build_default_prompt(template, context_data) - - # 记录性能数据 - total_time = time.time() - start_time - logger.debug(f"SmartPrompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s") - - return result - - except asyncio.TimeoutError as e: - logger.error(f"构建Prompt超时: {e}") - raise TimeoutError(f"构建Prompt超时: {e}") - except Exception as e: - logger.error(f"构建Prompt失败: {e}") - raise RuntimeError(f"构建Prompt失败: {e}") - - async def _get_template(self) -> Optional[Prompt]: - """获取模板""" - try: - return await global_prompt_manager.get_prompt_async(self.template_name) - except Exception as e: - logger.error(f"获取模板 {self.template_name} 失败: {e}") - raise RuntimeError(f"获取模板 {self.template_name} 失败: {e}") - - async def _build_s4u_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str: - """构建S4U模式的完整Prompt - 使用新参数结构""" - params = { - **context_data, - "expression_habits_block": context_data.get("expression_habits_block", ""), - "tool_info_block": context_data.get("tool_info_block", ""), - "knowledge_prompt": context_data.get("knowledge_prompt", ""), - "memory_block": context_data.get("memory_block", ""), - "relation_info_block": context_data.get("relation_info_block", ""), - "extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""), - "cross_context_block": context_data.get("cross_context_block", ""), - "identity": self.parameters.identity_block or context_data.get("identity", ""), - "action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""), - "sender_name": self.parameters.sender, - "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), - "background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""), - "time_block": context_data.get("time_block", ""), - "core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""), - "reply_target_block": context_data.get("reply_target_block", ""), - "reply_style": global_config.personality.reply_style, - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt - or context_data.get("keywords_reaction_prompt", ""), - "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), - } - return await global_prompt_manager.format_prompt(self.template_name, **params) - - async def _build_normal_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str: - """构建Normal模式的完整Prompt - 使用新参数结构""" - params = { - **context_data, - "expression_habits_block": context_data.get("expression_habits_block", ""), - "tool_info_block": context_data.get("tool_info_block", ""), - "knowledge_prompt": context_data.get("knowledge_prompt", ""), - "memory_block": context_data.get("memory_block", ""), - "relation_info_block": context_data.get("relation_info_block", ""), - "extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""), - "cross_context_block": context_data.get("cross_context_block", ""), - "identity": self.parameters.identity_block or context_data.get("identity", ""), - "action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""), - "schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""), - "time_block": context_data.get("time_block", ""), - "chat_info": context_data.get("chat_info", ""), - "reply_target_block": context_data.get("reply_target_block", ""), - "config_expression_style": global_config.personality.reply_style, - "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt - or context_data.get("keywords_reaction_prompt", ""), - "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), - } - return await global_prompt_manager.format_prompt(self.template_name, **params) - - async def _build_default_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str: - """构建默认模式的Prompt - 使用新参数结构""" - params = { - "expression_habits_block": context_data.get("expression_habits_block", ""), - "relation_info_block": context_data.get("relation_info_block", ""), - "chat_target": "", - "time_block": context_data.get("time_block", ""), - "chat_info": context_data.get("chat_info", ""), - "identity": self.parameters.identity_block or context_data.get("identity", ""), - "chat_target_2": "", - "reply_target_block": context_data.get("reply_target_block", ""), - "raw_reply": self.parameters.target, - "reason": "", - "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), - "reply_style": global_config.personality.reply_style, - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt - or context_data.get("keywords_reaction_prompt", ""), - "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), - } - return await global_prompt_manager.format_prompt(self.template_name, **params) - - -# 工厂函数 - 简化创建 - 更新参数结构 -def create_smart_prompt( - chat_id: str = "", sender_name: str = "", target_message: str = "", reply_to: str = "", **kwargs -) -> SmartPrompt: - """快速创建智能Prompt实例的工厂函数 - 使用新参数结构""" - - # 使用新的参数结构 - parameters = SmartPromptParameters( - chat_id=chat_id, sender=sender_name, target=target_message, reply_to=reply_to, **kwargs - ) - - return SmartPrompt(parameters=parameters) - - -class SmartPromptHealthChecker: - """SmartPrompt健康检查器 - 移除依赖检查""" - - @staticmethod - async def check_system_health() -> Dict[str, Any]: - """检查系统健康状态 - 移除依赖检查""" - health_status = {"status": "healthy", "components": {}, "issues": []} - - try: - # 检查配置 - try: - from src.config.config import global_config - - health_status["components"]["config"] = "ok" - - # 检查关键配置项 - if not hasattr(global_config, "personality") or not hasattr(global_config.personality, "prompt_mode"): - health_status["issues"].append("缺少personality.prompt_mode配置") - health_status["status"] = "degraded" - - if not hasattr(global_config, "memory") or not hasattr(global_config.memory, "enable_memory"): - health_status["issues"].append("缺少memory.enable_memory配置") - - except Exception as e: - health_status["components"]["config"] = f"failed: {str(e)}" - health_status["issues"].append("配置加载失败") - health_status["status"] = "unhealthy" - - # 检查Prompt模板 - try: - required_templates = ["s4u_style_prompt", "normal_style_prompt", "default_expressor_prompt"] - for template_name in required_templates: - try: - await global_prompt_manager.get_prompt_async(template_name) - health_status["components"][f"template_{template_name}"] = "ok" - except Exception as e: - health_status["components"][f"template_{template_name}"] = f"failed: {str(e)}" - health_status["issues"].append(f"模板{template_name}加载失败") - health_status["status"] = "degraded" - - except Exception as e: - health_status["components"]["prompt_templates"] = f"failed: {str(e)}" - health_status["issues"].append("Prompt模板检查失败") - health_status["status"] = "unhealthy" - - return health_status - - except Exception as e: - return {"status": "unhealthy", "components": {}, "issues": [f"健康检查异常: {str(e)}"]} - - @staticmethod - async def run_performance_test() -> Dict[str, Any]: - """运行性能测试""" - test_results = {"status": "completed", "tests": {}, "summary": {}} - - try: - # 创建测试参数 - test_params = SmartPromptParameters( - chat_id="test_chat", - sender="test_user", - target="test_message", - reply_to="test_user:test_message", - prompt_mode="s4u", - ) - - # 测试不同模式下的构建性能 - modes = ["s4u", "normal", "minimal"] - for mode in modes: - test_params.prompt_mode = mode - smart_prompt = SmartPrompt(parameters=test_params) - - # 运行多次测试取平均值 - times = [] - for _ in range(3): - start_time = time.time() - try: - await smart_prompt.build_prompt() - end_time = time.time() - times.append(end_time - start_time) - except Exception as e: - times.append(float("inf")) - logger.error(f"性能测试失败 (模式: {mode}): {e}") - - # 计算统计信息 - valid_times = [t for t in times if t != float("inf")] - if valid_times: - avg_time = sum(valid_times) / len(valid_times) - min_time = min(valid_times) - max_time = max(valid_times) - - test_results["tests"][mode] = { - "avg_time": avg_time, - "min_time": min_time, - "max_time": max_time, - "success_rate": len(valid_times) / len(times), - } - else: - test_results["tests"][mode] = { - "avg_time": float("inf"), - "min_time": float("inf"), - "max_time": float("inf"), - "success_rate": 0, - } - - # 计算总体统计 - all_avg_times = [ - test["avg_time"] for test in test_results["tests"].values() if test["avg_time"] != float("inf") - ] - if all_avg_times: - test_results["summary"] = { - "overall_avg_time": sum(all_avg_times) / len(all_avg_times), - "fastest_mode": min(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0], - "slowest_mode": max(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0], - } - - return test_results - - except Exception as e: - return {"status": "failed", "tests": {}, "summary": {}, "error": str(e)} diff --git a/src/mais4u/mai_think.py b/src/mais4u/mai_think.py index 3daa5875d..4c34c4798 100644 --- a/src/mais4u/mai_think.py +++ b/src/mais4u/mai_think.py @@ -1,6 +1,6 @@ from src.chat.message_receive.chat_stream import get_chat_manager import time -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.llm_models.utils_model import LLMRequest from src.config.config import model_config from src.chat.message_receive.message import MessageRecvS4U diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index 26af9fedd..bf3640be0 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.config.config import global_config, model_config -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index 734193c91..8d1e22b8f 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.config.config import global_config, model_config -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api from src.mais4u.constant_s4u import ENABLE_S4U diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index 598ee4e89..db6a6edf9 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -1,6 +1,6 @@ from src.config.config import global_config from src.common.logger import get_logger -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat import time from src.chat.utils.utils import get_recent_group_speaker diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 1fc04c9d8..95a365b41 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -6,7 +6,7 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import get_chat_manager -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.llm_models.utils_model import LLMRequest from src.manager.async_task_manager import AsyncTask, async_task_manager diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index ba55feca8..1c62dec1a 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -9,7 +9,7 @@ from json_repair import repair_json from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.person_info import get_person_info_manager diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index ee57e5d82..1b2618f43 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -6,7 +6,7 @@ from src.plugin_system.core.global_announcement_manager import global_announceme from src.llm_models.utils_model import LLMRequest from src.llm_models.payload_content import ToolCall from src.config.config import global_config, model_config -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager import inspect from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger From d05e2f9ee45bccab096df935295e4deebc526569 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 6 Sep 2025 01:36:00 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=B8=80=E5=A0=86?= =?UTF-8?q?=E6=96=B0prompt=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 43 +++++--- src/chat/utils/prompt.py | 142 ++++++++++++++++++++++++-- src/chat/utils/prompt_utils.py | 132 ------------------------ 3 files changed, 164 insertions(+), 153 deletions(-) delete mode 100644 src/chat/utils/prompt_utils.py diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index bee1ad802..3cc694c1d 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -11,7 +11,6 @@ import re from typing import List, Optional, Dict, Any, Tuple from datetime import datetime -from src.chat.utils.prompt_utils import PromptUtils from src.mais4u.mai_think import mai_thinking_manager from src.common.logger import get_logger from src.config.config import global_config, model_config @@ -36,10 +35,9 @@ from src.person_info.relationship_fetcher import relationship_fetcher_manager from src.person_info.person_info import get_person_info_manager from src.plugin_system.base.component_types import ActionInfo, EventType from src.plugin_system.apis import llm_api -from src.schedule.schedule_manager import schedule_manager # 导入新的统一Prompt系统 -from src.chat.utils.prompt import Prompt, PromptContext +from src.chat.utils.prompt import Prompt, PromptParameters logger = get_logger("replyer") @@ -599,7 +597,8 @@ class DefaultReplyer: def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: """解析回复目标消息 - 使用共享工具""" - return PromptUtils.parse_reply_target(target_message) + from src.chat.utils.prompt import Prompt + return Prompt.parse_reply_target(target_message) async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: """构建关键词反应提示 @@ -874,7 +873,8 @@ class DefaultReplyer: target_user_info = None if sender: target_user_info = await person_info_manager.get_person_info_by_name(sender) - + + from src.chat.utils.prompt import Prompt # 并行执行六个构建任务 task_results = await asyncio.gather( self._time_and_run_task( @@ -887,7 +887,7 @@ class DefaultReplyer: ), self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"), self._time_and_run_task( - PromptUtils.build_cross_context(chat_id, target_user_info, global_config.personality.prompt_mode), + Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info), "cross_context", ), ) @@ -939,6 +939,7 @@ class DefaultReplyer: schedule_block = "" if global_config.schedule.enable: + from src.schedule.schedule_manager import schedule_manager current_activity = schedule_manager.get_current_activity() if current_activity: schedule_block = f"你当前正在:{current_activity}。" @@ -970,8 +971,8 @@ class DefaultReplyer: # 根据配置选择模板 current_prompt_mode = global_config.personality.prompt_mode - # 使用新的统一Prompt系统 - prompt_context = PromptContext( + # 使用新的统一Prompt系统 - 创建PromptParameters + prompt_parameters = PromptParameters( chat_id=chat_id, is_group_chat=is_group_chat, sender=sender, @@ -1004,9 +1005,19 @@ class DefaultReplyer: action_descriptions=action_descriptions, ) - # 使用新的统一Prompt系统 - prompt = Prompt(template_name=None, context=prompt_context) # 由current_prompt_mode自动选择 - prompt_text = await prompt.build_prompt() + # 使用新的统一Prompt系统 - 使用正确的模板名称 + template_name = None + if current_prompt_mode == "s4u": + template_name = "s4u_style_prompt" + elif current_prompt_mode == "normal": + template_name = "normal_style_prompt" + elif current_prompt_mode == "minimal": + template_name = "default_expressor_prompt" + + # 获取模板内容 + template_prompt = await global_prompt_manager.get_prompt_async(template_name) + prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters) + prompt_text = await prompt.build() return prompt_text @@ -1107,8 +1118,8 @@ class DefaultReplyer: template_name = "default_expressor_prompt" - # 使用新的统一Prompt系统 - Expressor模式 - prompt_context = PromptContext( + # 使用新的统一Prompt系统 - Expressor模式,创建PromptParameters + prompt_parameters = PromptParameters( chat_id=chat_id, is_group_chat=is_group_chat, sender=sender, @@ -1128,8 +1139,10 @@ class DefaultReplyer: relation_info_block=relation_info, ) - prompt = Prompt(template_name=template_name, context=prompt_context) - prompt_text = await prompt.build_prompt() + # 使用新的统一Prompt系统 - Expressor模式 + template_prompt = await global_prompt_manager.get_prompt_async("default_expressor_prompt") + prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters) + prompt_text = await prompt.build() return prompt_text diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 1e44b72d8..b5cf140c5 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -8,14 +8,14 @@ import asyncio import time import contextvars from dataclasses import dataclass, field -from typing import Dict, Any, Optional, List, Union, Literal, Tuple +from typing import Dict, Any, Optional, List, Literal, Tuple from contextlib import asynccontextmanager from rich.traceback import install from src.common.logger import get_logger from src.config.config import global_config from src.chat.utils.chat_message_builder import build_readable_messages -from src.chat.utils.prompt_utils import PromptUtils +from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.person_info import get_person_info_manager install(extra_lines=3) @@ -472,7 +472,7 @@ class Prompt: try: msg_user_id = str(msg_dict.get("user_id")) reply_to = msg_dict.get("reply_to", "") - platform, reply_to_user_id = PromptUtils.parse_reply_target(reply_to) + platform, reply_to_user_id = Prompt.parse_reply_target(reply_to) if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: core_dialogue_list.append(msg_dict) except Exception as e: @@ -531,7 +531,7 @@ class Prompt: async def _build_relation_info(self) -> Dict[str, Any]: """构建关系信息""" try: - relation_info = await PromptUtils.build_relation_info(self.parameters.chat_id, self.parameters.reply_to) + relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to) return {"relation_info_block": relation_info} except Exception as e: logger.error(f"构建关系信息失败: {e}") @@ -550,7 +550,7 @@ class Prompt: async def _build_cross_context(self) -> Dict[str, Any]: """构建跨群上下文""" try: - cross_context = await PromptUtils.build_cross_context( + cross_context = await Prompt.build_cross_context( self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info ) return {"cross_context_block": cross_context} @@ -666,6 +666,135 @@ class Prompt: """返回提示词的表示形式""" return f"Prompt(template='{self.template}', name='{self.name}')" + # ============================================================================= + # PromptUtils功能迁移 - 静态工具方法 + # 这些方法原来在PromptUtils类中,现在作为Prompt类的静态方法 + # 解决循环导入问题 + # ============================================================================= + + @staticmethod + def parse_reply_target(target_message: str) -> Tuple[str, str]: + """ + 解析回复目标消息 - 统一实现 + + Args: + target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容" + + Returns: + Tuple[str, str]: (发送者名称, 消息内容) + """ + sender = "" + target = "" + + # 添加None检查,防止NoneType错误 + if target_message is None: + return sender, target + + if ":" in target_message or ":" in target_message: + # 使用正则表达式匹配中文或英文冒号 + parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1) + if len(parts) == 2: + sender = parts[0].strip() + target = parts[1].strip() + return sender, target + + @staticmethod + async def build_relation_info(chat_id: str, reply_to: str) -> str: + """ + 构建关系信息 - 统一实现 + + Args: + chat_id: 聊天ID + reply_to: 回复目标字符串 + + Returns: + str: 关系信息字符串 + """ + if not global_config.relationship.enable_relationship: + return "" + + from src.person_info.relationship_fetcher import relationship_fetcher_manager + + relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id) + + if not reply_to: + return "" + sender, text = Prompt.parse_reply_target(reply_to) + if not sender or not text: + return "" + + # 获取用户ID + person_info_manager = get_person_info_manager() + person_id = person_info_manager.get_person_id_by_person_name(sender) + if not person_id: + logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") + return f"你完全不认识{sender},不理解ta的相关信息。" + + return await relationship_fetcher.build_relation_info(person_id, points_num=5) + + @staticmethod + async def build_cross_context( + chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]] + ) -> str: + """ + 构建跨群聊上下文 - 统一实现 + + Args: + chat_id: 聊天ID + prompt_mode: 当前提示词模式 + target_user_info: 目标用户信息 + + Returns: + str: 跨群聊上下文字符串 + """ + if not global_config.cross_context.enable: + return "" + + from src.plugin_system.apis import cross_context_api + + other_chat_raw_ids = cross_context_api.get_context_groups(chat_id) + if not other_chat_raw_ids: + return "" + + chat_stream = get_chat_manager().get_stream(chat_id) + if not chat_stream: + return "" + + if prompt_mode == "normal": + return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids) + elif prompt_mode == "s4u": + return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info) + + return "" + + @staticmethod + def parse_reply_target_id(reply_to: str) -> str: + """ + 解析回复目标中的用户ID + + Args: + reply_to: 回复目标字符串 + + Returns: + str: 用户ID + """ + if not reply_to: + return "" + + # 复用parse_reply_target方法的逻辑 + sender, _ = Prompt.parse_reply_target(reply_to) + if not sender: + return "" + + # 获取用户ID + person_info_manager = get_person_info_manager() + person_id = person_info_manager.get_person_id_by_person_name(sender) + if person_id: + user_id = person_info_manager.get_value_sync(person_id, "user_id") + return str(user_id) if user_id else "" + + return "" + # 工厂函数 def create_prompt( @@ -690,4 +819,5 @@ async def create_prompt_async( prompt = create_prompt(template, name, parameters, **kwargs) if global_prompt_manager._context._current_context: await global_prompt_manager._context.register_async(prompt) - return prompt \ No newline at end of file + return prompt + diff --git a/src/chat/utils/prompt_utils.py b/src/chat/utils/prompt_utils.py deleted file mode 100644 index 4eed6025f..000000000 --- a/src/chat/utils/prompt_utils.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -共享提示词工具模块 - 消除重复代码 -提供统一的工具函数供DefaultReplyer和统一Prompt系统使用 -""" - -import re -from typing import Dict, Any, Optional, Tuple - -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.message_receive.chat_stream import get_chat_manager -from src.person_info.person_info import get_person_info_manager -from src.plugin_system.apis import cross_context_api - -logger = get_logger("prompt_utils") - - -class PromptUtils: - """提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查""" - - @staticmethod - def parse_reply_target(target_message: str) -> Tuple[str, str]: - """ - 解析回复目标消息 - 统一实现 - - Args: - target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容" - - Returns: - Tuple[str, str]: (发送者名称, 消息内容) - """ - sender = "" - target = "" - - # 添加None检查,防止NoneType错误 - if target_message is None: - return sender, target - - if ":" in target_message or ":" in target_message: - # 使用正则表达式匹配中文或英文冒号 - parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1) - if len(parts) == 2: - sender = parts[0].strip() - target = parts[1].strip() - return sender, target - - @staticmethod - async def build_relation_info(chat_id: str, reply_to: str) -> str: - """ - 构建关系信息 - 统一实现 - - Args: - chat_id: 聊天ID - reply_to: 回复目标字符串 - - Returns: - str: 关系信息字符串 - """ - if not global_config.relationship.enable_relationship: - return "" - - from src.person_info.relationship_fetcher import relationship_fetcher_manager - - relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id) - - if not reply_to: - return "" - sender, text = PromptUtils.parse_reply_target(reply_to) - if not sender or not text: - return "" - - # 获取用户ID - person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) - if not person_id: - logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") - return f"你完全不认识{sender},不理解ta的相关信息。" - - return await relationship_fetcher.build_relation_info(person_id, points_num=5) - - @staticmethod - async def build_cross_context( - chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str - ) -> str: - """ - 构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能 - """ - if not global_config.cross_context.enable: - return "" - - other_chat_raw_ids = cross_context_api.get_context_groups(chat_id) - if not other_chat_raw_ids: - return "" - - chat_stream = get_chat_manager().get_stream(chat_id) - if not chat_stream: - return "" - - if current_prompt_mode == "normal": - return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids) - elif current_prompt_mode == "s4u": - return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info) - - return "" - - @staticmethod - def parse_reply_target_id(reply_to: str) -> str: - """ - 解析回复目标中的用户ID - - Args: - reply_to: 回复目标字符串 - - Returns: - str: 用户ID - """ - if not reply_to: - return "" - - # 复用parse_reply_target方法的逻辑 - sender, _ = PromptUtils.parse_reply_target(reply_to) - if not sender: - return "" - - # 获取用户ID - person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) - if person_id: - user_id = person_info_manager.get_value_sync(person_id, "user_id") - return str(user_id) if user_id else "" - - return "" From 6c042cc73fc0199593213e6db4fe63cbd36c3b45 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 6 Sep 2025 03:38:43 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E8=BF=81=E7=A7=BBnapcat=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E8=87=B3built=5Fin?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/config/__init__.py | 5 -- .../src/mmc_com_layer.py | 26 -------- src/chat/chat_loop/cycle_processor.py | 2 +- src/chat/chat_loop/hfc_utils.py | 7 ++- src/plugins/built_in/core_actions/emoji.py | 5 +- .../napcat_adapter_plugin/.gitignore | 0 .../built_in}/napcat_adapter_plugin/CONSTS.py | 0 .../napcat_adapter_plugin/_manifest.json | 0 .../napcat_adapter_plugin/event_handlers.py | 29 +++++++++ .../napcat_adapter_plugin/event_types.py | 24 +++++++ .../built_in}/napcat_adapter_plugin/plugin.py | 63 +++++++++++++++++-- .../napcat_adapter_plugin/pyproject.toml | 0 .../napcat_adapter_plugin/src/__init__.py | 0 .../src/config/__init__.py | 2 + .../src/config/config.py | 0 .../src/config/config_base.py | 0 .../src/config/config_utils.py | 0 .../src/config/features_config.py | 0 .../src/config/migrate_features.py | 0 .../src/config/official_configs.py | 0 .../napcat_adapter_plugin/src/database.py | 0 .../src/message_buffer.py | 0 .../src/message_chunker.py | 12 +++- .../src/mmc_com_layer.py | 44 +++++++++++++ .../src/recv_handler/__init__.py | 0 .../src/recv_handler/message_handler.py | 31 +++++---- .../src/recv_handler/message_sending.py | 14 +++-- .../src/recv_handler/meta_event_handler.py | 0 .../src/recv_handler/notice_handler.py | 0 .../src/recv_handler/qq_emoji_list.py | 0 .../src/response_pool.py | 19 +++++- .../napcat_adapter_plugin/src/send_handler.py | 13 +++- .../napcat_adapter_plugin/src/utils.py | 0 .../src/video_handler.py | 0 .../src/websocket_manager.py | 28 +++++---- .../template/features_template.toml | 0 .../template/template_config.toml | 0 .../built_in}/napcat_adapter_plugin/todo.md | 0 38 files changed, 243 insertions(+), 81 deletions(-) delete mode 100644 plugins/napcat_adapter_plugin/src/config/__init__.py delete mode 100644 plugins/napcat_adapter_plugin/src/mmc_com_layer.py rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/.gitignore (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/CONSTS.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/_manifest.json (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/event_handlers.py (98%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/event_types.py (98%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/plugin.py (76%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/pyproject.toml (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/__init__.py (100%) create mode 100644 src/plugins/built_in/napcat_adapter_plugin/src/config/__init__.py rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/config/config.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/config/config_base.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/config/config_utils.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/config/features_config.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/config/migrate_features.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/config/official_configs.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/database.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/message_buffer.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/message_chunker.py (95%) create mode 100644 src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/recv_handler/__init__.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/recv_handler/message_handler.py (96%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/recv_handler/message_sending.py (82%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/recv_handler/notice_handler.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/response_pool.py (73%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/send_handler.py (98%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/utils.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/video_handler.py (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/src/websocket_manager.py (85%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/template/features_template.toml (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/template/template_config.toml (100%) rename {plugins => src/plugins/built_in}/napcat_adapter_plugin/todo.md (100%) diff --git a/plugins/napcat_adapter_plugin/src/config/__init__.py b/plugins/napcat_adapter_plugin/src/config/__init__.py deleted file mode 100644 index 40ba89aeb..000000000 --- a/plugins/napcat_adapter_plugin/src/config/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .config import global_config - -__all__ = [ - "global_config", -] diff --git a/plugins/napcat_adapter_plugin/src/mmc_com_layer.py b/plugins/napcat_adapter_plugin/src/mmc_com_layer.py deleted file mode 100644 index 14cddf102..000000000 --- a/plugins/napcat_adapter_plugin/src/mmc_com_layer.py +++ /dev/null @@ -1,26 +0,0 @@ -from maim_message import Router, RouteConfig, TargetConfig -from .config import global_config -from src.common.logger import get_logger -from .send_handler import send_handler - -logger = get_logger("napcat_adapter") - -route_config = RouteConfig( - route_config={ - global_config.maibot_server.platform_name: TargetConfig( - url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws", - token=None, - ) - } -) -router = Router(route_config) - - -async def mmc_start_com(): - logger.info("正在连接MaiBot") - router.register_class_handler(send_handler.handle_message) - await router.run() - - -async def mmc_stop_com(): - await router.stop() diff --git a/src/chat/chat_loop/cycle_processor.py b/src/chat/chat_loop/cycle_processor.py index b446c697a..bb1a1a5f0 100644 --- a/src/chat/chat_loop/cycle_processor.py +++ b/src/chat/chat_loop/cycle_processor.py @@ -149,7 +149,7 @@ class CycleProcessor: logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考") if ENABLE_S4U: - await send_typing() + await send_typing(self.context.chat_stream.user_info.user_id) loop_start_time = time.time() diff --git a/src/chat/chat_loop/hfc_utils.py b/src/chat/chat_loop/hfc_utils.py index ae77b2378..32d31fd52 100644 --- a/src/chat/chat_loop/hfc_utils.py +++ b/src/chat/chat_loop/hfc_utils.py @@ -121,7 +121,7 @@ class CycleDetail: self.loop_action_info = loop_info["loop_action_info"] -async def send_typing(): +async def send_typing(user_id): """ 发送打字状态指示 @@ -139,6 +139,11 @@ async def send_typing(): group_info=group_info, ) + from plugin_system.core.event_manager import event_manager + from src.plugins.built_in.napcat_adapter_plugin.event_types import NapcatEvent + # 设置正在输入状态 + await event_manager.trigger_event(NapcatEvent.PERSONAL.SET_INPUT_STATUS,user_id=user_id,event_type=1) + await send_api.custom_to_stream( message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False ) diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index ab5b18386..25e09d8d6 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -39,8 +39,9 @@ class EmojiAction(BaseAction): llm_judge_prompt = """ 判定是否需要使用表情动作的条件: 1. 用户明确要求使用表情包 - 2. 这是一个适合表达强烈情绪的场合 - 3. 不要发送太多表情包,如果你已经发送过多个表情包则回答"否" + 2. 这是一个适合表达情绪的场合 + 3. 发表情包能使当前对话更有趣 + 4. 不要发送太多表情包,如果你已经发送过多个表情包则回答"否" 请回答"是"或"否"。 """ diff --git a/plugins/napcat_adapter_plugin/.gitignore b/src/plugins/built_in/napcat_adapter_plugin/.gitignore similarity index 100% rename from plugins/napcat_adapter_plugin/.gitignore rename to src/plugins/built_in/napcat_adapter_plugin/.gitignore diff --git a/plugins/napcat_adapter_plugin/CONSTS.py b/src/plugins/built_in/napcat_adapter_plugin/CONSTS.py similarity index 100% rename from plugins/napcat_adapter_plugin/CONSTS.py rename to src/plugins/built_in/napcat_adapter_plugin/CONSTS.py diff --git a/plugins/napcat_adapter_plugin/_manifest.json b/src/plugins/built_in/napcat_adapter_plugin/_manifest.json similarity index 100% rename from plugins/napcat_adapter_plugin/_manifest.json rename to src/plugins/built_in/napcat_adapter_plugin/_manifest.json diff --git a/plugins/napcat_adapter_plugin/event_handlers.py b/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py similarity index 98% rename from plugins/napcat_adapter_plugin/event_handlers.py rename to src/plugins/built_in/napcat_adapter_plugin/event_handlers.py index 521bc77f4..1e5fbd531 100644 --- a/plugins/napcat_adapter_plugin/event_handlers.py +++ b/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py @@ -1746,3 +1746,32 @@ class SetGroupSignHandler(BaseEventHandler): else: logger.error("事件 napcat_set_group_sign 请求失败!") return HandlerResult(False, False, {"status": "error"}) + +# ===PERSONAL=== +class SetInputStatusHandler(BaseEventHandler): + handler_name: str = "napcat_set_input_status_handler" + handler_description: str = "设置输入状态" + weight: int = 100 + intercept_message: bool = False + init_subscribe = [NapcatEvent.PERSONAL.SET_INPUT_STATUS] + + async def execute(self, params: dict): + raw = params.get("raw", {}) + user_id = params.get("user_id", "") + event_type = params.get("event_type", 0) + + if params.get("raw", ""): + user_id = raw.get("user_id", "") + event_type = raw.get("event_type", 0) + + if not user_id or event_type is None: + logger.error("事件 napcat_set_input_status 缺少必要参数: user_id 或 event_type") + return HandlerResult(False, False, {"status": "error"}) + + payload = {"user_id": str(user_id), "event_type": int(event_type)} + response = await send_handler.send_message_to_napcat(action="set_input_status", params=payload) + if response.get("status", "") == "ok": + return HandlerResult(True, True, response) + else: + logger.error("事件 napcat_set_input_status 请求失败!") + return HandlerResult(False, False, {"status": "error"}) diff --git a/plugins/napcat_adapter_plugin/event_types.py b/src/plugins/built_in/napcat_adapter_plugin/event_types.py similarity index 98% rename from plugins/napcat_adapter_plugin/event_types.py rename to src/plugins/built_in/napcat_adapter_plugin/event_types.py index ee318834d..af417f37a 100644 --- a/plugins/napcat_adapter_plugin/event_types.py +++ b/src/plugins/built_in/napcat_adapter_plugin/event_types.py @@ -1816,3 +1816,27 @@ class NapcatEvent: """ class FILE(Enum): ... + + class PERSONAL(Enum): + SET_INPUT_STATUS = "napcat_set_input_status" + """ + 设置输入状态 + + Args: + user_id (Optional[str|int]): 用户id(必需) + event_type (Optional[int]): 输入状态id(必需) + raw (Optional[dict]): 原始请求体 + + Returns: + dict: { + "status": "ok", + "retcode": 0, + "data": { + "result": 0, + "errMsg": "string" + }, + "message": "string", + "wording": "string", + "echo": "string" + } + """ diff --git a/plugins/napcat_adapter_plugin/plugin.py b/src/plugins/built_in/napcat_adapter_plugin/plugin.py similarity index 76% rename from plugins/napcat_adapter_plugin/plugin.py rename to src/plugins/built_in/napcat_adapter_plugin/plugin.py index 48ae8603d..0067ba964 100644 --- a/plugins/napcat_adapter_plugin/plugin.py +++ b/src/plugins/built_in/napcat_adapter_plugin/plugin.py @@ -8,6 +8,7 @@ from typing import List from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField from src.plugin_system.core.event_manager import event_manager +from src.plugin_system.apis import config_api from src.common.logger import get_logger @@ -17,7 +18,6 @@ from .src.recv_handler.meta_event_handler import meta_event_handler from .src.recv_handler.notice_handler import notice_handler from .src.recv_handler.message_sending import message_send_instance from .src.send_handler import send_handler -from .src.config import global_config from .src.config.features_config import features_manager from .src.config.migrate_features import auto_migrate_features from .src.mmc_com_layer import mmc_start_com, router, mmc_stop_com @@ -134,13 +134,14 @@ async def message_process(): logger.debug(f"清理消息队列时出错: {e}") -async def napcat_server(): +async def napcat_server(plugin_config: dict): """启动 Napcat WebSocket 连接(支持正向和反向连接)""" - mode = global_config.napcat_server.mode + # 使用插件系统配置API获取配置 + mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode") logger.info(f"正在启动 adapter,连接模式: {mode}") try: - await websocket_manager.start_connection(message_recv) + await websocket_manager.start_connection(message_recv, plugin_config) except Exception as e: logger.error(f"启动 WebSocket 连接失败: {e}") raise @@ -240,9 +241,18 @@ class LauchNapcatAdapterHandler(BaseEventHandler): logger.info("功能管理器初始化完成") logger.info("开始启动Napcat Adapter") message_send_instance.maibot_router = router + # 设置插件配置 + message_send_instance.set_plugin_config(self.plugin_config) + # 设置chunker的插件配置 + chunker.set_plugin_config(self.plugin_config) + # 设置response_pool的插件配置 + from .src.response_pool import set_plugin_config as set_response_pool_config + set_response_pool_config(self.plugin_config) + # 设置send_handler的插件配置 + send_handler.set_plugin_config(self.plugin_config) # 创建单独的异步任务,防止阻塞主线程 - asyncio.create_task(napcat_server()) - asyncio.create_task(mmc_start_com()) + asyncio.create_task(napcat_server(self.plugin_config)) + asyncio.create_task(mmc_start_com(self.plugin_config)) asyncio.create_task(message_process()) asyncio.create_task(check_timeout_response()) @@ -278,9 +288,50 @@ class NapcatAdapterPlugin(BasePlugin): "name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"), "version": ConfigField(type=str, default="1.0.0", description="插件版本"), "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), + }, + "inner": { + "version": ConfigField(type=str, default="0.2.1", description="配置版本号,请勿修改"), + }, + "nickname": { + "nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"), + }, + "napcat_server": { + "mode": ConfigField(type=str, default="reverse", description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]), + "host": ConfigField(type=str, default="localhost", description="主机地址"), + "port": ConfigField(type=int, default=8095, description="端口号"), + "url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)"), + "access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"), + "heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"), + }, + "maibot_server": { + "host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"), + "port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口,即PORT字段"), + "platform_name": ConfigField(type=str, default="napcat", description="平台名称,用于消息路由"), + }, + "voice": { + "use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"), + }, + "slicing": { + "max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"), + "delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"), + }, + "debug": { + "level": ConfigField(type=str, default="INFO", description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), } } + # 配置节描述 + config_section_descriptions = { + "plugin": "插件基本信息", + "inner": "内部配置信息(请勿修改)", + "nickname": "昵称配置(目前未使用)", + "napcat_server": "Napcat连接的ws服务设置", + "maibot_server": "连接麦麦的ws服务设置", + "voice": "发送语音设置", + "slicing": "WebSocket消息切片设置", + "debug": "调试设置" + } + def register_events(self): # 注册事件 for e in event_types.NapcatEvent.ON_RECEIVED: diff --git a/plugins/napcat_adapter_plugin/pyproject.toml b/src/plugins/built_in/napcat_adapter_plugin/pyproject.toml similarity index 100% rename from plugins/napcat_adapter_plugin/pyproject.toml rename to src/plugins/built_in/napcat_adapter_plugin/pyproject.toml diff --git a/plugins/napcat_adapter_plugin/src/__init__.py b/src/plugins/built_in/napcat_adapter_plugin/src/__init__.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/__init__.py rename to src/plugins/built_in/napcat_adapter_plugin/src/__init__.py diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/config/__init__.py b/src/plugins/built_in/napcat_adapter_plugin/src/config/__init__.py new file mode 100644 index 000000000..99c6f490c --- /dev/null +++ b/src/plugins/built_in/napcat_adapter_plugin/src/config/__init__.py @@ -0,0 +1,2 @@ +# 配置已迁移到插件系统,此文件不再需要 +# 所有配置访问应通过插件系统的 config_api 进行 diff --git a/plugins/napcat_adapter_plugin/src/config/config.py b/src/plugins/built_in/napcat_adapter_plugin/src/config/config.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/config/config.py rename to src/plugins/built_in/napcat_adapter_plugin/src/config/config.py diff --git a/plugins/napcat_adapter_plugin/src/config/config_base.py b/src/plugins/built_in/napcat_adapter_plugin/src/config/config_base.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/config/config_base.py rename to src/plugins/built_in/napcat_adapter_plugin/src/config/config_base.py diff --git a/plugins/napcat_adapter_plugin/src/config/config_utils.py b/src/plugins/built_in/napcat_adapter_plugin/src/config/config_utils.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/config/config_utils.py rename to src/plugins/built_in/napcat_adapter_plugin/src/config/config_utils.py diff --git a/plugins/napcat_adapter_plugin/src/config/features_config.py b/src/plugins/built_in/napcat_adapter_plugin/src/config/features_config.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/config/features_config.py rename to src/plugins/built_in/napcat_adapter_plugin/src/config/features_config.py diff --git a/plugins/napcat_adapter_plugin/src/config/migrate_features.py b/src/plugins/built_in/napcat_adapter_plugin/src/config/migrate_features.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/config/migrate_features.py rename to src/plugins/built_in/napcat_adapter_plugin/src/config/migrate_features.py diff --git a/plugins/napcat_adapter_plugin/src/config/official_configs.py b/src/plugins/built_in/napcat_adapter_plugin/src/config/official_configs.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/config/official_configs.py rename to src/plugins/built_in/napcat_adapter_plugin/src/config/official_configs.py diff --git a/plugins/napcat_adapter_plugin/src/database.py b/src/plugins/built_in/napcat_adapter_plugin/src/database.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/database.py rename to src/plugins/built_in/napcat_adapter_plugin/src/database.py diff --git a/plugins/napcat_adapter_plugin/src/message_buffer.py b/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/message_buffer.py rename to src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py diff --git a/plugins/napcat_adapter_plugin/src/message_chunker.py b/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py similarity index 95% rename from plugins/napcat_adapter_plugin/src/message_chunker.py rename to src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py index f4e150711..0f25bd62e 100644 --- a/plugins/napcat_adapter_plugin/src/message_chunker.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py @@ -9,7 +9,7 @@ import uuid import asyncio import time from typing import List, Dict, Any, Optional, Union -from .config import global_config +from src.plugin_system.apis import config_api from src.common.logger import get_logger @@ -20,7 +20,15 @@ class MessageChunker: """消息切片器,用于处理大消息的分片发送""" def __init__(self): - self.max_chunk_size = global_config.slicing.max_frame_size * 1024 + self.max_chunk_size = 64 * 1024 # 默认值,将在设置配置时更新 + self.plugin_config = None + + def set_plugin_config(self, plugin_config: dict): + """设置插件配置""" + self.plugin_config = plugin_config + if plugin_config: + max_frame_size = config_api.get_plugin_config(plugin_config, "slicing.max_frame_size", 64) + self.max_chunk_size = max_frame_size * 1024 def should_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool: """判断消息是否需要切片""" diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py new file mode 100644 index 000000000..c735d63cf --- /dev/null +++ b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py @@ -0,0 +1,44 @@ +from maim_message import Router, RouteConfig, TargetConfig +from src.common.logger import get_logger +from .send_handler import send_handler +from src.plugin_system.apis import config_api + +logger = get_logger("napcat_adapter") + +router = None + + +def create_router(plugin_config: dict): + """创建路由器实例""" + global router + platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "napcat") + host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost") + port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000) + + route_config = RouteConfig( + route_config={ + platform_name: TargetConfig( + url=f"ws://{host}:{port}/ws", + token=None, + ) + } + ) + router = Router(route_config) + return router + + +async def mmc_start_com(plugin_config: dict = None): + """启动MaiBot连接""" + logger.info("正在连接MaiBot") + if plugin_config: + create_router(plugin_config) + + if router: + router.register_class_handler(send_handler.handle_message) + await router.run() + + +async def mmc_stop_com(): + """停止MaiBot连接""" + if router: + await router.stop() diff --git a/plugins/napcat_adapter_plugin/src/recv_handler/__init__.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/recv_handler/__init__.py rename to src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py diff --git a/plugins/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py similarity index 96% rename from plugins/napcat_adapter_plugin/src/recv_handler/message_handler.py rename to src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index 1bd34ceac..f5edbb6c5 100644 --- a/plugins/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -5,8 +5,7 @@ from ...CONSTS import PLUGIN_NAME logger = get_logger("napcat_adapter") -from ..config import global_config -from ..config.features_config import features_manager +from src.plugin_system.core.config_manager import config_api from ..message_buffer import SimpleMessageBuffer from ..utils import ( get_group_info, @@ -90,21 +89,21 @@ class MessageHandler: # 使用新的权限管理器检查权限 if group_id: - if not features_manager.is_group_allowed(group_id): + if not config_api.get_plugin_config(PLUGIN_NAME, f"features.group_allowed.{group_id}", True): logger.warning("群聊不在聊天权限范围内,消息被丢弃") return False else: - if not features_manager.is_private_allowed(user_id): + if not config_api.get_plugin_config(PLUGIN_NAME, f"features.private_allowed.{user_id}", True): logger.warning("私聊不在聊天权限范围内,消息被丢弃") return False # 检查全局禁止名单 - if not ignore_global_list and features_manager.is_user_banned(user_id): + if not ignore_global_list and config_api.get_plugin_config(PLUGIN_NAME, f"features.user_banned.{user_id}", False): logger.warning("用户在全局黑名单中,消息被丢弃") return False # 检查QQ官方机器人 - if features_manager.is_qq_bot_banned() and group_id and not ignore_bot: + if config_api.get_plugin_config(PLUGIN_NAME, "features.qq_bot_banned", False) and group_id and not ignore_bot: logger.debug("开始判断是否为机器人") member_info = await get_member_info(self.get_server_connection(), group_id, user_id) if member_info: @@ -149,7 +148,7 @@ class MessageHandler: # 发送者用户信息 user_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), user_id=sender_info.get("user_id"), user_nickname=sender_info.get("nickname"), user_cardname=sender_info.get("card"), @@ -175,7 +174,7 @@ class MessageHandler: nickname = fetched_member_info.get("nickname") if fetched_member_info else None # 发送者用户信息 user_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), user_id=sender_info.get("user_id"), user_nickname=nickname, user_cardname=None, @@ -192,7 +191,7 @@ class MessageHandler: group_name = fetched_group_info.get("group_name") group_info: GroupInfo = GroupInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), group_id=raw_message.get("group_id"), group_name=group_name, ) @@ -210,7 +209,7 @@ class MessageHandler: # 发送者用户信息 user_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), user_id=sender_info.get("user_id"), user_nickname=sender_info.get("nickname"), user_cardname=sender_info.get("card"), @@ -223,7 +222,7 @@ class MessageHandler: group_name = fetched_group_info.get("group_name") group_info: GroupInfo = GroupInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), group_id=raw_message.get("group_id"), group_name=group_name, ) @@ -233,12 +232,12 @@ class MessageHandler: return None additional_config: dict = {} - if global_config.voice.use_tts: + if config_api.get_plugin_config(PLUGIN_NAME, "voice.use_tts"): additional_config["allow_tts"] = True # 消息信息 message_info: BaseMessageInfo = BaseMessageInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), message_id=message_id, time=message_time, user_info=user_info, @@ -260,14 +259,14 @@ class MessageHandler: return None # 检查是否需要使用消息缓冲 - if features_manager.is_message_buffer_enabled(): + if config_api.get_plugin_config(PLUGIN_NAME, "features.message_buffer_enabled", False): # 检查消息类型是否启用缓冲 message_type = raw_message.get("message_type") should_use_buffer = False - if message_type == "group" and features_manager.is_message_buffer_group_enabled(): + if message_type == "group" and config_api.get_plugin_config(PLUGIN_NAME, "features.message_buffer_group_enabled", False): should_use_buffer = True - elif message_type == "private" and features_manager.is_message_buffer_private_enabled(): + elif message_type == "private" and config_api.get_plugin_config(PLUGIN_NAME, "features.message_buffer_private_enabled", False): should_use_buffer = True if should_use_buffer: diff --git a/plugins/napcat_adapter_plugin/src/recv_handler/message_sending.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py similarity index 82% rename from plugins/napcat_adapter_plugin/src/recv_handler/message_sending.py rename to src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py index 653fe5444..3372aa262 100644 --- a/plugins/napcat_adapter_plugin/src/recv_handler/message_sending.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py @@ -2,7 +2,7 @@ import asyncio from src.common.logger import get_logger from ..message_chunker import chunker -from ..config import global_config +from src.plugin_system.apis import config_api logger = get_logger("napcat_adapter") from maim_message import MessageBase, Router @@ -14,10 +14,15 @@ class MessageSending: """ maibot_router: Router = None + plugin_config = None def __init__(self): pass + def set_plugin_config(self, plugin_config: dict): + """设置插件配置""" + self.plugin_config = plugin_config + async def message_send(self, message_base: MessageBase) -> bool: """ 发送消息(Ada -> MMC 方向,需要实现切片) @@ -52,9 +57,10 @@ class MessageSending: return False # 使用配置中的延迟时间 - if i < len(chunks) - 1: - delay_seconds = global_config.slicing.delay_ms / 1000.0 - logger.debug(f"切片发送延迟: {global_config.slicing.delay_ms}毫秒") + if i < len(chunks) - 1 and self.plugin_config: + delay_ms = config_api.get_plugin_config(self.plugin_config, "slicing.delay_ms", 10) + delay_seconds = delay_ms / 1000.0 + logger.debug(f"切片发送延迟: {delay_ms}毫秒") await asyncio.sleep(delay_seconds) logger.debug("所有切片发送完成") diff --git a/plugins/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py rename to src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py diff --git a/plugins/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/recv_handler/notice_handler.py rename to src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py diff --git a/plugins/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py rename to src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py diff --git a/plugins/napcat_adapter_plugin/src/response_pool.py b/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py similarity index 73% rename from plugins/napcat_adapter_plugin/src/response_pool.py rename to src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py index 998b316dc..0c5072fa5 100644 --- a/plugins/napcat_adapter_plugin/src/response_pool.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py @@ -1,13 +1,20 @@ import asyncio import time from typing import Dict -from .config import global_config from src.common.logger import get_logger +from src.plugin_system.apis import config_api logger = get_logger("napcat_adapter") response_dict: Dict = {} response_time_dict: Dict = {} +plugin_config = None + + +def set_plugin_config(config: dict): + """设置插件配置""" + global plugin_config + plugin_config = config async def get_response(request_id: str, timeout: int = 10) -> dict: @@ -38,11 +45,17 @@ async def check_timeout_response() -> None: while True: cleaned_message_count: int = 0 now_time = time.time() + + # 获取心跳间隔配置 + heartbeat_interval = 30 # 默认值 + if plugin_config: + heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30) + for echo_id, response_time in list(response_time_dict.items()): - if now_time - response_time > global_config.napcat_server.heartbeat_interval: + if now_time - response_time > heartbeat_interval: cleaned_message_count += 1 response_dict.pop(echo_id) response_time_dict.pop(echo_id) logger.warning(f"响应消息 {echo_id} 超时,已删除") logger.info(f"已删除 {cleaned_message_count} 条超时响应消息") - await asyncio.sleep(global_config.napcat_server.heartbeat_interval) + await asyncio.sleep(heartbeat_interval) diff --git a/plugins/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py similarity index 98% rename from plugins/napcat_adapter_plugin/src/send_handler.py rename to src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index b4fb19471..a6eda3b00 100644 --- a/plugins/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -12,9 +12,9 @@ from maim_message import ( MessageBase, ) from typing import Dict, Any, Tuple, Optional +from src.plugin_system.apis import config_api from . import CommandType -from .config import global_config from .response_pool import get_response from src.common.logger import get_logger @@ -28,6 +28,11 @@ from .config.features_config import features_manager class SendHandler: def __init__(self): self.server_connection: Optional[Server.ServerConnection] = None + self.plugin_config = None + + def set_plugin_config(self, plugin_config: dict): + """设置插件配置""" + self.plugin_config = plugin_config async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: """设置Napcat连接""" @@ -354,7 +359,11 @@ class SendHandler: def handle_voice_message(self, encoded_voice: str) -> dict: """处理语音消息""" - if not global_config.voice.use_tts: + use_tts = False + if self.plugin_config: + use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False) + + if not use_tts: logger.warning("未启用语音消息处理") return {} if not encoded_voice: diff --git a/plugins/napcat_adapter_plugin/src/utils.py b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/utils.py rename to src/plugins/built_in/napcat_adapter_plugin/src/utils.py diff --git a/plugins/napcat_adapter_plugin/src/video_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/video_handler.py similarity index 100% rename from plugins/napcat_adapter_plugin/src/video_handler.py rename to src/plugins/built_in/napcat_adapter_plugin/src/video_handler.py diff --git a/plugins/napcat_adapter_plugin/src/websocket_manager.py b/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py similarity index 85% rename from plugins/napcat_adapter_plugin/src/websocket_manager.py rename to src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py index 1b156451c..484b9b59e 100644 --- a/plugins/napcat_adapter_plugin/src/websocket_manager.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py @@ -2,9 +2,9 @@ import asyncio import websockets as Server from typing import Optional, Callable, Any from src.common.logger import get_logger +from src.plugin_system.apis import config_api logger = get_logger("napcat_adapter") -from .config import global_config class WebSocketManager: @@ -16,10 +16,12 @@ class WebSocketManager: self.is_running = False self.reconnect_interval = 5 # 重连间隔(秒) self.max_reconnect_attempts = 10 # 最大重连次数 + self.plugin_config = None - async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None: + async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None: """根据配置启动 WebSocket 连接""" - mode = global_config.napcat_server.mode + self.plugin_config = plugin_config + mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode") if mode == "reverse": await self._start_reverse_connection(message_handler) @@ -30,8 +32,8 @@ class WebSocketManager: async def _start_reverse_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None: """启动反向连接(作为服务器)""" - host = global_config.napcat_server.host - port = global_config.napcat_server.port + host = config_api.get_plugin_config(self.plugin_config, "napcat_server.host") + port = config_api.get_plugin_config(self.plugin_config, "napcat_server.port") logger.info(f"正在启动反向连接模式,监听地址: ws://{host}:{port}") @@ -68,9 +70,10 @@ class WebSocketManager: connect_kwargs = {"max_size": 2**26} # 如果配置了访问令牌,添加到请求头 - if global_config.napcat_server.access_token: + access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token") + if access_token: connect_kwargs["additional_headers"] = { - "Authorization": f"Bearer {global_config.napcat_server.access_token}" + "Authorization": f"Bearer {access_token}" } logger.info("已添加访问令牌到连接请求头") @@ -112,15 +115,14 @@ class WebSocketManager: def _get_forward_url(self) -> str: """获取正向连接的 URL""" - config = global_config.napcat_server - # 如果配置了完整的 URL,直接使用 - if config.url: - return config.url + url = config_api.get_plugin_config(self.plugin_config, "napcat_server.url") + if url: + return url # 否则根据 host 和 port 构建 URL - host = config.host - port = config.port + host = config_api.get_plugin_config(self.plugin_config, "napcat_server.host") + port = config_api.get_plugin_config(self.plugin_config, "napcat_server.port") return f"ws://{host}:{port}" async def stop_connection(self) -> None: diff --git a/plugins/napcat_adapter_plugin/template/features_template.toml b/src/plugins/built_in/napcat_adapter_plugin/template/features_template.toml similarity index 100% rename from plugins/napcat_adapter_plugin/template/features_template.toml rename to src/plugins/built_in/napcat_adapter_plugin/template/features_template.toml diff --git a/plugins/napcat_adapter_plugin/template/template_config.toml b/src/plugins/built_in/napcat_adapter_plugin/template/template_config.toml similarity index 100% rename from plugins/napcat_adapter_plugin/template/template_config.toml rename to src/plugins/built_in/napcat_adapter_plugin/template/template_config.toml diff --git a/plugins/napcat_adapter_plugin/todo.md b/src/plugins/built_in/napcat_adapter_plugin/todo.md similarity index 100% rename from plugins/napcat_adapter_plugin/todo.md rename to src/plugins/built_in/napcat_adapter_plugin/todo.md From 41dc58d4fbd448a96fb09b18ab2dd89555482802 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 6 Sep 2025 05:45:00 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=20=20=20=E7=BB=A7=E7=BB=AD=E5=B0=9D?= =?UTF-8?q?=E8=AF=95=E8=BF=81=E7=A7=BB=EF=BC=8C=E4=BD=86=E6=98=AF=E7=BB=84?= =?UTF-8?q?=E4=BB=B6=E8=8E=B7=E5=8F=96=E6=8F=92=E4=BB=B6=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E5=AD=98=E5=9C=A8=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_system/base/base_events_handler.py | 17 +- src/plugin_system/core/component_registry.py | 1 + src/plugin_system/core/event_manager.py | 10 +- .../built_in/napcat_adapter_plugin/plugin.py | 21 +- .../src/config/config.py | 151 -------- .../src/config/features_config.py | 359 ------------------ .../src/message_buffer.py | 28 +- .../src/recv_handler/message_handler.py | 38 +- .../src/recv_handler/meta_event_handler.py | 11 +- .../src/recv_handler/notice_handler.py | 34 +- .../napcat_adapter_plugin/src/send_handler.py | 8 +- 11 files changed, 94 insertions(+), 584 deletions(-) delete mode 100644 src/plugins/built_in/napcat_adapter_plugin/src/config/config.py delete mode 100644 src/plugins/built_in/napcat_adapter_plugin/src/config/features_config.py diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py index 999126a02..07dd9a7af 100644 --- a/src/plugin_system/base/base_events_handler.py +++ b/src/plugin_system/base/base_events_handler.py @@ -23,17 +23,20 @@ class BaseEventHandler(ABC): """是否拦截消息,默认为否""" init_subscribe: List[Union[EventType, str]] = [EventType.UNKNOWN] """初始化时订阅的事件名称""" + plugin_name = None def __init__(self): self.log_prefix = "[EventHandler]" """对应插件名""" - self.plugin_config: Optional[Dict] = None - """插件配置字典""" + self.subscribed_events = [] """订阅的事件列表""" if EventType.UNKNOWN in self.init_subscribe: raise NotImplementedError("事件处理器必须指定 event_type") + from src.plugin_system.core.component_registry import component_registry + self.plugin_config = component_registry.get_plugin_config(self.plugin_name) + @abstractmethod async def execute(self, kwargs: dict | None) -> Tuple[bool, bool, Optional[str]]: """执行事件处理的抽象方法,子类必须实现 @@ -89,15 +92,7 @@ class BaseEventHandler(ABC): weight=cls.weight, intercept_message=cls.intercept_message, ) - - def set_plugin_config(self, plugin_config: Dict) -> None: - """设置插件配置 - - Args: - plugin_config (dict): 插件配置字典 - """ - self.plugin_config = plugin_config - + def set_plugin_name(self, plugin_name: str) -> None: """设置插件名称 diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 9f4385fd3..7dfba5bd3 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -248,6 +248,7 @@ class ComponentRegistry: logger.error(f"注册失败: {handler_name} 不是有效的EventHandler") return False + handler_class.plugin_name = handler_info.plugin_name self._event_handler_registry[handler_name] = handler_class if not handler_info.enabled: diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index 7f92b1632..4e950fd76 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -145,11 +145,12 @@ class EventManager: logger.info(f"事件 {event_name} 已禁用") return True - def register_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: + def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool: """注册事件处理器 Args: handler_class (Type[BaseEventHandler]): 事件处理器类 + plugin_config (Optional[dict]): 插件配置字典,默认为None Returns: bool: 注册成功返回True,已存在返回False @@ -163,7 +164,12 @@ class EventManager: logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册") return False - self._event_handlers[handler_name] = handler_class() + # 创建事件处理器实例,传递插件配置 + handler_instance = handler_class() + if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'): + handler_instance.set_plugin_config(plugin_config) + + self._event_handlers[handler_name] = handler_instance # 处理init_subscribe,缓存失败的订阅 if self._event_handlers[handler_name].init_subscribe: diff --git a/src/plugins/built_in/napcat_adapter_plugin/plugin.py b/src/plugins/built_in/napcat_adapter_plugin/plugin.py index 0067ba964..c3dc3b23b 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/plugin.py +++ b/src/plugins/built_in/napcat_adapter_plugin/plugin.py @@ -18,7 +18,6 @@ from .src.recv_handler.meta_event_handler import meta_event_handler from .src.recv_handler.notice_handler import notice_handler from .src.recv_handler.message_sending import message_send_instance from .src.send_handler import send_handler -from .src.config.features_config import features_manager from .src.config.migrate_features import auto_migrate_features from .src.mmc_com_layer import mmc_start_com, router, mmc_stop_com from .src.response_pool import put_response, check_timeout_response @@ -158,11 +157,7 @@ async def graceful_shutdown(): except Exception as e: logger.warning(f"停止消息重组器清理任务时出错: {e}") - # 停止功能管理器文件监控 - try: - await features_manager.stop_file_watcher() - except Exception as e: - logger.warning(f"停止功能管理器文件监控时出错: {e}") + # 停止功能管理器文件监控(已迁移到插件系统配置,无需操作) # 关闭消息处理器(包括消息缓冲器) try: @@ -234,11 +229,8 @@ class LauchNapcatAdapterHandler(BaseEventHandler): logger.info("启动消息重组器...") await reassembler.start_cleanup_task() - # 初始化功能管理器 - logger.info("正在初始化功能管理器...") - features_manager.load_config() - await features_manager.start_file_watcher(check_interval=2.0) - logger.info("功能管理器初始化完成") + # 功能管理器已迁移到插件系统配置 + logger.info("功能配置已迁移到插件系统") logger.info("开始启动Napcat Adapter") message_send_instance.maibot_router = router # 设置插件配置 @@ -250,6 +242,12 @@ class LauchNapcatAdapterHandler(BaseEventHandler): set_response_pool_config(self.plugin_config) # 设置send_handler的插件配置 send_handler.set_plugin_config(self.plugin_config) + # 设置message_handler的插件配置 + message_handler.set_plugin_config(self.plugin_config) + # 设置notice_handler的插件配置 + notice_handler.set_plugin_config(self.plugin_config) + # 设置meta_event_handler的插件配置 + meta_event_handler.set_plugin_config(self.plugin_config) # 创建单独的异步任务,防止阻塞主线程 asyncio.create_task(napcat_server(self.plugin_config)) asyncio.create_task(mmc_start_com(self.plugin_config)) @@ -287,6 +285,7 @@ class NapcatAdapterPlugin(BasePlugin): "plugin": { "name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"), "version": ConfigField(type=str, default="1.0.0", description="插件版本"), + "config_version": ConfigField(type=str, default="1.2.0", description="配置文件版本"), "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), }, "inner": { diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/config/config.py b/src/plugins/built_in/napcat_adapter_plugin/src/config/config.py deleted file mode 100644 index 97b3f57e3..000000000 --- a/src/plugins/built_in/napcat_adapter_plugin/src/config/config.py +++ /dev/null @@ -1,151 +0,0 @@ -import os -from dataclasses import dataclass -from datetime import datetime - -import tomlkit -import shutil - -from tomlkit import TOMLDocument -from tomlkit.items import Table -from src.common.logger import get_logger - -logger = get_logger("napcat_adapter") -from rich.traceback import install - -from .config_base import ConfigBase -from .official_configs import ( - DebugConfig, - MaiBotServerConfig, - NapcatServerConfig, - NicknameConfig, - SlicingConfig, - VoiceConfig, -) - -install(extra_lines=3) - -TEMPLATE_DIR = "plugins/napcat_adapter_plugin/template" -CONFIG_DIR = "plugins/napcat_adapter_plugin/config" -OLD_CONFIG_DIR = "plugins/napcat_adapter_plugin/config/old" - - -def ensure_config_directories(): - """确保配置目录存在""" - os.makedirs(CONFIG_DIR, exist_ok=True) - os.makedirs(OLD_CONFIG_DIR, exist_ok=True) - - -def update_config(): - """更新配置文件,统一使用 config/old 目录进行备份""" - # 确保目录存在 - ensure_config_directories() - - # 定义文件路径 - template_path = f"{TEMPLATE_DIR}/template_config.toml" - config_path = f"{CONFIG_DIR}/config.toml" - - # 检查配置文件是否存在 - if not os.path.exists(config_path): - logger.info("主配置文件不存在,从模板创建新配置") - shutil.copy2(template_path, config_path) - logger.info(f"已创建新配置文件: {config_path}") - logger.info("程序将退出,请检查配置文件后重启") - - # 读取配置文件和模板文件 - with open(config_path, "r", encoding="utf-8") as f: - old_config = tomlkit.load(f) - with open(template_path, "r", encoding="utf-8") as f: - new_config = tomlkit.load(f) - - # 检查version是否相同 - if old_config and "inner" in old_config and "inner" in new_config: - old_version = old_config["inner"].get("version") - new_version = new_config["inner"].get("version") - if old_version and new_version and old_version == new_version: - logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") - return - else: - logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") - else: - logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") - - # 创建备份文件 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - backup_path = os.path.join(OLD_CONFIG_DIR, f"config.toml.bak.{timestamp}") - - # 备份旧配置文件 - shutil.copy2(config_path, backup_path) - logger.info(f"已备份旧配置文件到: {backup_path}") - - # 复制模板文件到配置目录 - shutil.copy2(template_path, config_path) - logger.info(f"已创建新配置文件: {config_path}") - - def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict): - """将source字典的值更新到target字典中(如果target中存在相同的键)""" - for key, value in source.items(): - # 跳过version字段的更新 - if key == "version": - continue - if key in target: - if isinstance(value, dict) and isinstance(target[key], (dict, Table)): - update_dict(target[key], value) - else: - try: - # 对数组类型进行特殊处理 - if isinstance(value, list): - # 如果是空数组,确保它保持为空数组 - target[key] = tomlkit.array(str(value)) if value else tomlkit.array() - else: - # 其他类型使用item方法创建新值 - target[key] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - target[key] = value - - # 将旧配置的值更新到新配置中 - logger.info("开始合并新旧配置...") - update_dict(new_config, old_config) - - # 保存更新后的配置(保留注释和格式) - with open(config_path, "w", encoding="utf-8") as f: - f.write(tomlkit.dumps(new_config)) - logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") - - -@dataclass -class Config(ConfigBase): - """总配置类""" - - nickname: NicknameConfig - napcat_server: NapcatServerConfig - maibot_server: MaiBotServerConfig - voice: VoiceConfig - slicing: SlicingConfig - debug: DebugConfig - - -def load_config(config_path: str) -> Config: - """ - 加载配置文件 - :param config_path: 配置文件路径 - :return: Config对象 - """ - # 读取配置文件 - with open(config_path, "r", encoding="utf-8") as f: - config_data = tomlkit.load(f) - - # 创建Config对象 - try: - return Config.from_dict(config_data) - except Exception as e: - logger.critical("配置文件解析失败") - raise e - - -# 更新配置 -update_config() - -logger.info("正在品鉴配置文件...") -global_config = load_config(config_path=f"{CONFIG_DIR}/config.toml") -logger.info("非常的新鲜,非常的美味!") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/config/features_config.py b/src/plugins/built_in/napcat_adapter_plugin/src/config/features_config.py deleted file mode 100644 index 08c9df079..000000000 --- a/src/plugins/built_in/napcat_adapter_plugin/src/config/features_config.py +++ /dev/null @@ -1,359 +0,0 @@ -import asyncio -from dataclasses import dataclass, field -from typing import Literal, Optional -from pathlib import Path -import tomlkit -from src.common.logger import get_logger - -logger = get_logger("napcat_adapter") -from .config_base import ConfigBase -from .config_utils import create_config_from_template, create_default_config_dict - - -@dataclass -class FeaturesConfig(ConfigBase): - """功能配置类""" - - group_list_type: Literal["whitelist", "blacklist"] = "whitelist" - """群聊列表类型 白名单/黑名单""" - - group_list: list[int] = field(default_factory=list) - """群聊列表""" - - private_list_type: Literal["whitelist", "blacklist"] = "whitelist" - """私聊列表类型 白名单/黑名单""" - - private_list: list[int] = field(default_factory=list) - """私聊列表""" - - ban_user_id: list[int] = field(default_factory=list) - """被封禁的用户ID列表,封禁后将无法与其进行交互""" - - ban_qq_bot: bool = False - """是否屏蔽QQ官方机器人,若为True,则所有QQ官方机器人将无法与MaiMCore进行交互""" - - enable_poke: bool = True - """是否启用戳一戳功能""" - - ignore_non_self_poke: bool = False - """是否无视不是针对自己的戳一戳""" - - poke_debounce_seconds: int = 3 - """戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略""" - - enable_reply_at: bool = True - """是否启用引用回复时艾特用户的功能""" - - reply_at_rate: float = 0.5 - """引用回复时艾特用户的几率 (0.0 ~ 1.0)""" - - enable_video_analysis: bool = True - """是否启用视频识别功能""" - - max_video_size_mb: int = 100 - """视频文件最大大小限制(MB)""" - - download_timeout: int = 60 - """视频下载超时时间(秒)""" - - supported_formats: list[str] = field(default_factory=lambda: ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"]) - """支持的视频格式""" - - # 消息缓冲配置 - enable_message_buffer: bool = True - """是否启用消息缓冲合并功能""" - - message_buffer_enable_group: bool = True - """是否启用群消息缓冲合并""" - - message_buffer_enable_private: bool = True - """是否启用私聊消息缓冲合并""" - - message_buffer_interval: float = 3.0 - """消息合并间隔时间(秒),在此时间内的连续消息将被合并""" - - message_buffer_initial_delay: float = 0.5 - """消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并""" - - message_buffer_max_components: int = 50 - """单个会话最大缓冲消息组件数量,超过此数量将强制合并""" - - message_buffer_block_prefixes: list[str] = field(default_factory=lambda: ["/", "!", "!", ".", "。", "#", "%"]) - """消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲""" - - -class FeaturesManager: - """功能管理器,支持热重载""" - - def __init__(self, config_path: str = "plugins/napcat_adapter_plugin/config/features.toml"): - self.config_path = Path(config_path) - self.config: Optional[FeaturesConfig] = None - self._file_watcher_task: Optional[asyncio.Task] = None - self._last_modified: Optional[float] = None - self._callbacks: list = [] - - def add_reload_callback(self, callback): - """添加配置重载回调函数""" - self._callbacks.append(callback) - - def remove_reload_callback(self, callback): - """移除配置重载回调函数""" - if callback in self._callbacks: - self._callbacks.remove(callback) - - async def _notify_callbacks(self): - """通知所有回调函数配置已重载""" - for callback in self._callbacks: - try: - if asyncio.iscoroutinefunction(callback): - await callback(self.config) - else: - callback(self.config) - except Exception as e: - logger.error(f"配置重载回调执行失败: {e}") - - def load_config(self) -> FeaturesConfig: - """加载功能配置文件""" - try: - # 检查配置文件是否存在,如果不存在则创建并退出程序 - if not self.config_path.exists(): - logger.info(f"功能配置文件不存在: {self.config_path}") - self._create_default_config() - # 配置文件创建后程序应该退出,让用户检查配置 - logger.info("程序将退出,请检查功能配置文件后重启") - quit(0) - - with open(self.config_path, "r", encoding="utf-8") as f: - config_data = tomlkit.load(f) - - self.config = FeaturesConfig.from_dict(config_data) - self._last_modified = self.config_path.stat().st_mtime - logger.info(f"功能配置加载成功: {self.config_path}") - return self.config - - except Exception as e: - logger.error(f"功能配置加载失败: {e}") - logger.critical("无法加载功能配置文件,程序退出") - quit(1) - - def _create_default_config(self): - """创建默认功能配置文件""" - template_path = "template/features_template.toml" - - # 尝试从模板创建配置文件 - if create_config_from_template( - str(self.config_path), - template_path, - "功能配置文件", - should_exit=False, # 不在这里退出,由调用方决定 - ): - return - - # 如果模板文件不存在,创建基本配置 - logger.info("模板文件不存在,创建基本功能配置") - default_config = { - "group_list_type": "whitelist", - "group_list": [], - "private_list_type": "whitelist", - "private_list": [], - "ban_user_id": [], - "ban_qq_bot": False, - "enable_poke": True, - "ignore_non_self_poke": False, - "poke_debounce_seconds": 3, - "enable_reply_at": True, - "reply_at_rate": 0.5, - "enable_video_analysis": True, - "max_video_size_mb": 100, - "download_timeout": 60, - "supported_formats": ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], - # 消息缓冲配置 - "enable_message_buffer": True, - "message_buffer_enable_group": True, - "message_buffer_enable_private": True, - "message_buffer_interval": 3.0, - "message_buffer_initial_delay": 0.5, - "message_buffer_max_components": 50, - "message_buffer_block_prefixes": ["/", "!", "!", ".", "。", "#", "%"], - } - - if not create_default_config_dict(default_config, str(self.config_path), "功能配置文件"): - logger.critical("无法创建功能配置文件") - quit(1) - - async def reload_config(self) -> bool: - """重新加载配置文件""" - try: - if not self.config_path.exists(): - logger.warning(f"功能配置文件不存在,无法重载: {self.config_path}") - return False - - current_modified = self.config_path.stat().st_mtime - if self._last_modified and current_modified <= self._last_modified: - return False # 文件未修改 - - old_config = self.config - new_config = self.load_config() - - # 检查配置是否真的发生了变化 - if old_config and self._configs_equal(old_config, new_config): - return False - - logger.info("功能配置已重载") - await self._notify_callbacks() - return True - - except Exception as e: - logger.error(f"功能配置重载失败: {e}") - return False - - def _configs_equal(self, config1: FeaturesConfig, config2: FeaturesConfig) -> bool: - """比较两个配置是否相等""" - return ( - config1.group_list_type == config2.group_list_type - and set(config1.group_list) == set(config2.group_list) - and config1.private_list_type == config2.private_list_type - and set(config1.private_list) == set(config2.private_list) - and set(config1.ban_user_id) == set(config2.ban_user_id) - and config1.ban_qq_bot == config2.ban_qq_bot - and config1.enable_poke == config2.enable_poke - and config1.ignore_non_self_poke == config2.ignore_non_self_poke - and config1.poke_debounce_seconds == config2.poke_debounce_seconds - and config1.enable_reply_at == config2.enable_reply_at - and config1.reply_at_rate == config2.reply_at_rate - and config1.enable_video_analysis == config2.enable_video_analysis - and config1.max_video_size_mb == config2.max_video_size_mb - and config1.download_timeout == config2.download_timeout - and set(config1.supported_formats) == set(config2.supported_formats) - and - # 消息缓冲配置比较 - config1.enable_message_buffer == config2.enable_message_buffer - and config1.message_buffer_enable_group == config2.message_buffer_enable_group - and config1.message_buffer_enable_private == config2.message_buffer_enable_private - and config1.message_buffer_interval == config2.message_buffer_interval - and config1.message_buffer_initial_delay == config2.message_buffer_initial_delay - and config1.message_buffer_max_components == config2.message_buffer_max_components - and set(config1.message_buffer_block_prefixes) == set(config2.message_buffer_block_prefixes) - ) - - async def start_file_watcher(self, check_interval: float = 1.0): - """启动文件监控,定期检查配置文件变化""" - if self._file_watcher_task and not self._file_watcher_task.done(): - logger.warning("文件监控已在运行") - return - - self._file_watcher_task = asyncio.create_task(self._file_watcher_loop(check_interval)) - logger.info(f"功能配置文件监控已启动,检查间隔: {check_interval}秒") - - async def stop_file_watcher(self): - """停止文件监控""" - if self._file_watcher_task and not self._file_watcher_task.done(): - self._file_watcher_task.cancel() - try: - await self._file_watcher_task - except asyncio.CancelledError: - pass - logger.info("功能配置文件监控已停止") - - async def _file_watcher_loop(self, check_interval: float): - """文件监控循环""" - while True: - try: - await asyncio.sleep(check_interval) - await self.reload_config() - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"文件监控循环出错: {e}") - await asyncio.sleep(check_interval) - - def get_config(self) -> FeaturesConfig: - """获取当前功能配置""" - if self.config is None: - return self.load_config() - return self.config - - def is_group_allowed(self, group_id: int) -> bool: - """检查群聊是否被允许""" - config = self.get_config() - if config.group_list_type == "whitelist": - return group_id in config.group_list - else: # blacklist - return group_id not in config.group_list - - def is_private_allowed(self, user_id: int) -> bool: - """检查私聊是否被允许""" - config = self.get_config() - if config.private_list_type == "whitelist": - return user_id in config.private_list - else: # blacklist - return user_id not in config.private_list - - def is_user_banned(self, user_id: int) -> bool: - """检查用户是否被全局禁止""" - config = self.get_config() - return user_id in config.ban_user_id - - def is_qq_bot_banned(self) -> bool: - """检查是否禁止QQ官方机器人""" - config = self.get_config() - return config.ban_qq_bot - - def is_poke_enabled(self) -> bool: - """检查戳一戳功能是否启用""" - config = self.get_config() - return config.enable_poke - - def is_non_self_poke_ignored(self) -> bool: - """检查是否忽略非自己戳一戳""" - config = self.get_config() - return config.ignore_non_self_poke - - def is_message_buffer_enabled(self) -> bool: - """检查消息缓冲功能是否启用""" - config = self.get_config() - return config.enable_message_buffer - - def is_message_buffer_group_enabled(self) -> bool: - """检查群消息缓冲是否启用""" - config = self.get_config() - return config.message_buffer_enable_group - - def is_message_buffer_private_enabled(self) -> bool: - """检查私聊消息缓冲是否启用""" - config = self.get_config() - return config.message_buffer_enable_private - - def get_message_buffer_interval(self) -> float: - """获取消息缓冲间隔时间""" - config = self.get_config() - return config.message_buffer_interval - - def get_message_buffer_initial_delay(self) -> float: - """获取消息缓冲初始延迟""" - config = self.get_config() - return config.message_buffer_initial_delay - - def get_message_buffer_max_components(self) -> int: - """获取消息缓冲最大组件数量""" - config = self.get_config() - return config.message_buffer_max_components - - def is_message_buffer_group_enabled(self) -> bool: - """检查是否启用群聊消息缓冲""" - config = self.get_config() - return config.message_buffer_enable_group - - def is_message_buffer_private_enabled(self) -> bool: - """检查是否启用私聊消息缓冲""" - config = self.get_config() - return config.message_buffer_enable_private - - def get_message_buffer_block_prefixes(self) -> list[str]: - """获取消息缓冲屏蔽前缀列表""" - config = self.get_config() - return config.message_buffer_block_prefixes - - -# 全局功能管理器实例 -features_manager = FeaturesManager() diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py b/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py index 0dccb31a8..1988e6c40 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py @@ -7,7 +7,7 @@ from src.common.logger import get_logger logger = get_logger("napcat_adapter") -from .config.features_config import features_manager +from src.plugin_system.apis import config_api from .recv_handler import RealMessageType @@ -43,6 +43,11 @@ class SimpleMessageBuffer: self.lock = asyncio.Lock() self.merge_callback = merge_callback self._shutdown = False + self.plugin_config = None + + def set_plugin_config(self, plugin_config: dict): + """设置插件配置""" + self.plugin_config = plugin_config def get_session_id(self, event_data: Dict[str, Any]) -> str: """根据事件数据生成会话ID""" @@ -97,8 +102,7 @@ class SimpleMessageBuffer: return True # 检查屏蔽前缀 - config = features_manager.get_config() - block_prefixes = tuple(config.message_buffer_block_prefixes) + block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])) text = text.strip() if text.startswith(block_prefixes): @@ -124,15 +128,15 @@ class SimpleMessageBuffer: if self._shutdown: return False - config = features_manager.get_config() - if not config.enable_message_buffer: + # 检查是否启用消息缓冲 + if not config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", False): return False # 检查是否启用对应类型的缓冲 message_type = event_data.get("message_type", "") - if message_type == "group" and not config.message_buffer_enable_group: + if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False): return False - elif message_type == "private" and not config.message_buffer_enable_private: + elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False): return False # 提取文本 @@ -154,7 +158,7 @@ class SimpleMessageBuffer: session = self.buffer_pool[session_id] # 检查是否超过最大组件数量 - if len(session.messages) >= config.message_buffer_max_components: + if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5): logger.info(f"会话 {session_id} 消息数量达到上限,强制合并") asyncio.create_task(self._force_merge_session(session_id)) self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event) @@ -187,8 +191,8 @@ class SimpleMessageBuffer: async def _wait_and_start_merge(self, session_id: str): """等待初始延迟后开始合并定时器""" - config = features_manager.get_config() - await asyncio.sleep(config.message_buffer_initial_delay) + initial_delay = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_initial_delay", 0.5) + await asyncio.sleep(initial_delay) async with self.lock: session = self.buffer_pool.get(session_id) @@ -206,8 +210,8 @@ class SimpleMessageBuffer: async def _wait_and_merge(self, session_id: str): """等待合并间隔后执行合并""" - config = features_manager.get_config() - await asyncio.sleep(config.message_buffer_interval) + interval = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_interval", 2.0) + await asyncio.sleep(interval) await self._merge_session(session_id) async def _force_merge_session(self, session_id: str): diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index f5edbb6c5..f761dc33f 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -5,7 +5,7 @@ from ...CONSTS import PLUGIN_NAME logger = get_logger("napcat_adapter") -from src.plugin_system.core.config_manager import config_api +from src.plugin_system.apis import config_api from ..message_buffer import SimpleMessageBuffer from ..utils import ( get_group_info, @@ -47,9 +47,17 @@ class MessageHandler: def __init__(self): self.server_connection: Server.ServerConnection = None self.bot_id_list: Dict[int, bool] = {} + self.plugin_config = None # 初始化简化消息缓冲器,传入回调函数 self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message) + def set_plugin_config(self, plugin_config: dict): + """设置插件配置""" + self.plugin_config = plugin_config + # 将配置传递给消息缓冲器 + if self.message_buffer: + self.message_buffer.set_plugin_config(plugin_config) + async def shutdown(self): """关闭消息处理器,清理资源""" if self.message_buffer: @@ -89,21 +97,21 @@ class MessageHandler: # 使用新的权限管理器检查权限 if group_id: - if not config_api.get_plugin_config(PLUGIN_NAME, f"features.group_allowed.{group_id}", True): + if not config_api.get_plugin_config(self.plugin_config, f"features.group_allowed.{group_id}", True): logger.warning("群聊不在聊天权限范围内,消息被丢弃") return False else: - if not config_api.get_plugin_config(PLUGIN_NAME, f"features.private_allowed.{user_id}", True): + if not config_api.get_plugin_config(self.plugin_config, f"features.private_allowed.{user_id}", True): logger.warning("私聊不在聊天权限范围内,消息被丢弃") return False # 检查全局禁止名单 - if not ignore_global_list and config_api.get_plugin_config(PLUGIN_NAME, f"features.user_banned.{user_id}", False): + if not ignore_global_list and config_api.get_plugin_config(self.plugin_config, f"features.user_banned.{user_id}", False): logger.warning("用户在全局黑名单中,消息被丢弃") return False # 检查QQ官方机器人 - if config_api.get_plugin_config(PLUGIN_NAME, "features.qq_bot_banned", False) and group_id and not ignore_bot: + if config_api.get_plugin_config(self.plugin_config, "features.qq_bot_banned", False) and group_id and not ignore_bot: logger.debug("开始判断是否为机器人") member_info = await get_member_info(self.get_server_connection(), group_id, user_id) if member_info: @@ -148,7 +156,7 @@ class MessageHandler: # 发送者用户信息 user_info: UserInfo = UserInfo( - platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"), user_id=sender_info.get("user_id"), user_nickname=sender_info.get("nickname"), user_cardname=sender_info.get("card"), @@ -174,7 +182,7 @@ class MessageHandler: nickname = fetched_member_info.get("nickname") if fetched_member_info else None # 发送者用户信息 user_info: UserInfo = UserInfo( - platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"), user_id=sender_info.get("user_id"), user_nickname=nickname, user_cardname=None, @@ -191,7 +199,7 @@ class MessageHandler: group_name = fetched_group_info.get("group_name") group_info: GroupInfo = GroupInfo( - platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"), group_id=raw_message.get("group_id"), group_name=group_name, ) @@ -209,7 +217,7 @@ class MessageHandler: # 发送者用户信息 user_info: UserInfo = UserInfo( - platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"), user_id=sender_info.get("user_id"), user_nickname=sender_info.get("nickname"), user_cardname=sender_info.get("card"), @@ -222,7 +230,7 @@ class MessageHandler: group_name = fetched_group_info.get("group_name") group_info: GroupInfo = GroupInfo( - platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"), group_id=raw_message.get("group_id"), group_name=group_name, ) @@ -232,12 +240,12 @@ class MessageHandler: return None additional_config: dict = {} - if config_api.get_plugin_config(PLUGIN_NAME, "voice.use_tts"): + if config_api.get_plugin_config(self.plugin_config, "voice.use_tts"): additional_config["allow_tts"] = True # 消息信息 message_info: BaseMessageInfo = BaseMessageInfo( - platform=config_api.get_plugin_config(PLUGIN_NAME, "maibot_server.platform_name"), + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"), message_id=message_id, time=message_time, user_info=user_info, @@ -259,14 +267,14 @@ class MessageHandler: return None # 检查是否需要使用消息缓冲 - if config_api.get_plugin_config(PLUGIN_NAME, "features.message_buffer_enabled", False): + if config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enabled", False): # 检查消息类型是否启用缓冲 message_type = raw_message.get("message_type") should_use_buffer = False - if message_type == "group" and config_api.get_plugin_config(PLUGIN_NAME, "features.message_buffer_group_enabled", False): + if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_group_enabled", False): should_use_buffer = True - elif message_type == "private" and config_api.get_plugin_config(PLUGIN_NAME, "features.message_buffer_private_enabled", False): + elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_private_enabled", False): should_use_buffer = True if should_use_buffer: diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py index eae6fd01a..217347c36 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py @@ -1,7 +1,7 @@ from src.common.logger import get_logger logger = get_logger("napcat_adapter") -from ..config import global_config +from src.plugin_system.apis import config_api import time import asyncio @@ -14,8 +14,15 @@ class MetaEventHandler: """ def __init__(self): - self.interval = global_config.napcat_server.heartbeat_interval + self.interval = 5.0 # 默认值,稍后通过set_plugin_config设置 self._interval_checking = False + self.plugin_config = None + + def set_plugin_config(self, plugin_config: dict): + """设置插件配置""" + self.plugin_config = plugin_config + # 更新interval值 + self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000 async def handle_meta_event(self, message: dict) -> None: event_type = message.get("meta_event_type") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index 2f4fddda2..0efdcd352 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -8,8 +8,7 @@ from src.common.logger import get_logger logger = get_logger("napcat_adapter") -from ..config import global_config -from ..config.features_config import features_manager +from src.plugin_system.apis import config_api from ..database import BanUser, db_manager, is_identical from . import NoticeType, ACCEPT_FORMAT from .message_sending import message_send_instance @@ -38,6 +37,11 @@ class NoticeHandler: def __init__(self): self.server_connection: Server.ServerConnection | None = None self.last_poke_time: float = 0.0 # 记录最后一次针对机器人的戳一戳时间 + self.plugin_config = None + + def set_plugin_config(self, plugin_config: dict): + """设置插件配置""" + self.plugin_config = plugin_config async def set_server_connection(self, server_connection: Server.ServerConnection) -> None: """设置Napcat连接""" @@ -112,7 +116,7 @@ class NoticeHandler: sub_type = raw_message.get("sub_type") match sub_type: case NoticeType.Notify.poke: - if features_manager.is_poke_enabled() and await message_handler.check_allow_to_chat( + if config_api.get_plugin_config(self.plugin_config, "features.poke_enabled", True) and await message_handler.check_allow_to_chat( user_id, group_id, False, False ): logger.info("处理戳一戳消息") @@ -159,13 +163,13 @@ class NoticeHandler: else: logger.warning("无法获取notice消息所在群的名称") group_info = GroupInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), group_id=group_id, group_name=group_name, ) message_info: BaseMessageInfo = BaseMessageInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), message_id="notice", time=message_time, user_info=user_info, @@ -206,7 +210,7 @@ class NoticeHandler: # 防抖检查:如果是针对机器人的戳一戳,检查防抖时间 if self_id == target_id: current_time = time.time() - debounce_seconds = features_manager.get_config().poke_debounce_seconds + debounce_seconds = config_api.get_plugin_config(self.plugin_config, "features.poke_debounce_seconds", 2.0) if self.last_poke_time > 0: time_diff = current_time - self.last_poke_time @@ -243,7 +247,7 @@ class NoticeHandler: else: # 如果配置为忽略不是针对自己的戳一戳,则直接返回None - if features_manager.is_non_self_poke_ignored(): + if config_api.get_plugin_config(self.plugin_config, "features.non_self_poke_ignored", False): logger.info("忽略不是针对自己的戳一戳消息") return None, None @@ -268,7 +272,7 @@ class NoticeHandler: logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本") user_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), user_id=user_id, user_nickname=user_name, user_cardname=user_cardname, @@ -299,7 +303,7 @@ class NoticeHandler: operator_nickname = "QQ用户" operator_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), user_id=operator_id, user_nickname=operator_nickname, user_cardname=operator_cardname, @@ -328,7 +332,7 @@ class NoticeHandler: user_nickname = fetched_member_info.get("nickname") user_cardname = fetched_member_info.get("card") banned_user_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname, @@ -367,7 +371,7 @@ class NoticeHandler: operator_nickname = "QQ用户" operator_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), user_id=operator_id, user_nickname=operator_nickname, user_cardname=operator_cardname, @@ -393,7 +397,7 @@ class NoticeHandler: else: logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效") lifted_user_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname, @@ -436,13 +440,13 @@ class NoticeHandler: else: logger.warning("无法获取notice消息所在群的名称") group_info = GroupInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), group_id=group_id, group_name=group_name, ) message_info: BaseMessageInfo = BaseMessageInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), message_id="notice", time=time.time(), user_info=None, # 自然解除禁言没有操作者 @@ -493,7 +497,7 @@ class NoticeHandler: user_cardname = fetched_member_info.get("card") lifted_user_info: UserInfo = UserInfo( - platform=global_config.maibot_server.platform_name, + platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname, diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index a6eda3b00..5d6d91467 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -22,7 +22,6 @@ logger = get_logger("napcat_adapter") from .utils import get_image_format, convert_image_to_gif from .recv_handler.message_sending import message_send_instance from .websocket_manager import websocket_manager -from .config.features_config import features_manager class SendHandler: @@ -292,11 +291,8 @@ class SendHandler: """处理回复消息""" reply_seg = {"type": "reply", "data": {"id": id}} - # 获取功能配置 - ft_config = features_manager.get_config() - # 检查是否启用引用艾特功能 - if not ft_config.enable_reply_at: + if not config_api.get_plugin_config(self.plugin_config, "features.enable_reply_at", False): return reply_seg try: @@ -315,7 +311,7 @@ class SendHandler: return reply_seg # 根据概率决定是否艾特用户 - if random.random() < ft_config.reply_at_rate: + if random.random() < config_api.get_plugin_config(self.plugin_config, "features.reply_at_rate", 0.5): at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}} # 在艾特后面添加一个空格 text_seg = {"type": "text", "data": {"text": " "}}