feat(memory): implement three-stage memory retrieval and simplify the extraction strategy
- Remove the rule-based and hybrid extraction strategies; extraction now uses the LLM exclusively
- Implement three-stage retrieval: metadata coarse filter → vector refinement → composite rerank
- Add a JSON metadata index to speed up retrieval
- Improve Vector DB configuration management and batching
- Unify the memory scope to global, so all memories are fully shared
- Strengthen query planning and the composite scoring algorithm
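In outline, the new recall path looks like the self-contained toy below (illustrative only: the real implementation in this commit uses MemoryMetadataIndex, ChromaDB and LLM query planning, and the 0.6/0.25/0.15 weights are placeholders, not the shipped configuration):

import math, time

memories = [
    {"id": "m1", "keywords": ["food"], "importance": 3, "created_at": time.time() - 5 * 86400, "vec_sim": 0.82},
    {"id": "m2", "keywords": ["game"], "importance": 1, "created_at": time.time() - 90 * 86400, "vec_sim": 0.91},
]

def retrieve(query_keywords, limit=1):
    # Stage 1: metadata coarse filter (keyword overlap stands in for the JSON index)
    coarse = [m for m in memories if set(query_keywords) & set(m["keywords"])]
    # Stage 2: vector refinement (vec_sim stands in for a ChromaDB similarity score)
    fine = sorted(coarse, key=lambda m: m["vec_sim"], reverse=True)[:50]
    # Stage 3: composite rerank (vector + recency + importance; assumed weights)
    def score(m):
        age_days = (time.time() - m["created_at"]) / 86400
        return 0.6 * m["vec_sim"] + 0.25 * math.exp(-age_days / 30) + 0.15 * (m["importance"] - 1) / 3
    return sorted(fine, key=score, reverse=True)[:limit]

print(retrieve(["food"]))  # -> the "m1" entry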
scripts/rebuild_metadata_index.py · new file · 139 lines
@@ -0,0 +1,139 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Rebuild the JSON metadata index from existing ChromaDB data
"""
import asyncio
import sys
import os

import orjson

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry
from src.common.logger import get_logger

logger = get_logger(__name__)


async def rebuild_metadata_index():
    """Rebuild the metadata index from ChromaDB"""
    print("=" * 80)
    print("Rebuilding JSON metadata index")
    print("=" * 80)

    # Initialize the memory system
    print("\n🔧 Initializing memory system...")
    ms = MemorySystem()
    await ms.initialize()
    print("✅ Memory system initialized")

    if not hasattr(ms.unified_storage, 'metadata_index'):
        print("❌ Metadata index manager is not initialized")
        return

    # Fetch all memories
    print("\n📥 Fetching all memories from ChromaDB...")
    from src.common.vector_db import vector_db_service

    try:
        # Fetch every memory in the collection
        collection_name = ms.unified_storage.config.memory_collection
        result = vector_db_service.get(
            collection_name=collection_name,
            include=["documents", "metadatas", "embeddings"]
        )

        if not result or not result.get("ids"):
            print("❌ No memory data found in ChromaDB")
            return

        ids = result["ids"]
        metadatas = result.get("metadatas", [])

        print(f"✅ Found {len(ids)} memories")

        # Rebuild the metadata index
        print("\n🔨 Rebuilding metadata index...")
        entries = []
        success_count = 0

        for i, (memory_id, metadata) in enumerate(zip(ids, metadatas), 1):
            try:
                # Rebuild an index entry from the ChromaDB metadata
                entry = MemoryMetadataIndexEntry(
                    memory_id=memory_id,
                    user_id=metadata.get("user_id", "unknown"),
                    memory_type=metadata.get("memory_type", "general"),
                    subjects=orjson.loads(metadata.get("subjects", "[]")),
                    objects=[metadata.get("object")] if metadata.get("object") else [],
                    keywords=orjson.loads(metadata.get("keywords", "[]")),
                    tags=orjson.loads(metadata.get("tags", "[]")),
                    importance=2,  # default: NORMAL
                    confidence=2,  # default: MEDIUM
                    created_at=metadata.get("created_at", 0.0),
                    access_count=metadata.get("access_count", 0),
                    chat_id=metadata.get("chat_id"),
                    content_preview=None
                )

                # Try to resolve the importance/confidence enum names
                if "importance" in metadata:
                    imp_str = metadata["importance"]
                    if imp_str == "LOW":
                        entry.importance = 1
                    elif imp_str == "NORMAL":
                        entry.importance = 2
                    elif imp_str == "HIGH":
                        entry.importance = 3
                    elif imp_str == "CRITICAL":
                        entry.importance = 4

                if "confidence" in metadata:
                    conf_str = metadata["confidence"]
                    if conf_str == "LOW":
                        entry.confidence = 1
                    elif conf_str == "MEDIUM":
                        entry.confidence = 2
                    elif conf_str == "HIGH":
                        entry.confidence = 3
                    elif conf_str == "VERIFIED":
                        entry.confidence = 4

                entries.append(entry)
                success_count += 1

                if i % 100 == 0:
                    print(f"  progress: {i}/{len(ids)} ({success_count} ok)")

            except Exception as e:
                logger.warning(f"Failed to process memory {memory_id}: {e}")
                continue

        print(f"\n✅ Parsed metadata for {success_count}/{len(ids)} memories")

        # Batch-update the index
        print("\n💾 Saving metadata index...")
        ms.unified_storage.metadata_index.batch_add_or_update(entries)
        ms.unified_storage.metadata_index.save()

        # Show statistics
        stats = ms.unified_storage.metadata_index.get_stats()
        print("\n📊 Index statistics after rebuild:")
        print(f"  - total memories: {stats['total_memories']}")
        print(f"  - subjects: {stats['subjects_count']}")
        print(f"  - keywords: {stats['keywords_count']}")
        print(f"  - tags: {stats['tags_count']}")
        print("  - type distribution:")
        for mtype, count in stats['types'].items():
            print(f"    - {mtype}: {count}")

        print("\n✅ Metadata index rebuild complete!")

    except Exception as e:
        logger.error(f"Failed to rebuild index: {e}", exc_info=True)
        print(f"❌ Failed to rebuild index: {e}")


if __name__ == '__main__':
    asyncio.run(rebuild_metadata_index())
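The two if/elif ladders above work, but the same enum-name mapping can be table-driven; a small sketch (not part of this commit) using the same level names:

IMPORTANCE = {"LOW": 1, "NORMAL": 2, "HIGH": 3, "CRITICAL": 4}
CONFIDENCE = {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "VERIFIED": 4}

def parse_levels(metadata: dict) -> tuple[int, int]:
    """Map stored enum names to index integers, defaulting to NORMAL/MEDIUM."""
    importance = IMPORTANCE.get(metadata.get("importance", ""), 2)
    confidence = CONFIDENCE.get(metadata.get("confidence", ""), 2)
    return importance, confidence

print(parse_levels({"importance": "HIGH"}))  # (3, 2)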
scripts/run_multi_stage_smoke.py · new file · 23 lines
@@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Lightweight smoke test: initialize MemorySystem and run one retrieval to verify
that accessing MemoryMetadata.source no longer raises
"""
import asyncio
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.chat.memory_system.memory_system import MemorySystem


async def main():
    ms = MemorySystem()
    await ms.initialize()
    results = await ms.retrieve_relevant_memories(query_text="测试查询:杰瑞喵喜欢什么?", limit=3)
    print(f"Retrieved {len(results)} memories (> 0 means the run succeeded)")
    for i, m in enumerate(results, 1):
        print(f"{i}. id={m.metadata.memory_id} source={getattr(m.metadata, 'source', None)}")


if __name__ == '__main__':
    asyncio.run(main())
@@ -96,19 +96,8 @@ class MemoryBuilder:
         try:
             logger.debug(f"Building memories from conversation, text length: {len(conversation_text)}")
 
-            # Preprocess the text
-            processed_text = self._preprocess_text(conversation_text)
-
-            # Decide on an extraction strategy
-            strategy = self._determine_extraction_strategy(processed_text, context)
-
-            # Extract memories according to the strategy
-            if strategy == ExtractionStrategy.LLM_BASED:
-                memories = await self._extract_with_llm(processed_text, context, user_id, timestamp)
-            elif strategy == ExtractionStrategy.RULE_BASED:
-                memories = self._extract_with_rules(processed_text, context, user_id, timestamp)
-            else:  # HYBRID
-                memories = await self._extract_with_hybrid(processed_text, context, user_id, timestamp)
+            # Extract memories with the LLM
+            memories = await self._extract_with_llm(conversation_text, context, user_id, timestamp)
 
             # Post-process and validate
             validated_memories = self._validate_and_enhance_memories(memories, context)
@@ -129,41 +118,6 @@ class MemoryBuilder:
             self.extraction_stats["failed_extractions"] += 1
             raise
 
-    def _preprocess_text(self, text: str) -> str:
-        """Preprocess text"""
-        # Collapse redundant whitespace
-        text = re.sub(r'\s+', ' ', text.strip())
-
-        # Strip special characters but keep basic punctuation
-        text = re.sub(r'[^\w\s\u4e00-\u9fff,。!?、;:""''()【】]', '', text)
-
-        # Truncate overly long text
-        if len(text) > 2000:
-            text = text[:2000] + "..."
-
-        return text
-
-    def _determine_extraction_strategy(self, text: str, context: Dict[str, Any]) -> ExtractionStrategy:
-        """Decide on an extraction strategy"""
-        text_length = len(text)
-        has_structured_data = any(key in context for key in ["structured_data", "entities", "keywords"])
-        message_type = context.get("message_type", "normal")
-
-        # Short text: rule-based extraction
-        if text_length < 50:
-            return ExtractionStrategy.RULE_BASED
-
-        # Structured data present: hybrid strategy
-        if has_structured_data:
-            return ExtractionStrategy.HYBRID
-
-        # System messages or commands: rule-based extraction
-        if message_type in ["command", "system"]:
-            return ExtractionStrategy.RULE_BASED
-
-        # Default: LLM-based extraction
-        return ExtractionStrategy.LLM_BASED
-
     async def _extract_with_llm(
         self,
         text: str,
@@ -190,79 +144,10 @@ class MemoryBuilder:
             logger.error(f"LLM extraction failed: {e}")
             raise MemoryExtractionError(str(e)) from e
 
-    def _extract_with_rules(
-        self,
-        text: str,
-        context: Dict[str, Any],
-        user_id: str,
-        timestamp: float
-    ) -> List[MemoryChunk]:
-        """Extract memories with rules"""
-        memories = []
-
-        subjects = self._resolve_conversation_participants(context, user_id)
-
-        # Rule 1: detect personal information
-        personal_info = self._extract_personal_info(text, user_id, timestamp, context, subjects)
-        memories.extend(personal_info)
-
-        # Rule 2: detect preference information
-        preferences = self._extract_preferences(text, user_id, timestamp, context, subjects)
-        memories.extend(preferences)
-
-        # Rule 3: detect event information
-        events = self._extract_events(text, user_id, timestamp, context, subjects)
-        memories.extend(events)
-
-        return memories
-
-    async def _extract_with_hybrid(
-        self,
-        text: str,
-        context: Dict[str, Any],
-        user_id: str,
-        timestamp: float
-    ) -> List[MemoryChunk]:
-        """Extract memories with the hybrid strategy"""
-        all_memories = []
-
-        # First extract with rules
-        rule_memories = self._extract_with_rules(text, context, user_id, timestamp)
-        all_memories.extend(rule_memories)
-
-        # Then extract with the LLM
-        llm_memories = await self._extract_with_llm(text, context, user_id, timestamp)
-
-        # Merge and deduplicate
-        final_memories = self._merge_hybrid_results(all_memories, llm_memories)
-
-        return final_memories
-
     def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
         """Build the LLM extraction prompt"""
         current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         chat_id = context.get("chat_id", "unknown")
         message_type = context.get("message_type", "normal")
         target_user_id = context.get("user_id", "用户")
         target_user_id = str(target_user_id)
 
         target_user_name = (
             context.get("user_display_name")
             or context.get("user_name")
             or context.get("nickname")
             or context.get("sender_name")
         )
         if isinstance(target_user_name, str):
             target_user_name = target_user_name.strip()
         else:
             target_user_name = ""
 
         if not target_user_name or self._looks_like_system_identifier(target_user_name):
             target_user_name = "该用户"
 
         target_user_id_display = target_user_id
         if self._looks_like_system_identifier(target_user_id_display):
             target_user_id_display = "(系统ID,勿写入记忆)"
 
         bot_name = context.get("bot_name")
         bot_identity = context.get("bot_identity")
@@ -966,145 +851,6 @@ class MemoryBuilder:
             return f"{subject_phrase}{predicate}".strip()
         return subject_phrase
 
-    def _extract_personal_info(
-        self,
-        text: str,
-        user_id: str,
-        timestamp: float,
-        context: Dict[str, Any],
-        subjects: List[str]
-    ) -> List[MemoryChunk]:
-        """Extract personal information"""
-        memories = []
-
-        # Common personal-information patterns
-        patterns = {
-            r"我叫(\w+)": ("is_named", {"name": "$1"}),
-            r"我今年(\d+)岁": ("is_age", {"age": "$1"}),
-            r"我是(\w+)": ("is_profession", {"profession": "$1"}),
-            r"我住在(\w+)": ("lives_in", {"location": "$1"}),
-            r"我的电话是(\d+)": ("has_phone", {"phone": "$1"}),
-            r"我的邮箱是(\w+@\w+\.\w+)": ("has_email", {"email": "$1"}),
-        }
-
-        for pattern, (predicate, obj_template) in patterns.items():
-            match = re.search(pattern, text)
-            if match:
-                obj = obj_template
-                for i, group in enumerate(match.groups(), 1):
-                    obj = {k: v.replace(f"${i}", group) for k, v in obj.items()}
-
-                memory = create_memory_chunk(
-                    user_id=user_id,
-                    subject=subjects,
-                    predicate=predicate,
-                    obj=obj,
-                    memory_type=MemoryType.PERSONAL_FACT,
-                    chat_id=context.get("chat_id"),
-                    importance=ImportanceLevel.HIGH,
-                    confidence=ConfidenceLevel.HIGH,
-                    display=self._compose_display_text(subjects, predicate, obj)
-                )
-
-                memories.append(memory)
-
-        return memories
-
-    def _extract_preferences(
-        self,
-        text: str,
-        user_id: str,
-        timestamp: float,
-        context: Dict[str, Any],
-        subjects: List[str]
-    ) -> List[MemoryChunk]:
-        """Extract preference information"""
-        memories = []
-
-        # Preference patterns
-        preference_patterns = [
-            (r"我喜欢(.+)", "likes"),
-            (r"我不喜欢(.+)", "dislikes"),
-            (r"我爱吃(.+)", "likes_food"),
-            (r"我讨厌(.+)", "hates"),
-            (r"我最喜欢的(.+)", "favorite_is"),
-        ]
-
-        for pattern, predicate in preference_patterns:
-            match = re.search(pattern, text)
-            if match:
-                memory = create_memory_chunk(
-                    user_id=user_id,
-                    subject=subjects,
-                    predicate=predicate,
-                    obj=match.group(1),
-                    memory_type=MemoryType.PREFERENCE,
-                    chat_id=context.get("chat_id"),
-                    importance=ImportanceLevel.NORMAL,
-                    confidence=ConfidenceLevel.MEDIUM,
-                    display=self._compose_display_text(subjects, predicate, match.group(1))
-                )
-
-                memories.append(memory)
-
-        return memories
-
-    def _extract_events(
-        self,
-        text: str,
-        user_id: str,
-        timestamp: float,
-        context: Dict[str, Any],
-        subjects: List[str]
-    ) -> List[MemoryChunk]:
-        """Extract event information"""
-        memories = []
-
-        # Event keywords
-        event_keywords = ["明天", "今天", "昨天", "上周", "下周", "约会", "会议", "活动", "旅行", "生日"]
-
-        if any(keyword in text for keyword in event_keywords):
-            memory = create_memory_chunk(
-                user_id=user_id,
-                subject=subjects,
-                predicate="mentioned_event",
-                obj={"event_text": text, "timestamp": timestamp},
-                memory_type=MemoryType.EVENT,
-                chat_id=context.get("chat_id"),
-                importance=ImportanceLevel.NORMAL,
-                confidence=ConfidenceLevel.MEDIUM,
-                display=self._compose_display_text(subjects, "mentioned_event", text)
-            )
-
-            memories.append(memory)
-
-        return memories
-
-    def _merge_hybrid_results(
-        self,
-        rule_memories: List[MemoryChunk],
-        llm_memories: List[MemoryChunk]
-    ) -> List[MemoryChunk]:
-        """Merge hybrid-strategy results"""
-        all_memories = rule_memories.copy()
-
-        # Add LLM memories, avoiding duplicates
-        for llm_memory in llm_memories:
-            is_duplicate = False
-            for rule_memory in rule_memories:
-                if llm_memory.is_similar_to(rule_memory, threshold=0.7):
-                    is_duplicate = True
-                    # Merge confidence levels
-                    rule_memory.metadata.confidence = ConfidenceLevel(
-                        max(rule_memory.metadata.confidence.value, llm_memory.metadata.confidence.value)
-                    )
-                    break
-
-            if not is_duplicate:
-                all_memories.append(llm_memory)
-
-        return all_memories
-
     def _validate_and_enhance_memories(
         self,
         memories: List[MemoryChunk],
@@ -127,6 +127,8 @@ class MemoryMetadata:
 
     # Provenance
     source_context: Optional[str] = None  # snippet of the originating context
+    # Legacy compatibility: some code paths and older versions access metadata.source directly
+    source: Optional[str] = None
 
     def __post_init__(self):
         """Post-initialization processing"""
@@ -150,6 +152,19 @@ class MemoryMetadata:
         if self.last_forgetting_check == 0:
             self.last_forgetting_check = current_time
 
+        # Compatibility: if the legacy `source` field is in use, keep it in sync with `source_context`
+        if not getattr(self, 'source', None) and getattr(self, 'source_context', None):
+            try:
+                self.source = str(self.source_context)
+            except Exception:
+                self.source = None
+        # Conversely, if `source` is set but `source_context` is empty, sync it back
+        if not getattr(self, 'source_context', None) and getattr(self, 'source', None):
+            try:
+                self.source_context = str(self.source)
+            except Exception:
+                self.source_context = None
+
     def update_access(self):
         """Update access bookkeeping"""
         current_time = time.time()
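The sync logic above is easiest to see in isolation; a toy reproduction (the real MemoryMetadata carries many more fields) behaves like this:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Meta:
    source_context: Optional[str] = None
    source: Optional[str] = None

    def __post_init__(self):
        # Whichever of the two fields is set wins; the other is mirrored from it.
        if not self.source and self.source_context:
            self.source = str(self.source_context)
        if not self.source_context and self.source:
            self.source_context = str(self.source)

assert Meta(source_context="chat#42").source == "chat#42"   # new field mirrored to legacy
assert Meta(source="legacy").source_context == "legacy"      # legacy field mirrored back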
src/chat/memory_system/memory_metadata_index.py · new file · 316 lines
@@ -0,0 +1,316 @@
# -*- coding: utf-8 -*-
"""
Memory metadata index manager.
Stores memory metadata in a JSON file and supports fast fuzzy search and filtering.
"""

import orjson
import threading
from pathlib import Path
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, asdict
from datetime import datetime

from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryType, ImportanceLevel, ConfidenceLevel

logger = get_logger(__name__)


@dataclass
class MemoryMetadataIndexEntry:
    """Metadata index entry (lightweight; used only for fast filtering)"""
    memory_id: str
    user_id: str

    # Classification
    memory_type: str  # MemoryType.value
    subjects: List[str]  # subject list
    objects: List[str]  # object list
    keywords: List[str]  # keyword list
    tags: List[str]  # tag list

    # Numeric fields (for range filters)
    importance: int  # ImportanceLevel.value (1-4)
    confidence: int  # ConfidenceLevel.value (1-4)
    created_at: float  # creation timestamp
    access_count: int  # access count

    # Optional fields
    chat_id: Optional[str] = None
    content_preview: Optional[str] = None  # content preview (first 100 characters)


class MemoryMetadataIndex:
    """Memory metadata index manager"""

    def __init__(self, index_file: str = "data/memory_metadata_index.json"):
        self.index_file = Path(index_file)
        self.index: Dict[str, MemoryMetadataIndexEntry] = {}  # memory_id -> entry

        # Inverted indices (for fast lookup)
        self.type_index: Dict[str, Set[str]] = {}  # type -> {memory_ids}
        self.subject_index: Dict[str, Set[str]] = {}  # subject -> {memory_ids}
        self.keyword_index: Dict[str, Set[str]] = {}  # keyword -> {memory_ids}
        self.tag_index: Dict[str, Set[str]] = {}  # tag -> {memory_ids}

        self.lock = threading.RLock()

        # Load an existing index, if any
        self._load_index()

    def _load_index(self):
        """Load the index from disk"""
        if not self.index_file.exists():
            logger.info(f"Metadata index file does not exist; a new index will be created: {self.index_file}")
            return

        try:
            with open(self.index_file, 'rb') as f:
                data = orjson.loads(f.read())

            # Rebuild the in-memory index
            for entry_data in data.get('entries', []):
                entry = MemoryMetadataIndexEntry(**entry_data)
                self.index[entry.memory_id] = entry
                self._update_inverted_indices(entry)

            logger.info(f"✅ Loaded metadata index: {len(self.index)} entries")

        except Exception as e:
            logger.error(f"Failed to load metadata index: {e}", exc_info=True)

    def _save_index(self):
        """Save the index to disk"""
        try:
            # Make sure the directory exists
            self.index_file.parent.mkdir(parents=True, exist_ok=True)

            # Serialize all entries
            entries = [asdict(entry) for entry in self.index.values()]
            data = {
                'version': '1.0',
                'count': len(entries),
                'last_updated': datetime.now().isoformat(),
                'entries': entries
            }

            # Write via a temp file plus an atomic rename
            temp_file = self.index_file.with_suffix('.tmp')
            with open(temp_file, 'wb') as f:
                f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))

            temp_file.replace(self.index_file)
            logger.debug(f"Metadata index saved: {len(entries)} entries")

        except Exception as e:
            logger.error(f"Failed to save metadata index: {e}", exc_info=True)

    def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry):
        """Update the inverted indices"""
        memory_id = entry.memory_id

        # Type index
        self.type_index.setdefault(entry.memory_type, set()).add(memory_id)

        # Subject index
        for subject in entry.subjects:
            subject_norm = subject.strip().lower()
            if subject_norm:
                self.subject_index.setdefault(subject_norm, set()).add(memory_id)

        # Keyword index
        for keyword in entry.keywords:
            keyword_norm = keyword.strip().lower()
            if keyword_norm:
                self.keyword_index.setdefault(keyword_norm, set()).add(memory_id)

        # Tag index
        for tag in entry.tags:
            tag_norm = tag.strip().lower()
            if tag_norm:
                self.tag_index.setdefault(tag_norm, set()).add(memory_id)

    def add_or_update(self, entry: MemoryMetadataIndexEntry):
        """Add or update an index entry"""
        with self.lock:
            # If the entry already exists, drop the old record from the inverted indices first
            if entry.memory_id in self.index:
                self._remove_from_inverted_indices(entry.memory_id)

            # Add the new record
            self.index[entry.memory_id] = entry
            self._update_inverted_indices(entry)

    def _remove_from_inverted_indices(self, memory_id: str):
        """Remove a record from the inverted indices"""
        if memory_id not in self.index:
            return

        entry = self.index[memory_id]

        # Remove from the type index
        if entry.memory_type in self.type_index:
            self.type_index[entry.memory_type].discard(memory_id)

        # Remove from the subject index
        for subject in entry.subjects:
            subject_norm = subject.strip().lower()
            if subject_norm in self.subject_index:
                self.subject_index[subject_norm].discard(memory_id)

        # Remove from the keyword index
        for keyword in entry.keywords:
            keyword_norm = keyword.strip().lower()
            if keyword_norm in self.keyword_index:
                self.keyword_index[keyword_norm].discard(memory_id)

        # Remove from the tag index
        for tag in entry.tags:
            tag_norm = tag.strip().lower()
            if tag_norm in self.tag_index:
                self.tag_index[tag_norm].discard(memory_id)

    def remove(self, memory_id: str):
        """Remove an index entry"""
        with self.lock:
            if memory_id in self.index:
                self._remove_from_inverted_indices(memory_id)
                del self.index[memory_id]

    def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
        """Add or update entries in batch"""
        with self.lock:
            for entry in entries:
                self.add_or_update(entry)

    def save(self):
        """Persist the index to disk"""
        with self.lock:
            self._save_index()

    def search(
        self,
        memory_types: Optional[List[str]] = None,
        subjects: Optional[List[str]] = None,
        keywords: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
        importance_min: Optional[int] = None,
        importance_max: Optional[int] = None,
        created_after: Optional[float] = None,
        created_before: Optional[float] = None,
        user_id: Optional[str] = None,
        limit: Optional[int] = None
    ) -> List[str]:
        """
        Search for memory IDs matching the given conditions (fuzzy matching supported)

        Returns:
            List[str]: matching memory_id list
        """
        with self.lock:
            # Initial candidate set (all memories)
            candidate_ids: Optional[Set[str]] = None

            # User filter (applied first)
            if user_id:
                candidate_ids = {
                    mid for mid, entry in self.index.items()
                    if entry.user_id == user_id
                }
            else:
                candidate_ids = set(self.index.keys())

            # Type filter (OR semantics)
            if memory_types:
                type_ids = set()
                for mtype in memory_types:
                    type_ids.update(self.type_index.get(mtype, set()))
                candidate_ids &= type_ids

            # Subject filter (OR semantics, fuzzy matching)
            if subjects:
                subject_ids = set()
                for subject in subjects:
                    subject_norm = subject.strip().lower()
                    # Exact match
                    if subject_norm in self.subject_index:
                        subject_ids.update(self.subject_index[subject_norm])
                    # Fuzzy match (substring containment)
                    for indexed_subject, ids in self.subject_index.items():
                        if subject_norm in indexed_subject or indexed_subject in subject_norm:
                            subject_ids.update(ids)
                candidate_ids &= subject_ids

            # Keyword filter (OR semantics, fuzzy matching)
            if keywords:
                keyword_ids = set()
                for keyword in keywords:
                    keyword_norm = keyword.strip().lower()
                    # Exact match
                    if keyword_norm in self.keyword_index:
                        keyword_ids.update(self.keyword_index[keyword_norm])
                    # Fuzzy match (substring containment)
                    for indexed_keyword, ids in self.keyword_index.items():
                        if keyword_norm in indexed_keyword or indexed_keyword in keyword_norm:
                            keyword_ids.update(ids)
                candidate_ids &= keyword_ids

            # Tag filter (OR semantics)
            if tags:
                tag_ids = set()
                for tag in tags:
                    tag_norm = tag.strip().lower()
                    tag_ids.update(self.tag_index.get(tag_norm, set()))
                candidate_ids &= tag_ids

            # Importance filter
            if importance_min is not None or importance_max is not None:
                importance_ids = {
                    mid for mid in candidate_ids
                    if (importance_min is None or self.index[mid].importance >= importance_min)
                    and (importance_max is None or self.index[mid].importance <= importance_max)
                }
                candidate_ids &= importance_ids

            # Time-range filter
            if created_after is not None or created_before is not None:
                time_ids = {
                    mid for mid in candidate_ids
                    if (created_after is None or self.index[mid].created_at >= created_after)
                    and (created_before is None or self.index[mid].created_at <= created_before)
                }
                candidate_ids &= time_ids

            # Convert to a list sorted by creation time, newest first
            result_ids = sorted(
                candidate_ids,
                key=lambda mid: self.index[mid].created_at,
                reverse=True
            )

            # Apply the limit
            if limit:
                result_ids = result_ids[:limit]

            logger.debug(
                f"Metadata index search: types={memory_types}, subjects={subjects}, "
                f"keywords={keywords}, returned={len(result_ids)}"
            )

            return result_ids

    def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
        """Get a single index entry"""
        return self.index.get(memory_id)

    def get_stats(self) -> Dict[str, Any]:
        """Get index statistics"""
        with self.lock:
            return {
                'total_memories': len(self.index),
                'types': {mtype: len(ids) for mtype, ids in self.type_index.items()},
                'subjects_count': len(self.subject_index),
                'keywords_count': len(self.keyword_index),
                'tags_count': len(self.tag_index),
            }
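A usage sketch for this index (paths, IDs and field values below are illustrative only; "global" is assumed to be the value of GLOBAL_MEMORY_SCOPE):

import time
from src.chat.memory_system.memory_metadata_index import (
    MemoryMetadataIndex, MemoryMetadataIndexEntry,
)

index = MemoryMetadataIndex(index_file="data/demo_metadata_index.json")
index.add_or_update(MemoryMetadataIndexEntry(
    memory_id="m-001", user_id="global",
    memory_type="preference", subjects=["杰瑞喵"], objects=["fish"],
    keywords=["fish", "food"], tags=["likes"],
    importance=3, confidence=2, created_at=time.time(), access_count=0,
))
# Fuzzy keyword match plus an importance floor, newest first:
ids = index.search(keywords=["food"], importance_min=2, limit=10)
index.save()  # atomic write via the .tmp + rename path shown above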
@@ -380,11 +380,11 @@ class MemorySystem:
             self.status = original_status
             return []
 
-        # 2. Build memory chunks
+        # 2. Build memory chunks (all memories use the global scope and are fully shared)
         memory_chunks = await self.memory_builder.build_memories(
             conversation_text,
             normalized_context,
-            GLOBAL_MEMORY_SCOPE,
+            GLOBAL_MEMORY_SCOPE,  # force the global scope; users are not distinguished
             timestamp or time.time()
         )
@@ -609,7 +609,7 @@ class MemorySystem:
         limit: int = 5,
         **kwargs
     ) -> List[MemoryChunk]:
-        """Retrieve relevant memories (simplified version backed by unified storage)"""
+        """Retrieve relevant memories (three-stage recall: metadata coarse filter → vector refinement → composite rerank)"""
         raw_query = query_text or kwargs.get("query")
         if not raw_query:
             raise ValueError("Either query_text or the query kwarg must be provided")
@@ -619,6 +619,8 @@ class MemorySystem:
             return []
 
         context = context or {}
 
+        # All memories are fully shared: always use the global scope, never per-user
+        resolved_user_id = GLOBAL_MEMORY_SCOPE
+
         self.status = MemorySystemStatus.RETRIEVING
@@ -626,48 +628,165 @@ class MemorySystem:
         try:
             normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
 
             effective_limit = limit or self.config.final_recall_limit
 
-            # Build filters
-            filters = {
-                "user_id": resolved_user_id
+            # === Stage 1: metadata coarse filtering (soft filters) ===
+            coarse_filters = {
+                "user_id": GLOBAL_MEMORY_SCOPE,  # required: keeps the scope correct
             }
 
+            # Optional: importance threshold (filters out low-value memories)
+            if hasattr(self.config, 'memory_importance_threshold'):
+                importance_threshold = self.config.memory_importance_threshold
+                if importance_threshold > 0:
+                    coarse_filters["importance"] = {"$gte": importance_threshold}
+                    logger.debug(f"[stage 1] importance filter enabled: >= {importance_threshold}")
+
+            # Optional: time window (search only the last N days)
+            if hasattr(self.config, 'memory_recency_days'):
+                recency_days = self.config.memory_recency_days
+                if recency_days > 0:
+                    cutoff_time = time.time() - (recency_days * 24 * 3600)
+                    coarse_filters["created_at"] = {"$gte": cutoff_time}
+                    logger.debug(f"[stage 1] recency filter enabled: last {recency_days} days")
+
-            # Apply the query-planning result
+            # Apply query planning (optimize the query text and build metadata filters)
+            optimized_query = raw_query
+            metadata_filters = {}
+
             if self.query_planner:
                 try:
                     query_plan = await self.query_planner.plan_query(raw_query, normalized_context)
-                    if getattr(query_plan, "memory_types", None):
-                        filters["memory_types"] = [mt.value for mt in query_plan.memory_types]
-                    if getattr(query_plan, "subject_includes", None):
-                        filters["keywords"] = query_plan.subject_includes
 
+                    # Use the LLM-optimized query text (a more precise semantic formulation)
                     if getattr(query_plan, "semantic_query", None):
-                        raw_query = query_plan.semantic_query
+                        optimized_query = query_plan.semantic_query
 
+                    # Build the JSON metadata filters (used by the stage-1 coarse pass):
+                    # translate the query plan into metadata filter conditions
+                    if getattr(query_plan, "memory_types", None):
+                        metadata_filters['memory_types'] = [mt.value for mt in query_plan.memory_types]
+
+                    if getattr(query_plan, "subject_includes", None):
+                        metadata_filters['subjects'] = query_plan.subject_includes
+
+                    if getattr(query_plan, "required_keywords", None):
+                        metadata_filters['keywords'] = query_plan.required_keywords
+
+                    # Time-range filtering
+                    recency = getattr(query_plan, "recency_preference", "any")
+                    current_time = time.time()
+                    if recency == "recent":
+                        # last 7 days
+                        metadata_filters['created_after'] = current_time - (7 * 24 * 3600)
+                    elif recency == "historical":
+                        # older than 30 days
+                        metadata_filters['created_before'] = current_time - (30 * 24 * 3600)
+
+                    # Add the user ID to the metadata filters
+                    metadata_filters['user_id'] = GLOBAL_MEMORY_SCOPE
+
+                    logger.debug(f"[stage 1] query optimization: '{raw_query}' → '{optimized_query}'")
+                    logger.debug(f"[stage 1] metadata filters: {metadata_filters}")
 
                 except Exception as plan_exc:
-                    logger.warning("Query planning failed; using the default retrieval strategy: %s", plan_exc, exc_info=True)
+                    logger.warning("Query planning failed; falling back to the raw query: %s", plan_exc, exc_info=True)
+                    # Even if planning fails, keep the basic user_id filter
+                    metadata_filters = {'user_id': GLOBAL_MEMORY_SCOPE}
 
-            # Search via the Vector DB storage
+            # === Stage 2: vector refinement ===
+            coarse_limit = self.config.coarse_recall_limit  # the coarse stage returns more candidates
+
+            logger.debug(f"[stage 2] starting vector search: query='{optimized_query[:60]}...', limit={coarse_limit}")
+
             search_results = await self.unified_storage.search_similar_memories(
-                query_text=raw_query,
-                limit=effective_limit,
-                filters=filters
+                query_text=optimized_query,
+                limit=coarse_limit,
+                filters=coarse_filters,  # ChromaDB where conditions (kept for compatibility)
+                metadata_filters=metadata_filters  # JSON metadata index filters
             )
 
+            logger.info(f"[stage 2] vector search done: {len(search_results)} candidates")
+
-            # Convert to memory objects - search_results is a List[Tuple[MemoryChunk, float]]
-            final_memories = []
-            for memory, similarity_score in search_results:
+            # === Stage 3: composite reranking ===
+            scored_memories = []
+            current_time = time.time()
+            import math  # math.exp rather than np.exp (avoids a numpy dependency)
+
+            for memory, vector_similarity in search_results:
+                # 1. Vector similarity score (already normalized to 0-1)
+                vector_score = vector_similarity
+
+                # 2. Recency score (exponential decay with a 30-day time constant)
+                age_seconds = current_time - memory.metadata.created_at
+                age_days = age_seconds / (24 * 3600)
+                recency_score = math.exp(-age_days / 30)
+
+                # 3. Importance score (enum value normalized to 0-1)
+                # ImportanceLevel: LOW=1, NORMAL=2, HIGH=3, CRITICAL=4
+                importance_enum = memory.metadata.importance
+                if hasattr(importance_enum, 'value'):
+                    # Enum: map onto the 0-1 range via (value - 1) / 3
+                    importance_score = (importance_enum.value - 1) / 3.0
+                else:
+                    # Already numeric: use it directly
+                    importance_score = float(importance_enum) if importance_enum else 0.5
+
+                # 4. Access-frequency score (normalized; 10+ accesses scores full marks)
+                access_count = memory.metadata.access_count
+                frequency_score = min(access_count / 10.0, 1.0)
+
+                # Composite score (weighted average)
+                final_score = (
+                    self.config.vector_weight * vector_score +
+                    self.config.recency_weight * recency_score +
+                    self.config.context_weight * importance_score +
+                    0.1 * frequency_score  # fixed 10% weight for access frequency
+                )
+
+                scored_memories.append((memory, final_score, {
+                    "vector": vector_score,
+                    "recency": recency_score,
+                    "importance": importance_score,
+                    "frequency": frequency_score,
+                    "final": final_score
+                }))
+
                 # Update access bookkeeping
                 memory.update_access()
-                final_memories.append(memory)
+
+            # Sort by composite score
+            scored_memories.sort(key=lambda x: x[1], reverse=True)
+
+            # Return the top K
+            final_memories = [mem for mem, score, details in scored_memories[:effective_limit]]
 
             retrieval_time = time.time() - start_time
 
+            # Detailed logging
+            if scored_memories:
+                logger.info("[stage 3] composite rerank done; top-3 score breakdown:")
+                for i, (mem, score, details) in enumerate(scored_memories[:3], 1):
+                    try:
+                        summary = mem.content[:60] if hasattr(mem, 'content') and mem.content else ""
+                    except Exception:
+                        summary = ""
+                    logger.info(
+                        f"  #{i} | final={details['final']:.3f} "
+                        f"(vec={details['vector']:.3f}, rec={details['recency']:.3f}, "
+                        f"imp={details['importance']:.3f}, freq={details['frequency']:.3f}) "
+                        f"| {summary}"
+                    )
+
             logger.info(
-                "✅ Simplified memory retrieval finished"
+                "✅ Three-stage memory retrieval finished"
                 f" | user={resolved_user_id}"
-                f" | count={len(final_memories)}"
+                f" | coarse={len(search_results)}"
+                f" | scored={len(scored_memories)}"
+                f" | returned={len(final_memories)}"
                 f" | duration={retrieval_time:.3f}s"
-                f" | query='{raw_query}'"
+                f" | query='{optimized_query[:60]}...'"
             )
 
             self.last_retrieval_time = time.time()
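For concreteness, one pass of the stage-3 formula, assuming vector_weight=0.5, recency_weight=0.2 and context_weight=0.2 (the actual values come from the MemorySystem config):

import math

vector_score = 0.80                      # cosine similarity from stage 2
recency_score = math.exp(-15 / 30)       # memory is 15 days old -> ~0.607
importance_score = (3 - 1) / 3.0         # ImportanceLevel.HIGH -> ~0.667
frequency_score = min(3 / 10.0, 1.0)     # accessed 3 times -> 0.3

final = 0.5 * vector_score + 0.2 * recency_score + 0.2 * importance_score + 0.1 * frequency_score
print(f"{final:.3f}")  # 0.685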
@@ -717,8 +836,8 @@ class MemorySystem:
         except Exception:
             context = dict(raw_context or {})
 
-        # Base fields (always use the global scope)
-        context["user_id"] = GLOBAL_MEMORY_SCOPE
+        # Base fields: force the user_id argument (already normalized to GLOBAL_MEMORY_SCOPE)
+        context["user_id"] = user_id or GLOBAL_MEMORY_SCOPE
         context["timestamp"] = context.get("timestamp") or timestamp or time.time()
         context["message_type"] = context.get("message_type") or "normal"
         context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"
@@ -26,6 +26,7 @@ from src.common.vector_db import vector_db_service
 from src.chat.utils.utils import get_embedding
 from src.chat.memory_system.memory_chunk import MemoryChunk
 from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
+from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry
 
 logger = get_logger(__name__)
@@ -38,7 +39,7 @@ class VectorStorageConfig:
     metadata_collection: str = "memory_metadata_v2"
 
     # Retrieval settings
-    similarity_threshold: float = 0.8
+    similarity_threshold: float = 0.5  # lowered to improve recall (0.5-0.6 is a sensible range)
     search_limit: int = 20
     batch_size: int = 100
@@ -50,6 +51,26 @@ class VectorStorageConfig:
     # Forgetting settings
     enable_forgetting: bool = True
     retention_hours: int = 24 * 30  # 30 days
 
+    @classmethod
+    def from_global_config(cls):
+        """Create an instance from the global configuration"""
+        from src.config.config import global_config
+
+        memory_cfg = global_config.memory
+
+        return cls(
+            memory_collection=getattr(memory_cfg, 'vector_db_memory_collection', 'unified_memory_v2'),
+            metadata_collection=getattr(memory_cfg, 'vector_db_metadata_collection', 'memory_metadata_v2'),
+            similarity_threshold=getattr(memory_cfg, 'vector_db_similarity_threshold', 0.5),
+            search_limit=getattr(memory_cfg, 'vector_db_search_limit', 20),
+            batch_size=getattr(memory_cfg, 'vector_db_batch_size', 100),
+            enable_caching=getattr(memory_cfg, 'vector_db_enable_caching', True),
+            cache_size_limit=getattr(memory_cfg, 'vector_db_cache_size_limit', 1000),
+            auto_cleanup_interval=getattr(memory_cfg, 'vector_db_auto_cleanup_interval', 3600),
+            enable_forgetting=getattr(memory_cfg, 'enable_memory_forgetting', True),
+            retention_hours=getattr(memory_cfg, 'vector_db_retention_hours', 720),
+        )
 
 
 class VectorMemoryStorage:
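The getattr-with-default pattern used throughout from_global_config degrades gracefully when a field is missing from the [memory] section; a minimal, self-contained illustration:

class _Cfg:  # stand-in for global_config.memory with only one field set
    vector_db_batch_size = 50

cfg = _Cfg()
batch_size = getattr(cfg, 'vector_db_batch_size', 100)            # 50 (present)
threshold = getattr(cfg, 'vector_db_similarity_threshold', 0.5)   # 0.5 (falls back to the default)
print(batch_size, threshold)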
@@ -71,7 +92,16 @@ class VectorMemoryStorage:
     """Vector-DB-backed memory storage"""
 
     def __init__(self, config: Optional[VectorStorageConfig] = None):
-        self.config = config or VectorStorageConfig()
+        # If no config is passed in, read from the global configuration by default
+        if config is None:
+            try:
+                self.config = VectorStorageConfig.from_global_config()
+                logger.info("✅ Vector storage config loaded from the global configuration")
+            except Exception as e:
+                logger.warning(f"Loading from the global configuration failed; using defaults: {e}")
+                self.config = VectorStorageConfig()
+        else:
+            self.config = config
 
         # Pull the batch size and collection names from the config
         self.batch_size = self.config.batch_size
@@ -83,6 +113,9 @@ class VectorMemoryStorage:
         self.cache_timestamps: Dict[str, float] = {}
         self._cache = self.memory_cache  # alias kept for older code
 
+        # Metadata index manager (JSON file index)
+        self.metadata_index = MemoryMetadataIndex()
+
         # Forgetting engine
         self.forgetting_engine: Optional[MemoryForgettingEngine] = None
         if self.config.enable_forgetting:
@@ -354,14 +387,45 @@ class VectorMemoryStorage:
                     success = True
 
                 if success:
-                    # Update the cache
+                    # Update the cache and the metadata index
+                    metadata_entries = []
                     for item in batch:
                         memory_id = item["id"]
+                        # Find the matching MemoryChunk in the original memories list
+                        memory = next((m for m in memories if (getattr(m.metadata, 'memory_id', None) or getattr(m, 'memory_id', None)) == memory_id), None)
+                        if memory:
                             # Update the cache
                             self._cache[memory_id] = memory
                             success_count += 1
 
+                            # Create a metadata index entry
+                            try:
+                                index_entry = MemoryMetadataIndexEntry(
+                                    memory_id=memory_id,
+                                    user_id=memory.metadata.user_id or "unknown",
+                                    memory_type=memory.memory_type.value,
+                                    subjects=memory.subjects,
+                                    objects=[str(memory.content.object)] if memory.content.object else [],
+                                    keywords=memory.keywords,
+                                    tags=memory.tags,
+                                    importance=memory.metadata.importance.value,
+                                    confidence=memory.metadata.confidence.value,
+                                    created_at=memory.metadata.created_at,
+                                    access_count=memory.metadata.access_count,
+                                    chat_id=memory.metadata.chat_id,
+                                    content_preview=str(memory.content)[:100] if memory.content else None
+                                )
+                                metadata_entries.append(index_entry)
+                            except Exception as e:
+                                logger.warning(f"Failed to create metadata index entry (memory_id={memory_id}): {e}")
+
+                    # Batch-update the metadata index
+                    if metadata_entries:
+                        try:
+                            self.metadata_index.batch_add_or_update(metadata_entries)
+                            logger.debug(f"Metadata index updated: {len(metadata_entries)} entries")
+                        except Exception as e:
+                            logger.error(f"Batch metadata index update failed: {e}")
                 else:
                     logger.warning(f"Batch store failed; skipping {len(batch)} memories")
@@ -372,6 +436,14 @@ class VectorMemoryStorage:
             duration = (datetime.now() - start_time).total_seconds()
             logger.info(f"Stored {success_count}/{len(memories)} memories in {duration:.2f}s")
 
+            # Persist the metadata index to disk
+            if success_count > 0:
+                try:
+                    self.metadata_index.save()
+                    logger.debug("Metadata index saved to disk")
+                except Exception as e:
+                    logger.error(f"Failed to save metadata index: {e}")
+
             return success_count
 
         except Exception as e:
@@ -388,13 +460,57 @@ class VectorMemoryStorage:
         query_text: str,
         limit: int = 10,
         similarity_threshold: Optional[float] = None,
-        filters: Optional[Dict[str, Any]] = None
+        filters: Optional[Dict[str, Any]] = None,
+        # New: metadata filter parameters (for the JSON-index coarse pass)
+        metadata_filters: Optional[Dict[str, Any]] = None
     ) -> List[Tuple[MemoryChunk, float]]:
-        """Search for similar memories"""
+        """
+        Search for similar memories (hybrid index mode)
+
+        Args:
+            query_text: query text
+            limit: maximum number of results
+            similarity_threshold: similarity threshold
+            filters: ChromaDB where conditions (kept for compatibility)
+            metadata_filters: JSON metadata index filters; supports:
+                - memory_types: List[str]
+                - subjects: List[str]
+                - keywords: List[str]
+                - tags: List[str]
+                - importance_min: int
+                - importance_max: int
+                - created_after: float
+                - created_before: float
+                - user_id: str
+        """
         if not query_text.strip():
             return []
 
         try:
+            # === Stage 1: JSON metadata coarse filtering (optional) ===
+            candidate_ids: Optional[List[str]] = None
+            if metadata_filters:
+                logger.debug(f"[JSON metadata coarse pass] starting, filters: {metadata_filters}")
+                candidate_ids = self.metadata_index.search(
+                    memory_types=metadata_filters.get('memory_types'),
+                    subjects=metadata_filters.get('subjects'),
+                    keywords=metadata_filters.get('keywords'),
+                    tags=metadata_filters.get('tags'),
+                    importance_min=metadata_filters.get('importance_min'),
+                    importance_max=metadata_filters.get('importance_max'),
+                    created_after=metadata_filters.get('created_after'),
+                    created_before=metadata_filters.get('created_before'),
+                    user_id=metadata_filters.get('user_id'),
+                    limit=self.config.search_limit * 2  # the coarse pass returns more candidates
+                )
+                logger.info(f"[JSON metadata coarse pass] done: {len(candidate_ids)} candidate IDs")
+
+                # If the coarse pass leaves nothing, return straight away
+                if not candidate_ids:
+                    logger.warning("No candidates left after the JSON metadata coarse pass; returning an empty result")
+                    return []
+
+            # === Stage 2: vector refinement ===
             # Generate the query embedding
             query_embedding = await get_embedding(query_text)
             if not query_embedding:
@@ -405,7 +521,14 @@ class VectorMemoryStorage:
             # Build the where conditions
             where_conditions = filters or {}
 
+            # If there is a candidate ID list, add it to the where conditions
+            if candidate_ids:
+                # ChromaDB where conditions use the $in operator for this
+                where_conditions["memory_id"] = {"$in": candidate_ids}
+                logger.debug(f"[vector refinement] search restricted to {len(candidate_ids)} candidate IDs")
+
             # Query the Vector DB
+            logger.debug(f"[vector refinement] starting, limit={min(limit, self.config.search_limit)}")
             results = vector_db_service.query(
                 collection_name=self.config.memory_collection,
                 query_embeddings=[query_embedding],
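A call sketch for the hybrid search above (the storage instance and a populated collection are assumed; the filter keys mirror the docstring, and "global" stands in for GLOBAL_MEMORY_SCOPE):

import asyncio, time

async def demo(storage):
    results = await storage.search_similar_memories(
        query_text="What food does the user like?",
        limit=5,
        metadata_filters={
            "memory_types": ["preference"],
            "keywords": ["food"],
            "created_after": time.time() - 7 * 24 * 3600,  # last 7 days
            "user_id": "global",
        },
    )
    for memory, similarity in results:  # List[Tuple[MemoryChunk, float]]
        print(similarity, memory)

# asyncio.run(demo(VectorMemoryStorage()))  # requires a populated ChromaDB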
@@ -311,11 +311,12 @@ class MemoryConfig(ValidatedConfigBase):
     enable_vector_instant_memory: bool = Field(default=True, description="Enable vector-based instant memory")
 
     # Vector DB settings
-    vector_db_memory_collection: str = Field(default="unified_memory_v2", description="Vector DB collection name")
-    vector_db_similarity_threshold: float = Field(default=0.8, description="Vector DB similarity threshold")
+    vector_db_memory_collection: str = Field(default="unified_memory_v2", description="Vector DB memory collection name")
+    vector_db_metadata_collection: str = Field(default="memory_metadata_v2", description="Vector DB metadata collection name")
+    vector_db_similarity_threshold: float = Field(default=0.5, description="Vector DB similarity threshold (0.5-0.6 recommended; set it too high and retrieval returns nothing)")
     vector_db_search_limit: int = Field(default=20, description="Vector DB search limit")
     vector_db_batch_size: int = Field(default=100, description="Batch size")
-    vector_db_enable_caching: bool = Field(default=True, description="Enable caching")
+    vector_db_enable_caching: bool = Field(default=True, description="Enable in-memory caching")
     vector_db_cache_size_limit: int = Field(default=1000, description="Cache size limit")
     vector_db_auto_cleanup_interval: int = Field(default=3600, description="Auto-cleanup interval (seconds)")
+    vector_db_retention_hours: int = Field(default=720, description="Memory retention (hours; default 30 days)")
@@ -1,5 +1,5 @@
 [inner]
-version = "7.1.3"
+version = "7.1.4"
 
 #---- The following is for developers; if you have only deployed MoFox-Bot you do not need to read it ----
 # If you modify the config file, increment the version value
@@ -303,11 +303,46 @@ max_frequency_bonus = 10.0 # maximum activation-frequency bonus (days)
 # Dormancy mechanism
 dormant_threshold_days = 90 # days after which an unvisited memory is considered dormant
 
-# Unified storage settings (new)
-unified_storage_path = "data/unified_memory" # unified storage data path
-unified_storage_cache_limit = 10000 # in-memory cache size limit
-unified_storage_auto_save_interval = 50 # auto-save interval (number of memories)
-unified_storage_enable_compression = true # enable data compression
+# Unified storage settings (deprecated - use the Vector DB settings instead)
+# DEPRECATED: unified_storage_path = "data/unified_memory"
+# DEPRECATED: unified_storage_cache_limit = 10000
+# DEPRECATED: unified_storage_auto_save_interval = 50
+# DEPRECATED: unified_storage_enable_compression = true
+
+# Vector DB storage settings (new - replaces JSON storage)
+enable_vector_memory_storage = true # enable Vector DB storage
+enable_llm_instant_memory = true # enable LLM-based instant memory
+enable_vector_instant_memory = true # enable vector-based instant memory
+
+# Vector DB settings
+vector_db_memory_collection = "unified_memory_v2" # main Vector DB memory collection name
+vector_db_metadata_collection = "memory_metadata_v2" # Vector DB metadata collection name
+vector_db_similarity_threshold = 0.5 # Vector DB similarity threshold (recommended range: 0.5-0.6; too high and retrieval returns nothing)
+vector_db_search_limit = 20 # maximum number of results per Vector DB search
+vector_db_batch_size = 100 # batch size (memories per batch when storing in bulk)
+vector_db_enable_caching = true # enable in-memory caching
+vector_db_cache_size_limit = 1000 # cache size limit (maximum memories kept in the in-memory cache)
+vector_db_auto_cleanup_interval = 3600 # auto-cleanup interval (seconds)
+vector_db_retention_hours = 720 # memory retention (hours; default 30 days)
+
+# Multi-stage recall settings (optional)
+# Uncomment to enable a stricter coarse pass, intended for large memory stores (>100k entries)
+# memory_importance_threshold = 0.3 # importance threshold (filters low-value memories, range 0.0-1.0)
+# memory_recency_days = 30 # time window (search only the last N days; 0 means unlimited)
+
+# Vector DB settings (ChromaDB)
+[vector_db]
+type = "chromadb" # Vector DB type
+path = "data/chroma_db" # Vector DB data path
+
+[vector_db.settings]
+anonymized_telemetry = false # disable anonymous telemetry
+allow_reset = true # allow reset
+
+[vector_db.collections]
+unified_memory_v2 = { description = "Unified memory storage V2", hnsw_space = "cosine", version = "2.0" }
+memory_metadata_v2 = { description = "Memory metadata index", hnsw_space = "cosine", version = "2.0" }
+semantic_cache = { description = "Semantic cache", hnsw_space = "cosine" }
 
 [voice]
 enable_asr = true # enable speech recognition; when enabled, MoFox-Bot can recognize voice messages (requires the ASR model configured under [model.voice])