修复一堆新prompt的bug
This commit is contained in:
@@ -8,14 +8,14 @@ import asyncio
|
||||
import time
|
||||
import contextvars
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Union, Literal, Tuple
|
||||
from typing import Dict, Any, Optional, List, Literal, Tuple
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.prompt_utils import PromptUtils
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -472,7 +472,7 @@ class Prompt:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
reply_to = msg_dict.get("reply_to", "")
|
||||
platform, reply_to_user_id = PromptUtils.parse_reply_target(reply_to)
|
||||
platform, reply_to_user_id = Prompt.parse_reply_target(reply_to)
|
||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
@@ -531,7 +531,7 @@ class Prompt:
|
||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||
"""构建关系信息"""
|
||||
try:
|
||||
relation_info = await PromptUtils.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||
return {"relation_info_block": relation_info}
|
||||
except Exception as e:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
@@ -550,7 +550,7 @@ class Prompt:
|
||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||
"""构建跨群上下文"""
|
||||
try:
|
||||
cross_context = await PromptUtils.build_cross_context(
|
||||
cross_context = await Prompt.build_cross_context(
|
||||
self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info
|
||||
)
|
||||
return {"cross_context_block": cross_context}
|
||||
@@ -666,6 +666,135 @@ class Prompt:
|
||||
"""返回提示词的表示形式"""
|
||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||
|
||||
# =============================================================================
|
||||
# PromptUtils功能迁移 - 静态工具方法
|
||||
# 这些方法原来在PromptUtils类中,现在作为Prompt类的静态方法
|
||||
# 解决循环导入问题
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
Args:
|
||||
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (发送者名称, 消息内容)
|
||||
"""
|
||||
sender = ""
|
||||
target = ""
|
||||
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
@staticmethod
|
||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
||||
"""
|
||||
构建关系信息 - 统一实现
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 关系信息字符串
|
||||
"""
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
if not reply_to:
|
||||
return ""
|
||||
sender, text = Prompt.parse_reply_target(reply_to)
|
||||
if not sender or not text:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
prompt_mode: 当前提示词模式
|
||||
target_user_info: 目标用户信息
|
||||
|
||||
Returns:
|
||||
str: 跨群聊上下文字符串
|
||||
"""
|
||||
if not global_config.cross_context.enable:
|
||||
return ""
|
||||
|
||||
from src.plugin_system.apis import cross_context_api
|
||||
|
||||
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||
if not other_chat_raw_ids:
|
||||
return ""
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
return ""
|
||||
|
||||
if prompt_mode == "normal":
|
||||
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
||||
elif prompt_mode == "s4u":
|
||||
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target_id(reply_to: str) -> str:
|
||||
"""
|
||||
解析回复目标中的用户ID
|
||||
|
||||
Args:
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 用户ID
|
||||
"""
|
||||
if not reply_to:
|
||||
return ""
|
||||
|
||||
# 复用parse_reply_target方法的逻辑
|
||||
sender, _ = Prompt.parse_reply_target(reply_to)
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
# 工厂函数
|
||||
def create_prompt(
|
||||
@@ -690,4 +819,5 @@ async def create_prompt_async(
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
return prompt
|
||||
return prompt
|
||||
|
||||
|
||||
Reference in New Issue
Block a user