修复代码格式和文件名大小写问题
This commit is contained in:
@@ -2,16 +2,14 @@
|
||||
共享提示词工具模块 - 消除重复代码
|
||||
提供统一的工具函数供DefaultReplyer和SmartPrompt使用
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, Tuple, Union
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
build_readable_messages_with_id,
|
||||
)
|
||||
@@ -23,25 +21,25 @@ logger = get_logger("prompt_utils")
|
||||
|
||||
class PromptUtils:
|
||||
"""提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
|
||||
Args:
|
||||
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (发送者名称, 消息内容)
|
||||
"""
|
||||
sender = ""
|
||||
target = ""
|
||||
|
||||
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
|
||||
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
@@ -49,16 +47,16 @@ class PromptUtils:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
||||
"""
|
||||
构建关系信息 - 统一实现
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
|
||||
Returns:
|
||||
str: 关系信息字符串
|
||||
"""
|
||||
@@ -66,8 +64,9 @@ class PromptUtils:
|
||||
return ""
|
||||
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
|
||||
if not reply_to:
|
||||
return ""
|
||||
sender, text = PromptUtils.parse_reply_target(reply_to)
|
||||
@@ -82,21 +81,19 @@ class PromptUtils:
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str,
|
||||
target_user_info: Optional[Dict[str, Any]],
|
||||
current_prompt_mode: str
|
||||
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 当前聊天ID
|
||||
target_user_info: 目标用户信息
|
||||
current_prompt_mode: 当前提示模式
|
||||
|
||||
|
||||
Returns:
|
||||
str: 跨群上下文块
|
||||
"""
|
||||
@@ -108,7 +105,7 @@ class PromptUtils:
|
||||
current_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not current_stream or not current_stream.group_info:
|
||||
return ""
|
||||
|
||||
|
||||
try:
|
||||
current_chat_raw_id = current_stream.group_info.group_id
|
||||
except Exception as e:
|
||||
@@ -144,7 +141,7 @@ class PromptUtils:
|
||||
if messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f"[以下是来自\"{chat_name}\"的近期消息]\n{formatted_messages}")
|
||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||
except Exception as e:
|
||||
logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}")
|
||||
continue
|
||||
@@ -175,14 +172,15 @@ class PromptUtils:
|
||||
if user_messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||
user_name = (
|
||||
target_user_info.get("person_name") or
|
||||
target_user_info.get("user_nickname") or user_id
|
||||
target_user_info.get("person_name")
|
||||
or target_user_info.get("user_nickname")
|
||||
or user_id
|
||||
)
|
||||
formatted_messages, _ = build_readable_messages_with_id(
|
||||
user_messages, timestamp_mode="relative"
|
||||
)
|
||||
cross_context_messages.append(
|
||||
f"[以下是\"{user_name}\"在\"{chat_name}\"的近期发言]\n{formatted_messages}"
|
||||
f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}")
|
||||
@@ -192,31 +190,31 @@ class PromptUtils:
|
||||
return ""
|
||||
|
||||
return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target_id(reply_to: str) -> str:
|
||||
"""
|
||||
解析回复目标中的用户ID
|
||||
|
||||
|
||||
Args:
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
|
||||
Returns:
|
||||
str: 用户ID
|
||||
"""
|
||||
if not reply_to:
|
||||
return ""
|
||||
|
||||
|
||||
# 复用parse_reply_target方法的逻辑
|
||||
sender, _ = PromptUtils.parse_reply_target(reply_to)
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user