feat(memory): 重构记忆系统并移除插件热重载

重构记忆系统核心模块,引入全局记忆作用域、记忆指纹去重机制和查询规划器,优化多阶段检索性能。移除插件热重载系统及其相关依赖。

主要变更:
- 引入全局记忆作用域,简化记忆管理
- 实现记忆指纹去重,避免重复记忆存储
- 新增查询规划器,支持语义查询规划和记忆类型过滤
- 优化多阶段检索,增加语义重排和权重配置
- 改进向量存储,支持嵌入维度自动解析和查询向量生成
- 增强元数据索引,支持主体索引和更新操作
- 记忆构建器支持多主体和自然语言展示
- 移除watchdog依赖和插件热重载模块
- 更新配置模板,简化记忆配置项

BREAKING CHANGE: 移除插件热重载系统,相关API和命令不再可用。记忆系统接口有较大调整,使用该系统的模块需要适配新接口。
This commit is contained in:
Windpicker-owo
2025-10-01 04:56:32 +08:00
parent ac73994847
commit 3fcf8e9add
29 changed files with 1643 additions and 925 deletions

View File

@@ -12,6 +12,7 @@ from sqlalchemy import select
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
logger = get_logger("bot_interest_manager")
@@ -28,7 +29,9 @@ class BotInterestManager:
# Embedding客户端配置
self.embedding_request = None
self.embedding_config = None
self.embedding_dimension = 1024 # 默认BGE-M3 embedding维度
configured_dim = resolve_embedding_dimension()
self.embedding_dimension = int(configured_dim) if configured_dim else 0
self._detected_embedding_dimension: Optional[int] = None
@property
def is_initialized(self) -> bool:
@@ -82,8 +85,11 @@ class BotInterestManager:
logger.info("📋 找到embedding模型配置")
self.embedding_config = model_config.model_task_config.embedding
self.embedding_dimension = 1024 # BGE-M3的维度
logger.info(f"📐 使用模型维度: {self.embedding_dimension}")
if self.embedding_dimension:
logger.info(f"📐 配置的embedding维度: {self.embedding_dimension}")
else:
logger.info("📐 未在配置中检测到embedding维度将根据首次返回的向量自动识别")
# 创建LLMRequest实例用于embedding
self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding")
@@ -350,7 +356,27 @@ class BotInterestManager:
if embedding and len(embedding) > 0:
self.embedding_cache[text] = embedding
logger.debug(f"✅ Embedding获取成功维度: {len(embedding)}, 模型: {model_name}")
current_dim = len(embedding)
if self._detected_embedding_dimension is None:
self._detected_embedding_dimension = current_dim
if self.embedding_dimension and self.embedding_dimension != current_dim:
logger.warning(
"⚠️ 实际embedding维度(%d)与配置值(%d)不一致,请在 model_config.model_task_config.embedding.embedding_dimension 中同步更新",
current_dim,
self.embedding_dimension,
)
else:
self.embedding_dimension = current_dim
logger.info(f"📏 检测到embedding维度: {current_dim}")
elif current_dim != self.embedding_dimension:
logger.warning(
"⚠️ 收到的embedding维度发生变化: 之前=%d, 当前=%d。请确认模型配置是否正确。",
self.embedding_dimension,
current_dim,
)
logger.debug(f"✅ Embedding获取成功维度: {current_dim}, 模型: {model_name}")
return embedding
else:
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")

View File

@@ -26,6 +26,7 @@ from rich.progress import (
TextColumn,
)
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
install(extra_lines=3)
@@ -504,7 +505,10 @@ class EmbeddingStore:
# L2归一化
faiss.normalize_L2(embeddings)
# 构建索引
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
if not embedding_dim:
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
self.faiss_index.add(embeddings)
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:

View File

@@ -11,12 +11,27 @@ from dataclasses import dataclass
from src.common.logger import get_logger
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
MEMORY_TYPE_LABELS = {
MemoryType.PERSONAL_FACT: "个人事实",
MemoryType.EVENT: "事件",
MemoryType.PREFERENCE: "偏好",
MemoryType.OPINION: "观点",
MemoryType.RELATIONSHIP: "关系",
MemoryType.EMOTION: "情感",
MemoryType.KNOWLEDGE: "知识",
MemoryType.SKILL: "技能",
MemoryType.GOAL: "目标",
MemoryType.EXPERIENCE: "经验",
MemoryType.CONTEXTUAL: "上下文",
}
@dataclass
class AdapterConfig:
"""适配器配置"""
@@ -85,12 +100,9 @@ class EnhancedMemoryAdapter:
async def process_conversation_memory(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""处理对话记忆"""
"""处理对话记忆,以上下文为唯一输入"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"success": False, "error": "Enhanced memory not available"}
@@ -98,10 +110,30 @@ class EnhancedMemoryAdapter:
self.adapter_stats["total_processed"] += 1
try:
payload_context: Dict[str, Any] = dict(context or {})
conversation_text = payload_context.get("conversation_text")
if not conversation_text:
conversation_candidate = (
payload_context.get("message_content")
or payload_context.get("latest_message")
or payload_context.get("raw_text")
)
if conversation_candidate is not None:
conversation_text = str(conversation_candidate)
payload_context["conversation_text"] = conversation_text
else:
conversation_text = ""
else:
conversation_text = str(conversation_text)
if "timestamp" not in payload_context:
payload_context["timestamp"] = time.time()
logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text))
# 使用集成层处理对话
result = await self.integration_layer.process_conversation(
conversation_text, context, user_id, timestamp
)
result = await self.integration_layer.process_conversation(payload_context)
# 更新统计
processing_time = time.time() - start_time
@@ -132,7 +164,7 @@ class EnhancedMemoryAdapter:
try:
limit = limit or self.config.max_retrieval_results
memories = await self.integration_layer.retrieve_relevant_memories(
query, user_id, context, limit
query, None, context, limit
)
self.adapter_stats["memories_retrieved"] += len(memories)
@@ -157,12 +189,15 @@ class EnhancedMemoryAdapter:
if not memories:
return ""
# 格式化记忆为提示词友好的格式
memory_context_parts = []
for memory in memories:
memory_context_parts.append(f"- {memory.text_content}")
# 格式化记忆为提示词友好的Markdown结构
lines: List[str] = ["### 🧠 相关记忆 (Relevant Memories)", ""]
return "\n".join(memory_context_parts)
for memory in memories:
type_label = MEMORY_TYPE_LABELS.get(memory.memory_type, memory.memory_type.value)
display_text = memory.display or memory.text_content
lines.append(f"- **[{type_label}]** {display_text}")
return "\n".join(lines)
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
"""获取增强记忆系统摘要"""
@@ -270,13 +305,10 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
async def process_conversation_with_enhanced_memory(
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None,
llm_model: Optional[LLMRequest] = None
) -> Dict[str, Any]:
"""使用增强记忆系统处理对话"""
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
if not llm_model:
# 获取默认的LLM模型
from src.llm_models.utils_model import get_global_llm_model
@@ -284,7 +316,18 @@ async def process_conversation_with_enhanced_memory(
try:
adapter = await get_enhanced_memory_adapter(llm_model)
return await adapter.process_conversation_memory(conversation_text, context, user_id, timestamp)
payload_context = dict(context or {})
if "conversation_text" not in payload_context:
conversation_candidate = (
payload_context.get("message_content")
or payload_context.get("latest_message")
or payload_context.get("raw_text")
)
if conversation_candidate is not None:
payload_context["conversation_text"] = str(conversation_candidate)
return await adapter.process_conversation_memory(payload_context)
except Exception as e:
logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}

View File

@@ -1,13 +1,15 @@
# -*- coding: utf-8 -*-
"""
增强型精准记忆系统核心模块
基于文档设计的高效记忆构建、存储与召回优化系统
1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。
2. 内置 LLM 查询规划器与嵌入维度自动解析机制,直接从模型配置推断向量存储参数。
"""
import asyncio
import time
import orjson
import re
import hashlib
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
@@ -22,12 +24,16 @@ from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
from src.chat.memory_system.metadata_index import MetadataIndexManager
from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger(__name__)
# 全局记忆作用域(共享记忆库)
GLOBAL_MEMORY_SCOPE = "global"
class MemorySystemStatus(Enum):
"""记忆系统状态"""
@@ -47,14 +53,20 @@ class MemorySystemConfig:
memory_value_threshold: float = 0.7
min_build_interval_seconds: float = 300.0
# 向量存储配置
vector_dimension: int = 768
# 向量存储配置(嵌入维度自动来自模型配置)
vector_dimension: int = 1024
similarity_threshold: float = 0.8
# 召回配置
coarse_recall_limit: int = 50
fine_recall_limit: int = 10
semantic_rerank_limit: int = 20
final_recall_limit: int = 5
semantic_similarity_threshold: float = 0.6
vector_weight: float = 0.4
semantic_weight: float = 0.3
context_weight: float = 0.2
recency_weight: float = 0.1
# 融合配置
fusion_similarity_threshold: float = 0.85
@@ -64,6 +76,23 @@ class MemorySystemConfig:
def from_global_config(cls):
"""从全局配置创建配置实例"""
embedding_dimension = None
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
if embedding_task is not None:
embedding_dimension = getattr(embedding_task, "embedding_dimension", None)
except Exception:
embedding_dimension = None
if not embedding_dimension:
try:
embedding_dimension = getattr(global_config.lpmm_knowledge, "embedding_dimension", None)
except Exception:
embedding_dimension = None
if not embedding_dimension:
embedding_dimension = 1024
return cls(
# 记忆构建配置
min_memory_length=global_config.memory.min_memory_length,
@@ -72,13 +101,19 @@ class MemorySystemConfig:
min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0),
# 向量存储配置
vector_dimension=global_config.memory.vector_dimension,
vector_dimension=int(embedding_dimension),
similarity_threshold=global_config.memory.vector_similarity_threshold,
# 召回配置
coarse_recall_limit=global_config.memory.metadata_filter_limit,
fine_recall_limit=global_config.memory.final_result_limit,
fine_recall_limit=global_config.memory.vector_search_limit,
semantic_rerank_limit=global_config.memory.semantic_rerank_limit,
final_recall_limit=global_config.memory.final_result_limit,
semantic_similarity_threshold=getattr(global_config.memory, "semantic_similarity_threshold", 0.6),
vector_weight=global_config.memory.vector_weight,
semantic_weight=global_config.memory.semantic_weight,
context_weight=global_config.memory.context_weight,
recency_weight=global_config.memory.recency_weight,
# 融合配置
fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
@@ -104,6 +139,7 @@ class EnhancedMemorySystem:
self.vector_storage: VectorStorageManager = None
self.metadata_index: MetadataIndexManager = None
self.retrieval_system: MultiStageRetrieval = None
self.query_planner: MemoryQueryPlanner = None
# LLM模型
self.value_assessment_model: LLMRequest = None
@@ -117,6 +153,9 @@ class EnhancedMemorySystem:
# 构建节流记录
self._last_memory_build_times: Dict[str, float] = {}
# 记忆指纹缓存,用于快速检测重复记忆
self._memory_fingerprints: Dict[str, str] = {}
logger.info("EnhancedMemorySystem 初始化开始")
async def initialize(self):
@@ -125,19 +164,29 @@ class EnhancedMemorySystem:
logger.info("正在初始化增强型记忆系统...")
# 初始化LLM模型
task_config = (
self.llm_model.model_for_task
if self.llm_model is not None
else model_config.model_task_config.utils_small
)
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
value_task_config = getattr(model_config.model_task_config, "utils_small", None)
extraction_task_config = getattr(model_config.model_task_config, "utils", None)
if value_task_config is None:
logger.warning("未找到 utils_small 模型配置,回退到 utils 或外部提供的模型配置。")
value_task_config = extraction_task_config or fallback_task
if extraction_task_config is None:
logger.warning("未找到 utils 模型配置,回退到 utils_small 或外部提供的模型配置。")
extraction_task_config = value_task_config or fallback_task
if value_task_config is None or extraction_task_config is None:
raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。")
self.value_assessment_model = LLMRequest(
model_set=task_config,
model_set=value_task_config,
request_type="memory.value_assessment"
)
self.memory_extraction_model = LLMRequest(
model_set=task_config,
model_set=extraction_task_config,
request_type="memory.extraction"
)
@@ -155,13 +204,36 @@ class EnhancedMemorySystem:
retrieval_config = RetrievalConfig(
metadata_filter_limit=self.config.coarse_recall_limit,
vector_search_limit=self.config.fine_recall_limit,
final_result_limit=self.config.final_recall_limit
semantic_rerank_limit=self.config.semantic_rerank_limit,
final_result_limit=self.config.final_recall_limit,
vector_similarity_threshold=self.config.similarity_threshold,
semantic_similarity_threshold=self.config.semantic_similarity_threshold,
vector_weight=self.config.vector_weight,
semantic_weight=self.config.semantic_weight,
context_weight=self.config.context_weight,
recency_weight=self.config.recency_weight,
)
self.retrieval_system = MultiStageRetrieval(retrieval_config)
planner_task_config = getattr(model_config.model_task_config, "planner", None)
planner_model: Optional[LLMRequest] = None
try:
planner_model = LLMRequest(
model_set=planner_task_config,
request_type="memory.query_planner"
)
except Exception as planner_exc:
logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True)
self.query_planner = MemoryQueryPlanner(
planner_model,
default_limit=self.config.final_recall_limit
)
# 加载持久化数据
await self.vector_storage.load_storage()
await self.metadata_index.load_index()
self._populate_memory_fingerprints()
self.status = MemorySystemStatus.READY
logger.info("✅ 增强型记忆系统初始化完成")
@@ -174,7 +246,7 @@ class EnhancedMemorySystem:
async def retrieve_memories_for_building(
self,
query_text: str,
user_id: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
) -> List[MemoryChunk]:
@@ -182,7 +254,6 @@ class EnhancedMemorySystem:
Args:
query_text: 查询文本
user_id: 用户ID
context: 上下文信息
limit: 返回结果数量限制
@@ -201,7 +272,6 @@ class EnhancedMemorySystem:
# 执行检索
memories = await self.vector_storage.search_similar_memories(
query_text=query_text,
user_id=user_id,
limit=limit
)
@@ -218,23 +288,18 @@ class EnhancedMemorySystem:
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
) -> List[MemoryChunk]:
"""从对话中构建记忆
Args:
conversation_text: 对话文本
context: 上下文信息(包括用户信息、群组信息等)
user_id: 用户ID
context: 上下文信息
timestamp: 时间戳,默认为当前时间
Returns:
构建的记忆块列表
"""
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
raise RuntimeError("记忆系统未就绪")
original_status = self.status
self.status = MemorySystemStatus.BUILDING
start_time = time.time()
@@ -243,9 +308,9 @@ class EnhancedMemorySystem:
build_marker_time: Optional[float] = None
try:
normalized_context = self._normalize_context(context, user_id, timestamp)
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
build_scope_key = self._get_build_scope_key(normalized_context, user_id)
build_scope_key = self._get_build_scope_key(normalized_context, GLOBAL_MEMORY_SCOPE)
min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0))
current_time = time.time()
@@ -266,7 +331,7 @@ class EnhancedMemorySystem:
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
logger.debug("开始构建记忆,文本长度: %d", len(conversation_text))
# 1. 信息价值评估
value_score = await self._assess_information_value(conversation_text, normalized_context)
@@ -280,7 +345,7 @@ class EnhancedMemorySystem:
memory_chunks = await self.memory_builder.build_memories(
conversation_text,
normalized_context,
user_id,
GLOBAL_MEMORY_SCOPE,
timestamp or time.time()
)
@@ -293,19 +358,24 @@ class EnhancedMemorySystem:
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks)
# 4. 存储记忆
await self._store_memories(fused_chunks)
stored_count = await self._store_memories(fused_chunks)
# 4.1 控制台预览
self._log_memory_preview(fused_chunks)
# 5. 更新统计
self.total_memories += len(fused_chunks)
self.total_memories += stored_count
self.last_build_time = time.time()
if build_scope_key:
self._last_memory_build_times[build_scope_key] = self.last_build_time
build_time = time.time() - start_time
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}")
logger.info(
"✅ 生成 %d 条记忆,成功入库 %d 条,耗时 %.2f",
len(fused_chunks),
stored_count,
build_time,
)
self.status = original_status
return fused_chunks
@@ -347,21 +417,34 @@ class EnhancedMemorySystem:
async def process_conversation_memory(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
context: Dict[str, Any]
) -> Dict[str, Any]:
"""对外暴露的对话记忆处理接口,兼容旧调用方式"""
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
start_time = time.time()
try:
normalized_context = self._normalize_context(context, user_id, timestamp)
context = dict(context or {})
conversation_candidate = (
context.get("conversation_text")
or context.get("message_content")
or context.get("latest_message")
or context.get("raw_text")
or ""
)
conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate)
timestamp = context.get("timestamp")
if timestamp is None:
timestamp = time.time()
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
normalized_context.setdefault("conversation_text", conversation_text)
memories = await self.build_memory_from_conversation(
conversation_text=conversation_text,
context=normalized_context,
user_id=user_id,
timestamp=timestamp
)
@@ -395,52 +478,77 @@ class EnhancedMemorySystem:
**kwargs
) -> List[MemoryChunk]:
"""检索相关记忆,兼容 query/query_text 参数形式"""
if self.status != MemorySystemStatus.READY:
raise RuntimeError("记忆系统未就绪")
query_text = query_text or kwargs.get("query")
if not query_text:
raw_query = query_text or kwargs.get("query")
if not raw_query:
raise ValueError("query_text 或 query 参数不能为空")
context = context or {}
user_id = user_id or kwargs.get("user_id")
resolved_user_id = GLOBAL_MEMORY_SCOPE
if self.retrieval_system is None or self.metadata_index is None:
raise RuntimeError("检索组件未初始化")
all_memories_cache = self.vector_storage.memory_cache
if not all_memories_cache:
logger.debug("记忆缓存为空,返回空结果")
self.last_retrieval_time = time.time()
self.status = MemorySystemStatus.READY
return []
self.status = MemorySystemStatus.RETRIEVING
start_time = time.time()
try:
normalized_context = self._normalize_context(context, user_id, None)
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
candidate_memories = list(self.vector_storage.memory_cache.values())
if user_id:
candidate_memories = [m for m in candidate_memories if m.user_id == user_id]
effective_limit = limit or self.config.final_recall_limit
query_plan = None
planner_ran = False
resolved_query_text = raw_query
if self.query_planner:
try:
planner_ran = True
query_plan = await self.query_planner.plan_query(raw_query, normalized_context)
normalized_context["query_plan"] = query_plan
effective_limit = min(effective_limit, query_plan.limit or effective_limit)
if getattr(query_plan, "semantic_query", None):
resolved_query_text = query_plan.semantic_query
logger.debug(
"查询规划: semantic='%s', types=%s, subjects=%s, limit=%d",
query_plan.semantic_query,
[mt.value for mt in query_plan.memory_types],
query_plan.subject_includes,
query_plan.limit,
)
except Exception as plan_exc:
logger.warning("查询规划失败,使用默认检索策略: %s", plan_exc, exc_info=True)
if not candidate_memories:
self.status = MemorySystemStatus.READY
self.last_retrieval_time = time.time()
logger.debug(f"未找到用户 {user_id} 的候选记忆")
return []
effective_limit = effective_limit or self.config.final_recall_limit
effective_limit = max(1, min(effective_limit, self.config.final_recall_limit))
normalized_context["resolved_query_text"] = resolved_query_text
scored_memories = []
for memory in candidate_memories:
score = self._compute_memory_score(query_text, memory, normalized_context)
if score > 0:
scored_memories.append((memory, score))
if not scored_memories:
# 如果所有分数为0返回最近的记忆作为降级策略
if normalized_context.get("__memory_building__"):
logger.debug("当前处于记忆构建流程,跳过查询规划并进行降级检索")
self.status = MemorySystemStatus.BUILDING
final_memories = []
candidate_memories = list(all_memories_cache.values())
candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True)
scored_memories = [(memory, 0.0) for memory in candidate_memories[:limit]]
final_memories = candidate_memories[:effective_limit]
else:
scored_memories.sort(key=lambda item: item[1], reverse=True)
retrieval_result = await self.retrieval_system.retrieve_memories(
query=resolved_query_text,
user_id=resolved_user_id,
context=normalized_context,
metadata_index=self.metadata_index,
vector_storage=self.vector_storage,
all_memories_cache=all_memories_cache,
limit=effective_limit,
)
top_memories = [memory for memory, _ in scored_memories[:limit]]
final_memories = retrieval_result.final_memories
# 更新访问信息和缓存
for memory, score in scored_memories[:limit]:
for memory in final_memories:
memory.update_access()
memory.update_relevance(score)
cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id)
if cache_entry is not None:
cache_entry["last_accessed"] = memory.metadata.last_accessed
@@ -448,14 +556,34 @@ class EnhancedMemorySystem:
cache_entry["relevance_score"] = memory.metadata.relevance_score
retrieval_time = time.time() - start_time
logger.info(
f"✅ 为用户 {user_id or 'unknown'} 检索到 {len(top_memories)} 条相关记忆,耗时 {retrieval_time:.3f}"
plan_summary = ""
if planner_ran and query_plan:
plan_types = ",".join(mt.value for mt in query_plan.memory_types) or "-"
plan_subjects = ",".join(query_plan.subject_includes) or "-"
plan_summary = (
f" | planner.semantic='{query_plan.semantic_query}'"
f" | planner.limit={query_plan.limit}"
f" | planner.types={plan_types}"
f" | planner.subjects={plan_subjects}"
)
log_message = (
"✅ 记忆检索完成"
f" | user={resolved_user_id}"
f" | count={len(final_memories)}"
f" | duration={retrieval_time:.3f}s"
f" | applied_limit={effective_limit}"
f" | raw_query='{raw_query}'"
f" | semantic_query='{resolved_query_text}'"
f"{plan_summary}"
)
logger.info(log_message)
self.last_retrieval_time = time.time()
self.status = MemorySystemStatus.READY
return top_memories
return final_memories
except Exception as e:
self.status = MemorySystemStatus.ERROR
@@ -499,8 +627,8 @@ class EnhancedMemorySystem:
except Exception:
context = dict(raw_context or {})
# 基础字段
context["user_id"] = context.get("user_id") or user_id or "unknown"
# 基础字段(统一使用全局作用域)
context["user_id"] = GLOBAL_MEMORY_SCOPE
context["timestamp"] = context.get("timestamp") or timestamp or time.time()
context["message_type"] = context.get("message_type") or "normal"
context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"
@@ -523,8 +651,8 @@ class EnhancedMemorySystem:
if stream_id:
context["stream_id"] = stream_id
# chat_id 兜底
context["chat_id"] = context.get("chat_id") or context.get("stream_id") or f"session_{context['user_id']}"
# 全局记忆无需聊天隔离
context["chat_id"] = context.get("chat_id") or "global_chat"
# 历史窗口配置
window_candidate = (
@@ -616,18 +744,7 @@ class EnhancedMemorySystem:
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
"""确定用于节流控制的记忆构建作用域"""
stream_id = context.get("stream_id")
if stream_id:
return f"stream::{stream_id}"
chat_id = context.get("chat_id")
if chat_id:
return f"chat::{chat_id}"
if user_id:
return f"user::{user_id}"
return None
return "global_scope"
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
"""确定历史消息获取数量限制在30-50之间"""
@@ -789,24 +906,134 @@ class EnhancedMemorySystem:
logger.error(f"信息价值评估失败: {e}", exc_info=True)
return 0.5 # 默认中等价值
async def _store_memories(self, memory_chunks: List[MemoryChunk]):
"""存储记忆块到各个存储系统"""
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
"""存储记忆块到各个存储系统,返回成功入库数量"""
if not memory_chunks:
return
return 0
unique_memories: List[MemoryChunk] = []
skipped_duplicates = 0
for memory in memory_chunks:
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
existing_id = self._memory_fingerprints.get(key)
if existing_id:
existing = self.vector_storage.memory_cache.get(existing_id)
if existing:
self._merge_existing_memory(existing, memory)
await self.metadata_index.update_memory_entry(existing)
skipped_duplicates += 1
logger.debug(
"检测到重复记忆,已合并到现有记录 | memory_id=%s",
existing.memory_id,
)
continue
else:
# 指纹存在但缓存缺失,视为新记忆并覆盖旧映射
logger.debug("检测到过期指纹映射,重写现有条目")
unique_memories.append(memory)
if not unique_memories:
if skipped_duplicates:
logger.info("本次记忆全部与现有内容重复,跳过入库")
return 0
# 并行存储到向量数据库和元数据索引
storage_tasks = []
storage_tasks = [
self.vector_storage.store_memories(unique_memories),
self.metadata_index.index_memories(unique_memories),
]
# 向量存储
storage_tasks.append(self.vector_storage.store_memories(memory_chunks))
# 元数据索引
storage_tasks.append(self.metadata_index.index_memories(memory_chunks))
# 等待所有存储任务完成
await asyncio.gather(*storage_tasks, return_exceptions=True)
logger.debug(f"成功存储 {len(memory_chunks)} 条记忆到各个存储系统")
self._register_memory_fingerprints(unique_memories)
logger.debug(
"成功存储 %d 条记忆(跳过重复 %d 条)",
len(unique_memories),
skipped_duplicates,
)
return len(unique_memories)
def _merge_existing_memory(self, existing: MemoryChunk, incoming: MemoryChunk) -> None:
"""将新记忆的信息合并到已存在的记忆中"""
updated = False
for keyword in incoming.keywords:
if keyword not in existing.keywords:
existing.add_keyword(keyword)
updated = True
for tag in incoming.tags:
if tag not in existing.tags:
existing.add_tag(tag)
updated = True
for category in incoming.categories:
if category not in existing.categories:
existing.add_category(category)
updated = True
if incoming.metadata.source_context:
existing.metadata.source_context = incoming.metadata.source_context
if incoming.metadata.importance.value > existing.metadata.importance.value:
existing.metadata.importance = incoming.metadata.importance
updated = True
if incoming.metadata.confidence.value > existing.metadata.confidence.value:
existing.metadata.confidence = incoming.metadata.confidence
updated = True
if incoming.metadata.relevance_score > existing.metadata.relevance_score:
existing.metadata.relevance_score = incoming.metadata.relevance_score
updated = True
if updated:
existing.metadata.last_modified = time.time()
def _populate_memory_fingerprints(self) -> None:
"""基于当前缓存构建记忆指纹映射"""
self._memory_fingerprints.clear()
for memory in self.vector_storage.memory_cache.values():
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
self._memory_fingerprints[key] = memory.memory_id
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
for memory in memories:
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
self._memory_fingerprints[key] = memory.memory_id
def _build_memory_fingerprint(self, memory: MemoryChunk) -> str:
subjects = memory.subjects or []
subject_part = "|".join(sorted(s.strip() for s in subjects if s))
predicate_part = (memory.content.predicate or "").strip()
obj = memory.content.object
if isinstance(obj, (dict, list)):
obj_part = orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8")
else:
obj_part = str(obj).strip()
base = "|".join([
str(memory.user_id or "unknown"),
memory.memory_type.value,
subject_part,
predicate_part,
obj_part,
])
return hashlib.sha256(base.encode("utf-8")).hexdigest()
@staticmethod
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
return f"{str(user_id)}:{fingerprint}"
def get_system_stats(self) -> Dict[str, Any]:
"""获取系统统计信息"""

View File

@@ -241,12 +241,12 @@ class EnhancedMemoryManager:
return []
try:
result = await self.enhanced_system.process_conversation_memory(
conversation_text=conversation_text,
context=context,
user_id=user_id,
timestamp=timestamp
)
payload_context = dict(context or {})
payload_context.setdefault("conversation_text", conversation_text)
if timestamp is not None:
payload_context.setdefault("timestamp", timestamp)
result = await self.enhanced_system.process_conversation_memory(payload_context)
# 从结果中提取记忆块
memory_chunks = []
@@ -274,7 +274,7 @@ class EnhancedMemoryManager:
try:
relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
query=query_text,
user_id=user_id,
user_id=None,
context=context or {},
limit=limit
)
@@ -303,6 +303,9 @@ class EnhancedMemoryManager:
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
"""将记忆块转换为更易读的文本描述"""
structure = memory.content.to_dict()
if memory.display:
return self._clean_text(memory.display), structure
subject = structure.get("subject")
predicate = structure.get("predicate") or ""
obj = structure.get("object")

View File

@@ -114,12 +114,9 @@ class MemoryIntegrationLayer:
async def process_conversation(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
context: Dict[str, Any]
) -> Dict[str, Any]:
"""处理对话记忆"""
"""处理对话记忆,仅使用上下文信息"""
if not self._initialized or not self.enhanced_memory:
return {"success": False, "error": "Memory system not available"}
@@ -128,13 +125,12 @@ class MemoryIntegrationLayer:
self.integration_stats["enhanced_queries"] += 1
try:
payload_context = dict(context or {})
conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text))
# 直接使用增强记忆系统处理
result = await self.enhanced_memory.process_conversation_memory(
conversation_text=conversation_text,
context=context,
user_id=user_id,
timestamp=timestamp
)
result = await self.enhanced_memory.process_conversation_memory(payload_context)
# 更新统计
processing_time = time.time() - start_time
@@ -156,7 +152,7 @@ class MemoryIntegrationLayer:
async def retrieve_relevant_memories(
self,
query: str,
user_id: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None
) -> List[MemoryChunk]:
@@ -168,7 +164,7 @@ class MemoryIntegrationLayer:
limit = limit or self.config.max_retrieval_results
memories = await self.enhanced_memory.retrieve_relevant_memories(
query=query,
user_id=user_id,
user_id=None,
context=context or {},
limit=limit
)

View File

@@ -2,21 +2,48 @@
"""
记忆构建模块
从对话流中提取高质量、结构化记忆单元
输出格式要求:
{{
"memories": [
{{
"type": "记忆类型",
"display": "用于直接展示和检索的自然语言描述",
"subject": ["主体1", "主体2"],
"predicate": "谓语(动作/状态)",
"object": "宾语(对象/属性或结构体)",
"keywords": ["关键词1", "关键词2"],
"importance": "重要性等级(1-4)",
"confidence": "置信度(1-4)",
"reasoning": "提取理由"
}}
]
}}
注意:
1. `subject` 可包含多个主体,请用数组表示;若主体不明确,请根据上下文给出最合理的称呼
2. `display` 必须是一句完整流畅的中文描述,可直接用于用户展示和向量搜索
3. 只提取确实值得记忆的信息,不要提取琐碎内容
4. 确保信息准确、具体、有价值
5. 重要性: 1=低, 2=一般, 3=高, 4=关键;置信度: 1=低, 2=中等, 3=高, 4=已验证
"""
import re
import time
import orjson
from typing import Dict, List, Optional, Any
from datetime import datetime
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union, Type
import orjson
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
create_memory_chunk
MemoryChunk,
MemoryType,
ConfidenceLevel,
ImportanceLevel,
create_memory_chunk,
)
logger = get_logger(__name__)
@@ -24,6 +51,7 @@ logger = get_logger(__name__)
class ExtractionStrategy(Enum):
"""提取策略"""
LLM_BASED = "llm_based" # 基于LLM的智能提取
RULE_BASED = "rule_based" # 基于规则的提取
HYBRID = "hybrid" # 混合策略
@@ -171,18 +199,18 @@ class MemoryBuilder:
"""使用规则提取记忆"""
memories = []
subject_display = self._resolve_user_display(context, user_id)
subjects = self._resolve_conversation_participants(context, user_id)
# 规则1: 检测个人信息
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subject_display)
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subjects)
memories.extend(personal_info)
# 规则2: 检测偏好信息
preferences = self._extract_preferences(text, user_id, timestamp, context, subject_display)
preferences = self._extract_preferences(text, user_id, timestamp, context, subjects)
memories.extend(preferences)
# 规则3: 检测事件信息
events = self._extract_events(text, user_id, timestamp, context, subject_display)
events = self._extract_events(text, user_id, timestamp, context, subjects)
memories.extend(events)
return memories
@@ -258,10 +286,7 @@ class MemoryBuilder:
你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。
当前时间: {current_date}
聊天ID: {chat_id}
消息类型: {message_type}
目标用户ID: {target_user_id_display}
目标用户称呼: {target_user_name}
## 🤖 机器人身份(仅供参考,禁止写入记忆)
- 机器人名称: {bot_name_display}
@@ -272,7 +297,6 @@ class MemoryBuilder:
请务必遵守以下命名规范:
- 当说话者是机器人时,请使用“{bot_name_display}”或其他明确称呼作为主语;
- 如果看到系统自动生成的长ID类似 {target_user_id}),请改用“{target_user_name}”、机器人的称呼或“该用户”描述不要把ID写入记忆
- 记录关键事实时,请准确标记主体是机器人还是用户,避免混淆。
对话内容:
@@ -450,7 +474,7 @@ class MemoryBuilder:
bot_identifiers = self._collect_bot_identifiers(context)
system_identifiers = self._collect_system_identifiers(context)
default_subject = self._resolve_user_display(context, user_id)
default_subjects = self._resolve_conversation_participants(context, user_id)
bot_display = None
if context:
@@ -481,19 +505,33 @@ class MemoryBuilder:
for mem_data in memory_list:
try:
subject_value = mem_data.get("subject")
normalized_subject = self._normalize_subject(
normalized_subject = self._normalize_subjects(
subject_value,
bot_identifiers,
system_identifiers,
default_subject,
default_subjects,
bot_display
)
if normalized_subject is None:
if not normalized_subject:
logger.debug("跳过疑似机器人自身信息的记忆: %s", mem_data)
continue
# 创建记忆块
importance_level = self._parse_enum_value(
ImportanceLevel,
mem_data.get("importance"),
ImportanceLevel.NORMAL,
"importance"
)
confidence_level = self._parse_enum_value(
ConfidenceLevel,
mem_data.get("confidence"),
ConfidenceLevel.MEDIUM,
"confidence"
)
memory = create_memory_chunk(
user_id=user_id,
subject=normalized_subject,
@@ -502,8 +540,9 @@ class MemoryBuilder:
memory_type=MemoryType(mem_data.get("type", "contextual")),
chat_id=context.get("chat_id"),
source_context=mem_data.get("reasoning", ""),
importance=ImportanceLevel(mem_data.get("importance", 2)),
confidence=ConfidenceLevel(mem_data.get("confidence", 2))
importance=importance_level,
confidence=confidence_level,
display=mem_data.get("display")
)
# 添加关键词
@@ -511,13 +550,6 @@ class MemoryBuilder:
for keyword in keywords:
memory.add_keyword(keyword)
subject_text = memory.content.subject.strip() if isinstance(memory.content.subject, str) else str(memory.content.subject)
if not subject_text:
memory.content.subject = default_subject
elif subject_text.lower() in system_identifiers or self._looks_like_system_identifier(subject_text):
logger.debug("将系统标识主语替换为默认用户名称: %s", subject_text)
memory.content.subject = default_subject
memories.append(memory)
except Exception as e:
@@ -526,6 +558,64 @@ class MemoryBuilder:
return memories
def _parse_enum_value(
self,
enum_cls: Type[Enum],
raw_value: Any,
default: Enum,
field_name: str
) -> Enum:
"""解析枚举值,兼容数字/字符串表示"""
if isinstance(raw_value, enum_cls):
return raw_value
if raw_value is None:
return default
# 直接尝试整数转换
if isinstance(raw_value, (int, float)):
int_value = int(raw_value)
try:
return enum_cls(int_value)
except ValueError:
logger.debug("%s=%s 无法解析为 %s", field_name, raw_value, enum_cls.__name__)
return default
if isinstance(raw_value, str):
value_str = raw_value.strip()
if not value_str:
return default
if value_str.isdigit():
try:
return enum_cls(int(value_str))
except ValueError:
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
else:
normalized = value_str.replace("-", "_").replace(" ", "_").upper()
for member in enum_cls:
if member.name == normalized:
return member
for member in enum_cls:
if str(member.value).lower() == value_str.lower():
return member
try:
return enum_cls(value_str)
except ValueError:
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
try:
return enum_cls(raw_value)
except Exception:
logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s",
field_name,
raw_value,
type(raw_value).__name__,
enum_cls.__name__,
default.name)
return default
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
identifiers: set[str] = {"bot", "机器人", "ai助手"}
if not context:
@@ -580,6 +670,58 @@ class MemoryBuilder:
return identifiers
def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
participants: List[str] = []
if context:
candidate_keys = [
"participants",
"participant_names",
"speaker_names",
"members",
"member_names",
"mention_users",
"audiences"
]
for key in candidate_keys:
value = context.get(key)
if isinstance(value, (list, tuple, set)):
for item in value:
if isinstance(item, str):
cleaned = self._clean_subject_text(item)
if cleaned:
participants.append(cleaned)
elif isinstance(value, str):
for part in self._split_subject_string(value):
if part:
participants.append(part)
fallback = self._resolve_user_display(context, user_id)
if fallback:
participants.append(fallback)
if context:
bot_name = context.get("bot_name") or context.get("bot_identity")
if isinstance(bot_name, str):
cleaned = self._clean_subject_text(bot_name)
if cleaned:
participants.append(cleaned)
if not participants:
participants = ["对话参与者"]
deduplicated: List[str] = []
seen = set()
for name in participants:
key = name.lower()
if key in seen:
continue
seen.add(key)
deduplicated.append(name)
return deduplicated
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
candidate_keys = [
"user_display_name",
@@ -626,51 +768,160 @@ class MemoryBuilder:
return False
def _normalize_subject(
def _split_subject_string(self, value: str) -> List[str]:
if not value:
return []
replaced = re.sub(r"\band\b", "", value, flags=re.IGNORECASE)
replaced = replaced.replace("", "").replace("", "").replace("", "")
replaced = replaced.replace("&", "").replace("/", "").replace("+", "")
tokens = [self._clean_subject_text(token) for token in re.split(r"[、,;]+", replaced)]
return [token for token in tokens if token]
def _normalize_subjects(
self,
subject: Any,
bot_identifiers: set[str],
system_identifiers: set[str],
default_subject: str,
default_subjects: List[str],
bot_display: Optional[str] = None
) -> Optional[str]:
if subject is None:
return default_subject
) -> List[str]:
defaults = default_subjects or ["对话参与者"]
subject_str = subject if isinstance(subject, str) else str(subject)
cleaned = self._clean_subject_text(subject_str)
if not cleaned:
return default_subject
raw_candidates: List[str] = []
if isinstance(subject, list):
for item in subject:
if isinstance(item, str):
raw_candidates.extend(self._split_subject_string(item))
elif item is not None:
raw_candidates.extend(self._split_subject_string(str(item)))
elif isinstance(subject, str):
raw_candidates.extend(self._split_subject_string(subject))
elif subject is not None:
raw_candidates.extend(self._split_subject_string(str(subject)))
lowered = cleaned.lower()
normalized: List[str] = []
bot_primary = self._clean_subject_text(bot_display or "")
if lowered in bot_identifiers:
return bot_primary or cleaned
for candidate in raw_candidates:
if not candidate:
continue
if lowered in {"用户", "user", "the user", "对方", "对手"}:
return default_subject
lowered = candidate.lower()
if lowered in bot_identifiers:
normalized.append(bot_primary or candidate)
continue
prefix_match = re.match(r"^(用户|User|user|USER|成员|member|Member|target|Target|TARGET)[\s:\-\u2014_]*?(.*)$", cleaned)
if prefix_match:
remainder = self._clean_subject_text(prefix_match.group(2))
if not remainder:
return default_subject
remainder_lower = remainder.lower()
if remainder_lower in bot_identifiers:
return bot_primary or remainder
if (
remainder_lower in system_identifiers
or self._looks_like_system_identifier(remainder)
):
return default_subject
cleaned = remainder
lowered = cleaned.lower()
if lowered in {"用户", "user", "the user", "对方", "对手"}:
normalized.extend(defaults)
continue
if lowered in system_identifiers or self._looks_like_system_identifier(cleaned):
return default_subject
if lowered in system_identifiers or self._looks_like_system_identifier(candidate):
continue
return cleaned
normalized.append(candidate)
if not normalized:
normalized = list(defaults)
deduplicated: List[str] = []
seen = set()
for name in normalized:
key = name.lower()
if key in seen:
continue
seen.add(key)
deduplicated.append(name)
return deduplicated
def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
if isinstance(obj, dict):
for key in keys:
value = obj.get(key)
if value is None:
continue
if isinstance(value, list):
compact = "".join(str(item) for item in value[:3])
if compact:
return compact
else:
value_str = str(value).strip()
if value_str:
return value_str
elif isinstance(obj, list):
compact = "".join(str(item) for item in obj[:3])
return compact or None
elif isinstance(obj, str):
return obj.strip() or None
return None
def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str:
subject_phrase = "".join(subjects) if subjects else "对话参与者"
predicate = (predicate or "").strip()
if predicate == "is_named":
name = self._extract_value_from_object(obj, ["name", "nickname"]) or ""
name = self._clean_subject_text(name)
if name:
quoted = name if (name.startswith("") and name.endswith("")) else f"{name}"
return f"{subject_phrase}的昵称是{quoted}"
elif predicate == "is_age":
age = self._extract_value_from_object(obj, ["age"]) or ""
age = self._clean_subject_text(age)
if age:
return f"{subject_phrase}今年{age}"
elif predicate == "is_profession":
profession = self._extract_value_from_object(obj, ["profession", "job"]) or ""
profession = self._clean_subject_text(profession)
if profession:
return f"{subject_phrase}的职业是{profession}"
elif predicate == "lives_in":
location = self._extract_value_from_object(obj, ["location", "city", "place"]) or ""
location = self._clean_subject_text(location)
if location:
return f"{subject_phrase}居住在{location}"
elif predicate == "has_phone":
phone = self._extract_value_from_object(obj, ["phone", "number"]) or ""
phone = self._clean_subject_text(phone)
if phone:
return f"{subject_phrase}的电话号码是{phone}"
elif predicate == "has_email":
email = self._extract_value_from_object(obj, ["email"]) or ""
email = self._clean_subject_text(email)
if email:
return f"{subject_phrase}的邮箱是{email}"
elif predicate in {"likes", "likes_food", "favorite_is"}:
liked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
liked = self._clean_subject_text(liked)
if liked:
verb = "喜欢" if predicate != "likes_food" else "爱吃"
if predicate == "favorite_is":
verb = "最喜欢"
return f"{subject_phrase}{verb}{liked}"
elif predicate in {"dislikes", "hates"}:
disliked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
disliked = self._clean_subject_text(disliked)
if disliked:
verb = "不喜欢" if predicate == "dislikes" else "讨厌"
return f"{subject_phrase}{verb}{disliked}"
elif predicate == "mentioned_event":
description = self._extract_value_from_object(obj, ["event_text", "description"]) or ""
description = self._clean_subject_text(description)
if description:
return f"{subject_phrase}提到了:{description}"
obj_text = self._extract_value_from_object(obj, ["value", "detail", "content"]) or ""
obj_text = self._clean_subject_text(obj_text)
if predicate and obj_text:
return f"{subject_phrase}{predicate}{obj_text}".strip()
if obj_text:
return f"{subject_phrase}{obj_text}".strip()
if predicate:
return f"{subject_phrase}{predicate}".strip()
return subject_phrase
def _extract_personal_info(
self,
@@ -678,7 +929,7 @@ class MemoryBuilder:
user_id: str,
timestamp: float,
context: Dict[str, Any],
subject_display: str
subjects: List[str]
) -> List[MemoryChunk]:
"""提取个人信息"""
memories = []
@@ -702,13 +953,14 @@ class MemoryBuilder:
memory = create_memory_chunk(
user_id=user_id,
subject=subject_display,
subject=subjects,
predicate=predicate,
obj=obj,
memory_type=MemoryType.PERSONAL_FACT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.HIGH,
confidence=ConfidenceLevel.HIGH
confidence=ConfidenceLevel.HIGH,
display=self._compose_display_text(subjects, predicate, obj)
)
memories.append(memory)
@@ -721,7 +973,7 @@ class MemoryBuilder:
user_id: str,
timestamp: float,
context: Dict[str, Any],
subject_display: str
subjects: List[str]
) -> List[MemoryChunk]:
"""提取偏好信息"""
memories = []
@@ -740,13 +992,14 @@ class MemoryBuilder:
if match:
memory = create_memory_chunk(
user_id=user_id,
subject=subject_display,
subject=subjects,
predicate=predicate,
obj=match.group(1),
memory_type=MemoryType.PREFERENCE,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM
confidence=ConfidenceLevel.MEDIUM,
display=self._compose_display_text(subjects, predicate, match.group(1))
)
memories.append(memory)
@@ -759,7 +1012,7 @@ class MemoryBuilder:
user_id: str,
timestamp: float,
context: Dict[str, Any],
subject_display: str
subjects: List[str]
) -> List[MemoryChunk]:
"""提取事件信息"""
memories = []
@@ -770,13 +1023,14 @@ class MemoryBuilder:
if any(keyword in text for keyword in event_keywords):
memory = create_memory_chunk(
user_id=user_id,
subject=subject_display,
subject=subjects,
predicate="mentioned_event",
obj={"event_text": text, "timestamp": timestamp},
memory_type=MemoryType.EVENT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM
confidence=ConfidenceLevel.MEDIUM,
display=self._compose_display_text(subjects, "mentioned_event", text)
)
memories.append(memory)

View File

@@ -7,7 +7,7 @@
import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union
from typing import Dict, List, Optional, Any, Union, Iterable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
@@ -52,17 +52,20 @@ class ImportanceLevel(Enum):
@dataclass
class ContentStructure:
"""主谓宾三元组结构"""
subject: str # 主语(通常为用户)
predicate: str # 谓语(动作、状态、关系)
object: Union[str, Dict] # 宾语(对象、属性、值)
"""主谓宾结构,包含自然语言描述"""
subject: Union[str, List[str]]
predicate: str
object: Union[str, Dict]
display: str = ""
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"subject": self.subject,
"predicate": self.predicate,
"object": self.object
"object": self.object,
"display": self.display
}
@classmethod
@@ -71,16 +74,25 @@ class ContentStructure:
return cls(
subject=data.get("subject", ""),
predicate=data.get("predicate", ""),
object=data.get("object", "")
object=data.get("object", ""),
display=data.get("display", "")
)
def to_subject_list(self) -> List[str]:
"""将主语转换为列表形式"""
if isinstance(self.subject, list):
return [s for s in self.subject if isinstance(s, str) and s.strip()]
if isinstance(self.subject, str) and self.subject.strip():
return [self.subject.strip()]
return []
def __str__(self) -> str:
"""字符串表示"""
if isinstance(self.object, dict):
object_str = str(self.object)
else:
object_str = str(self.object)
return f"{self.subject} {self.predicate} {object_str}"
if self.display:
return self.display
subjects = "".join(self.to_subject_list()) or str(self.subject)
object_str = self.object if isinstance(self.object, str) else str(self.object)
return f"{subjects} {self.predicate} {object_str}".strip()
@dataclass
@@ -236,9 +248,19 @@ class MemoryChunk:
@property
def text_content(self) -> str:
"""获取文本内容"""
"""获取文本内容优先使用display"""
return str(self.content)
@property
def display(self) -> str:
"""获取展示文本"""
return self.content.display or str(self.content)
@property
def subjects(self) -> List[str]:
"""获取主语列表"""
return self.content.to_subject_list()
def update_access(self):
"""更新访问信息"""
self.metadata.update_access()
@@ -415,16 +437,42 @@ class MemoryChunk:
confidence_icon = "" * self.metadata.confidence.value
importance_icon = "" * self.metadata.importance.value
return f"{emoji} [{self.memory_type.value}] {self.text_content} {confidence_icon} {importance_icon}"
return f"{emoji} [{self.memory_type.value}] {self.display} {confidence_icon} {importance_icon}"
def __repr__(self) -> str:
"""调试表示"""
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
"""根据主谓宾生成自然语言描述"""
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
subject_part = "".join(subjects_clean) if subjects_clean else "对话参与者"
if isinstance(obj, dict):
object_candidates = []
for key, value in obj.items():
if isinstance(value, (str, int, float)):
object_candidates.append(f"{key}:{value}")
elif isinstance(value, list):
compact = "".join(str(item) for item in value[:3])
object_candidates.append(f"{key}:{compact}")
object_part = "".join(object_candidates) if object_candidates else str(obj)
else:
object_part = str(obj).strip()
predicate_clean = predicate.strip()
if not predicate_clean:
return f"{subject_part} {object_part}".strip()
if object_part:
return f"{subject_part}{predicate_clean}{object_part}".strip()
return f"{subject_part}{predicate_clean}".strip()
def create_memory_chunk(
user_id: str,
subject: str,
subject: Union[str, List[str]],
predicate: str,
obj: Union[str, Dict],
memory_type: MemoryType,
@@ -432,6 +480,7 @@ def create_memory_chunk(
source_context: Optional[str] = None,
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
display: Optional[str] = None,
**kwargs
) -> MemoryChunk:
"""便捷的内存块创建函数"""
@@ -447,10 +496,22 @@ def create_memory_chunk(
source_context=source_context
)
subjects: List[str]
if isinstance(subject, list):
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
subject_payload: Union[str, List[str]] = subjects
else:
cleaned = subject.strip() if isinstance(subject, str) else ""
subjects = [cleaned] if cleaned else []
subject_payload = cleaned
display_text = display or _build_display_text(subjects, predicate, obj)
content = ContentStructure(
subject=subject,
subject=subject_payload,
predicate=predicate,
object=obj
object=obj,
display=display_text
)
chunk = MemoryChunk(

View File

@@ -266,8 +266,12 @@ class MemoryFusionEngine:
consistency_score = 0.0
# 主语一致性
if mem1.content.subject == mem2.content.subject:
consistency_score += 0.4
subjects1 = set(mem1.subjects)
subjects2 = set(mem2.subjects)
if subjects1 or subjects2:
overlap = len(subjects1 & subjects2)
union_count = max(len(subjects1 | subjects2), 1)
consistency_score += (overlap / union_count) * 0.4
# 谓语相似性
predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)

View File

@@ -282,9 +282,11 @@ class MemoryIntegrationHooks:
}
# 使用增强记忆系统处理对话
result = await process_conversation_with_enhanced_memory(
conversation_text, context, user_id
)
memory_context = dict(context)
memory_context["conversation_text"] = conversation_text
memory_context["user_id"] = user_id
result = await process_conversation_with_enhanced_memory(memory_context)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
@@ -336,9 +338,11 @@ class MemoryIntegrationHooks:
}
# 使用增强记忆系统处理对话
result = await process_conversation_with_enhanced_memory(
conversation_text, context, user_id
)
memory_context = dict(context)
memory_context["conversation_text"] = conversation_text
memory_context["user_id"] = user_id
result = await process_conversation_with_enhanced_memory(memory_context)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)

View File

@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
"""记忆检索查询规划器"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import orjson
from src.chat.memory_system.memory_chunk import MemoryType
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@dataclass
class MemoryQueryPlan:
"""查询规划结果"""
semantic_query: str
memory_types: List[MemoryType] = field(default_factory=list)
subject_includes: List[str] = field(default_factory=list)
object_includes: List[str] = field(default_factory=list)
required_keywords: List[str] = field(default_factory=list)
optional_keywords: List[str] = field(default_factory=list)
owner_filters: List[str] = field(default_factory=list)
recency_preference: str = "any"
limit: int = 10
emphasis: Optional[str] = None
raw_plan: Dict[str, Any] = field(default_factory=dict)
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
if not self.semantic_query:
self.semantic_query = fallback_query
if self.limit <= 0:
self.limit = default_limit
self.recency_preference = (self.recency_preference or "any").lower()
if self.recency_preference not in {"any", "recent", "historical"}:
self.recency_preference = "any"
self.emphasis = (self.emphasis or "balanced").lower()
class MemoryQueryPlanner:
"""基于小模型的记忆检索查询规划器"""
def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
self.model = planner_model
self.default_limit = default_limit
async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
if not self.model:
logger.debug("未提供查询规划模型,使用默认规划")
return self._default_plan(query_text)
prompt = self._build_prompt(query_text, context)
try:
response, _ = await self.model.generate_response_async(prompt, temperature=0.2)
payload = self._extract_json_payload(response)
if not payload:
logger.debug("查询规划模型未返回结构化结果,使用默认规划")
return self._default_plan(query_text)
try:
data = orjson.loads(payload)
except orjson.JSONDecodeError as exc:
preview = payload[:200]
logger.warning("解析查询规划JSON失败: %s,片段: %s", exc, preview)
return self._default_plan(query_text)
plan = self._parse_plan_dict(data, query_text)
plan.ensure_defaults(query_text, self.default_limit)
return plan
except Exception as exc:
logger.error("查询规划模型调用失败: %s", exc, exc_info=True)
return self._default_plan(query_text)
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
return MemoryQueryPlan(
semantic_query=query_text,
limit=self.default_limit
)
def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
def _collect_list(key: str) -> List[str]:
value = data.get(key)
if isinstance(value, str):
return [value]
if isinstance(value, list):
return [self._safe_str(item) for item in value if self._safe_str(item)]
return []
memory_type_values = _collect_list("memory_types")
memory_types: List[MemoryType] = []
for item in memory_type_values:
if not item:
continue
try:
memory_types.append(MemoryType(item))
except ValueError:
# 尝试匹配value值
normalized = item.lower()
for mt in MemoryType:
if mt.value == normalized:
memory_types.append(mt)
break
plan = MemoryQueryPlan(
semantic_query=semantic_query,
memory_types=memory_types,
subject_includes=_collect_list("subject_includes"),
object_includes=_collect_list("object_includes"),
required_keywords=_collect_list("required_keywords"),
optional_keywords=_collect_list("optional_keywords"),
owner_filters=_collect_list("owner_filters"),
recency_preference=self._safe_str(data.get("recency")) or "any",
limit=self._safe_int(data.get("limit"), self.default_limit),
emphasis=self._safe_str(data.get("emphasis")) or "balanced",
raw_plan=data
)
return plan
def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
participants = context.get("participants") or context.get("speaker_names") or []
if isinstance(participants, str):
participants = [participants]
participants = [p for p in participants if isinstance(p, str) and p.strip()]
participant_preview = "".join(participants[:5]) or "未知"
persona = context.get("bot_personality") or context.get("bot_identity") or "未知"
return f"""
你是一名记忆检索分析师,将根据对话查询生成结构化的检索计划。
请结合提供的上下文输出一个JSON对象字段含义如下
- semantic_query: 提供给向量检索的自然语言查询,要求清晰具体;
- memory_types: 建议检索的记忆类型数组,取值范围参见 MemoryType 枚举 (personal_fact,event,preference,opinion,relationship,emotion,knowledge,skill,goal,experience,contextual)
- subject_includes: 需要出现在记忆主语中的人或角色列表;
- object_includes: 记忆中需要提到的重要对象或主题关键词列表;
- required_keywords: 检索时必须包含的关键词;
- optional_keywords: 可以提升相关性的附加关键词;
- owner_filters: 如果需要限制检索所属主体请列出用户ID或其它标识
- recency: 建议的时间偏好,可选 recent/any/historical
- emphasis: 检索策略倾向,可选 precision/recall/balanced
- limit: 推荐的最大返回数量(1-15之间)
- notes: 额外说明,可选。
当前查询: "{query_text}"
已知的对话参与者: {participant_preview}
机器人设定: {persona}
请输出符合要求的JSON禁止添加额外说明或Markdown代码块。
"""
def _extract_json_payload(self, response: str) -> Optional[str]:
if not response:
return None
stripped = response.strip()
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
if code_block_match:
candidate = code_block_match.group(1).strip()
if candidate:
return candidate
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1]
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
@staticmethod
def _safe_str(value: Any) -> str:
if isinstance(value, str):
return value.strip()
if value is None:
return ""
return str(value).strip()
@staticmethod
def _safe_int(value: Any, default: int) -> int:
try:
number = int(value)
if number <= 0:
return default
return number
except (TypeError, ValueError):
return default

View File

@@ -25,6 +25,7 @@ class IndexType(Enum):
"""索引类型"""
MEMORY_TYPE = "memory_type" # 记忆类型索引
USER_ID = "user_id" # 用户ID索引
SUBJECT = "subject" # 主体索引
KEYWORD = "keyword" # 关键词索引
TAG = "tag" # 标签索引
CATEGORY = "category" # 分类索引
@@ -41,6 +42,7 @@ class IndexQuery:
"""索引查询条件"""
user_ids: Optional[List[str]] = None
memory_types: Optional[List[MemoryType]] = None
subjects: Optional[List[str]] = None
keywords: Optional[List[str]] = None
tags: Optional[List[str]] = None
categories: Optional[List[str]] = None
@@ -76,6 +78,7 @@ class MetadataIndexManager:
self.indices = {
IndexType.MEMORY_TYPE: defaultdict(set),
IndexType.USER_ID: defaultdict(set),
IndexType.SUBJECT: defaultdict(set),
IndexType.KEYWORD: defaultdict(set),
IndexType.TAG: defaultdict(set),
IndexType.CATEGORY: defaultdict(set),
@@ -110,6 +113,41 @@ class MetadataIndexManager:
self.auto_save_interval = 500 # 每500次操作自动保存
self._operation_count = 0
@staticmethod
def _serialize_index_key(index_type: IndexType, key: Any) -> str:
"""将索引键序列化为字符串以便存储"""
if isinstance(key, Enum):
value = key.value
else:
value = key
return str(value)
@staticmethod
def _deserialize_index_key(index_type: IndexType, key: str) -> Any:
"""根据索引类型反序列化索引键"""
try:
if index_type == IndexType.MEMORY_TYPE:
return MemoryType(key)
if index_type == IndexType.CONFIDENCE:
return ConfidenceLevel(int(key))
if index_type == IndexType.IMPORTANCE:
return ImportanceLevel(int(key))
# 其他索引键默认使用原始字符串可能已经是lower后的字符串
return key
except Exception:
logger.warning("无法反序列化索引键 %s 在索引 %s 中,使用原始字符串", key, index_type.value)
return key
@staticmethod
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
serialized = {}
for field_name, value in metadata.items():
if isinstance(value, Enum):
serialized[field_name] = value.value
else:
serialized[field_name] = value
return serialized
async def index_memories(self, memories: List[MemoryChunk]):
"""为记忆建立索引"""
if not memories:
@@ -142,6 +180,68 @@ class MetadataIndexManager:
except Exception as e:
logger.error(f"❌ 元数据索引失败: {e}", exc_info=True)
async def update_memory_entry(self, memory: MemoryChunk):
"""更新已存在记忆的索引信息"""
if not memory:
return
with self._lock:
entry = self.memory_metadata_cache.get(memory.memory_id)
if entry is None:
# 若不存在则作为新记忆索引
self._index_single_memory(memory)
return
old_confidence = entry.get("confidence")
old_importance = entry.get("importance")
old_semantic_hash = entry.get("semantic_hash")
entry.update(
{
"user_id": memory.user_id,
"memory_type": memory.memory_type,
"created_at": memory.metadata.created_at,
"last_accessed": memory.metadata.last_accessed,
"access_count": memory.metadata.access_count,
"confidence": memory.metadata.confidence,
"importance": memory.metadata.importance,
"relationship_score": memory.metadata.relationship_score,
"relevance_score": memory.metadata.relevance_score,
"semantic_hash": memory.semantic_hash,
"subjects": memory.subjects,
}
)
# 更新置信度/重要性索引
if isinstance(old_confidence, ConfidenceLevel):
self.indices[IndexType.CONFIDENCE][old_confidence].discard(memory.memory_id)
if isinstance(old_importance, ImportanceLevel):
self.indices[IndexType.IMPORTANCE][old_importance].discard(memory.memory_id)
if isinstance(old_semantic_hash, str):
self.indices[IndexType.SEMANTIC_HASH][old_semantic_hash].discard(memory.memory_id)
self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory.memory_id)
self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory.memory_id)
if memory.semantic_hash:
self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory.memory_id)
# 同步关键词/标签/分类索引
for keyword in memory.keywords:
if keyword:
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory.memory_id)
for tag in memory.tags:
if tag:
self.indices[IndexType.TAG][tag.lower()].add(memory.memory_id)
for category in memory.categories:
if category:
self.indices[IndexType.CATEGORY][category.lower()].add(memory.memory_id)
for subject in memory.subjects:
if subject:
self.indices[IndexType.SUBJECT][subject.strip().lower()].add(memory.memory_id)
def _index_single_memory(self, memory: MemoryChunk):
"""为单个记忆建立索引"""
memory_id = memory.memory_id
@@ -157,7 +257,8 @@ class MetadataIndexManager:
"importance": memory.metadata.importance,
"relationship_score": memory.metadata.relationship_score,
"relevance_score": memory.metadata.relevance_score,
"semantic_hash": memory.semantic_hash
"semantic_hash": memory.semantic_hash,
"subjects": memory.subjects
}
# 记忆类型索引
@@ -166,6 +267,12 @@ class MetadataIndexManager:
# 用户ID索引
self.indices[IndexType.USER_ID][memory.user_id].add(memory_id)
# 主体索引
for subject in memory.subjects:
normalized = subject.strip().lower()
if normalized:
self.indices[IndexType.SUBJECT][normalized].add(memory_id)
# 关键词索引
for keyword in memory.keywords:
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id)
@@ -282,13 +389,6 @@ class MetadataIndexManager:
# 应用最严格的过滤条件
applied_filters = []
if query.user_ids:
user_ids_set = set()
for user_id in query.user_ids:
user_ids_set.update(self.indices[IndexType.USER_ID].get(user_id, set()))
candidate_ids.update(user_ids_set)
applied_filters.append("user_ids")
if query.memory_types:
memory_types_set = set()
for memory_type in query.memory_types:
@@ -302,7 +402,7 @@ class MetadataIndexManager:
if query.keywords:
keywords_set = set()
for keyword in query.keywords:
keywords_set.update(self.indices[IndexType.KEYWORD].get(keyword.lower(), set()))
keywords_set.update(self._collect_index_matches(IndexType.KEYWORD, keyword))
if applied_filters:
candidate_ids &= keywords_set
else:
@@ -329,12 +429,55 @@ class MetadataIndexManager:
candidate_ids.update(categories_set)
applied_filters.append("categories")
if query.subjects:
subjects_set = set()
for subject in query.subjects:
subjects_set.update(self._collect_index_matches(IndexType.SUBJECT, subject))
if applied_filters:
candidate_ids &= subjects_set
else:
candidate_ids.update(subjects_set)
applied_filters.append("subjects")
# 如果没有应用任何过滤条件,返回所有记忆
if not applied_filters:
return all_memory_ids
return candidate_ids
def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
"""根据给定token收集索引匹配支持部分匹配"""
mapping = self.indices.get(index_type)
if mapping is None:
return set()
key = ""
if isinstance(token, Enum):
key = str(token.value).strip().lower()
elif isinstance(token, str):
key = token.strip().lower()
elif token is not None:
key = str(token).strip().lower()
if not key:
return set()
matches: Set[str] = set(mapping.get(key, set()))
if matches:
return set(matches)
for existing_key, ids in mapping.items():
if not existing_key or not isinstance(existing_key, str):
continue
normalized = existing_key.strip().lower()
if not normalized:
continue
if key in normalized or normalized in key:
matches.update(ids)
return matches
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
"""应用过滤条件"""
filtered_ids = list(candidate_ids)
@@ -440,10 +583,10 @@ class MetadataIndexManager:
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
"""获取应用的过滤器列表"""
filters = []
if query.user_ids:
filters.append("user_ids")
if query.memory_types:
filters.append("memory_types")
if query.subjects:
filters.append("subjects")
if query.keywords:
filters.append("keywords")
if query.tags:
@@ -502,6 +645,18 @@ class MetadataIndexManager:
# 从各类索引中移除
self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id)
self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id)
subjects = metadata.get("subjects") or []
for subject in subjects:
if not isinstance(subject, str):
continue
normalized = subject.strip().lower()
if not normalized:
continue
subject_bucket = self.indices[IndexType.SUBJECT].get(normalized)
if subject_bucket is not None:
subject_bucket.discard(memory_id)
if not subject_bucket:
self.indices[IndexType.SUBJECT].pop(normalized, None)
# 从时间索引中移除
self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id]
@@ -625,11 +780,13 @@ class MetadataIndexManager:
logger.info("正在保存元数据索引...")
# 保存各类索引
indices_data = {}
indices_data: Dict[str, Dict[str, List[str]]] = {}
for index_type, index_data in self.indices.items():
indices_data[index_type.value] = {
key: list(values) for key, values in index_data.items()
}
serialized_index = {}
for key, values in index_data.items():
serialized_key = self._serialize_index_key(index_type, key)
serialized_index[serialized_key] = list(values)
indices_data[index_type.value] = serialized_index
indices_file = self.index_path / "indices.json"
with open(indices_file, 'w', encoding='utf-8') as f:
@@ -652,8 +809,12 @@ class MetadataIndexManager:
# 保存元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
metadata_serialized = {
memory_id: self._serialize_metadata_entry(metadata)
for memory_id, metadata in self.memory_metadata_cache.items()
}
with open(metadata_cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.memory_metadata_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存统计信息
stats_file = self.index_path / "index_stats.json"
@@ -679,9 +840,11 @@ class MetadataIndexManager:
for index_type_value, index_data in indices_data.items():
index_type = IndexType(index_type_value)
self.indices[index_type] = {
key: set(values) for key, values in index_data.items()
}
restored_index = defaultdict(set)
for key_str, values in index_data.items():
restored_key = self._deserialize_index_key(index_type, key_str)
restored_index[restored_key] = set(values)
self.indices[index_type] = restored_index
# 加载时间索引
time_index_file = self.index_path / "time_index.json"
@@ -709,10 +872,38 @@ class MetadataIndexManager:
# 转换置信度和重要性为枚举类型
for memory_id, metadata in cache_data.items():
if isinstance(metadata["confidence"], str):
metadata["confidence"] = ConfidenceLevel(metadata["confidence"])
if isinstance(metadata["importance"], str):
metadata["importance"] = ImportanceLevel(metadata["importance"])
memory_type_value = metadata.get("memory_type")
if isinstance(memory_type_value, str):
try:
metadata["memory_type"] = MemoryType(memory_type_value)
except ValueError:
logger.warning("无法解析memory_type %s", memory_type_value)
confidence_value = metadata.get("confidence")
if isinstance(confidence_value, (str, int)):
try:
metadata["confidence"] = ConfidenceLevel(int(confidence_value))
except ValueError:
logger.warning("无法解析confidence %s", confidence_value)
importance_value = metadata.get("importance")
if isinstance(importance_value, (str, int)):
try:
metadata["importance"] = ImportanceLevel(int(importance_value))
except ValueError:
logger.warning("无法解析importance %s", importance_value)
subjects_value = metadata.get("subjects")
if isinstance(subjects_value, str):
metadata["subjects"] = [subjects_value]
elif isinstance(subjects_value, list):
cleaned_subjects = []
for item in subjects_value:
if isinstance(item, str) and item.strip():
cleaned_subjects.append(item.strip())
metadata["subjects"] = cleaned_subjects
else:
metadata["subjects"] = []
self.memory_metadata_cache = cache_data

View File

@@ -203,11 +203,17 @@ class MultiStageRetrieval:
try:
from .metadata_index import IndexQuery
# 构建索引查询
query_plan = context.get("query_plan")
memory_types = self._extract_memory_types_from_context(context)
keywords = self._extract_keywords_from_query(query, query_plan)
subjects = query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None
index_query = IndexQuery(
user_ids=[user_id],
memory_types=self._extract_memory_types_from_context(context),
keywords=self._extract_keywords_from_query(query),
user_ids=None,
memory_types=memory_types,
subjects=subjects,
keywords=keywords,
limit=self.config.metadata_filter_limit,
sort_by="last_accessed",
sort_order="desc"
@@ -215,13 +221,66 @@ class MultiStageRetrieval:
# 执行查询
result = await metadata_index.query_memories(index_query)
filtered_count = result.total_count - len(result.memory_ids)
result_ids = list(result.memory_ids)
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
logger.debug(f"元数据过滤:{result.total_count} -> {len(result.memory_ids)} 条记忆")
# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
if not result_ids:
sorted_ids = sorted(
(memory.memory_id for memory in all_memories_cache.values()),
key=lambda mid: all_memories_cache[mid].metadata.last_accessed if mid in all_memories_cache else 0,
reverse=True,
)
if memory_types:
type_filtered = [
mid for mid in sorted_ids
if all_memories_cache[mid].memory_type in memory_types
]
sorted_ids = type_filtered or sorted_ids
if subjects:
subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()]
if subject_candidates:
subject_filtered = [
mid for mid in sorted_ids
if any(
subj.strip().lower() in subject_candidates
for subj in all_memories_cache[mid].subjects
)
]
sorted_ids = subject_filtered or sorted_ids
if keywords:
keyword_pool = {kw.lower() for kw in keywords if isinstance(kw, str) and kw.strip()}
if keyword_pool:
keyword_filtered = []
for mid in sorted_ids:
memory_text = (
(all_memories_cache[mid].display or "")
+ "\n"
+ (all_memories_cache[mid].text_content or "")
).lower()
if any(kw in memory_text for kw in keyword_pool):
keyword_filtered.append(mid)
sorted_ids = keyword_filtered or sorted_ids
result_ids = sorted_ids[: self.config.metadata_filter_limit]
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
logger.debug(
"元数据过滤未命中索引,使用近似回退: types=%s, subjects=%s, keywords=%s",
bool(memory_types),
bool(subjects),
bool(keywords),
)
logger.debug(
"元数据过滤:候选=%d, 返回=%d",
len(all_memories_cache),
len(result_ids),
)
return StageResult(
stage=RetrievalStage.METADATA_FILTERING,
memory_ids=result.memory_ids,
memory_ids=result_ids,
processing_time=time.time() - start_time,
filtered_count=filtered_count,
score_threshold=0.0
@@ -251,7 +310,7 @@ class MultiStageRetrieval:
try:
# 生成查询向量
query_embedding = await self._generate_query_embedding(query, context)
query_embedding = await self._generate_query_embedding(query, context, vector_storage)
if not query_embedding:
return StageResult(
@@ -263,22 +322,24 @@ class MultiStageRetrieval:
)
# 执行向量搜索
search_result = await vector_storage.search_similar(
query_embedding,
search_result = await vector_storage.search_similar_memories(
query_vector=query_embedding,
limit=self.config.vector_search_limit
)
candidate_pool = candidate_ids or set(all_memories_cache.keys())
# 过滤候选记忆
filtered_memories = []
for memory_id, similarity in search_result:
if memory_id in candidate_ids and similarity >= self.config.vector_similarity_threshold:
if memory_id in candidate_pool and similarity >= self.config.vector_similarity_threshold:
filtered_memories.append((memory_id, similarity))
# 按相似度排序
filtered_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]
filtered_count = len(candidate_ids) - len(result_ids)
filtered_count = max(0, len(candidate_pool) - len(result_ids))
logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆")
@@ -407,12 +468,20 @@ class MultiStageRetrieval:
score_threshold=0.0
)
async def _generate_query_embedding(self, query: str, context: Dict[str, Any]) -> Optional[List[float]]:
async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]:
"""生成查询向量"""
try:
# 这里应该调用embedding模型
# 由于我们可能没有直接的embedding模型返回None或使用简单的方法
# 在实际实现中这里应该调用与记忆存储相同的embedding模型
query_plan = context.get("query_plan")
query_text = query
if query_plan and getattr(query_plan, "semantic_query", None):
query_text = query_plan.semantic_query
if not query_text:
return None
if hasattr(vector_storage, "generate_query_embedding"):
return await vector_storage.generate_query_embedding(query_text)
return None
except Exception as e:
logger.warning(f"生成查询向量失败: {e}")
@@ -421,9 +490,15 @@ class MultiStageRetrieval:
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
"""计算语义相似度"""
try:
query_plan = context.get("query_plan")
query_text = query
if query_plan and getattr(query_plan, "semantic_query", None):
query_text = query_plan.semantic_query
# 简单的文本相似度计算
query_words = set(query.lower().split())
memory_words = set(memory.text_content.lower().split())
query_words = set(query_text.lower().split())
memory_text = (memory.display or memory.text_content or "").lower()
memory_words = set(memory_text.split())
if not query_words or not memory_words:
return 0.0
@@ -443,10 +518,15 @@ class MultiStageRetrieval:
try:
score = 0.0
query_plan = context.get("query_plan")
# 检查记忆类型是否匹配上下文
if context.get("expected_memory_types"):
if memory.memory_type in context["expected_memory_types"]:
score += 0.3
elif query_plan and getattr(query_plan, "memory_types", None):
if memory.memory_type in query_plan.memory_types:
score += 0.3
# 检查关键词匹配
if context.get("keywords"):
@@ -456,6 +536,35 @@ class MultiStageRetrieval:
if overlap:
score += len(overlap) / max(len(context_keywords), 1) * 0.4
if query_plan:
# 主体匹配
subject_score = self._calculate_subject_overlap(memory, getattr(query_plan, "subject_includes", []))
score += subject_score * 0.3
# 对象/描述匹配
object_keywords = getattr(query_plan, "object_includes", []) or []
if object_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
if hits:
score += min(0.3, hits * 0.1)
optional_keywords = getattr(query_plan, "optional_keywords", []) or []
if optional_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
if hits:
score += min(0.2, hits * 0.05)
# 时间偏好
recency_pref = getattr(query_plan, "recency_preference", "")
if recency_pref:
memory_age = time.time() - memory.metadata.created_at
if recency_pref == "recent" and memory_age < 7 * 24 * 3600:
score += 0.2
elif recency_pref == "historical" and memory_age > 30 * 24 * 3600:
score += 0.1
# 检查时效性
if context.get("recent_only", False):
memory_age = time.time() - memory.metadata.created_at
@@ -471,6 +580,8 @@ class MultiStageRetrieval:
async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
"""计算最终评分"""
try:
query_plan = context.get("query_plan")
# 语义相似度
semantic_score = await self._calculate_semantic_similarity(query, memory, context)
@@ -482,13 +593,29 @@ class MultiStageRetrieval:
# 时效性评分
recency_score = self._calculate_recency_score(memory.metadata.created_at)
if query_plan:
recency_pref = getattr(query_plan, "recency_preference", "")
if recency_pref == "recent":
recency_score = max(recency_score, 0.8)
elif recency_pref == "historical":
recency_score = min(recency_score, 0.5)
# 权重组合
vector_weight = self.config.vector_weight
semantic_weight = self.config.semantic_weight
context_weight = self.config.context_weight
recency_weight = self.config.recency_weight
if query_plan and getattr(query_plan, "emphasis", None) == "precision":
semantic_weight += 0.05
elif query_plan and getattr(query_plan, "emphasis", None) == "recall":
context_weight += 0.05
final_score = (
semantic_score * self.config.semantic_weight +
vector_score * self.config.vector_weight +
context_score * self.config.context_weight +
recency_score * self.config.recency_weight
semantic_score * semantic_weight +
vector_score * vector_weight +
context_score * context_weight +
recency_score * recency_weight
)
# 加入记忆重要性权重
@@ -501,6 +628,31 @@ class MultiStageRetrieval:
logger.warning(f"计算最终评分失败: {e}")
return 0.0
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
if not required_subjects:
return 0.0
memory_subjects = {subject.lower() for subject in memory.subjects if isinstance(subject, str)}
if not memory_subjects:
return 0.0
hit = 0
total = 0
for subject in required_subjects:
if not isinstance(subject, str):
continue
total += 1
normalized = subject.strip().lower()
if not normalized:
continue
if any(normalized in mem_subject for mem_subject in memory_subjects):
hit += 1
if total == 0:
return 0.0
return hit / total
def _calculate_recency_score(self, timestamp: float) -> float:
"""计算时效性评分"""
try:
@@ -524,6 +676,10 @@ class MultiStageRetrieval:
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
"""从上下文中提取记忆类型"""
try:
query_plan = context.get("query_plan")
if query_plan and getattr(query_plan, "memory_types", None):
return query_plan.memory_types
if "expected_memory_types" in context:
return context["expected_memory_types"]
@@ -544,15 +700,30 @@ class MultiStageRetrieval:
except Exception:
return []
def _extract_keywords_from_query(self, query: str) -> List[str]:
def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
"""从查询中提取关键词"""
try:
extracted: List[str] = []
if query_plan and getattr(query_plan, "required_keywords", None):
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
# 简单的关键词提取
words = query.lower().split()
# 过滤停用词
stopwords = {"", "", "", "", "", "", "", "", "", "", "", "", "", ""}
keywords = [word for word in words if len(word) > 1 and word not in stopwords]
return keywords[:10] # 最多返回10个关键词
extracted.extend(word for word in words if len(word) > 1 and word not in stopwords)
# 去重并保留顺序
seen = set()
deduplicated = []
for word in extracted:
if word in seen or not word:
continue
seen.add(word)
deduplicated.append(word)
return deduplicated[:10]
except Exception:
return []

View File

@@ -20,6 +20,7 @@ from pathlib import Path
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.common.config_helpers import resolve_embedding_dimension
from src.chat.memory_system.memory_chunk import MemoryChunk
logger = get_logger(__name__)
@@ -36,12 +37,12 @@ except ImportError:
@dataclass
class VectorStorageConfig:
"""向量存储配置"""
dimension: int = 768
dimension: int = 1024
similarity_threshold: float = 0.8
index_type: str = "flat" # flat, ivf, hnsw
max_index_size: int = 100000
storage_path: str = "data/memory_vectors"
auto_save_interval: int = 100 # 每N次操作自动保存
auto_save_interval: int = 10 # 每N次操作自动保存
enable_compression: bool = True
@@ -50,6 +51,15 @@ class VectorStorageManager:
def __init__(self, config: Optional[VectorStorageConfig] = None):
self.config = config or VectorStorageConfig()
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
if resolved_dimension and resolved_dimension != self.config.dimension:
logger.info(
"向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
resolved_dimension,
self.config.dimension,
)
self.config.dimension = resolved_dimension
self.storage_path = Path(self.config.storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
@@ -117,6 +127,32 @@ class VectorStorageManager:
)
logger.info("✅ 嵌入模型初始化完成")
async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
"""生成查询向量,用于记忆召回"""
if not query_text:
return None
try:
await self.initialize_embedding_model()
embedding, _ = await self.embedding_model.get_embedding(query_text)
if not embedding:
return None
if len(embedding) != self.config.dimension:
logger.warning(
"查询向量维度不匹配: 期望 %d, 实际 %d",
self.config.dimension,
len(embedding)
)
return None
return self._normalize_vector(embedding)
except Exception as exc:
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
return None
async def store_memories(self, memories: List[MemoryChunk]):
"""存储记忆向量"""
if not memories:
@@ -213,7 +249,7 @@ class VectorStorageManager:
results[memory_id] = embedding
else:
logger.warning(
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)",
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。",
self.config.dimension,
len(embedding) if embedding else 0,
memory_id,
@@ -299,14 +335,32 @@ class VectorStorageManager:
async def search_similar_memories(
self,
query_vector: List[float],
query_vector: Optional[List[float]] = None,
*,
query_text: Optional[str] = None,
limit: int = 10,
user_id: Optional[str] = None
scope_id: Optional[str] = None
) -> List[Tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()
try:
if query_vector is None:
if not query_text:
return []
query_vector = await self.generate_query_embedding(query_text)
if not query_vector:
return []
scope_filter: Optional[str] = None
if isinstance(scope_id, str):
normalized_scope = scope_id.strip().lower()
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
scope_filter = scope_id
elif scope_id:
scope_filter = str(scope_id)
# 规范化查询向量
query_vector = self._normalize_vector(query_vector)
@@ -341,10 +395,9 @@ class VectorStorageManager:
memory_id = self.index_to_memory_id.get(index)
if memory_id:
# 应用用户过滤
if user_id:
if scope_filter:
memory = self.memory_cache.get(memory_id)
if memory and memory.user_id != user_id:
if memory and str(memory.user_id) != scope_filter:
continue
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
@@ -481,8 +534,14 @@ class VectorStorageManager:
# 保存映射关系
mapping_file = self.storage_path / "id_mapping.json"
mapping_data = {
"memory_id_to_index": self.memory_id_to_index,
"index_to_memory_id": self.index_to_memory_id
"memory_id_to_index": {
str(memory_id): int(index)
for memory_id, index in self.memory_id_to_index.items()
},
"index_to_memory_id": {
str(index): memory_id
for index, memory_id in self.index_to_memory_id.items()
}
}
with open(mapping_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
@@ -529,8 +588,17 @@ class VectorStorageManager:
if mapping_file.exists():
with open(mapping_file, 'r', encoding='utf-8') as f:
mapping_data = orjson.loads(f.read())
self.memory_id_to_index = mapping_data.get("memory_id_to_index", {})
self.index_to_memory_id = mapping_data.get("index_to_memory_id", {})
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
str(memory_id): int(index)
for memory_id, index in raw_memory_to_index.items()
}
raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
self.index_to_memory_id = {
int(index): memory_id
for index, memory_id in raw_index_to_memory.items()
}
# 加载FAISS索引如果可用
if FAISS_AVAILABLE:

View File

@@ -473,14 +473,14 @@ class ChatBot:
async def preprocess():
# 存储消息到数据库
from .storage import MessageStorage
try:
await MessageStorage.store_message(message, message.chat_stream)
logger.debug(f"消息已存储到数据库: {message.message_info.message_id}")
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc()
# 使用消息管理器处理消息(保持原有功能)
from src.common.data_models.database_data_model import DatabaseMessages

View File

@@ -377,12 +377,12 @@ class Prompt:
# 性能优化 - 为不同任务设置不同的超时时间
task_timeouts = {
"memory_block": 5.0, # 记忆系统可能较慢,单独设置超时
"tool_info": 3.0, # 工具信息中等速度
"relation_info": 2.0, # 关系信息通常较快
"knowledge_info": 3.0, # 知识库查询中等速度
"cross_context": 2.0, # 上下文处理通常较快
"expression_habits": 1.5, # 表达习惯处理很快
"memory_block": 15.0, # 记忆系统
"tool_info": 15.0, # 工具信息
"relation_info": 10.0, # 关系信息
"knowledge_info": 10.0, # 知识库查询
"cross_context": 10.0, # 上下文处理
"expression_habits": 10.0, # 表达习惯
}
# 分别处理每个任务,避免慢任务影响快任务
@@ -562,12 +562,8 @@ class Prompt:
)
]
# 等待所有记忆查询完成最多3秒
try:
running_memories, instant_memory = await asyncio.wait_for(
asyncio.gather(*memory_tasks, return_exceptions=True),
timeout=3.0
)
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
# 处理可能的异常结果
if isinstance(running_memories, Exception):