修复一堆新prompt的bug
This commit is contained in:
@@ -11,7 +11,6 @@ import re
|
|||||||
|
|
||||||
from typing import List, Optional, Dict, Any, Tuple
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from src.chat.utils.prompt_utils import PromptUtils
|
|
||||||
from src.mais4u.mai_think import mai_thinking_manager
|
from src.mais4u.mai_think import mai_thinking_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
@@ -36,10 +35,9 @@ from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
|||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||||
from src.plugin_system.apis import llm_api
|
from src.plugin_system.apis import llm_api
|
||||||
from src.schedule.schedule_manager import schedule_manager
|
|
||||||
|
|
||||||
# 导入新的统一Prompt系统
|
# 导入新的统一Prompt系统
|
||||||
from src.chat.utils.prompt import Prompt, PromptContext
|
from src.chat.utils.prompt import Prompt, PromptParameters
|
||||||
|
|
||||||
logger = get_logger("replyer")
|
logger = get_logger("replyer")
|
||||||
|
|
||||||
@@ -599,7 +597,8 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||||
"""解析回复目标消息 - 使用共享工具"""
|
"""解析回复目标消息 - 使用共享工具"""
|
||||||
return PromptUtils.parse_reply_target(target_message)
|
from src.chat.utils.prompt import Prompt
|
||||||
|
return Prompt.parse_reply_target(target_message)
|
||||||
|
|
||||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||||
"""构建关键词反应提示
|
"""构建关键词反应提示
|
||||||
@@ -874,7 +873,8 @@ class DefaultReplyer:
|
|||||||
target_user_info = None
|
target_user_info = None
|
||||||
if sender:
|
if sender:
|
||||||
target_user_info = await person_info_manager.get_person_info_by_name(sender)
|
target_user_info = await person_info_manager.get_person_info_by_name(sender)
|
||||||
|
|
||||||
|
from src.chat.utils.prompt import Prompt
|
||||||
# 并行执行六个构建任务
|
# 并行执行六个构建任务
|
||||||
task_results = await asyncio.gather(
|
task_results = await asyncio.gather(
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
@@ -887,7 +887,7 @@ class DefaultReplyer:
|
|||||||
),
|
),
|
||||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
PromptUtils.build_cross_context(chat_id, target_user_info, global_config.personality.prompt_mode),
|
Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info),
|
||||||
"cross_context",
|
"cross_context",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -939,6 +939,7 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
schedule_block = ""
|
schedule_block = ""
|
||||||
if global_config.schedule.enable:
|
if global_config.schedule.enable:
|
||||||
|
from src.schedule.schedule_manager import schedule_manager
|
||||||
current_activity = schedule_manager.get_current_activity()
|
current_activity = schedule_manager.get_current_activity()
|
||||||
if current_activity:
|
if current_activity:
|
||||||
schedule_block = f"你当前正在:{current_activity}。"
|
schedule_block = f"你当前正在:{current_activity}。"
|
||||||
@@ -970,8 +971,8 @@ class DefaultReplyer:
|
|||||||
# 根据配置选择模板
|
# 根据配置选择模板
|
||||||
current_prompt_mode = global_config.personality.prompt_mode
|
current_prompt_mode = global_config.personality.prompt_mode
|
||||||
|
|
||||||
# 使用新的统一Prompt系统
|
# 使用新的统一Prompt系统 - 创建PromptParameters
|
||||||
prompt_context = PromptContext(
|
prompt_parameters = PromptParameters(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
is_group_chat=is_group_chat,
|
is_group_chat=is_group_chat,
|
||||||
sender=sender,
|
sender=sender,
|
||||||
@@ -1004,9 +1005,19 @@ class DefaultReplyer:
|
|||||||
action_descriptions=action_descriptions,
|
action_descriptions=action_descriptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用新的统一Prompt系统
|
# 使用新的统一Prompt系统 - 使用正确的模板名称
|
||||||
prompt = Prompt(template_name=None, context=prompt_context) # 由current_prompt_mode自动选择
|
template_name = None
|
||||||
prompt_text = await prompt.build_prompt()
|
if current_prompt_mode == "s4u":
|
||||||
|
template_name = "s4u_style_prompt"
|
||||||
|
elif current_prompt_mode == "normal":
|
||||||
|
template_name = "normal_style_prompt"
|
||||||
|
elif current_prompt_mode == "minimal":
|
||||||
|
template_name = "default_expressor_prompt"
|
||||||
|
|
||||||
|
# 获取模板内容
|
||||||
|
template_prompt = await global_prompt_manager.get_prompt_async(template_name)
|
||||||
|
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
||||||
|
prompt_text = await prompt.build()
|
||||||
|
|
||||||
return prompt_text
|
return prompt_text
|
||||||
|
|
||||||
@@ -1107,8 +1118,8 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
template_name = "default_expressor_prompt"
|
template_name = "default_expressor_prompt"
|
||||||
|
|
||||||
# 使用新的统一Prompt系统 - Expressor模式
|
# 使用新的统一Prompt系统 - Expressor模式,创建PromptParameters
|
||||||
prompt_context = PromptContext(
|
prompt_parameters = PromptParameters(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
is_group_chat=is_group_chat,
|
is_group_chat=is_group_chat,
|
||||||
sender=sender,
|
sender=sender,
|
||||||
@@ -1128,8 +1139,10 @@ class DefaultReplyer:
|
|||||||
relation_info_block=relation_info,
|
relation_info_block=relation_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = Prompt(template_name=template_name, context=prompt_context)
|
# 使用新的统一Prompt系统 - Expressor模式
|
||||||
prompt_text = await prompt.build_prompt()
|
template_prompt = await global_prompt_manager.get_prompt_async("default_expressor_prompt")
|
||||||
|
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
||||||
|
prompt_text = await prompt.build()
|
||||||
|
|
||||||
return prompt_text
|
return prompt_text
|
||||||
|
|
||||||
|
|||||||
@@ -8,14 +8,14 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import contextvars
|
import contextvars
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Any, Optional, List, Union, Literal, Tuple
|
from typing import Dict, Any, Optional, List, Literal, Tuple
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
from src.chat.utils.prompt_utils import PromptUtils
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -472,7 +472,7 @@ class Prompt:
|
|||||||
try:
|
try:
|
||||||
msg_user_id = str(msg_dict.get("user_id"))
|
msg_user_id = str(msg_dict.get("user_id"))
|
||||||
reply_to = msg_dict.get("reply_to", "")
|
reply_to = msg_dict.get("reply_to", "")
|
||||||
platform, reply_to_user_id = PromptUtils.parse_reply_target(reply_to)
|
platform, reply_to_user_id = Prompt.parse_reply_target(reply_to)
|
||||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||||
core_dialogue_list.append(msg_dict)
|
core_dialogue_list.append(msg_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -531,7 +531,7 @@ class Prompt:
|
|||||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||||
"""构建关系信息"""
|
"""构建关系信息"""
|
||||||
try:
|
try:
|
||||||
relation_info = await PromptUtils.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||||
return {"relation_info_block": relation_info}
|
return {"relation_info_block": relation_info}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"构建关系信息失败: {e}")
|
logger.error(f"构建关系信息失败: {e}")
|
||||||
@@ -550,7 +550,7 @@ class Prompt:
|
|||||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||||
"""构建跨群上下文"""
|
"""构建跨群上下文"""
|
||||||
try:
|
try:
|
||||||
cross_context = await PromptUtils.build_cross_context(
|
cross_context = await Prompt.build_cross_context(
|
||||||
self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info
|
self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info
|
||||||
)
|
)
|
||||||
return {"cross_context_block": cross_context}
|
return {"cross_context_block": cross_context}
|
||||||
@@ -666,6 +666,135 @@ class Prompt:
|
|||||||
"""返回提示词的表示形式"""
|
"""返回提示词的表示形式"""
|
||||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# PromptUtils功能迁移 - 静态工具方法
|
||||||
|
# 这些方法原来在PromptUtils类中,现在作为Prompt类的静态方法
|
||||||
|
# 解决循环导入问题
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
解析回复目标消息 - 统一实现
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, str]: (发送者名称, 消息内容)
|
||||||
|
"""
|
||||||
|
sender = ""
|
||||||
|
target = ""
|
||||||
|
|
||||||
|
# 添加None检查,防止NoneType错误
|
||||||
|
if target_message is None:
|
||||||
|
return sender, target
|
||||||
|
|
||||||
|
if ":" in target_message or ":" in target_message:
|
||||||
|
# 使用正则表达式匹配中文或英文冒号
|
||||||
|
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
sender = parts[0].strip()
|
||||||
|
target = parts[1].strip()
|
||||||
|
return sender, target
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
||||||
|
"""
|
||||||
|
构建关系信息 - 统一实现
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
reply_to: 回复目标字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 关系信息字符串
|
||||||
|
"""
|
||||||
|
if not global_config.relationship.enable_relationship:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||||
|
|
||||||
|
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||||
|
|
||||||
|
if not reply_to:
|
||||||
|
return ""
|
||||||
|
sender, text = Prompt.parse_reply_target(reply_to)
|
||||||
|
if not sender or not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 获取用户ID
|
||||||
|
person_info_manager = get_person_info_manager()
|
||||||
|
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||||
|
if not person_id:
|
||||||
|
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||||
|
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||||
|
|
||||||
|
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def build_cross_context(
|
||||||
|
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
构建跨群聊上下文 - 统一实现
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
prompt_mode: 当前提示词模式
|
||||||
|
target_user_info: 目标用户信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 跨群聊上下文字符串
|
||||||
|
"""
|
||||||
|
if not global_config.cross_context.enable:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
from src.plugin_system.apis import cross_context_api
|
||||||
|
|
||||||
|
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||||
|
if not other_chat_raw_ids:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||||
|
if not chat_stream:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if prompt_mode == "normal":
|
||||||
|
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
||||||
|
elif prompt_mode == "s4u":
|
||||||
|
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_reply_target_id(reply_to: str) -> str:
|
||||||
|
"""
|
||||||
|
解析回复目标中的用户ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reply_to: 回复目标字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 用户ID
|
||||||
|
"""
|
||||||
|
if not reply_to:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 复用parse_reply_target方法的逻辑
|
||||||
|
sender, _ = Prompt.parse_reply_target(reply_to)
|
||||||
|
if not sender:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 获取用户ID
|
||||||
|
person_info_manager = get_person_info_manager()
|
||||||
|
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||||
|
if person_id:
|
||||||
|
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||||
|
return str(user_id) if user_id else ""
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
# 工厂函数
|
# 工厂函数
|
||||||
def create_prompt(
|
def create_prompt(
|
||||||
@@ -690,4 +819,5 @@ async def create_prompt_async(
|
|||||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||||
if global_prompt_manager._context._current_context:
|
if global_prompt_manager._context._current_context:
|
||||||
await global_prompt_manager._context.register_async(prompt)
|
await global_prompt_manager._context.register_async(prompt)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|||||||
@@ -1,132 +0,0 @@
|
|||||||
"""
|
|
||||||
共享提示词工具模块 - 消除重复代码
|
|
||||||
提供统一的工具函数供DefaultReplyer和统一Prompt系统使用
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Dict, Any, Optional, Tuple
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.person_info.person_info import get_person_info_manager
|
|
||||||
from src.plugin_system.apis import cross_context_api
|
|
||||||
|
|
||||||
logger = get_logger("prompt_utils")
|
|
||||||
|
|
||||||
|
|
||||||
class PromptUtils:
|
|
||||||
"""提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
|
||||||
"""
|
|
||||||
解析回复目标消息 - 统一实现
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[str, str]: (发送者名称, 消息内容)
|
|
||||||
"""
|
|
||||||
sender = ""
|
|
||||||
target = ""
|
|
||||||
|
|
||||||
# 添加None检查,防止NoneType错误
|
|
||||||
if target_message is None:
|
|
||||||
return sender, target
|
|
||||||
|
|
||||||
if ":" in target_message or ":" in target_message:
|
|
||||||
# 使用正则表达式匹配中文或英文冒号
|
|
||||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
|
||||||
if len(parts) == 2:
|
|
||||||
sender = parts[0].strip()
|
|
||||||
target = parts[1].strip()
|
|
||||||
return sender, target
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
|
||||||
"""
|
|
||||||
构建关系信息 - 统一实现
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
reply_to: 回复目标字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 关系信息字符串
|
|
||||||
"""
|
|
||||||
if not global_config.relationship.enable_relationship:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
|
||||||
|
|
||||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
|
||||||
|
|
||||||
if not reply_to:
|
|
||||||
return ""
|
|
||||||
sender, text = PromptUtils.parse_reply_target(reply_to)
|
|
||||||
if not sender or not text:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 获取用户ID
|
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
|
||||||
if not person_id:
|
|
||||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
|
||||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
|
||||||
|
|
||||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def build_cross_context(
|
|
||||||
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
|
||||||
"""
|
|
||||||
if not global_config.cross_context.enable:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
|
||||||
if not other_chat_raw_ids:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
|
||||||
if not chat_stream:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
if current_prompt_mode == "normal":
|
|
||||||
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
|
||||||
elif current_prompt_mode == "s4u":
|
|
||||||
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_reply_target_id(reply_to: str) -> str:
|
|
||||||
"""
|
|
||||||
解析回复目标中的用户ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
reply_to: 回复目标字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 用户ID
|
|
||||||
"""
|
|
||||||
if not reply_to:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 复用parse_reply_target方法的逻辑
|
|
||||||
sender, _ = PromptUtils.parse_reply_target(reply_to)
|
|
||||||
if not sender:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 获取用户ID
|
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
|
||||||
if person_id:
|
|
||||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
|
||||||
return str(user_id) if user_id else ""
|
|
||||||
|
|
||||||
return ""
|
|
||||||
Reference in New Issue
Block a user