From 2bfb3a151c8d881f36ce6c36f4f992e0bb4af9be Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 6 Sep 2025 01:36:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=B8=80=E5=A0=86=E6=96=B0pr?= =?UTF-8?q?ompt=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 41 +++++--- src/chat/utils/prompt.py | 142 ++++++++++++++++++++++++-- src/chat/utils/prompt_utils.py | 132 ------------------------ 3 files changed, 163 insertions(+), 152 deletions(-) delete mode 100644 src/chat/utils/prompt_utils.py diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 0e01e3c7f..62f15feb3 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -11,7 +11,6 @@ import re from typing import List, Optional, Dict, Any, Tuple from datetime import datetime -from src.chat.utils.prompt_utils import PromptUtils from src.mais4u.mai_think import mai_thinking_manager from src.common.logger import get_logger from src.config.config import global_config, model_config @@ -37,7 +36,7 @@ from src.plugin_system.base.component_types import ActionInfo, EventType from src.plugin_system.apis import llm_api # 导入新的统一Prompt系统 -from src.chat.utils.prompt import Prompt, PromptContext +from src.chat.utils.prompt import Prompt, PromptParameters logger = get_logger("replyer") @@ -598,7 +597,8 @@ class DefaultReplyer: def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: """解析回复目标消息 - 使用共享工具""" - return PromptUtils.parse_reply_target(target_message) + from src.chat.utils.prompt import Prompt + return Prompt.parse_reply_target(target_message) async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: """构建关键词反应提示 @@ -910,7 +910,8 @@ class DefaultReplyer: target_user_info = None if sender: target_user_info = await person_info_manager.get_person_info_by_name(sender) - + + from src.chat.utils.prompt import Prompt # 并行执行六个构建任务 task_results = await asyncio.gather( self._time_and_run_task( @@ -923,7 +924,7 @@ class DefaultReplyer: ), self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"), self._time_and_run_task( - PromptUtils.build_cross_context(chat_id, target_user_info, global_config.personality.prompt_mode), + Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info), "cross_context", ), ) @@ -998,8 +999,8 @@ class DefaultReplyer: # 根据配置选择模板 current_prompt_mode = global_config.personality.prompt_mode - # 使用新的统一Prompt系统 - prompt_context = PromptContext( + # 使用新的统一Prompt系统 - 创建PromptParameters + prompt_parameters = PromptParameters( chat_id=chat_id, is_group_chat=is_group_chat, sender=sender, @@ -1032,9 +1033,19 @@ class DefaultReplyer: action_descriptions=action_descriptions, ) - # 使用新的统一Prompt系统 - prompt = Prompt(template_name=None, context=prompt_context) # 由current_prompt_mode自动选择 - prompt_text = await prompt.build_prompt() + # 使用新的统一Prompt系统 - 使用正确的模板名称 + template_name = None + if current_prompt_mode == "s4u": + template_name = "s4u_style_prompt" + elif current_prompt_mode == "normal": + template_name = "normal_style_prompt" + elif current_prompt_mode == "minimal": + template_name = "default_expressor_prompt" + + # 获取模板内容 + template_prompt = await global_prompt_manager.get_prompt_async(template_name) + prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters) + prompt_text = await prompt.build() return prompt_text @@ -1136,8 +1147,8 @@ class DefaultReplyer: template_name = "default_expressor_prompt" - # 使用新的统一Prompt系统 - Expressor模式 - prompt_context = PromptContext( + # 使用新的统一Prompt系统 - Expressor模式,创建PromptParameters + prompt_parameters = PromptParameters( chat_id=chat_id, is_group_chat=is_group_chat, sender=sender, @@ -1157,8 +1168,10 @@ class DefaultReplyer: relation_info_block=relation_info, ) - prompt = Prompt(template_name=template_name, context=prompt_context) - prompt_text = await prompt.build_prompt() + # 使用新的统一Prompt系统 - Expressor模式 + template_prompt = await global_prompt_manager.get_prompt_async("default_expressor_prompt") + prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters) + prompt_text = await prompt.build() return prompt_text diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 1e44b72d8..b5cf140c5 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -8,14 +8,14 @@ import asyncio import time import contextvars from dataclasses import dataclass, field -from typing import Dict, Any, Optional, List, Union, Literal, Tuple +from typing import Dict, Any, Optional, List, Literal, Tuple from contextlib import asynccontextmanager from rich.traceback import install from src.common.logger import get_logger from src.config.config import global_config from src.chat.utils.chat_message_builder import build_readable_messages -from src.chat.utils.prompt_utils import PromptUtils +from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.person_info import get_person_info_manager install(extra_lines=3) @@ -472,7 +472,7 @@ class Prompt: try: msg_user_id = str(msg_dict.get("user_id")) reply_to = msg_dict.get("reply_to", "") - platform, reply_to_user_id = PromptUtils.parse_reply_target(reply_to) + platform, reply_to_user_id = Prompt.parse_reply_target(reply_to) if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: core_dialogue_list.append(msg_dict) except Exception as e: @@ -531,7 +531,7 @@ class Prompt: async def _build_relation_info(self) -> Dict[str, Any]: """构建关系信息""" try: - relation_info = await PromptUtils.build_relation_info(self.parameters.chat_id, self.parameters.reply_to) + relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to) return {"relation_info_block": relation_info} except Exception as e: logger.error(f"构建关系信息失败: {e}") @@ -550,7 +550,7 @@ class Prompt: async def _build_cross_context(self) -> Dict[str, Any]: """构建跨群上下文""" try: - cross_context = await PromptUtils.build_cross_context( + cross_context = await Prompt.build_cross_context( self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info ) return {"cross_context_block": cross_context} @@ -666,6 +666,135 @@ class Prompt: """返回提示词的表示形式""" return f"Prompt(template='{self.template}', name='{self.name}')" + # ============================================================================= + # PromptUtils功能迁移 - 静态工具方法 + # 这些方法原来在PromptUtils类中,现在作为Prompt类的静态方法 + # 解决循环导入问题 + # ============================================================================= + + @staticmethod + def parse_reply_target(target_message: str) -> Tuple[str, str]: + """ + 解析回复目标消息 - 统一实现 + + Args: + target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容" + + Returns: + Tuple[str, str]: (发送者名称, 消息内容) + """ + sender = "" + target = "" + + # 添加None检查,防止NoneType错误 + if target_message is None: + return sender, target + + if ":" in target_message or ":" in target_message: + # 使用正则表达式匹配中文或英文冒号 + parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1) + if len(parts) == 2: + sender = parts[0].strip() + target = parts[1].strip() + return sender, target + + @staticmethod + async def build_relation_info(chat_id: str, reply_to: str) -> str: + """ + 构建关系信息 - 统一实现 + + Args: + chat_id: 聊天ID + reply_to: 回复目标字符串 + + Returns: + str: 关系信息字符串 + """ + if not global_config.relationship.enable_relationship: + return "" + + from src.person_info.relationship_fetcher import relationship_fetcher_manager + + relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id) + + if not reply_to: + return "" + sender, text = Prompt.parse_reply_target(reply_to) + if not sender or not text: + return "" + + # 获取用户ID + person_info_manager = get_person_info_manager() + person_id = person_info_manager.get_person_id_by_person_name(sender) + if not person_id: + logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") + return f"你完全不认识{sender},不理解ta的相关信息。" + + return await relationship_fetcher.build_relation_info(person_id, points_num=5) + + @staticmethod + async def build_cross_context( + chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]] + ) -> str: + """ + 构建跨群聊上下文 - 统一实现 + + Args: + chat_id: 聊天ID + prompt_mode: 当前提示词模式 + target_user_info: 目标用户信息 + + Returns: + str: 跨群聊上下文字符串 + """ + if not global_config.cross_context.enable: + return "" + + from src.plugin_system.apis import cross_context_api + + other_chat_raw_ids = cross_context_api.get_context_groups(chat_id) + if not other_chat_raw_ids: + return "" + + chat_stream = get_chat_manager().get_stream(chat_id) + if not chat_stream: + return "" + + if prompt_mode == "normal": + return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids) + elif prompt_mode == "s4u": + return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info) + + return "" + + @staticmethod + def parse_reply_target_id(reply_to: str) -> str: + """ + 解析回复目标中的用户ID + + Args: + reply_to: 回复目标字符串 + + Returns: + str: 用户ID + """ + if not reply_to: + return "" + + # 复用parse_reply_target方法的逻辑 + sender, _ = Prompt.parse_reply_target(reply_to) + if not sender: + return "" + + # 获取用户ID + person_info_manager = get_person_info_manager() + person_id = person_info_manager.get_person_id_by_person_name(sender) + if person_id: + user_id = person_info_manager.get_value_sync(person_id, "user_id") + return str(user_id) if user_id else "" + + return "" + # 工厂函数 def create_prompt( @@ -690,4 +819,5 @@ async def create_prompt_async( prompt = create_prompt(template, name, parameters, **kwargs) if global_prompt_manager._context._current_context: await global_prompt_manager._context.register_async(prompt) - return prompt \ No newline at end of file + return prompt + diff --git a/src/chat/utils/prompt_utils.py b/src/chat/utils/prompt_utils.py deleted file mode 100644 index 4eed6025f..000000000 --- a/src/chat/utils/prompt_utils.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -共享提示词工具模块 - 消除重复代码 -提供统一的工具函数供DefaultReplyer和统一Prompt系统使用 -""" - -import re -from typing import Dict, Any, Optional, Tuple - -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.message_receive.chat_stream import get_chat_manager -from src.person_info.person_info import get_person_info_manager -from src.plugin_system.apis import cross_context_api - -logger = get_logger("prompt_utils") - - -class PromptUtils: - """提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查""" - - @staticmethod - def parse_reply_target(target_message: str) -> Tuple[str, str]: - """ - 解析回复目标消息 - 统一实现 - - Args: - target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容" - - Returns: - Tuple[str, str]: (发送者名称, 消息内容) - """ - sender = "" - target = "" - - # 添加None检查,防止NoneType错误 - if target_message is None: - return sender, target - - if ":" in target_message or ":" in target_message: - # 使用正则表达式匹配中文或英文冒号 - parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1) - if len(parts) == 2: - sender = parts[0].strip() - target = parts[1].strip() - return sender, target - - @staticmethod - async def build_relation_info(chat_id: str, reply_to: str) -> str: - """ - 构建关系信息 - 统一实现 - - Args: - chat_id: 聊天ID - reply_to: 回复目标字符串 - - Returns: - str: 关系信息字符串 - """ - if not global_config.relationship.enable_relationship: - return "" - - from src.person_info.relationship_fetcher import relationship_fetcher_manager - - relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id) - - if not reply_to: - return "" - sender, text = PromptUtils.parse_reply_target(reply_to) - if not sender or not text: - return "" - - # 获取用户ID - person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) - if not person_id: - logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") - return f"你完全不认识{sender},不理解ta的相关信息。" - - return await relationship_fetcher.build_relation_info(person_id, points_num=5) - - @staticmethod - async def build_cross_context( - chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str - ) -> str: - """ - 构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能 - """ - if not global_config.cross_context.enable: - return "" - - other_chat_raw_ids = cross_context_api.get_context_groups(chat_id) - if not other_chat_raw_ids: - return "" - - chat_stream = get_chat_manager().get_stream(chat_id) - if not chat_stream: - return "" - - if current_prompt_mode == "normal": - return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids) - elif current_prompt_mode == "s4u": - return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info) - - return "" - - @staticmethod - def parse_reply_target_id(reply_to: str) -> str: - """ - 解析回复目标中的用户ID - - Args: - reply_to: 回复目标字符串 - - Returns: - str: 用户ID - """ - if not reply_to: - return "" - - # 复用parse_reply_target方法的逻辑 - sender, _ = PromptUtils.parse_reply_target(reply_to) - if not sender: - return "" - - # 获取用户ID - person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) - if person_id: - user_id = person_info_manager.get_value_sync(person_id, "user_id") - return str(user_id) if user_id else "" - - return ""