refactor(chat): 重构SmartPrompt系统简化架构并移除缓存机制
- 简化SmartPromptParameters类结构,移除复杂的分层参数架构 - 统一错误处理和降级机制,增强系统稳定性 - 移除缓存相关功能,简化架构并减少复杂性 - 完全继承DefaultReplyer功能,确保功能完整性 - 优化性能和依赖管理,改进并发任务处理 - 增强跨群上下文、关系信息、记忆系统等功能的错误处理 - 统一视频分析结果注入逻辑,避免重复代码
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
共享提示词工具模块 - 消除重复代码
|
||||
提供统一的工具函数供DefaultReplyer和SmartPrompt使用
|
||||
移除缓存相关功能
|
||||
"""
|
||||
import re
|
||||
import time
|
||||
@@ -22,7 +23,7 @@ logger = get_logger("prompt_utils")
|
||||
|
||||
|
||||
class PromptUtils:
|
||||
"""提示词工具类 - 提供共享功能"""
|
||||
"""提示词工具类 - 提供共享功能,移除缓存相关功能"""
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
@@ -51,13 +52,50 @@ class PromptUtils:
|
||||
return sender, target
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context_block(
|
||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
||||
"""
|
||||
构建关系信息 - 统一实现
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 关系信息字符串
|
||||
"""
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
try:
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
if not reply_to:
|
||||
return ""
|
||||
sender, text = PromptUtils.parse_reply_target(reply_to)
|
||||
if not sender or not text:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
except Exception as e:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str,
|
||||
target_user_info: Optional[Dict[str, Any]],
|
||||
current_prompt_mode: str
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
||||
|
||||
Args:
|
||||
chat_id: 当前聊天ID
|
||||
@@ -75,7 +113,12 @@ class PromptUtils:
|
||||
current_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not current_stream or not current_stream.group_info:
|
||||
return ""
|
||||
current_chat_raw_id = current_stream.group_info.group_id
|
||||
|
||||
try:
|
||||
current_chat_raw_id = current_stream.group_info.group_id
|
||||
except Exception as e:
|
||||
logger.error(f"获取群聊ID失败: {e}")
|
||||
return ""
|
||||
|
||||
for group in global_config.cross_context.groups:
|
||||
if str(current_chat_raw_id) in group.chat_ids:
|
||||
@@ -97,15 +140,19 @@ class PromptUtils:
|
||||
if not stream_id:
|
||||
continue
|
||||
|
||||
messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=5, # 可配置
|
||||
)
|
||||
if messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f"[以下是来自\"{chat_name}\"的近期消息]\n{formatted_messages}")
|
||||
try:
|
||||
messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=5, # 可配置
|
||||
)
|
||||
if messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f"[以下是来自\"{chat_name}\"的近期消息]\n{formatted_messages}")
|
||||
except Exception as e:
|
||||
logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}")
|
||||
continue
|
||||
|
||||
elif current_prompt_mode == "s4u":
|
||||
# s4u模式:获取当前发言用户在其他群聊的消息
|
||||
@@ -120,27 +167,31 @@ class PromptUtils:
|
||||
if not stream_id:
|
||||
continue
|
||||
|
||||
messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=20, # 获取更多消息以供筛选
|
||||
)
|
||||
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][
|
||||
-5:
|
||||
] # 筛选并取最近5条
|
||||
try:
|
||||
messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=20, # 获取更多消息以供筛选
|
||||
)
|
||||
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][
|
||||
-5:
|
||||
] # 筛选并取最近5条
|
||||
|
||||
if user_messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||
user_name = (
|
||||
target_user_info.get("person_name") or
|
||||
target_user_info.get("user_nickname") or user_id
|
||||
)
|
||||
formatted_messages, _ = build_readable_messages_with_id(
|
||||
user_messages, timestamp_mode="relative"
|
||||
)
|
||||
cross_context_messages.append(
|
||||
f"[以下是\"{user_name}\"在\"{chat_name}\"的近期发言]\n{formatted_messages}"
|
||||
)
|
||||
if user_messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||
user_name = (
|
||||
target_user_info.get("person_name") or
|
||||
target_user_info.get("user_nickname") or user_id
|
||||
)
|
||||
formatted_messages, _ = build_readable_messages_with_id(
|
||||
user_messages, timestamp_mode="relative"
|
||||
)
|
||||
cross_context_messages.append(
|
||||
f"[以下是\"{user_name}\"在\"{chat_name}\"的近期发言]\n{formatted_messages}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}")
|
||||
continue
|
||||
|
||||
if not cross_context_messages:
|
||||
return ""
|
||||
@@ -260,88 +311,4 @@ class DependencyChecker:
|
||||
"memory": await DependencyChecker.check_memory_dependencies(),
|
||||
"tool": await DependencyChecker.check_tool_dependencies(),
|
||||
"knowledge": await DependencyChecker.check_knowledge_dependencies(),
|
||||
}
|
||||
|
||||
|
||||
class SmartPromptCache:
|
||||
"""智能提示词缓存系统 - 分层缓存实现"""
|
||||
|
||||
def __init__(self):
|
||||
self._l1_cache: Dict[str, Tuple[str, float]] = {} # 内存缓存: {key: (value, timestamp)}
|
||||
self._l2_cache_enabled = False # 是否启用L2缓存
|
||||
self._cache_ttl = 300 # 默认缓存TTL: 5分钟
|
||||
|
||||
def enable_l2_cache(self, enabled: bool = True):
|
||||
"""启用或禁用L2缓存"""
|
||||
self._l2_cache_enabled = enabled
|
||||
|
||||
def set_cache_ttl(self, ttl: int):
|
||||
"""设置缓存TTL(秒)"""
|
||||
self._cache_ttl = ttl
|
||||
|
||||
def _generate_key(self, chat_id: str, prompt_mode: str, reply_to: str) -> str:
|
||||
"""生成缓存键"""
|
||||
import hashlib
|
||||
key_content = f"{chat_id}_{prompt_mode}_{reply_to}"
|
||||
return hashlib.md5(key_content.encode()).hexdigest()
|
||||
|
||||
def get(self, chat_id: str, prompt_mode: str, reply_to: str) -> Optional[str]:
|
||||
"""获取缓存值"""
|
||||
cache_key = self._generate_key(chat_id, prompt_mode, reply_to)
|
||||
|
||||
# 检查L1缓存
|
||||
if cache_key in self._l1_cache:
|
||||
value, timestamp = self._l1_cache[cache_key]
|
||||
if time.time() - timestamp < self._cache_ttl:
|
||||
logger.debug(f"L1缓存命中: {cache_key}")
|
||||
return value
|
||||
else:
|
||||
# 缓存过期,清理
|
||||
del self._l1_cache[cache_key]
|
||||
|
||||
# TODO: 实现L2缓存(如Redis)
|
||||
# if self._l2_cache_enabled:
|
||||
# return self._get_from_l2_cache(cache_key)
|
||||
|
||||
return None
|
||||
|
||||
def set(self, chat_id: str, prompt_mode: str, reply_to: str, value: str):
|
||||
"""设置缓存值"""
|
||||
cache_key = self._generate_key(chat_id, prompt_mode, reply_to)
|
||||
|
||||
# 设置L1缓存
|
||||
self._l1_cache[cache_key] = (value, time.time())
|
||||
|
||||
# TODO: 实现L2缓存
|
||||
# if self._l2_cache_enabled:
|
||||
# self._set_to_l2_cache(cache_key, value)
|
||||
|
||||
# 定期清理过期缓存
|
||||
if len(self._l1_cache) > 1000: # 缓存条目过多时清理
|
||||
self._clean_expired_cache()
|
||||
|
||||
def _clean_expired_cache(self):
|
||||
"""清理过期缓存"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key for key, (_, timestamp) in self._l1_cache.items()
|
||||
if current_time - timestamp >= self._cache_ttl
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._l1_cache[key]
|
||||
|
||||
logger.debug(f"清理过期缓存: {len(expired_keys)} 个条目")
|
||||
|
||||
def clear(self):
|
||||
"""清空所有缓存"""
|
||||
self._l1_cache.clear()
|
||||
# TODO: 清空L2缓存
|
||||
logger.info("缓存已清空")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"l1_cache_size": len(self._l1_cache),
|
||||
"l2_cache_enabled": self._l2_cache_enabled,
|
||||
"cache_ttl": self._cache_ttl,
|
||||
}
|
||||
Reference in New Issue
Block a user