修复一堆新prompt的bug
This commit is contained in:
@@ -11,7 +11,6 @@ import re
|
||||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from src.chat.utils.prompt_utils import PromptUtils
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -36,10 +35,9 @@ from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import Prompt, PromptContext
|
||||
from src.chat.utils.prompt import Prompt, PromptParameters
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
@@ -599,7 +597,8 @@ class DefaultReplyer:
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||
"""解析回复目标消息 - 使用共享工具"""
|
||||
return PromptUtils.parse_reply_target(target_message)
|
||||
from src.chat.utils.prompt import Prompt
|
||||
return Prompt.parse_reply_target(target_message)
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
"""构建关键词反应提示
|
||||
@@ -874,7 +873,8 @@ class DefaultReplyer:
|
||||
target_user_info = None
|
||||
if sender:
|
||||
target_user_info = await person_info_manager.get_person_info_by_name(sender)
|
||||
|
||||
|
||||
from src.chat.utils.prompt import Prompt
|
||||
# 并行执行六个构建任务
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
@@ -887,7 +887,7 @@ class DefaultReplyer:
|
||||
),
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(
|
||||
PromptUtils.build_cross_context(chat_id, target_user_info, global_config.personality.prompt_mode),
|
||||
Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info),
|
||||
"cross_context",
|
||||
),
|
||||
)
|
||||
@@ -939,6 +939,7 @@ class DefaultReplyer:
|
||||
|
||||
schedule_block = ""
|
||||
if global_config.schedule.enable:
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
current_activity = schedule_manager.get_current_activity()
|
||||
if current_activity:
|
||||
schedule_block = f"你当前正在:{current_activity}。"
|
||||
@@ -970,8 +971,8 @@ class DefaultReplyer:
|
||||
# 根据配置选择模板
|
||||
current_prompt_mode = global_config.personality.prompt_mode
|
||||
|
||||
# 使用新的统一Prompt系统
|
||||
prompt_context = PromptContext(
|
||||
# 使用新的统一Prompt系统 - 创建PromptParameters
|
||||
prompt_parameters = PromptParameters(
|
||||
chat_id=chat_id,
|
||||
is_group_chat=is_group_chat,
|
||||
sender=sender,
|
||||
@@ -1004,9 +1005,19 @@ class DefaultReplyer:
|
||||
action_descriptions=action_descriptions,
|
||||
)
|
||||
|
||||
# 使用新的统一Prompt系统
|
||||
prompt = Prompt(template_name=None, context=prompt_context) # 由current_prompt_mode自动选择
|
||||
prompt_text = await prompt.build_prompt()
|
||||
# 使用新的统一Prompt系统 - 使用正确的模板名称
|
||||
template_name = None
|
||||
if current_prompt_mode == "s4u":
|
||||
template_name = "s4u_style_prompt"
|
||||
elif current_prompt_mode == "normal":
|
||||
template_name = "normal_style_prompt"
|
||||
elif current_prompt_mode == "minimal":
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
# 获取模板内容
|
||||
template_prompt = await global_prompt_manager.get_prompt_async(template_name)
|
||||
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
||||
prompt_text = await prompt.build()
|
||||
|
||||
return prompt_text
|
||||
|
||||
@@ -1107,8 +1118,8 @@ class DefaultReplyer:
|
||||
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
# 使用新的统一Prompt系统 - Expressor模式
|
||||
prompt_context = PromptContext(
|
||||
# 使用新的统一Prompt系统 - Expressor模式,创建PromptParameters
|
||||
prompt_parameters = PromptParameters(
|
||||
chat_id=chat_id,
|
||||
is_group_chat=is_group_chat,
|
||||
sender=sender,
|
||||
@@ -1128,8 +1139,10 @@ class DefaultReplyer:
|
||||
relation_info_block=relation_info,
|
||||
)
|
||||
|
||||
prompt = Prompt(template_name=template_name, context=prompt_context)
|
||||
prompt_text = await prompt.build_prompt()
|
||||
# 使用新的统一Prompt系统 - Expressor模式
|
||||
template_prompt = await global_prompt_manager.get_prompt_async("default_expressor_prompt")
|
||||
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
||||
prompt_text = await prompt.build()
|
||||
|
||||
return prompt_text
|
||||
|
||||
|
||||
@@ -8,14 +8,14 @@ import asyncio
|
||||
import time
|
||||
import contextvars
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Union, Literal, Tuple
|
||||
from typing import Dict, Any, Optional, List, Literal, Tuple
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.prompt_utils import PromptUtils
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -472,7 +472,7 @@ class Prompt:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
reply_to = msg_dict.get("reply_to", "")
|
||||
platform, reply_to_user_id = PromptUtils.parse_reply_target(reply_to)
|
||||
platform, reply_to_user_id = Prompt.parse_reply_target(reply_to)
|
||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
@@ -531,7 +531,7 @@ class Prompt:
|
||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||
"""构建关系信息"""
|
||||
try:
|
||||
relation_info = await PromptUtils.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||
return {"relation_info_block": relation_info}
|
||||
except Exception as e:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
@@ -550,7 +550,7 @@ class Prompt:
|
||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||
"""构建跨群上下文"""
|
||||
try:
|
||||
cross_context = await PromptUtils.build_cross_context(
|
||||
cross_context = await Prompt.build_cross_context(
|
||||
self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info
|
||||
)
|
||||
return {"cross_context_block": cross_context}
|
||||
@@ -666,6 +666,135 @@ class Prompt:
|
||||
"""返回提示词的表示形式"""
|
||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||
|
||||
# =============================================================================
|
||||
# PromptUtils功能迁移 - 静态工具方法
|
||||
# 这些方法原来在PromptUtils类中,现在作为Prompt类的静态方法
|
||||
# 解决循环导入问题
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
Args:
|
||||
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (发送者名称, 消息内容)
|
||||
"""
|
||||
sender = ""
|
||||
target = ""
|
||||
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
@staticmethod
|
||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
||||
"""
|
||||
构建关系信息 - 统一实现
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 关系信息字符串
|
||||
"""
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
if not reply_to:
|
||||
return ""
|
||||
sender, text = Prompt.parse_reply_target(reply_to)
|
||||
if not sender or not text:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
prompt_mode: 当前提示词模式
|
||||
target_user_info: 目标用户信息
|
||||
|
||||
Returns:
|
||||
str: 跨群聊上下文字符串
|
||||
"""
|
||||
if not global_config.cross_context.enable:
|
||||
return ""
|
||||
|
||||
from src.plugin_system.apis import cross_context_api
|
||||
|
||||
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||
if not other_chat_raw_ids:
|
||||
return ""
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
return ""
|
||||
|
||||
if prompt_mode == "normal":
|
||||
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
||||
elif prompt_mode == "s4u":
|
||||
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target_id(reply_to: str) -> str:
|
||||
"""
|
||||
解析回复目标中的用户ID
|
||||
|
||||
Args:
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 用户ID
|
||||
"""
|
||||
if not reply_to:
|
||||
return ""
|
||||
|
||||
# 复用parse_reply_target方法的逻辑
|
||||
sender, _ = Prompt.parse_reply_target(reply_to)
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
# 工厂函数
|
||||
def create_prompt(
|
||||
@@ -690,4 +819,5 @@ async def create_prompt_async(
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
return prompt
|
||||
return prompt
|
||||
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
"""
|
||||
共享提示词工具模块 - 消除重复代码
|
||||
提供统一的工具函数供DefaultReplyer和统一Prompt系统使用
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.apis import cross_context_api
|
||||
|
||||
logger = get_logger("prompt_utils")
|
||||
|
||||
|
||||
class PromptUtils:
|
||||
"""提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查"""
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
Args:
|
||||
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (发送者名称, 消息内容)
|
||||
"""
|
||||
sender = ""
|
||||
target = ""
|
||||
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
@staticmethod
|
||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
||||
"""
|
||||
构建关系信息 - 统一实现
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 关系信息字符串
|
||||
"""
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
if not reply_to:
|
||||
return ""
|
||||
sender, text = PromptUtils.parse_reply_target(reply_to)
|
||||
if not sender or not text:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
||||
"""
|
||||
if not global_config.cross_context.enable:
|
||||
return ""
|
||||
|
||||
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||
if not other_chat_raw_ids:
|
||||
return ""
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
return ""
|
||||
|
||||
if current_prompt_mode == "normal":
|
||||
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
||||
elif current_prompt_mode == "s4u":
|
||||
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target_id(reply_to: str) -> str:
|
||||
"""
|
||||
解析回复目标中的用户ID
|
||||
|
||||
Args:
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 用户ID
|
||||
"""
|
||||
if not reply_to:
|
||||
return ""
|
||||
|
||||
# 复用parse_reply_target方法的逻辑
|
||||
sender, _ = PromptUtils.parse_reply_target(reply_to)
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
Reference in New Issue
Block a user