refactor(chat): 重构SmartPrompt系统简化架构并移除缓存机制

- 简化SmartPromptParameters类结构,移除复杂的分层参数架构
- 统一错误处理和降级机制,增强系统稳定性
- 移除缓存相关功能,简化架构并减少复杂性
- 完全继承DefaultReplyer功能,确保功能完整性
- 优化性能和依赖管理,改进并发任务处理
- 增强跨群上下文、关系信息、记忆系统等功能的错误处理
- 统一视频分析结果注入逻辑,避免重复代码
This commit is contained in:
Windpicker-owo
2025-08-31 19:09:36 +08:00
parent 9e7483d25a
commit a6e937de6d
4 changed files with 807 additions and 795 deletions

View File

@@ -1,56 +1,37 @@
"""
智能提示词参数模块 - 优化参数结构
SmartPromptParameters拆分为多个专用参数类
简化SmartPromptParameters,减少冗余和重复
"""
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal
@dataclass
class PromptCoreParams:
"""核心参数类 - 包含构建提示词的基本参数"""
class SmartPromptParameters:
"""简化的智能提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
current_prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
def validate(self) -> List[str]:
"""验证核心参数"""
errors = []
if not isinstance(self.chat_id, str):
errors.append("chat_id必须是字符串类型")
if not isinstance(self.reply_to, str):
errors.append("reply_to必须是字符串类型")
if self.current_prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("current_prompt_mode必须是's4u''normal''minimal'")
return errors
@dataclass
class PromptFeatureParams:
"""功能参数类 - 控制各种功能的开关"""
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
enable_cache: bool = True
# 性能和缓存控制
cache_ttl: int = 300
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
@dataclass
class PromptContentParams:
"""内容参数类 - 包含已构建的内容块"""
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
@@ -60,10 +41,10 @@ class PromptContentParams:
# 已构建的内容块
expression_habits_block: str = ""
relation_info: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info: str = ""
prompt_info: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
@@ -77,203 +58,37 @@ class PromptContentParams:
mood_prompt: str = ""
action_descriptions: str = ""
def has_prebuilt_content(self) -> bool:
"""检查是否有预构建的内容"""
return any([
self.expression_habits_block,
self.relation_info,
self.memory_block,
self.tool_info,
self.prompt_info,
self.cross_context_block
])
@dataclass
class SmartPromptParameters:
"""
智能提示词参数系统 - 重构版本
组合多个专用参数类,提供统一的接口
"""
# 核心参数
core: PromptCoreParams = field(default_factory=PromptCoreParams)
# 功能参数
features: PromptFeatureParams = field(default_factory=PromptFeatureParams)
# 内容参数
content: PromptContentParams = field(default_factory=PromptContentParams)
# 兼容性属性 - 提供与旧代码的兼容性
@property
def chat_id(self) -> str:
return self.core.chat_id
@chat_id.setter
def chat_id(self, value: str):
self.core.chat_id = value
@property
def is_group_chat(self) -> bool:
return self.core.is_group_chat
@is_group_chat.setter
def is_group_chat(self, value: bool):
self.core.is_group_chat = value
@property
def sender(self) -> str:
return self.core.sender
@sender.setter
def sender(self, value: str):
self.core.sender = value
@property
def target(self) -> str:
return self.core.target
@target.setter
def target(self, value: str):
self.core.target = value
@property
def reply_to(self) -> str:
return self.core.reply_to
@reply_to.setter
def reply_to(self, value: str):
self.core.reply_to = value
@property
def extra_info(self) -> str:
return self.core.extra_info
@extra_info.setter
def extra_info(self, value: str):
self.core.extra_info = value
@property
def current_prompt_mode(self) -> str:
return self.core.current_prompt_mode
@current_prompt_mode.setter
def current_prompt_mode(self, value: str):
self.core.current_prompt_mode = value
@property
def enable_tool(self) -> bool:
return self.features.enable_tool
@enable_tool.setter
def enable_tool(self, value: bool):
self.features.enable_tool = value
@property
def enable_memory(self) -> bool:
return self.features.enable_memory
@enable_memory.setter
def enable_memory(self, value: bool):
self.features.enable_memory = value
@property
def enable_cache(self) -> bool:
return self.features.enable_cache
@enable_cache.setter
def enable_cache(self, value: bool):
self.features.enable_cache = value
@property
def cache_ttl(self) -> int:
return self.features.cache_ttl
@cache_ttl.setter
def cache_ttl(self, value: int):
self.features.cache_ttl = value
@property
def expression_habits_block(self) -> str:
return self.content.expression_habits_block
@expression_habits_block.setter
def expression_habits_block(self, value: str):
self.content.expression_habits_block = value
@property
def relation_info(self) -> str:
return self.content.relation_info
@relation_info.setter
def relation_info(self, value: str):
self.content.relation_info = value
@property
def memory_block(self) -> str:
return self.content.memory_block
@memory_block.setter
def memory_block(self, value: str):
self.content.memory_block = value
@property
def tool_info(self) -> str:
return self.content.tool_info
@tool_info.setter
def tool_info(self, value: str):
self.content.tool_info = value
@property
def prompt_info(self) -> str:
return self.content.prompt_info
@prompt_info.setter
def prompt_info(self, value: str):
self.content.prompt_info = value
@property
def cross_context_block(self) -> str:
return self.content.cross_context_block
@cross_context_block.setter
def cross_context_block(self, value: str):
self.content.cross_context_block = value
# 兼容性方法 - 支持旧代码的直接访问
def validate(self) -> List[str]:
"""参数验证"""
errors = self.core.validate()
# 验证功能参数
if self.features.cache_ttl <= 0:
errors.append("cache_ttl必须大于0")
if self.features.max_context_messages <= 0:
"""统一的参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
def get_needed_build_tasks(self) -> List[str]:
"""获取需要执行的任务列表"""
tasks = []
if self.features.enable_expression and not self.content.expression_habits_block:
if self.enable_expression and not self.expression_habits_block:
tasks.append("expression_habits")
if self.features.enable_memory and not self.content.memory_block:
if self.enable_memory and not self.memory_block:
tasks.append("memory_block")
if self.features.enable_relation and not self.content.relation_info:
if self.enable_relation and not self.relation_info_block:
tasks.append("relation_info")
if self.features.enable_tool and not self.content.tool_info:
if self.enable_tool and not self.tool_info_block:
tasks.append("tool_info")
if self.features.enable_knowledge and not self.content.prompt_info:
if self.enable_knowledge and not self.knowledge_prompt:
tasks.append("knowledge_info")
if self.features.enable_cross_context and not self.content.cross_context_block:
if self.enable_cross_context and not self.cross_context_block:
tasks.append("cross_context")
return tasks
@@ -289,44 +104,44 @@ class SmartPromptParameters:
Returns:
SmartPromptParameters: 新参数对象
"""
# 创建核心参数
core_params = PromptCoreParams(
return cls(
# 基础参数
chat_id=kwargs.get("chat_id", ""),
is_group_chat=kwargs.get("is_group_chat", False),
sender=kwargs.get("sender", ""),
target=kwargs.get("target", ""),
reply_to=kwargs.get("reply_to", ""),
extra_info=kwargs.get("extra_info", ""),
current_prompt_mode=kwargs.get("current_prompt_mode", "s4u"),
)
# 创建功能参数
feature_params = PromptFeatureParams(
prompt_mode=kwargs.get("current_prompt_mode", "s4u"),
# 功能开关
enable_tool=kwargs.get("enable_tool", True),
enable_memory=kwargs.get("enable_memory", True),
enable_expression=kwargs.get("enable_expression", True),
enable_relation=kwargs.get("enable_relation", True),
enable_cross_context=kwargs.get("enable_cross_context", True),
enable_knowledge=kwargs.get("enable_knowledge", True),
enable_cache=kwargs.get("enable_cache", True),
cache_ttl=kwargs.get("cache_ttl", 300),
# 性能控制
max_context_messages=kwargs.get("max_context_messages", 50),
debug_mode=kwargs.get("debug_mode", False),
)
# 创建内容参数
content_params = PromptContentParams(
# 聊天历史和上下文
chat_target_info=kwargs.get("chat_target_info"),
message_list_before_now_long=kwargs.get("message_list_before_now_long", []),
message_list_before_short=kwargs.get("message_list_before_short", []),
chat_talking_prompt_short=kwargs.get("chat_talking_prompt_short", ""),
target_user_info=kwargs.get("target_user_info"),
# 已构建的内容块
expression_habits_block=kwargs.get("expression_habits_block", ""),
relation_info=kwargs.get("relation_info", ""),
relation_info_block=kwargs.get("relation_info", ""),
memory_block=kwargs.get("memory_block", ""),
tool_info=kwargs.get("tool_info", ""),
prompt_info=kwargs.get("prompt_info", ""),
tool_info_block=kwargs.get("tool_info", ""),
knowledge_prompt=kwargs.get("knowledge_prompt", ""),
cross_context_block=kwargs.get("cross_context_block", ""),
# 其他内容块
keywords_reaction_prompt=kwargs.get("keywords_reaction_prompt", ""),
extra_info_block=kwargs.get("extra_info_block", ""),
time_block=kwargs.get("time_block", ""),
@@ -336,10 +151,4 @@ class SmartPromptParameters:
reply_target_block=kwargs.get("reply_target_block", ""),
mood_prompt=kwargs.get("mood_prompt", ""),
action_descriptions=kwargs.get("action_descriptions", ""),
)
return cls(
core=core_params,
features=feature_params,
content=content_params
)

View File

@@ -1,6 +1,7 @@
"""
共享提示词工具模块 - 消除重复代码
提供统一的工具函数供DefaultReplyer和SmartPrompt使用
移除缓存相关功能
"""
import re
import time
@@ -22,7 +23,7 @@ logger = get_logger("prompt_utils")
class PromptUtils:
"""提示词工具类 - 提供共享功能"""
"""提示词工具类 - 提供共享功能,移除缓存相关功能"""
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
@@ -51,13 +52,50 @@ class PromptUtils:
return sender, target
@staticmethod
async def build_cross_context_block(
async def build_relation_info(chat_id: str, reply_to: str) -> str:
"""
构建关系信息 - 统一实现
Args:
chat_id: 聊天ID
reply_to: 回复目标字符串
Returns:
str: 关系信息字符串
"""
if not global_config.relationship.enable_relationship:
return ""
try:
from src.person_info.relationship_fetcher import relationship_fetcher_manager
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
if not reply_to:
return ""
sender, text = PromptUtils.parse_reply_target(reply_to)
if not sender or not text:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
except Exception as e:
logger.error(f"构建关系信息失败: {e}")
return ""
@staticmethod
async def build_cross_context(
chat_id: str,
target_user_info: Optional[Dict[str, Any]],
current_prompt_mode: str
) -> str:
"""
构建跨群聊上下文 - 统一实现
构建跨群聊上下文 - 统一实现完全继承DefaultReplyer功能
Args:
chat_id: 当前聊天ID
@@ -75,7 +113,12 @@ class PromptUtils:
current_stream = get_chat_manager().get_stream(chat_id)
if not current_stream or not current_stream.group_info:
return ""
current_chat_raw_id = current_stream.group_info.group_id
try:
current_chat_raw_id = current_stream.group_info.group_id
except Exception as e:
logger.error(f"获取群聊ID失败: {e}")
return ""
for group in global_config.cross_context.groups:
if str(current_chat_raw_id) in group.chat_ids:
@@ -97,15 +140,19 @@ class PromptUtils:
if not stream_id:
continue
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=5, # 可配置
)
if messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f"[以下是来自\"{chat_name}\"的近期消息]\n{formatted_messages}")
try:
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=5, # 可配置
)
if messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f"[以下是来自\"{chat_name}\"的近期消息]\n{formatted_messages}")
except Exception as e:
logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}")
continue
elif current_prompt_mode == "s4u":
# s4u模式获取当前发言用户在其他群聊的消息
@@ -120,27 +167,31 @@ class PromptUtils:
if not stream_id:
continue
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=20, # 获取更多消息以供筛选
)
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][
-5:
] # 筛选并取最近5条
try:
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=20, # 获取更多消息以供筛选
)
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][
-5:
] # 筛选并取最近5条
if user_messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
user_name = (
target_user_info.get("person_name") or
target_user_info.get("user_nickname") or user_id
)
formatted_messages, _ = build_readable_messages_with_id(
user_messages, timestamp_mode="relative"
)
cross_context_messages.append(
f"[以下是\"{user_name}\"\"{chat_name}\"的近期发言]\n{formatted_messages}"
)
if user_messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
user_name = (
target_user_info.get("person_name") or
target_user_info.get("user_nickname") or user_id
)
formatted_messages, _ = build_readable_messages_with_id(
user_messages, timestamp_mode="relative"
)
cross_context_messages.append(
f"[以下是\"{user_name}\"\"{chat_name}\"的近期发言]\n{formatted_messages}"
)
except Exception as e:
logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}")
continue
if not cross_context_messages:
return ""
@@ -260,88 +311,4 @@ class DependencyChecker:
"memory": await DependencyChecker.check_memory_dependencies(),
"tool": await DependencyChecker.check_tool_dependencies(),
"knowledge": await DependencyChecker.check_knowledge_dependencies(),
}
class SmartPromptCache:
"""智能提示词缓存系统 - 分层缓存实现"""
def __init__(self):
self._l1_cache: Dict[str, Tuple[str, float]] = {} # 内存缓存: {key: (value, timestamp)}
self._l2_cache_enabled = False # 是否启用L2缓存
self._cache_ttl = 300 # 默认缓存TTL: 5分钟
def enable_l2_cache(self, enabled: bool = True):
"""启用或禁用L2缓存"""
self._l2_cache_enabled = enabled
def set_cache_ttl(self, ttl: int):
"""设置缓存TTL"""
self._cache_ttl = ttl
def _generate_key(self, chat_id: str, prompt_mode: str, reply_to: str) -> str:
"""生成缓存键"""
import hashlib
key_content = f"{chat_id}_{prompt_mode}_{reply_to}"
return hashlib.md5(key_content.encode()).hexdigest()
def get(self, chat_id: str, prompt_mode: str, reply_to: str) -> Optional[str]:
"""获取缓存值"""
cache_key = self._generate_key(chat_id, prompt_mode, reply_to)
# 检查L1缓存
if cache_key in self._l1_cache:
value, timestamp = self._l1_cache[cache_key]
if time.time() - timestamp < self._cache_ttl:
logger.debug(f"L1缓存命中: {cache_key}")
return value
else:
# 缓存过期,清理
del self._l1_cache[cache_key]
# TODO: 实现L2缓存如Redis
# if self._l2_cache_enabled:
# return self._get_from_l2_cache(cache_key)
return None
def set(self, chat_id: str, prompt_mode: str, reply_to: str, value: str):
"""设置缓存值"""
cache_key = self._generate_key(chat_id, prompt_mode, reply_to)
# 设置L1缓存
self._l1_cache[cache_key] = (value, time.time())
# TODO: 实现L2缓存
# if self._l2_cache_enabled:
# self._set_to_l2_cache(cache_key, value)
# 定期清理过期缓存
if len(self._l1_cache) > 1000: # 缓存条目过多时清理
self._clean_expired_cache()
def _clean_expired_cache(self):
"""清理过期缓存"""
current_time = time.time()
expired_keys = [
key for key, (_, timestamp) in self._l1_cache.items()
if current_time - timestamp >= self._cache_ttl
]
for key in expired_keys:
del self._l1_cache[key]
logger.debug(f"清理过期缓存: {len(expired_keys)} 个条目")
def clear(self):
"""清空所有缓存"""
self._l1_cache.clear()
# TODO: 清空L2缓存
logger.info("缓存已清空")
def get_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
return {
"l1_cache_size": len(self._l1_cache),
"l2_cache_enabled": self._l2_cache_enabled,
"cache_ttl": self._cache_ttl,
}

File diff suppressed because it is too large Load Diff