feat(memory): 增强记忆构建系统并优化检索性能
- 添加记忆提取异常处理机制,提升系统稳定性
- 实现记忆内容格式化功能,增强可读性和结构化输出
- 优化LLM响应解析逻辑,避免系统标识误写入记忆
- 改进向量存储批量嵌入生成,提升处理效率
- 为记忆系统添加机器人身份上下文注入,避免自身信息记录
- 增强记忆检索接口,支持额外上下文参数传递
- 添加控制台记忆预览功能,便于人工检查
- 优化记忆融合算法,正确处理单记忆组情况
- 改进流循环管理器,支持未读消息积压强制分发机制
This commit is contained in:
@@ -17,7 +17,7 @@ from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
|
||||
from src.chat.memory_system.metadata_index import MetadataIndexManager
|
||||
@@ -295,6 +295,9 @@ class EnhancedMemorySystem:
|
||||
# 4. 存储记忆
|
||||
await self._store_memories(fused_chunks)
|
||||
|
||||
# 4.1 控制台预览
|
||||
self._log_memory_preview(fused_chunks)
|
||||
|
||||
# 5. 更新统计
|
||||
self.total_memories += len(fused_chunks)
|
||||
self.last_build_time = time.time()
|
||||
@@ -307,6 +310,15 @@ class EnhancedMemorySystem:
|
||||
self.status = original_status
|
||||
return fused_chunks
|
||||
|
||||
except MemoryExtractionError as e:
|
||||
if build_scope_key and build_marker_time is not None:
|
||||
recorded_time = self._last_memory_build_times.get(build_scope_key)
|
||||
if recorded_time == build_marker_time:
|
||||
self._last_memory_build_times.pop(build_scope_key, None)
|
||||
self.status = original_status
|
||||
logger.warning("记忆构建因LLM响应问题中断: %s", e)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
if build_scope_key and build_marker_time is not None:
|
||||
recorded_time = self._last_memory_build_times.get(build_scope_key)
|
||||
@@ -316,6 +328,23 @@ class EnhancedMemorySystem:
|
||||
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _log_memory_preview(self, memories: List[MemoryChunk]) -> None:
    """Log a short preview of the given memories for manual inspection.

    One line is logged per memory with its type value, importance name,
    confidence name and text content. Content longer than 120 characters
    is cut to its first 117 characters followed by "...".

    Args:
        memories: Memory chunks to preview. When empty (or otherwise falsy)
            a single "nothing generated" line is logged instead.
    """
    if memories:
        logger.info(f"📝 本次生成的记忆预览 ({len(memories)} 条):")
    else:
        logger.info("📝 本次未生成新的记忆")
        return

    for position, chunk in enumerate(memories, start=1):
        preview = chunk.text_content or ""
        # Keep each log line compact: 117 chars + "..." == 120 chars max.
        preview = preview if len(preview) <= 120 else preview[:117] + "..."

        type_label = chunk.memory_type.value
        importance_label = chunk.metadata.importance.name
        confidence_label = chunk.metadata.confidence.name
        logger.info(
            f" {position}) 类型={type_label} 重要性={importance_label} "
            f"置信度={confidence_label} | 内容={preview}"
        )
|
||||
|
||||
async def process_conversation_memory(
|
||||
self,
|
||||
conversation_text: str,
|
||||
|
||||
@@ -55,12 +55,27 @@ class EnhancedMemoryHooks:
|
||||
if not enhanced_memory_manager.is_initialized:
|
||||
await enhanced_memory_manager.initialize()
|
||||
|
||||
# 注入机器人基础人设,帮助记忆构建时避免记录自身信息
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
personality_config = getattr(global_config, "personality", None)
|
||||
bot_context = {}
|
||||
if bot_config is not None:
|
||||
bot_context["bot_name"] = getattr(bot_config, "nickname", None)
|
||||
bot_context["bot_aliases"] = list(getattr(bot_config, "alias_names", []) or [])
|
||||
bot_context["bot_account"] = getattr(bot_config, "qq_account", None)
|
||||
|
||||
if personality_config is not None:
|
||||
bot_context["bot_identity"] = getattr(personality_config, "identity", None)
|
||||
bot_context["bot_personality"] = getattr(personality_config, "personality_core", None)
|
||||
bot_context["bot_personality_side"] = getattr(personality_config, "personality_side", None)
|
||||
|
||||
# 构建上下文
|
||||
memory_context = {
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"timestamp": datetime.now().timestamp(),
|
||||
"message_type": "user_message",
|
||||
**bot_context,
|
||||
**(context or {})
|
||||
}
|
||||
|
||||
@@ -92,7 +107,8 @@ class EnhancedMemoryHooks:
|
||||
query_text: str,
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
limit: int = 5
|
||||
limit: int = 5,
|
||||
extra_context: Optional[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
@@ -123,6 +139,9 @@ class EnhancedMemoryHooks:
|
||||
]
|
||||
}
|
||||
|
||||
if extra_context:
|
||||
context.update(extra_context)
|
||||
|
||||
# 获取相关记忆
|
||||
enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
|
||||
query_text=query_text,
|
||||
@@ -140,7 +159,9 @@ class EnhancedMemoryHooks:
|
||||
"confidence": result.confidence,
|
||||
"importance": result.importance,
|
||||
"timestamp": result.timestamp,
|
||||
"source": result.source
|
||||
"source": result.source,
|
||||
"relevance": result.relevance_score,
|
||||
"structure": result.structure,
|
||||
}
|
||||
results.append(memory_dict)
|
||||
|
||||
|
||||
@@ -56,7 +56,8 @@ async def get_relevant_memories_for_response(
|
||||
query_text: str,
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
limit: int = 5
|
||||
limit: int = 5,
|
||||
extra_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
@@ -65,7 +66,8 @@ async def get_relevant_memories_for_response(
|
||||
query_text: 查询文本(通常是用户的当前消息)
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
limit: 返回记忆数量限制
|
||||
limit: 返回记忆数量限制
|
||||
extra_context: 额外上下文信息
|
||||
|
||||
Returns:
|
||||
Dict: 包含记忆信息的字典
|
||||
@@ -75,7 +77,8 @@ async def get_relevant_memories_for_response(
|
||||
query_text=query_text,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
limit=limit
|
||||
limit=limit,
|
||||
extra_context=extra_context
|
||||
)
|
||||
|
||||
result = {
|
||||
@@ -157,7 +160,8 @@ def get_memory_system_status() -> Dict[str, Any]:
|
||||
async def remember_message(
|
||||
message: str,
|
||||
user_id: str = "default_user",
|
||||
chat_id: str = "default_chat"
|
||||
chat_id: str = "default_chat",
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
便捷的记忆构建函数
|
||||
@@ -176,7 +180,8 @@ async def remember_message(
|
||||
message_content=message,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
message_id=message_id
|
||||
message_id=message_id,
|
||||
context=context
|
||||
)
|
||||
|
||||
|
||||
@@ -184,7 +189,8 @@ async def recall_memories(
|
||||
query: str,
|
||||
user_id: str = "default_user",
|
||||
chat_id: str = "default_chat",
|
||||
limit: int = 5
|
||||
limit: int = 5,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
便捷的记忆检索函数
|
||||
@@ -202,5 +208,6 @@ async def recall_memories(
|
||||
query_text=query,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
limit=limit
|
||||
limit=limit,
|
||||
extra_context=context
|
||||
)
|
||||
@@ -5,6 +5,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from datetime import datetime
|
||||
@@ -31,6 +32,7 @@ class EnhancedMemoryResult:
|
||||
timestamp: float
|
||||
source: str = "enhanced_memory"
|
||||
relevance_score: float = 0.0
|
||||
structure: Dict[str, Any] | None = None
|
||||
|
||||
|
||||
class EnhancedMemoryManager:
|
||||
@@ -41,6 +43,14 @@ class EnhancedMemoryManager:
|
||||
self.is_initialized = False
|
||||
self.user_cache = {} # 用户记忆缓存
|
||||
|
||||
def _clean_text(self, text: Any) -> str:
|
||||
if text is None:
|
||||
return ""
|
||||
|
||||
cleaned = re.sub(r"[\s\u3000]+", " ", str(text)).strip()
|
||||
cleaned = re.sub(r"[、,,;;]+$", "", cleaned)
|
||||
return cleaned
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化增强记忆系统"""
|
||||
if self.is_initialized:
|
||||
@@ -271,14 +281,16 @@ class EnhancedMemoryManager:
|
||||
|
||||
results = []
|
||||
for memory in relevant_memories:
|
||||
formatted_content, structure = self._format_memory_chunk(memory)
|
||||
result = EnhancedMemoryResult(
|
||||
content=memory.text_content,
|
||||
content=formatted_content,
|
||||
memory_type=memory.memory_type.value,
|
||||
confidence=memory.metadata.confidence.value,
|
||||
importance=memory.metadata.importance.value,
|
||||
timestamp=memory.metadata.created_at,
|
||||
source="enhanced_memory",
|
||||
relevance_score=memory.metadata.relevance_score
|
||||
relevance_score=memory.metadata.relevance_score,
|
||||
structure=structure
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
@@ -288,6 +300,197 @@ class EnhancedMemoryManager:
|
||||
logger.error(f"get_enhanced_memory_context 失败: {e}")
|
||||
return []
|
||||
|
||||
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
|
||||
"""将记忆块转换为更易读的文本描述"""
|
||||
structure = memory.content.to_dict()
|
||||
subject = structure.get("subject")
|
||||
predicate = structure.get("predicate") or ""
|
||||
obj = structure.get("object")
|
||||
|
||||
subject_display = self._format_subject(subject, memory)
|
||||
formatted = self._apply_predicate_format(subject_display, predicate, obj)
|
||||
|
||||
if not formatted:
|
||||
predicate_display = self._format_predicate(predicate)
|
||||
object_display = self._format_object(obj)
|
||||
formatted = f"{subject_display}{predicate_display}{object_display}".strip()
|
||||
|
||||
formatted = self._clean_text(formatted)
|
||||
|
||||
return formatted, structure
|
||||
|
||||
def _format_subject(self, subject: Optional[str], memory: MemoryChunk) -> str:
|
||||
if not subject:
|
||||
return "该用户"
|
||||
|
||||
if subject == memory.metadata.user_id:
|
||||
return "该用户"
|
||||
if memory.metadata.chat_id and subject == memory.metadata.chat_id:
|
||||
return "该聊天"
|
||||
return self._clean_text(subject)
|
||||
|
||||
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> Optional[str]:
|
||||
predicate = (predicate or "").strip()
|
||||
obj_value = obj
|
||||
|
||||
if predicate == "is_named":
|
||||
name = self._extract_from_object(obj_value, ["name", "nickname"]) or self._format_object(obj_value)
|
||||
name = self._clean_text(name)
|
||||
if not name:
|
||||
return None
|
||||
name_display = name if (name.startswith("「") and name.endswith("」")) else f"「{name}」"
|
||||
return f"{subject}的昵称是{name_display}"
|
||||
if predicate == "is_age":
|
||||
age = self._extract_from_object(obj_value, ["age"]) or self._format_object(obj_value)
|
||||
age = self._clean_text(age)
|
||||
if not age:
|
||||
return None
|
||||
return f"{subject}今年{age}岁"
|
||||
if predicate == "is_profession":
|
||||
profession = self._extract_from_object(obj_value, ["profession", "job"]) or self._format_object(obj_value)
|
||||
profession = self._clean_text(profession)
|
||||
if not profession:
|
||||
return None
|
||||
return f"{subject}的职业是{profession}"
|
||||
if predicate == "lives_in":
|
||||
location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object(obj_value)
|
||||
location = self._clean_text(location)
|
||||
if not location:
|
||||
return None
|
||||
return f"{subject}居住在{location}"
|
||||
if predicate == "has_phone":
|
||||
phone = self._extract_from_object(obj_value, ["phone", "number"]) or self._format_object(obj_value)
|
||||
phone = self._clean_text(phone)
|
||||
if not phone:
|
||||
return None
|
||||
return f"{subject}的电话号码是{phone}"
|
||||
if predicate == "has_email":
|
||||
email = self._extract_from_object(obj_value, ["email"]) or self._format_object(obj_value)
|
||||
email = self._clean_text(email)
|
||||
if not email:
|
||||
return None
|
||||
return f"{subject}的邮箱是{email}"
|
||||
if predicate == "likes":
|
||||
liked = self._format_object(obj_value)
|
||||
if not liked:
|
||||
return None
|
||||
return f"{subject}喜欢{liked}"
|
||||
if predicate == "likes_food":
|
||||
food = self._format_object(obj_value)
|
||||
if not food:
|
||||
return None
|
||||
return f"{subject}爱吃{food}"
|
||||
if predicate == "dislikes":
|
||||
disliked = self._format_object(obj_value)
|
||||
if not disliked:
|
||||
return None
|
||||
return f"{subject}不喜欢{disliked}"
|
||||
if predicate == "hates":
|
||||
hated = self._format_object(obj_value)
|
||||
if not hated:
|
||||
return None
|
||||
return f"{subject}讨厌{hated}"
|
||||
if predicate == "favorite_is":
|
||||
favorite = self._format_object(obj_value)
|
||||
if not favorite:
|
||||
return None
|
||||
return f"{subject}最喜欢{favorite}"
|
||||
if predicate == "mentioned_event":
|
||||
event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object(obj_value)
|
||||
event_text = self._clean_text(self._truncate(event_text))
|
||||
if not event_text:
|
||||
return None
|
||||
return f"{subject}提到了计划或事件:{event_text}"
|
||||
if predicate in {"正在", "在", "正在进行"}:
|
||||
action = self._format_object(obj_value)
|
||||
if not action:
|
||||
return None
|
||||
return f"{subject}{predicate}{action}"
|
||||
if predicate in {"感到", "觉得", "表示", "提到", "说道", "说"}:
|
||||
feeling = self._format_object(obj_value)
|
||||
if not feeling:
|
||||
return None
|
||||
return f"{subject}{predicate}{feeling}"
|
||||
if predicate in {"与", "和", "跟"}:
|
||||
counterpart = self._format_object(obj_value)
|
||||
if counterpart:
|
||||
return f"{subject}{predicate}{counterpart}"
|
||||
return f"{subject}{predicate}"
|
||||
|
||||
return None
|
||||
|
||||
def _format_predicate(self, predicate: str) -> str:
|
||||
if not predicate:
|
||||
return ""
|
||||
predicate_map = {
|
||||
"is_named": "的昵称是",
|
||||
"is_profession": "的职业是",
|
||||
"lives_in": "居住在",
|
||||
"has_phone": "的电话是",
|
||||
"has_email": "的邮箱是",
|
||||
"likes": "喜欢",
|
||||
"dislikes": "不喜欢",
|
||||
"likes_food": "爱吃",
|
||||
"hates": "讨厌",
|
||||
"favorite_is": "最喜欢",
|
||||
"mentioned_event": "提到的事件",
|
||||
}
|
||||
if predicate in predicate_map:
|
||||
connector = predicate_map[predicate]
|
||||
if connector.startswith("的"):
|
||||
return connector
|
||||
return f" {connector} "
|
||||
cleaned = predicate.replace("_", " ").strip()
|
||||
if re.search(r"[\u4e00-\u9fff]", cleaned):
|
||||
return cleaned
|
||||
return f" {cleaned} "
|
||||
|
||||
def _format_object(self, obj: Any) -> str:
|
||||
if obj is None:
|
||||
return ""
|
||||
if isinstance(obj, dict):
|
||||
parts = []
|
||||
for key, value in obj.items():
|
||||
formatted_value = self._format_object(value)
|
||||
if not formatted_value:
|
||||
continue
|
||||
pretty_key = {
|
||||
"name": "名字",
|
||||
"profession": "职业",
|
||||
"location": "位置",
|
||||
"event_text": "内容",
|
||||
"timestamp": "时间",
|
||||
}.get(key, key)
|
||||
parts.append(f"{pretty_key}: {formatted_value}")
|
||||
return self._clean_text(";".join(parts))
|
||||
if isinstance(obj, list):
|
||||
formatted_items = [self._format_object(item) for item in obj]
|
||||
filtered = [item for item in formatted_items if item]
|
||||
return self._clean_text("、".join(filtered)) if filtered else ""
|
||||
if isinstance(obj, (int, float)):
|
||||
return str(obj)
|
||||
text = self._truncate(str(obj).strip())
|
||||
return self._clean_text(text)
|
||||
|
||||
def _extract_from_object(self, obj: Any, keys: List[str]) -> Optional[str]:
|
||||
if isinstance(obj, dict):
|
||||
for key in keys:
|
||||
if key in obj and obj[key]:
|
||||
value = obj[key]
|
||||
if isinstance(value, (dict, list)):
|
||||
return self._clean_text(self._format_object(value))
|
||||
return self._clean_text(value)
|
||||
if isinstance(obj, list) and obj:
|
||||
return self._clean_text(self._format_object(obj[0]))
|
||||
if isinstance(obj, (str, int, float)):
|
||||
return self._clean_text(obj)
|
||||
return None
|
||||
|
||||
def _truncate(self, text: str, max_length: int = 80) -> str:
|
||||
if len(text) <= max_length:
|
||||
return text
|
||||
return text[: max_length - 1] + "…"
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭增强记忆系统"""
|
||||
if not self.is_initialized:
|
||||
|
||||
@@ -38,6 +38,10 @@ class ExtractionResult:
|
||||
strategy_used: ExtractionStrategy
|
||||
|
||||
|
||||
class MemoryExtractionError(Exception):
    """Unrecoverable error raised while extracting memories."""
|
||||
|
||||
|
||||
class MemoryBuilder:
|
||||
"""记忆构建器"""
|
||||
|
||||
@@ -87,10 +91,14 @@ class MemoryBuilder:
|
||||
logger.info(f"✅ 成功构建 {len(validated_memories)} 条记忆,耗时 {extraction_time:.2f}秒")
|
||||
return validated_memories
|
||||
|
||||
except MemoryExtractionError as e:
|
||||
logger.error(f"❌ 记忆构建失败(响应解析错误): {e}")
|
||||
self.extraction_stats["failed_extractions"] += 1
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
|
||||
self.extraction_stats["failed_extractions"] += 1
|
||||
return []
|
||||
raise
|
||||
|
||||
def _preprocess_text(self, text: str) -> str:
|
||||
"""预处理文本"""
|
||||
@@ -147,9 +155,11 @@ class MemoryBuilder:
|
||||
|
||||
return memories
|
||||
|
||||
except MemoryExtractionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"LLM提取失败: {e}")
|
||||
return []
|
||||
raise MemoryExtractionError(str(e)) from e
|
||||
|
||||
def _extract_with_rules(
|
||||
self,
|
||||
@@ -161,16 +171,18 @@ class MemoryBuilder:
|
||||
"""使用规则提取记忆"""
|
||||
memories = []
|
||||
|
||||
subject_display = self._resolve_user_display(context, user_id)
|
||||
|
||||
# 规则1: 检测个人信息
|
||||
personal_info = self._extract_personal_info(text, user_id, timestamp, context)
|
||||
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subject_display)
|
||||
memories.extend(personal_info)
|
||||
|
||||
# 规则2: 检测偏好信息
|
||||
preferences = self._extract_preferences(text, user_id, timestamp, context)
|
||||
preferences = self._extract_preferences(text, user_id, timestamp, context, subject_display)
|
||||
memories.extend(preferences)
|
||||
|
||||
# 规则3: 检测事件信息
|
||||
events = self._extract_events(text, user_id, timestamp, context)
|
||||
events = self._extract_events(text, user_id, timestamp, context, subject_display)
|
||||
memories.extend(events)
|
||||
|
||||
return memories
|
||||
@@ -202,6 +214,45 @@ class MemoryBuilder:
|
||||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
chat_id = context.get("chat_id", "unknown")
|
||||
message_type = context.get("message_type", "normal")
|
||||
target_user_id = context.get("user_id", "用户")
|
||||
target_user_id = str(target_user_id)
|
||||
|
||||
target_user_name = (
|
||||
context.get("user_display_name")
|
||||
or context.get("user_name")
|
||||
or context.get("nickname")
|
||||
or context.get("sender_name")
|
||||
)
|
||||
if isinstance(target_user_name, str):
|
||||
target_user_name = target_user_name.strip()
|
||||
else:
|
||||
target_user_name = ""
|
||||
|
||||
if not target_user_name or self._looks_like_system_identifier(target_user_name):
|
||||
target_user_name = "该用户"
|
||||
|
||||
target_user_id_display = target_user_id
|
||||
if self._looks_like_system_identifier(target_user_id_display):
|
||||
target_user_id_display = "(系统ID,勿写入记忆)"
|
||||
|
||||
bot_name = context.get("bot_name")
|
||||
bot_identity = context.get("bot_identity")
|
||||
bot_personality = context.get("bot_personality")
|
||||
bot_personality_side = context.get("bot_personality_side")
|
||||
bot_aliases = context.get("bot_aliases") or []
|
||||
if isinstance(bot_aliases, str):
|
||||
bot_aliases = [bot_aliases]
|
||||
|
||||
bot_name_display = bot_name or "机器人"
|
||||
alias_display = "、".join(a for a in bot_aliases if a) or "无"
|
||||
persona_details = []
|
||||
if bot_identity:
|
||||
persona_details.append(f"身份: {bot_identity}")
|
||||
if bot_personality:
|
||||
persona_details.append(f"核心人设: {bot_personality}")
|
||||
if bot_personality_side:
|
||||
persona_details.append(f"侧写: {bot_personality_side}")
|
||||
persona_display = ";".join(persona_details) if persona_details else "无"
|
||||
|
||||
prompt = f"""
|
||||
你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。
|
||||
@@ -209,6 +260,20 @@ class MemoryBuilder:
|
||||
当前时间: {current_date}
|
||||
聊天ID: {chat_id}
|
||||
消息类型: {message_type}
|
||||
目标用户ID: {target_user_id_display}
|
||||
目标用户称呼: {target_user_name}
|
||||
|
||||
## 🤖 机器人身份(仅供参考,禁止写入记忆)
|
||||
- 机器人名称: {bot_name_display}
|
||||
- 别名: {alias_display}
|
||||
- 机器人人设概述: {persona_display}
|
||||
|
||||
这些信息是机器人的固定设定,可用于帮助你理解对话。你可以在需要时记录机器人自身的状态、行为或设定,但要与用户信息清晰区分,避免误将系统ID写入记忆。
|
||||
|
||||
请务必遵守以下命名规范:
|
||||
- 当说话者是机器人时,请使用“{bot_name_display}”或其他明确称呼作为主语;
|
||||
- 如果看到系统自动生成的长ID(类似 {target_user_id}),请改用“{target_user_name}”、机器人的称呼或“该用户”描述,不要把ID写入记忆;
|
||||
- 记录关键事实时,请准确标记主体是机器人还是用户,避免混淆。
|
||||
|
||||
对话内容:
|
||||
{text}
|
||||
@@ -232,6 +297,7 @@ class MemoryBuilder:
|
||||
- 特殊经历:考试、面试、会议、搬家、购物
|
||||
- 计划安排:约会、会议、旅行、活动
|
||||
|
||||
|
||||
**判断标准:** 涉及时间地点的具体活动和经历,都应该记忆
|
||||
|
||||
### 3. **偏好** (preference) - 高优先级记忆
|
||||
@@ -364,56 +430,255 @@ class MemoryBuilder:
|
||||
context: Dict[str, Any]
|
||||
) -> List[MemoryChunk]:
|
||||
"""解析LLM响应"""
|
||||
memories = []
|
||||
if not response:
|
||||
raise MemoryExtractionError("LLM未返回任何响应")
|
||||
|
||||
json_payload = self._extract_json_payload(response)
|
||||
if not json_payload:
|
||||
preview = response[:200] if response else "空响应"
|
||||
raise MemoryExtractionError(f"未在LLM响应中找到有效的JSON负载,响应片段: {preview}")
|
||||
|
||||
try:
|
||||
# 提取JSON负载
|
||||
json_payload = self._extract_json_payload(response)
|
||||
if not json_payload:
|
||||
logger.error("未在响应中找到有效的JSON负载")
|
||||
return memories
|
||||
|
||||
data = orjson.loads(json_payload)
|
||||
memory_list = data.get("memories", [])
|
||||
except Exception as e:
|
||||
preview = json_payload[:200]
|
||||
raise MemoryExtractionError(
|
||||
f"LLM响应JSON解析失败: {e}, 片段: {preview}"
|
||||
) from e
|
||||
|
||||
for mem_data in memory_list:
|
||||
try:
|
||||
# 创建记忆块
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=mem_data.get("subject", user_id),
|
||||
predicate=mem_data.get("predicate", ""),
|
||||
obj=mem_data.get("object", ""),
|
||||
memory_type=MemoryType(mem_data.get("type", "contextual")),
|
||||
chat_id=context.get("chat_id"),
|
||||
source_context=mem_data.get("reasoning", ""),
|
||||
importance=ImportanceLevel(mem_data.get("importance", 2)),
|
||||
confidence=ConfidenceLevel(mem_data.get("confidence", 2))
|
||||
)
|
||||
memory_list = data.get("memories", [])
|
||||
|
||||
# 添加关键词
|
||||
keywords = mem_data.get("keywords", [])
|
||||
for keyword in keywords:
|
||||
memory.add_keyword(keyword)
|
||||
bot_identifiers = self._collect_bot_identifiers(context)
|
||||
system_identifiers = self._collect_system_identifiers(context)
|
||||
default_subject = self._resolve_user_display(context, user_id)
|
||||
|
||||
memories.append(memory)
|
||||
bot_display = None
|
||||
if context:
|
||||
primary_bot_name = context.get("bot_name")
|
||||
if isinstance(primary_bot_name, str) and primary_bot_name.strip():
|
||||
bot_display = primary_bot_name.strip()
|
||||
if bot_display is None:
|
||||
aliases = context.get("bot_aliases")
|
||||
if isinstance(aliases, (list, tuple, set)):
|
||||
for alias in aliases:
|
||||
if isinstance(alias, str) and alias.strip():
|
||||
bot_display = alias.strip()
|
||||
break
|
||||
elif isinstance(aliases, str) and aliases.strip():
|
||||
bot_display = aliases.strip()
|
||||
if bot_display is None:
|
||||
identity = context.get("bot_identity")
|
||||
if isinstance(identity, str) and identity.strip():
|
||||
bot_display = identity.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"解析单个记忆失败: {e}, 数据: {mem_data}")
|
||||
if not bot_display:
|
||||
bot_display = "机器人"
|
||||
|
||||
bot_display = self._clean_subject_text(bot_display)
|
||||
|
||||
memories: List[MemoryChunk] = []
|
||||
|
||||
for mem_data in memory_list:
|
||||
try:
|
||||
subject_value = mem_data.get("subject")
|
||||
normalized_subject = self._normalize_subject(
|
||||
subject_value,
|
||||
bot_identifiers,
|
||||
system_identifiers,
|
||||
default_subject,
|
||||
bot_display
|
||||
)
|
||||
|
||||
if normalized_subject is None:
|
||||
logger.debug("跳过疑似机器人自身信息的记忆: %s", mem_data)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
preview = response[:200] if response else "空响应"
|
||||
logger.error(f"解析LLM响应失败: {e}, 响应片段: {preview}")
|
||||
# 创建记忆块
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=normalized_subject,
|
||||
predicate=mem_data.get("predicate", ""),
|
||||
obj=mem_data.get("object", ""),
|
||||
memory_type=MemoryType(mem_data.get("type", "contextual")),
|
||||
chat_id=context.get("chat_id"),
|
||||
source_context=mem_data.get("reasoning", ""),
|
||||
importance=ImportanceLevel(mem_data.get("importance", 2)),
|
||||
confidence=ConfidenceLevel(mem_data.get("confidence", 2))
|
||||
)
|
||||
|
||||
# 添加关键词
|
||||
keywords = mem_data.get("keywords", [])
|
||||
for keyword in keywords:
|
||||
memory.add_keyword(keyword)
|
||||
|
||||
subject_text = memory.content.subject.strip() if isinstance(memory.content.subject, str) else str(memory.content.subject)
|
||||
if not subject_text:
|
||||
memory.content.subject = default_subject
|
||||
elif subject_text.lower() in system_identifiers or self._looks_like_system_identifier(subject_text):
|
||||
logger.debug("将系统标识主语替换为默认用户名称: %s", subject_text)
|
||||
memory.content.subject = default_subject
|
||||
|
||||
memories.append(memory)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"解析单个记忆失败: {e}, 数据: {mem_data}")
|
||||
continue
|
||||
|
||||
return memories
|
||||
|
||||
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
|
||||
identifiers: set[str] = {"bot", "机器人", "ai助手"}
|
||||
if not context:
|
||||
return identifiers
|
||||
|
||||
for key in [
|
||||
"bot_name",
|
||||
"bot_identity",
|
||||
"bot_personality",
|
||||
"bot_personality_side",
|
||||
"bot_account",
|
||||
]:
|
||||
value = context.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
identifiers.add(value.strip().lower())
|
||||
|
||||
aliases = context.get("bot_aliases")
|
||||
if isinstance(aliases, (list, tuple, set)):
|
||||
for alias in aliases:
|
||||
if isinstance(alias, str) and alias.strip():
|
||||
identifiers.add(alias.strip().lower())
|
||||
elif isinstance(aliases, str) and aliases.strip():
|
||||
identifiers.add(aliases.strip().lower())
|
||||
|
||||
return identifiers
|
||||
|
||||
def _collect_system_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
|
||||
identifiers: set[str] = set()
|
||||
if not context:
|
||||
return identifiers
|
||||
|
||||
keys = [
|
||||
"chat_id",
|
||||
"stream_id",
|
||||
"stram_id",
|
||||
"session_id",
|
||||
"conversation_id",
|
||||
"message_id",
|
||||
"topic_id",
|
||||
"thread_id",
|
||||
]
|
||||
|
||||
for key in keys:
|
||||
value = context.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
identifiers.add(value.strip().lower())
|
||||
|
||||
user_id_value = context.get("user_id")
|
||||
if isinstance(user_id_value, str) and user_id_value.strip():
|
||||
if self._looks_like_system_identifier(user_id_value):
|
||||
identifiers.add(user_id_value.strip().lower())
|
||||
|
||||
return identifiers
|
||||
|
||||
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
|
||||
candidate_keys = [
|
||||
"user_display_name",
|
||||
"user_name",
|
||||
"nickname",
|
||||
"sender_name",
|
||||
"member_name",
|
||||
"display_name",
|
||||
"from_user_name",
|
||||
"author_name",
|
||||
"speaker_name",
|
||||
]
|
||||
|
||||
if context:
|
||||
for key in candidate_keys:
|
||||
value = context.get(key)
|
||||
if isinstance(value, str):
|
||||
candidate = value.strip()
|
||||
if candidate:
|
||||
return self._clean_subject_text(candidate)
|
||||
|
||||
if user_id and not self._looks_like_system_identifier(user_id):
|
||||
return self._clean_subject_text(user_id)
|
||||
|
||||
return "该用户"
|
||||
|
||||
def _clean_subject_text(self, text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
cleaned = re.sub(r"[\s\u3000]+", " ", text).strip()
|
||||
cleaned = re.sub(r"[、,,;;]+$", "", cleaned)
|
||||
return cleaned
|
||||
|
||||
def _looks_like_system_identifier(self, value: str) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
|
||||
condensed = value.replace("-", "").replace("_", "").strip()
|
||||
if len(condensed) >= 16 and re.fullmatch(r"[0-9a-fA-F]+", condensed):
|
||||
return True
|
||||
|
||||
if len(value) >= 12 and re.fullmatch(r"[0-9A-Z_:-]+", value) and any(ch.isdigit() for ch in value):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _normalize_subject(
|
||||
self,
|
||||
subject: Any,
|
||||
bot_identifiers: set[str],
|
||||
system_identifiers: set[str],
|
||||
default_subject: str,
|
||||
bot_display: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
if subject is None:
|
||||
return default_subject
|
||||
|
||||
subject_str = subject if isinstance(subject, str) else str(subject)
|
||||
cleaned = self._clean_subject_text(subject_str)
|
||||
if not cleaned:
|
||||
return default_subject
|
||||
|
||||
lowered = cleaned.lower()
|
||||
bot_primary = self._clean_subject_text(bot_display or "")
|
||||
|
||||
if lowered in bot_identifiers:
|
||||
return bot_primary or cleaned
|
||||
|
||||
if lowered in {"用户", "user", "the user", "对方", "对手"}:
|
||||
return default_subject
|
||||
|
||||
prefix_match = re.match(r"^(用户|User|user|USER|成员|member|Member|target|Target|TARGET)[\s::\-\u2014_]*?(.*)$", cleaned)
|
||||
if prefix_match:
|
||||
remainder = self._clean_subject_text(prefix_match.group(2))
|
||||
if not remainder:
|
||||
return default_subject
|
||||
remainder_lower = remainder.lower()
|
||||
if remainder_lower in bot_identifiers:
|
||||
return bot_primary or remainder
|
||||
if (
|
||||
remainder_lower in system_identifiers
|
||||
or self._looks_like_system_identifier(remainder)
|
||||
):
|
||||
return default_subject
|
||||
cleaned = remainder
|
||||
lowered = cleaned.lower()
|
||||
|
||||
if lowered in system_identifiers or self._looks_like_system_identifier(cleaned):
|
||||
return default_subject
|
||||
|
||||
return cleaned
|
||||
|
||||
def _extract_personal_info(
|
||||
self,
|
||||
text: str,
|
||||
user_id: str,
|
||||
timestamp: float,
|
||||
context: Dict[str, Any]
|
||||
context: Dict[str, Any],
|
||||
subject_display: str
|
||||
) -> List[MemoryChunk]:
|
||||
"""提取个人信息"""
|
||||
memories = []
|
||||
@@ -437,7 +702,7 @@ class MemoryBuilder:
|
||||
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=user_id,
|
||||
subject=subject_display,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
memory_type=MemoryType.PERSONAL_FACT,
|
||||
@@ -455,7 +720,8 @@ class MemoryBuilder:
|
||||
text: str,
|
||||
user_id: str,
|
||||
timestamp: float,
|
||||
context: Dict[str, Any]
|
||||
context: Dict[str, Any],
|
||||
subject_display: str
|
||||
) -> List[MemoryChunk]:
|
||||
"""提取偏好信息"""
|
||||
memories = []
|
||||
@@ -474,7 +740,7 @@ class MemoryBuilder:
|
||||
if match:
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=user_id,
|
||||
subject=subject_display,
|
||||
predicate=predicate,
|
||||
obj=match.group(1),
|
||||
memory_type=MemoryType.PREFERENCE,
|
||||
@@ -492,7 +758,8 @@ class MemoryBuilder:
|
||||
text: str,
|
||||
user_id: str,
|
||||
timestamp: float,
|
||||
context: Dict[str, Any]
|
||||
context: Dict[str, Any],
|
||||
subject_display: str
|
||||
) -> List[MemoryChunk]:
|
||||
"""提取事件信息"""
|
||||
memories = []
|
||||
@@ -503,7 +770,7 @@ class MemoryBuilder:
|
||||
if any(keyword in text for keyword in event_keywords):
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=user_id,
|
||||
subject=subject_display,
|
||||
predicate="mentioned_event",
|
||||
obj={"event_text": text, "timestamp": timestamp},
|
||||
memory_type=MemoryType.EVENT,
|
||||
@@ -634,26 +901,24 @@ class MemoryBuilder:
|
||||
r'明年|下一年': str(current_time.year + 1),
|
||||
}
|
||||
|
||||
# 检查并替换记忆内容中的相对时间
|
||||
memory_content = memory.content.description
|
||||
def _normalize_value(value):
    """Recursively replace relative-time phrases inside ``value``.

    Strings are run through every pattern of ``relative_time_patterns``
    (a mapping taken from the enclosing scope) with ``re.sub``.  Dict values
    and list items are normalized recursively; dict keys are kept as they
    are.  Any other kind of value is returned unchanged.
    """
    if isinstance(value, dict):
        return {key: _normalize_value(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_normalize_value(entry) for entry in value]
    if not isinstance(value, str):
        return value

    # Substitutions are applied one after another, in the mapping's order.
    text = value
    for pattern, replacement in relative_time_patterns.items():
        text = re.sub(pattern, replacement, text)
    return text
|
||||
|
||||
# 应用时间规范化
|
||||
for pattern, replacement in relative_time_patterns.items():
|
||||
memory_content = re.sub(pattern, replacement, memory_content)
|
||||
# 规范化主语和谓语(通常是字符串)
|
||||
memory.content.subject = _normalize_value(memory.content.subject)
|
||||
memory.content.predicate = _normalize_value(memory.content.predicate)
|
||||
|
||||
# 更新记忆内容
|
||||
memory.content.description = memory_content
|
||||
|
||||
# 如果记忆有对象信息,也进行时间规范化
|
||||
if hasattr(memory.content, 'object') and isinstance(memory.content.object, dict):
|
||||
obj_str = str(memory.content.object)
|
||||
for pattern, replacement in relative_time_patterns.items():
|
||||
obj_str = re.sub(pattern, replacement, obj_str)
|
||||
try:
|
||||
# 尝试解析回字典(如果原来是字典)
|
||||
memory.content.object = eval(obj_str) if obj_str.startswith('{') else obj_str
|
||||
except Exception:
|
||||
memory.content.object = obj_str
|
||||
# 规范化宾语(可能是字符串、列表或字典)
|
||||
memory.content.object = _normalize_value(memory.content.object)
|
||||
|
||||
# 记录时间规范化操作
|
||||
logger.debug(f"记忆 {memory.memory_id} 已进行时间规范化")
|
||||
|
||||
@@ -80,6 +80,12 @@ class MemoryFusionEngine:
|
||||
new_memories, existing_memories or []
|
||||
)
|
||||
|
||||
if not duplicate_groups:
|
||||
fusion_time = time.time() - start_time
|
||||
self._update_fusion_stats(len(new_memories), 0, fusion_time)
|
||||
logger.info("✅ 记忆融合完成: %d 条记忆,移除 0 条重复", len(new_memories))
|
||||
return new_memories
|
||||
|
||||
# 2. 对每个重复组进行融合
|
||||
fused_memories = []
|
||||
removed_count = 0
|
||||
@@ -113,6 +119,7 @@ class MemoryFusionEngine:
|
||||
) -> List[DuplicateGroup]:
|
||||
"""检测重复记忆组"""
|
||||
all_memories = new_memories + existing_memories
|
||||
new_memory_ids = {memory.memory_id for memory in new_memories}
|
||||
groups = []
|
||||
processed_ids = set()
|
||||
|
||||
@@ -147,6 +154,10 @@ class MemoryFusionEngine:
|
||||
# 选择代表性记忆
|
||||
group.representative_memory = self._select_representative_memory(group)
|
||||
groups.append(group)
|
||||
else:
|
||||
# 仅包含单条记忆,只有当其来自新记忆列表时保留
|
||||
if memory1.memory_id in new_memory_ids:
|
||||
groups.append(group)
|
||||
|
||||
logger.debug(f"检测到 {len(groups)} 个重复记忆组")
|
||||
return groups
|
||||
|
||||
@@ -12,7 +12,6 @@ from typing import Dict, List, Optional, Tuple, Set, Any
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -130,10 +129,11 @@ class VectorStorageManager:
|
||||
await self.initialize_embedding_model()
|
||||
|
||||
# 批量获取嵌入向量
|
||||
embedding_tasks = []
|
||||
memory_texts = []
|
||||
|
||||
for memory in memories:
|
||||
# 预先缓存记忆,确保后续流程可访问
|
||||
self.memory_cache[memory.memory_id] = memory
|
||||
if memory.embedding is None:
|
||||
# 如果没有嵌入向量,需要生成
|
||||
text = self._prepare_embedding_text(memory)
|
||||
@@ -183,10 +183,10 @@ class VectorStorageManager:
|
||||
memory_ids = [memory_id for memory_id, _ in memory_texts]
|
||||
|
||||
# 批量生成嵌入向量
|
||||
embeddings = await self._batch_generate_embeddings(texts)
|
||||
embeddings = await self._batch_generate_embeddings(memory_ids, texts)
|
||||
|
||||
# 存储向量和记忆
|
||||
for memory_id, embedding in zip(memory_ids, embeddings):
|
||||
for memory_id, embedding in embeddings.items():
|
||||
if embedding and len(embedding) == self.config.dimension:
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if memory:
|
||||
@@ -195,76 +195,43 @@ class VectorStorageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
|
||||
|
||||
async def _batch_generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||
async def _batch_generate_embeddings(self, memory_ids: List[str], texts: List[str]) -> Dict[str, List[float]]:
|
||||
"""批量生成嵌入向量"""
|
||||
if not texts:
|
||||
return []
|
||||
return {}
|
||||
|
||||
results: Dict[str, List[float]] = {}
|
||||
|
||||
try:
|
||||
# 创建新的事件循环来运行异步操作
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))
|
||||
|
||||
try:
|
||||
# 使用线程池并行生成嵌入向量
|
||||
with ThreadPoolExecutor(max_workers=min(4, len(texts))) as executor:
|
||||
tasks = []
|
||||
for text in texts:
|
||||
task = loop.run_in_executor(
|
||||
executor,
|
||||
self._generate_single_embedding,
|
||||
text
|
||||
)
|
||||
tasks.append(task)
|
||||
async def generate_embedding(memory_id: str, text: str) -> None:
    """Embed ``text`` and record the outcome in ``results[memory_id]``.

    The call is made while holding ``semaphore``.  A vector whose length
    equals ``self.config.dimension`` is stored as-is; an empty or wrongly
    sized vector, or any exception raised along the way, is logged as a
    warning and recorded as an empty list.  Nothing is returned and nothing
    is raised to the caller.
    """
    async with semaphore:
        try:
            vector, _ = await self.embedding_model.get_embedding(text)
            dimension_ok = bool(vector) and len(vector) == self.config.dimension
            if not dimension_ok:
                logger.warning(
                    "嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)",
                    self.config.dimension,
                    len(vector) if vector else 0,
                    memory_id,
                )
                results[memory_id] = []
            else:
                results[memory_id] = vector
        except Exception as exc:
            # Keep one failing memory from aborting the rest of the batch.
            logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
            results[memory_id] = []
|
||||
|
||||
embeddings = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理结果
|
||||
valid_embeddings = []
|
||||
for i, embedding in enumerate(embeddings):
|
||||
if isinstance(embedding, Exception):
|
||||
logger.warning(f"生成第 {i} 个文本的嵌入向量失败: {embedding}")
|
||||
valid_embeddings.append([])
|
||||
elif embedding and len(embedding) == self.config.dimension:
|
||||
valid_embeddings.append(embedding)
|
||||
else:
|
||||
logger.warning(f"第 {i} 个文本的嵌入向量格式异常")
|
||||
valid_embeddings.append([])
|
||||
|
||||
return valid_embeddings
|
||||
|
||||
finally:
|
||||
loop.close()
|
||||
tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts)]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
|
||||
return [[] for _ in texts]
|
||||
for memory_id in memory_ids:
|
||||
results.setdefault(memory_id, [])
|
||||
|
||||
def _generate_single_embedding(self, text: str) -> List[float]:
|
||||
"""生成单个文本的嵌入向量"""
|
||||
try:
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 使用模型生成嵌入向量
|
||||
embedding, _ = loop.run_until_complete(
|
||||
self.embedding_model.get_embedding(text)
|
||||
)
|
||||
|
||||
if embedding and len(embedding) == self.config.dimension:
|
||||
return embedding
|
||||
else:
|
||||
logger.warning(f"嵌入向量维度不匹配: 期望 {self.config.dimension}, 实际 {len(embedding) if embedding else 0}")
|
||||
return []
|
||||
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成嵌入向量失败: {e}")
|
||||
return []
|
||||
return results
|
||||
|
||||
async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]):
|
||||
"""添加单个记忆到向量存储"""
|
||||
|
||||
@@ -38,6 +38,14 @@ class StreamLoopManager:
|
||||
global_config.chat, "max_concurrent_distributions", 10
|
||||
)
|
||||
|
||||
# 强制分发策略
|
||||
self.force_dispatch_unread_threshold: Optional[int] = getattr(
|
||||
global_config.chat, "force_dispatch_unread_threshold", 20
|
||||
)
|
||||
self.force_dispatch_min_interval: float = getattr(
|
||||
global_config.chat, "force_dispatch_min_interval", 0.1
|
||||
)
|
||||
|
||||
# Chatter管理器
|
||||
self.chatter_manager: Optional[ChatterManager] = None
|
||||
|
||||
@@ -75,7 +83,7 @@ class StreamLoopManager:
|
||||
|
||||
logger.info("流循环管理器已停止")
|
||||
|
||||
async def start_stream_loop(self, stream_id: str) -> bool:
|
||||
async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
|
||||
"""启动指定流的循环任务
|
||||
|
||||
Args:
|
||||
@@ -90,11 +98,19 @@ class StreamLoopManager:
|
||||
logger.debug(f"流 {stream_id} 循环已在运行")
|
||||
return True
|
||||
|
||||
# 判断是否需要强制分发
|
||||
force = force or self._should_force_dispatch_for_stream(stream_id)
|
||||
|
||||
# 检查是否超过最大并发限制
|
||||
if len(self.stream_loops) >= self.max_concurrent_streams:
|
||||
if len(self.stream_loops) >= self.max_concurrent_streams and not force:
|
||||
logger.warning(f"超过最大并发流数限制,无法启动流 {stream_id}")
|
||||
return False
|
||||
|
||||
if force and len(self.stream_loops) >= self.max_concurrent_streams:
|
||||
logger.warning(
|
||||
"流 %s 未读消息积压严重(>%s),突破并发限制强制启动分发", stream_id, self.force_dispatch_unread_threshold
|
||||
)
|
||||
|
||||
# 创建流循环任务
|
||||
task = asyncio.create_task(self._stream_loop(stream_id))
|
||||
self.stream_loops[stream_id] = task
|
||||
@@ -145,9 +161,16 @@ class StreamLoopManager:
|
||||
continue
|
||||
|
||||
# 2. 检查是否有消息需要处理
|
||||
has_messages = await self._has_messages_to_process(context)
|
||||
unread_count = self._get_unread_count(context)
|
||||
force_dispatch = self._needs_force_dispatch_for_context(context, unread_count)
|
||||
|
||||
has_messages = force_dispatch or await self._has_messages_to_process(context)
|
||||
|
||||
if has_messages:
|
||||
if force_dispatch:
|
||||
logger.info(
|
||||
"流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count
|
||||
)
|
||||
# 3. 激活chatter处理
|
||||
success = await self._process_stream_messages(stream_id, context)
|
||||
|
||||
@@ -162,6 +185,17 @@ class StreamLoopManager:
|
||||
# 4. 计算下次检查间隔
|
||||
interval = await self._calculate_interval(stream_id, has_messages)
|
||||
|
||||
if has_messages:
|
||||
updated_unread_count = self._get_unread_count(context)
|
||||
if self._needs_force_dispatch_for_context(context, updated_unread_count):
|
||||
interval = min(interval, max(self.force_dispatch_min_interval, 0.0))
|
||||
logger.debug(
|
||||
"流 %s 未读消息仍有 %d 条,使用加速分发间隔 %.2fs",
|
||||
stream_id,
|
||||
updated_unread_count,
|
||||
interval,
|
||||
)
|
||||
|
||||
# 5. sleep等待下次检查
|
||||
logger.info(f"流 {stream_id} 等待 {interval:.2f}s")
|
||||
await asyncio.sleep(interval)
|
||||
@@ -319,6 +353,38 @@ class StreamLoopManager:
|
||||
self.chatter_manager = chatter_manager
|
||||
logger.info(f"设置chatter管理器: {chatter_manager.__class__.__name__}")
|
||||
|
||||
def _should_force_dispatch_for_stream(self, stream_id: str) -> bool:
    """Tell whether the stream's unread backlog exceeds the force-dispatch threshold.

    Args:
        stream_id: Identifier passed to the chat manager's ``get_stream``.

    Returns:
        ``True`` only when the stream exists and the length of its context's
        ``unread_messages`` is strictly greater than
        ``self.force_dispatch_unread_threshold``.  ``False`` when the
        threshold is unset or not positive, when the stream is not found, or
        when the lookup raises (the error is logged at debug level).
    """
    threshold = self.force_dispatch_unread_threshold
    if not threshold or threshold <= 0:
        # Feature disabled.
        return False

    try:
        stream = get_chat_manager().get_stream(stream_id)
        if not stream:
            return False

        pending = getattr(stream.context_manager.context, "unread_messages", [])
        return len(pending) > threshold
    except Exception as e:
        logger.debug(f"检查流 {stream_id} 是否需要强制分发失败: {e}")
        return False
|
||||
|
||||
def _get_unread_count(self, context: Any) -> int:
    """Return the number of unread messages held by ``context``.

    Best-effort: a missing or ``None`` ``unread_messages`` attribute, or any
    error raised while reading or measuring it, yields 0.
    """
    try:
        pending = getattr(context, "unread_messages", None)
        return 0 if pending is None else len(pending)
    except Exception:
        return 0
|
||||
|
||||
def _needs_force_dispatch_for_context(self, context: Any, unread_count: Optional[int] = None) -> bool:
    """Tell whether ``context`` has more unread messages than the threshold allows.

    Args:
        context: Object handed to ``self._get_unread_count`` when no count
            is supplied.
        unread_count: Unread count already obtained by the caller; when
            ``None`` it is computed from ``context``.

    Returns:
        ``True`` when the count is strictly greater than
        ``self.force_dispatch_unread_threshold``; always ``False`` when that
        threshold is unset or not positive.
    """
    threshold = self.force_dispatch_unread_threshold
    if not threshold or threshold <= 0:
        return False

    if unread_count is None:
        unread_count = self._get_unread_count(context)
    return unread_count > threshold
|
||||
|
||||
def get_performance_summary(self) -> Dict[str, Any]:
|
||||
"""获取性能摘要
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import random
|
||||
import time
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING, List
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats
|
||||
@@ -86,11 +87,8 @@ class MessageManager:
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
success = await chat_stream.context_manager.add_message(message)
|
||||
if success:
|
||||
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
|
||||
else:
|
||||
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
||||
await self._check_and_handle_interruption(chat_stream)
|
||||
chat_stream.context_manager.context.processing_task = asyncio.create_task(chat_stream.context_manager.add_message(message))
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
|
||||
|
||||
@@ -280,51 +278,51 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"清理不活跃聊天流时发生错误: {e}")
|
||||
|
||||
async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str):
|
||||
async def _check_and_handle_interruption(self, chat_stream: Optional[ChatStream] = None):
|
||||
"""检查并处理消息打断"""
|
||||
if not global_config.chat.interruption_enabled:
|
||||
return
|
||||
|
||||
# 检查是否有正在进行的处理任务
|
||||
if context.processing_task and not context.processing_task.done():
|
||||
if chat_stream.context_manager.context.processing_task and not chat_stream.context_manager.context.processing_task.done():
|
||||
# 计算打断概率
|
||||
interruption_probability = context.calculate_interruption_probability(
|
||||
interruption_probability = chat_stream.context_manager.context.calculate_interruption_probability(
|
||||
global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor
|
||||
)
|
||||
|
||||
# 检查是否已达到最大打断次数
|
||||
if context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||
if chat_stream.context_manager.context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||
logger.debug(
|
||||
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},跳过打断检查"
|
||||
f"聊天流 {chat_stream.stream_id} 已达到最大打断次数 {chat_stream.context_manager.context.interruption_count}/{global_config.chat.interruption_max_limit},跳过打断检查"
|
||||
)
|
||||
return
|
||||
|
||||
# 根据概率决定是否打断
|
||||
if random.random() < interruption_probability:
|
||||
logger.info(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
|
||||
logger.info(f"聊天流 {chat_stream.stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
|
||||
|
||||
# 取消现有任务
|
||||
context.processing_task.cancel()
|
||||
chat_stream.context_manager.context.processing_task.cancel()
|
||||
try:
|
||||
await context.processing_task
|
||||
await chat_stream.context_manager.context.processing_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 增加打断计数并应用afc阈值降低
|
||||
context.increment_interruption_count()
|
||||
context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction)
|
||||
chat_stream.context_manager.context.increment_interruption_count()
|
||||
chat_stream.context_manager.context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction)
|
||||
|
||||
# 检查是否已达到最大次数
|
||||
if context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||
if chat_stream.context_manager.context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||
logger.warning(
|
||||
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},后续消息将不再打断"
|
||||
f"聊天流 {chat_stream.stream_id} 已达到最大打断次数 {chat_stream.context_manager.context.interruption_count}/{global_config.chat.interruption_max_limit},后续消息将不再打断"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {context.get_afc_threshold_adjustment()}"
|
||||
f"聊天流 {chat_stream.stream_id} 已打断,当前打断次数: {chat_stream.context_manager.context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {chat_stream.context_manager.context.get_afc_threshold_adjustment()}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
|
||||
logger.debug(f"聊天流 {chat_stream.stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
|
||||
|
||||
async def clear_all_unread_messages(self, stream_id: str):
|
||||
"""清除指定上下文中的所有未读消息,在消息处理完成后调用"""
|
||||
@@ -358,4 +356,4 @@ class MessageManager:
|
||||
|
||||
|
||||
# 创建全局消息管理器实例
|
||||
message_manager = MessageManager()
|
||||
message_manager = MessageManag
|
||||
@@ -470,20 +470,114 @@ class DefaultReplyer:
|
||||
try:
|
||||
# 使用新的增强记忆系统
|
||||
from src.chat.memory_system.enhanced_memory_integration import recall_memories, remember_message
|
||||
|
||||
|
||||
stream = self.chat_stream
|
||||
user_info_obj = getattr(stream, "user_info", None)
|
||||
group_info_obj = getattr(stream, "group_info", None)
|
||||
|
||||
memory_user_id = str(stream.stream_id)
|
||||
memory_user_display = None
|
||||
memory_aliases = []
|
||||
user_info_dict = {}
|
||||
|
||||
if user_info_obj is not None:
|
||||
raw_user_id = getattr(user_info_obj, "user_id", None)
|
||||
if raw_user_id:
|
||||
memory_user_id = str(raw_user_id)
|
||||
|
||||
if hasattr(user_info_obj, "to_dict"):
|
||||
try:
|
||||
user_info_dict = user_info_obj.to_dict() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
user_info_dict = {}
|
||||
|
||||
candidate_keys = [
|
||||
"user_cardname",
|
||||
"user_nickname",
|
||||
"nickname",
|
||||
"remark",
|
||||
"display_name",
|
||||
"user_name",
|
||||
]
|
||||
|
||||
for key in candidate_keys:
|
||||
value = user_info_dict.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
stripped = value.strip()
|
||||
if memory_user_display is None:
|
||||
memory_user_display = stripped
|
||||
elif stripped not in memory_aliases:
|
||||
memory_aliases.append(stripped)
|
||||
|
||||
attr_keys = [
|
||||
"user_cardname",
|
||||
"user_nickname",
|
||||
"nickname",
|
||||
"remark",
|
||||
"display_name",
|
||||
"name",
|
||||
]
|
||||
|
||||
for attr in attr_keys:
|
||||
value = getattr(user_info_obj, attr, None)
|
||||
if isinstance(value, str) and value.strip():
|
||||
stripped = value.strip()
|
||||
if memory_user_display is None:
|
||||
memory_user_display = stripped
|
||||
elif stripped not in memory_aliases:
|
||||
memory_aliases.append(stripped)
|
||||
|
||||
alias_values = (
|
||||
user_info_dict.get("aliases")
|
||||
or user_info_dict.get("alias_names")
|
||||
or user_info_dict.get("alias")
|
||||
)
|
||||
if isinstance(alias_values, (list, tuple, set)):
|
||||
for alias in alias_values:
|
||||
if isinstance(alias, str) and alias.strip():
|
||||
stripped = alias.strip()
|
||||
if stripped not in memory_aliases and stripped != memory_user_display:
|
||||
memory_aliases.append(stripped)
|
||||
|
||||
memory_context = {
|
||||
"user_id": memory_user_id,
|
||||
"user_display_name": memory_user_display or "",
|
||||
"user_name": memory_user_display or "",
|
||||
"nickname": memory_user_display or "",
|
||||
"sender_name": memory_user_display or "",
|
||||
"platform": getattr(stream, "platform", None),
|
||||
"chat_id": stream.stream_id,
|
||||
"stream_id": stream.stream_id,
|
||||
}
|
||||
|
||||
if memory_aliases:
|
||||
memory_context["user_aliases"] = memory_aliases
|
||||
|
||||
if group_info_obj is not None:
|
||||
group_name = getattr(group_info_obj, "group_name", None) or getattr(group_info_obj, "group_nickname", None)
|
||||
if group_name:
|
||||
memory_context["group_name"] = str(group_name)
|
||||
group_id = getattr(group_info_obj, "group_id", None)
|
||||
if group_id:
|
||||
memory_context["group_id"] = str(group_id)
|
||||
|
||||
memory_context = {key: value for key, value in memory_context.items() if value}
|
||||
|
||||
# 检索相关记忆
|
||||
enhanced_memories = await recall_memories(
|
||||
query=target,
|
||||
user_id=str(self.chat_stream.stream_id),
|
||||
chat_id=self.chat_stream.stream_id
|
||||
user_id=memory_user_id,
|
||||
chat_id=stream.stream_id,
|
||||
context=memory_context
|
||||
)
|
||||
|
||||
# 异步存储聊天历史(非阻塞)
|
||||
asyncio.create_task(
|
||||
remember_message(
|
||||
message=chat_history,
|
||||
user_id=str(self.chat_stream.stream_id),
|
||||
chat_id=self.chat_stream.stream_id
|
||||
user_id=memory_user_id,
|
||||
chat_id=stream.stream_id,
|
||||
context=memory_context
|
||||
)
|
||||
)
|
||||
|
||||
@@ -492,17 +586,20 @@ class DefaultReplyer:
|
||||
if enhanced_memories and enhanced_memories.get("has_memories"):
|
||||
for memory in enhanced_memories.get("memories", []):
|
||||
running_memories.append({
|
||||
'content': memory.get("content", ""),
|
||||
'score': memory.get("confidence", 0.0),
|
||||
'memory_type': memory.get("type", "unknown")
|
||||
"content": memory.get("content", ""),
|
||||
"memory_type": memory.get("type", "unknown"),
|
||||
"confidence": memory.get("confidence"),
|
||||
"importance": memory.get("importance"),
|
||||
"relevance": memory.get("relevance"),
|
||||
"source": memory.get("source"),
|
||||
"structure": memory.get("structure"),
|
||||
})
|
||||
|
||||
# 构建瞬时记忆字符串
|
||||
if enhanced_memories and enhanced_memories.get("has_memories"):
|
||||
instant_memory = "\\n".join([
|
||||
f"{memory.get('content', '')} (相似度: {memory.get('confidence', 0.0):.2f})"
|
||||
for memory in enhanced_memories.get("memories", [])[:3] # 取前3条
|
||||
])
|
||||
top_memory = enhanced_memories.get("memories", [])[:1]
|
||||
if top_memory:
|
||||
instant_memory = top_memory[0].get("content", "")
|
||||
|
||||
logger.info(f"增强记忆系统检索到 {len(running_memories)} 条记忆")
|
||||
|
||||
@@ -511,6 +608,20 @@ class DefaultReplyer:
|
||||
running_memories = []
|
||||
instant_memory = ""
|
||||
|
||||
def _format_confidence_label(value: Optional[float]) -> str:
    """Turn a numeric confidence value into a display label.

    ``None`` gives the "unknown" label.  Otherwise the value is truncated
    with ``int()`` and looked up among the levels 1-4; a value outside that
    table is rendered as the original number with two decimals.
    """
    if value is None:
        return "未知"

    level = int(value)  # truncates toward zero, so 3.9 is treated as level 3
    numeric_text = f"{value:.2f}"
    labels = {1: "较低", 2: "中等", 3: "高", 4: "已验证"}
    return labels.get(level, numeric_text)
|
||||
|
||||
def _format_importance_label(value: Optional[float]) -> str:
    """Turn a numeric importance value into a display label.

    ``None`` gives the "unknown" label.  Otherwise the value is truncated
    with ``int()`` and looked up among the levels 1-4; a value outside that
    table is rendered as the original number with two decimals.
    """
    if value is None:
        return "未知"

    level = int(value)  # truncates toward zero, so 3.9 is treated as level 3
    numeric_text = f"{value:.2f}"
    labels = {1: "较低", 2: "一般", 3: "高", 4: "关键"}
    return labels.get(level, numeric_text)
|
||||
|
||||
# 构建记忆字符串,即使某种记忆为空也要继续
|
||||
memory_str = ""
|
||||
has_any_memory = False
|
||||
@@ -520,15 +631,26 @@ class DefaultReplyer:
|
||||
if not memory_str:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
for running_memory in running_memories:
|
||||
memory_str += f"- {running_memory['content']} (类型: {running_memory['memory_type']}, 相似度: {running_memory['score']:.2f})\n"
|
||||
details = []
|
||||
details.append(f"类型: {running_memory.get('memory_type', 'unknown')}")
|
||||
if running_memory.get("confidence") is not None:
|
||||
details.append(f"置信度: {_format_confidence_label(running_memory.get('confidence'))}")
|
||||
if running_memory.get("importance") is not None:
|
||||
details.append(f"重要性: {_format_importance_label(running_memory.get('importance'))}")
|
||||
if running_memory.get("relevance") is not None:
|
||||
details.append(f"相关度: {running_memory['relevance']:.2f}")
|
||||
|
||||
detail_text = f" ({','.join(details)})" if details else ""
|
||||
memory_str += f"- {running_memory['content']}{detail_text}\n"
|
||||
has_any_memory = True
|
||||
|
||||
# 添加瞬时记忆
|
||||
if instant_memory:
|
||||
if not memory_str:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
memory_str += f"- {instant_memory}\n"
|
||||
has_any_memory = True
|
||||
if not any(rm["content"] == instant_memory for rm in running_memories):
|
||||
if not memory_str:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
memory_str += f"- 最相关记忆:{instant_memory}\n"
|
||||
has_any_memory = True
|
||||
|
||||
# 只有当完全没有任何记忆时才返回空字符串
|
||||
return memory_str if has_any_memory else ""
|
||||
|
||||
@@ -417,7 +417,7 @@ class Prompt:
|
||||
context_data[key] = value
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"构建超时 ({timeout_seconds}s)")
|
||||
logger.error("构建超时")
|
||||
context_data = {}
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
@@ -580,14 +580,18 @@ class Prompt:
|
||||
|
||||
# 构建记忆块
|
||||
memory_parts = []
|
||||
existing_contents = set()
|
||||
|
||||
if running_memories:
|
||||
memory_parts.append("以下是当前在聊天中,你回忆起的记忆:")
|
||||
for memory in running_memories:
|
||||
memory_parts.append(f"- {memory['content']}")
|
||||
content = memory["content"]
|
||||
memory_parts.append(f"- {content}")
|
||||
existing_contents.add(content)
|
||||
|
||||
if instant_memory:
|
||||
memory_parts.append(f"- {instant_memory}")
|
||||
if instant_memory not in existing_contents:
|
||||
memory_parts.append(f"- 最相关记忆:{instant_memory}")
|
||||
|
||||
memory_block = "\n".join(memory_parts) if memory_parts else ""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user