feat(memory): implement three-stage memory retrieval and simplify the extraction strategy

- Remove the rule-based and hybrid extraction strategies; use LLM extraction exclusively
- Implement three-stage retrieval: metadata coarse filtering → vector search → composite re-ranking
- Add a JSON metadata index to speed up retrieval
- Improve Vector DB configuration management and batch processing
- Unify the memory scope to global so all memories are fully shared
- Enhance query planning and the composite scoring algorithm
Author: Windpicker-owo
Date: 2025-10-02 10:13:38 +08:00
parent 6f750e2bac
commit 59bda71f29
9 changed files with 814 additions and 297 deletions
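As a quick orientation before the file diffs, here is a minimal sketch of the three retrieval stages described above. It is illustrative only: the method and attribute names (unified_storage.search_similar_memories, config.coarse_recall_limit) follow the diffs further down, and stage three is simplified to a plain sort.

from typing import List, Tuple

async def retrieve_sketch(memory_system, query: str, limit: int = 5) -> list:
    # Stage 1: coarse filtering via the JSON metadata index (scope, types, subjects, keywords).
    metadata_filters = {"user_id": "global"}

    # Stage 2: vector search restricted to the coarse candidates; yields (MemoryChunk, similarity) pairs.
    candidates: List[Tuple[object, float]] = await memory_system.unified_storage.search_similar_memories(
        query_text=query,
        limit=memory_system.config.coarse_recall_limit,
        metadata_filters=metadata_filters,
    )

    # Stage 3: re-rank and keep the top-k. The real implementation blends similarity with recency,
    # importance and access frequency (see the MemorySystem diff below); sorting by similarity keeps this short.
    ranked = sorted(candidates, key=lambda pair: pair[1], reverse=True)
    return [memory for memory, _ in ranked[:limit]]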


@@ -0,0 +1,139 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
从现有ChromaDB数据重建JSON元数据索引
"""
import asyncio
import sys
import os

import orjson
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry
from src.common.logger import get_logger
logger = get_logger(__name__)
async def rebuild_metadata_index():
"""从ChromaDB重建元数据索引"""
print("="*80)
print("重建JSON元数据索引")
print("="*80)
# 初始化记忆系统
print("\n🔧 初始化记忆系统...")
ms = MemorySystem()
await ms.initialize()
print("✅ 记忆系统已初始化")
if not hasattr(ms.unified_storage, 'metadata_index'):
print("❌ 元数据索引管理器未初始化")
return
# 获取所有记忆
print("\n📥 从ChromaDB获取所有记忆...")
from src.common.vector_db import vector_db_service
try:
# 获取集合中的所有记忆ID
collection_name = ms.unified_storage.config.memory_collection
result = vector_db_service.get(
collection_name=collection_name,
include=["documents", "metadatas", "embeddings"]
)
if not result or not result.get("ids"):
print("❌ ChromaDB中没有找到记忆数据")
return
ids = result["ids"]
metadatas = result.get("metadatas", [])
print(f"✅ 找到 {len(ids)} 条记忆")
# 重建元数据索引
print("\n🔨 开始重建元数据索引...")
entries = []
success_count = 0
for i, (memory_id, metadata) in enumerate(zip(ids, metadatas), 1):
try:
# 从ChromaDB元数据重建索引条目
entry = MemoryMetadataIndexEntry(
memory_id=memory_id,
user_id=metadata.get("user_id", "unknown"),
memory_type=metadata.get("memory_type", "general"),
subjects=orjson.loads(metadata.get("subjects", "[]")),
objects=[metadata.get("object")] if metadata.get("object") else [],
keywords=orjson.loads(metadata.get("keywords", "[]")),
tags=orjson.loads(metadata.get("tags", "[]")),
importance=2, # 默认NORMAL
confidence=2, # 默认MEDIUM
created_at=metadata.get("created_at", 0.0),
access_count=metadata.get("access_count", 0),
chat_id=metadata.get("chat_id"),
content_preview=None
)
# 尝试解析importance和confidence的枚举名称
if "importance" in metadata:
imp_str = metadata["importance"]
if imp_str == "LOW":
entry.importance = 1
elif imp_str == "NORMAL":
entry.importance = 2
elif imp_str == "HIGH":
entry.importance = 3
elif imp_str == "CRITICAL":
entry.importance = 4
if "confidence" in metadata:
conf_str = metadata["confidence"]
if conf_str == "LOW":
entry.confidence = 1
elif conf_str == "MEDIUM":
entry.confidence = 2
elif conf_str == "HIGH":
entry.confidence = 3
elif conf_str == "VERIFIED":
entry.confidence = 4
entries.append(entry)
success_count += 1
if i % 100 == 0:
print(f" 处理进度: {i}/{len(ids)} ({success_count} 成功)")
except Exception as e:
logger.warning(f"处理记忆 {memory_id} 失败: {e}")
continue
print(f"\n✅ 成功解析 {success_count}/{len(ids)} 条记忆元数据")
# 批量更新索引
print("\n💾 保存元数据索引...")
ms.unified_storage.metadata_index.batch_add_or_update(entries)
ms.unified_storage.metadata_index.save()
# 显示统计信息
stats = ms.unified_storage.metadata_index.get_stats()
print(f"\n📊 重建后的索引统计:")
print(f" - 总记忆数: {stats['total_memories']}")
print(f" - 主语数量: {stats['subjects_count']}")
print(f" - 关键词数量: {stats['keywords_count']}")
print(f" - 标签数量: {stats['tags_count']}")
print(f" - 类型分布:")
for mtype, count in stats['types'].items():
print(f" - {mtype}: {count}")
print("\n✅ 元数据索引重建完成!")
except Exception as e:
logger.error(f"重建索引失败: {e}", exc_info=True)
print(f"❌ 重建索引失败: {e}")
if __name__ == '__main__':
asyncio.run(rebuild_metadata_index())
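The two if/elif ladders above could also be written as lookup tables; a small sketch under the same enum names (the mapping dicts and helper are illustrative, not part of the codebase):

IMPORTANCE_NAME_TO_VALUE = {"LOW": 1, "NORMAL": 2, "HIGH": 3, "CRITICAL": 4}
CONFIDENCE_NAME_TO_VALUE = {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "VERIFIED": 4}

def parse_levels(metadata: dict) -> tuple:
    """Map the enum names stored in ChromaDB metadata to index integers, defaulting to NORMAL/MEDIUM."""
    importance = IMPORTANCE_NAME_TO_VALUE.get(str(metadata.get("importance", "")), 2)
    confidence = CONFIDENCE_NAME_TO_VALUE.get(str(metadata.get("confidence", "")), 2)
    return importance, confidence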


@@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
轻量烟雾测试:初始化 MemorySystem 并运行一次检索,验证 MemoryMetadata.source 访问不再报错
"""
import asyncio
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.chat.memory_system.memory_system import MemorySystem
async def main():
ms = MemorySystem()
await ms.initialize()
results = await ms.retrieve_relevant_memories(query_text="测试查询:杰瑞喵喜欢什么?", limit=3)
print(f"检索到 {len(results)} 条记忆(如果 >0 则表明运行成功)")
for i, m in enumerate(results, 1):
print(f"{i}. id={m.metadata.memory_id} source={getattr(m.metadata, 'source', None)}")
if __name__ == '__main__':
asyncio.run(main())
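A slightly stricter variant could also touch the compatibility field directly, which is what the MemoryMetadata change below is about (a sketch; the query string is illustrative):

async def check_source_compat() -> None:
    ms = MemorySystem()
    await ms.initialize()
    results = await ms.retrieve_relevant_memories(query_text="测试查询", limit=1)
    for memory in results:
        # Accessing .source must no longer raise AttributeError; it mirrors source_context when unset.
        _ = memory.metadata.source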


@@ -96,19 +96,8 @@ class MemoryBuilder:
try:
logger.debug(f"开始从对话构建记忆,文本长度: {len(conversation_text)}")
# 预处理文本
processed_text = self._preprocess_text(conversation_text)
# 确定提取策略
strategy = self._determine_extraction_strategy(processed_text, context)
# 根据策略提取记忆
if strategy == ExtractionStrategy.LLM_BASED:
memories = await self._extract_with_llm(processed_text, context, user_id, timestamp)
elif strategy == ExtractionStrategy.RULE_BASED:
memories = self._extract_with_rules(processed_text, context, user_id, timestamp)
else: # HYBRID
memories = await self._extract_with_hybrid(processed_text, context, user_id, timestamp)
# 使用LLM提取记忆
memories = await self._extract_with_llm(conversation_text, context, user_id, timestamp)
# 后处理和验证
validated_memories = self._validate_and_enhance_memories(memories, context)
@@ -129,41 +118,6 @@ class MemoryBuilder:
self.extraction_stats["failed_extractions"] += 1
raise
def _preprocess_text(self, text: str) -> str:
"""预处理文本"""
# 移除多余的空白字符
text = re.sub(r'\s+', ' ', text.strip())
# 移除特殊字符,但保留基本标点
text = re.sub(r'[^\w\s\u4e00-\u9fff""''()【】]', '', text)
# 截断过长的文本
if len(text) > 2000:
text = text[:2000] + "..."
return text
def _determine_extraction_strategy(self, text: str, context: Dict[str, Any]) -> ExtractionStrategy:
"""确定提取策略"""
text_length = len(text)
has_structured_data = any(key in context for key in ["structured_data", "entities", "keywords"])
message_type = context.get("message_type", "normal")
# 短文本使用规则提取
if text_length < 50:
return ExtractionStrategy.RULE_BASED
# 包含结构化数据使用混合策略
if has_structured_data:
return ExtractionStrategy.HYBRID
# 系统消息或命令使用规则提取
if message_type in ["command", "system"]:
return ExtractionStrategy.RULE_BASED
# 默认使用LLM提取
return ExtractionStrategy.LLM_BASED
async def _extract_with_llm(
self,
text: str,
@@ -190,79 +144,10 @@ class MemoryBuilder:
logger.error(f"LLM提取失败: {e}")
raise MemoryExtractionError(str(e)) from e
def _extract_with_rules(
self,
text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
) -> List[MemoryChunk]:
"""使用规则提取记忆"""
memories = []
subjects = self._resolve_conversation_participants(context, user_id)
# 规则1: 检测个人信息
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subjects)
memories.extend(personal_info)
# 规则2: 检测偏好信息
preferences = self._extract_preferences(text, user_id, timestamp, context, subjects)
memories.extend(preferences)
# 规则3: 检测事件信息
events = self._extract_events(text, user_id, timestamp, context, subjects)
memories.extend(events)
return memories
async def _extract_with_hybrid(
self,
text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
) -> List[MemoryChunk]:
"""混合策略提取记忆"""
all_memories = []
# 首先使用规则提取
rule_memories = self._extract_with_rules(text, context, user_id, timestamp)
all_memories.extend(rule_memories)
# 然后使用LLM提取
llm_memories = await self._extract_with_llm(text, context, user_id, timestamp)
# 合并和去重
final_memories = self._merge_hybrid_results(all_memories, llm_memories)
return final_memories
def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
"""构建LLM提取提示"""
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_id = context.get("chat_id", "unknown")
message_type = context.get("message_type", "normal")
target_user_id = context.get("user_id", "用户")
target_user_id = str(target_user_id)
target_user_name = (
context.get("user_display_name")
or context.get("user_name")
or context.get("nickname")
or context.get("sender_name")
)
if isinstance(target_user_name, str):
target_user_name = target_user_name.strip()
else:
target_user_name = ""
if not target_user_name or self._looks_like_system_identifier(target_user_name):
target_user_name = "该用户"
target_user_id_display = target_user_id
if self._looks_like_system_identifier(target_user_id_display):
target_user_id_display = "系统ID勿写入记忆"
bot_name = context.get("bot_name")
bot_identity = context.get("bot_identity")
@@ -966,145 +851,6 @@ class MemoryBuilder:
return f"{subject_phrase}{predicate}".strip()
return subject_phrase
def _extract_personal_info(
self,
text: str,
user_id: str,
timestamp: float,
context: Dict[str, Any],
subjects: List[str]
) -> List[MemoryChunk]:
"""提取个人信息"""
memories = []
# 常见个人信息模式
patterns = {
r"我叫(\w+)": ("is_named", {"name": "$1"}),
r"我今年(\d+)岁": ("is_age", {"age": "$1"}),
r"我是(\w+)": ("is_profession", {"profession": "$1"}),
r"我住在(\w+)": ("lives_in", {"location": "$1"}),
r"我的电话是(\d+)": ("has_phone", {"phone": "$1"}),
r"我的邮箱是(\w+@\w+\.\w+)": ("has_email", {"email": "$1"}),
}
for pattern, (predicate, obj_template) in patterns.items():
match = re.search(pattern, text)
if match:
obj = obj_template
for i, group in enumerate(match.groups(), 1):
obj = {k: v.replace(f"${i}", group) for k, v in obj.items()}
memory = create_memory_chunk(
user_id=user_id,
subject=subjects,
predicate=predicate,
obj=obj,
memory_type=MemoryType.PERSONAL_FACT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.HIGH,
confidence=ConfidenceLevel.HIGH,
display=self._compose_display_text(subjects, predicate, obj)
)
memories.append(memory)
return memories
def _extract_preferences(
self,
text: str,
user_id: str,
timestamp: float,
context: Dict[str, Any],
subjects: List[str]
) -> List[MemoryChunk]:
"""提取偏好信息"""
memories = []
# 偏好模式
preference_patterns = [
(r"我喜欢(.+)", "likes"),
(r"我不喜欢(.+)", "dislikes"),
(r"我爱吃(.+)", "likes_food"),
(r"我讨厌(.+)", "hates"),
(r"我最喜欢的(.+)", "favorite_is"),
]
for pattern, predicate in preference_patterns:
match = re.search(pattern, text)
if match:
memory = create_memory_chunk(
user_id=user_id,
subject=subjects,
predicate=predicate,
obj=match.group(1),
memory_type=MemoryType.PREFERENCE,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM,
display=self._compose_display_text(subjects, predicate, match.group(1))
)
memories.append(memory)
return memories
def _extract_events(
self,
text: str,
user_id: str,
timestamp: float,
context: Dict[str, Any],
subjects: List[str]
) -> List[MemoryChunk]:
"""提取事件信息"""
memories = []
# 事件关键词
event_keywords = ["明天", "今天", "昨天", "上周", "下周", "约会", "会议", "活动", "旅行", "生日"]
if any(keyword in text for keyword in event_keywords):
memory = create_memory_chunk(
user_id=user_id,
subject=subjects,
predicate="mentioned_event",
obj={"event_text": text, "timestamp": timestamp},
memory_type=MemoryType.EVENT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM,
display=self._compose_display_text(subjects, "mentioned_event", text)
)
memories.append(memory)
return memories
def _merge_hybrid_results(
self,
rule_memories: List[MemoryChunk],
llm_memories: List[MemoryChunk]
) -> List[MemoryChunk]:
"""合并混合策略结果"""
all_memories = rule_memories.copy()
# 添加LLM记忆避免重复
for llm_memory in llm_memories:
is_duplicate = False
for rule_memory in rule_memories:
if llm_memory.is_similar_to(rule_memory, threshold=0.7):
is_duplicate = True
# 合并置信度
rule_memory.metadata.confidence = ConfidenceLevel(
max(rule_memory.metadata.confidence.value, llm_memory.metadata.confidence.value)
)
break
if not is_duplicate:
all_memories.append(llm_memory)
return all_memories
def _validate_and_enhance_memories(
self,
memories: List[MemoryChunk],


@@ -127,6 +127,8 @@ class MemoryMetadata:
# 来源信息
source_context: Optional[str] = None # 来源上下文片段
# 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source
source: Optional[str] = None
def __post_init__(self):
"""后初始化处理"""
@@ -150,6 +152,19 @@ class MemoryMetadata:
if self.last_forgetting_check == 0:
self.last_forgetting_check = current_time
# 兼容性:如果旧字段 source 被使用,保证 source 与 source_context 同步
if not getattr(self, 'source', None) and getattr(self, 'source_context', None):
try:
self.source = str(self.source_context)
except Exception:
self.source = None
# 如果有 source 字段但 source_context 为空,也同步回去
if not getattr(self, 'source_context', None) and getattr(self, 'source', None):
try:
self.source_context = str(self.source)
except Exception:
self.source_context = None
def update_access(self):
"""更新访问信息"""
current_time = time.time()
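An alternative to carrying a duplicate source field would be a property that proxies source_context; a minimal sketch of that idea (a standalone class, not the project's MemoryMetadata):

from dataclasses import dataclass
from typing import Optional

@dataclass
class MetadataCompatSketch:
    source_context: Optional[str] = None

    @property
    def source(self) -> Optional[str]:
        # Old callers reading metadata.source transparently get the context string.
        return self.source_context

    @source.setter
    def source(self, value: Optional[str]) -> None:
        self.source_context = value

The field-based shim in the diff keeps dataclasses.asdict() output unchanged, which a property would not, so carrying the duplicate field is a defensible choice.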


@@ -0,0 +1,316 @@
# -*- coding: utf-8 -*-
"""
记忆元数据索引管理器
使用JSON文件存储记忆元数据支持快速模糊搜索和过滤
"""
import orjson
import threading
from pathlib import Path
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryType, ImportanceLevel, ConfidenceLevel
logger = get_logger(__name__)
@dataclass
class MemoryMetadataIndexEntry:
"""元数据索引条目(轻量级,只用于快速过滤)"""
memory_id: str
user_id: str
# 分类信息
memory_type: str # MemoryType.value
subjects: List[str] # 主语列表
objects: List[str] # 宾语列表
keywords: List[str] # 关键词列表
tags: List[str] # 标签列表
# 数值字段(用于范围过滤)
importance: int # ImportanceLevel.value (1-4)
confidence: int # ConfidenceLevel.value (1-4)
created_at: float # 创建时间戳
access_count: int # 访问次数
# 可选字段
chat_id: Optional[str] = None
content_preview: Optional[str] = None # 内容预览前100字符
class MemoryMetadataIndex:
"""记忆元数据索引管理器"""
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
self.index_file = Path(index_file)
self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
# 倒排索引(用于快速查找)
self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids}
self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids}
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids}
self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids}
self.lock = threading.RLock()
# 加载已有索引
self._load_index()
def _load_index(self):
"""从文件加载索引"""
if not self.index_file.exists():
logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}")
return
try:
with open(self.index_file, 'rb') as f:
data = orjson.loads(f.read())
# 重建内存索引
for entry_data in data.get('entries', []):
entry = MemoryMetadataIndexEntry(**entry_data)
self.index[entry.memory_id] = entry
self._update_inverted_indices(entry)
logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录")
except Exception as e:
logger.error(f"加载元数据索引失败: {e}", exc_info=True)
def _save_index(self):
"""保存索引到文件"""
try:
# 确保目录存在
self.index_file.parent.mkdir(parents=True, exist_ok=True)
# 序列化所有条目
entries = [asdict(entry) for entry in self.index.values()]
data = {
'version': '1.0',
'count': len(entries),
'last_updated': datetime.now().isoformat(),
'entries': entries
}
# 写入文件(使用临时文件 + 原子重命名)
temp_file = self.index_file.with_suffix('.tmp')
with open(temp_file, 'wb') as f:
f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
temp_file.replace(self.index_file)
logger.debug(f"元数据索引已保存: {len(entries)} 条记录")
except Exception as e:
logger.error(f"保存元数据索引失败: {e}", exc_info=True)
def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry):
"""更新倒排索引"""
memory_id = entry.memory_id
# 类型索引
self.type_index.setdefault(entry.memory_type, set()).add(memory_id)
# 主语索引
for subject in entry.subjects:
subject_norm = subject.strip().lower()
if subject_norm:
self.subject_index.setdefault(subject_norm, set()).add(memory_id)
# 关键词索引
for keyword in entry.keywords:
keyword_norm = keyword.strip().lower()
if keyword_norm:
self.keyword_index.setdefault(keyword_norm, set()).add(memory_id)
# 标签索引
for tag in entry.tags:
tag_norm = tag.strip().lower()
if tag_norm:
self.tag_index.setdefault(tag_norm, set()).add(memory_id)
def add_or_update(self, entry: MemoryMetadataIndexEntry):
"""添加或更新索引条目"""
with self.lock:
# 如果已存在,先从倒排索引中移除旧记录
if entry.memory_id in self.index:
self._remove_from_inverted_indices(entry.memory_id)
# 添加新记录
self.index[entry.memory_id] = entry
self._update_inverted_indices(entry)
def _remove_from_inverted_indices(self, memory_id: str):
"""从倒排索引中移除记录"""
if memory_id not in self.index:
return
entry = self.index[memory_id]
# 从类型索引移除
if entry.memory_type in self.type_index:
self.type_index[entry.memory_type].discard(memory_id)
# 从主语索引移除
for subject in entry.subjects:
subject_norm = subject.strip().lower()
if subject_norm in self.subject_index:
self.subject_index[subject_norm].discard(memory_id)
# 从关键词索引移除
for keyword in entry.keywords:
keyword_norm = keyword.strip().lower()
if keyword_norm in self.keyword_index:
self.keyword_index[keyword_norm].discard(memory_id)
# 从标签索引移除
for tag in entry.tags:
tag_norm = tag.strip().lower()
if tag_norm in self.tag_index:
self.tag_index[tag_norm].discard(memory_id)
def remove(self, memory_id: str):
"""移除索引条目"""
with self.lock:
if memory_id in self.index:
self._remove_from_inverted_indices(memory_id)
del self.index[memory_id]
def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
"""批量添加或更新"""
with self.lock:
for entry in entries:
self.add_or_update(entry)
def save(self):
"""保存索引到磁盘"""
with self.lock:
self._save_index()
def search(
self,
memory_types: Optional[List[str]] = None,
subjects: Optional[List[str]] = None,
keywords: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
importance_min: Optional[int] = None,
importance_max: Optional[int] = None,
created_after: Optional[float] = None,
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None
) -> List[str]:
"""
搜索符合条件的记忆ID列表支持模糊匹配
Returns:
List[str]: 符合条件的 memory_id 列表
"""
with self.lock:
# 初始候选集(所有记忆)
candidate_ids: Optional[Set[str]] = None
# 用户过滤(必选)
if user_id:
candidate_ids = {
mid for mid, entry in self.index.items()
if entry.user_id == user_id
}
else:
candidate_ids = set(self.index.keys())
# 类型过滤OR关系
if memory_types:
type_ids = set()
for mtype in memory_types:
type_ids.update(self.type_index.get(mtype, set()))
candidate_ids &= type_ids
# 主语过滤OR关系支持模糊匹配
if subjects:
subject_ids = set()
for subject in subjects:
subject_norm = subject.strip().lower()
# 精确匹配
if subject_norm in self.subject_index:
subject_ids.update(self.subject_index[subject_norm])
# 模糊匹配(包含)
for indexed_subject, ids in self.subject_index.items():
if subject_norm in indexed_subject or indexed_subject in subject_norm:
subject_ids.update(ids)
candidate_ids &= subject_ids
# 关键词过滤OR关系支持模糊匹配
if keywords:
keyword_ids = set()
for keyword in keywords:
keyword_norm = keyword.strip().lower()
# 精确匹配
if keyword_norm in self.keyword_index:
keyword_ids.update(self.keyword_index[keyword_norm])
# 模糊匹配(包含)
for indexed_keyword, ids in self.keyword_index.items():
if keyword_norm in indexed_keyword or indexed_keyword in keyword_norm:
keyword_ids.update(ids)
candidate_ids &= keyword_ids
# 标签过滤OR关系
if tags:
tag_ids = set()
for tag in tags:
tag_norm = tag.strip().lower()
tag_ids.update(self.tag_index.get(tag_norm, set()))
candidate_ids &= tag_ids
# 重要性过滤
if importance_min is not None or importance_max is not None:
importance_ids = {
mid for mid in candidate_ids
if (importance_min is None or self.index[mid].importance >= importance_min)
and (importance_max is None or self.index[mid].importance <= importance_max)
}
candidate_ids &= importance_ids
# 时间范围过滤
if created_after is not None or created_before is not None:
time_ids = {
mid for mid in candidate_ids
if (created_after is None or self.index[mid].created_at >= created_after)
and (created_before is None or self.index[mid].created_at <= created_before)
}
candidate_ids &= time_ids
# 转换为列表并排序(按创建时间倒序)
result_ids = sorted(
candidate_ids,
key=lambda mid: self.index[mid].created_at,
reverse=True
)
# 限制数量
if limit:
result_ids = result_ids[:limit]
logger.debug(
f"元数据索引搜索: types={memory_types}, subjects={subjects}, "
f"keywords={keywords}, 返回={len(result_ids)}"
)
return result_ids
def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
"""获取单个索引条目"""
return self.index.get(memory_id)
def get_stats(self) -> Dict[str, Any]:
"""获取索引统计信息"""
with self.lock:
return {
'total_memories': len(self.index),
'types': {mtype: len(ids) for mtype, ids in self.type_index.items()},
'subjects_count': len(self.subject_index),
'keywords_count': len(self.keyword_index),
'tags_count': len(self.tag_index),
}
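A minimal usage sketch of the index defined above (the file path and field values are illustrative):

import time

index = MemoryMetadataIndex("data/memory_metadata_index.json")
index.add_or_update(MemoryMetadataIndexEntry(
    memory_id="m-001",
    user_id="global",
    memory_type="preference",
    subjects=["杰瑞喵"],
    objects=["奶茶"],
    keywords=["奶茶", "喜欢"],
    tags=["饮食"],
    importance=2,
    confidence=2,
    created_at=time.time(),
    access_count=0,
))
index.save()

recent_ids = index.search(keywords=["奶茶"], user_id="global", limit=10)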


@@ -380,11 +380,11 @@ class MemorySystem:
self.status = original_status
return []
# 2. 构建记忆块
# 2. 构建记忆块(所有记忆统一使用 global 作用域,实现完全共享)
memory_chunks = await self.memory_builder.build_memories(
conversation_text,
normalized_context,
GLOBAL_MEMORY_SCOPE,
GLOBAL_MEMORY_SCOPE, # 强制使用 global不区分用户
timestamp or time.time()
)
@@ -609,7 +609,7 @@ class MemorySystem:
limit: int = 5,
**kwargs
) -> List[MemoryChunk]:
"""检索相关记忆(简化版,使用统一存储"""
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排"""
raw_query = query_text or kwargs.get("query")
if not raw_query:
raise ValueError("query_text 或 query 参数不能为空")
@@ -619,6 +619,8 @@ class MemorySystem:
return []
context = context or {}
# 所有记忆完全共享,统一使用 global 作用域,不区分用户
resolved_user_id = GLOBAL_MEMORY_SCOPE
self.status = MemorySystemStatus.RETRIEVING
@@ -626,48 +628,165 @@ class MemorySystem:
try:
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
effective_limit = limit or self.config.final_recall_limit
# 构建过滤器
filters = {
"user_id": resolved_user_id
# === 阶段一:元数据粗筛(软性过滤) ===
coarse_filters = {
"user_id": GLOBAL_MEMORY_SCOPE, # 必选:确保作用域正确
}
# 可选:添加重要性阈值(过滤低价值记忆)
if hasattr(self.config, 'memory_importance_threshold'):
importance_threshold = self.config.memory_importance_threshold
if importance_threshold > 0:
coarse_filters["importance"] = {"$gte": importance_threshold}
logger.debug(f"[阶段一] 启用重要性过滤: >= {importance_threshold}")
# 可选添加时间范围只搜索最近N天
if hasattr(self.config, 'memory_recency_days'):
recency_days = self.config.memory_recency_days
if recency_days > 0:
cutoff_time = time.time() - (recency_days * 24 * 3600)
coarse_filters["created_at"] = {"$gte": cutoff_time}
logger.debug(f"[阶段一] 启用时间过滤: 最近 {recency_days}")
# 应用查询规划结果
# 应用查询规划(优化查询语句并构建元数据过滤)
optimized_query = raw_query
metadata_filters = {}
if self.query_planner:
try:
query_plan = await self.query_planner.plan_query(raw_query, normalized_context)
if getattr(query_plan, "memory_types", None):
filters["memory_types"] = [mt.value for mt in query_plan.memory_types]
if getattr(query_plan, "subject_includes", None):
filters["keywords"] = query_plan.subject_includes
# 使用LLM优化后的查询语句更精确的语义表达
if getattr(query_plan, "semantic_query", None):
raw_query = query_plan.semantic_query
optimized_query = query_plan.semantic_query
# 构建JSON元数据过滤条件用于阶段一粗筛
# 将查询规划的结果转换为元数据过滤条件
if getattr(query_plan, "memory_types", None):
metadata_filters['memory_types'] = [mt.value for mt in query_plan.memory_types]
if getattr(query_plan, "subject_includes", None):
metadata_filters['subjects'] = query_plan.subject_includes
if getattr(query_plan, "required_keywords", None):
metadata_filters['keywords'] = query_plan.required_keywords
# 时间范围过滤
recency = getattr(query_plan, "recency_preference", "any")
current_time = time.time()
if recency == "recent":
# 最近7天
metadata_filters['created_after'] = current_time - (7 * 24 * 3600)
elif recency == "historical":
# 30天以前
metadata_filters['created_before'] = current_time - (30 * 24 * 3600)
# 添加用户ID到元数据过滤
metadata_filters['user_id'] = GLOBAL_MEMORY_SCOPE
logger.debug(f"[阶段一] 查询优化: '{raw_query}''{optimized_query}'")
logger.debug(f"[阶段一] 元数据过滤条件: {metadata_filters}")
except Exception as plan_exc:
logger.warning("查询规划失败,使用默认检索策略: %s", plan_exc, exc_info=True)
logger.warning("查询规划失败,使用原始查询: %s", plan_exc, exc_info=True)
# 即使查询规划失败也保留基本的user_id过滤
metadata_filters = {'user_id': GLOBAL_MEMORY_SCOPE}
# 使用Vector DB存储搜索
# === 阶段二:向量精筛 ===
coarse_limit = self.config.coarse_recall_limit # 粗筛阶段返回更多候选
logger.debug(f"[阶段二] 开始向量搜索: query='{optimized_query[:60]}...', limit={coarse_limit}")
search_results = await self.unified_storage.search_similar_memories(
query_text=raw_query,
limit=effective_limit,
filters=filters
query_text=optimized_query,
limit=coarse_limit,
filters=coarse_filters, # ChromaDB where条件保留兼容
metadata_filters=metadata_filters # JSON元数据索引过滤
)
logger.info(f"[阶段二] 向量搜索完成: 返回 {len(search_results)} 条候选")
# 转换为记忆对象 - search_results 返回 List[Tuple[MemoryChunk, float]]
final_memories = []
for memory, similarity_score in search_results:
# === 阶段三:综合重排 ===
scored_memories = []
current_time = time.time()
import math  # 局部导入,供时效性衰减计算使用(使用 math.exp 而非 np.exp避免引入 numpy 依赖)
for memory, vector_similarity in search_results:
# 1. 向量相似度得分(已归一化到 0-1
vector_score = vector_similarity
# 2. 时效性得分(指数衰减,时间常数 30 天,约 21 天半衰期)
age_seconds = current_time - memory.metadata.created_at
age_days = age_seconds / (24 * 3600)
recency_score = math.exp(-age_days / 30)
# 3. 重要性得分(枚举值转换为归一化得分 0-1
# ImportanceLevel: LOW=1, NORMAL=2, HIGH=3, CRITICAL=4
importance_enum = memory.metadata.importance
if hasattr(importance_enum, 'value'):
# 枚举类型转换为0-1范围(value - 1) / 3
importance_score = (importance_enum.value - 1) / 3.0
else:
# 如果已经是数值,直接使用
importance_score = float(importance_enum) if importance_enum else 0.5
# 4. 访问频率得分归一化访问10次以上得满分
access_count = memory.metadata.access_count
frequency_score = min(access_count / 10.0, 1.0)
# 综合得分(加权平均)
final_score = (
self.config.vector_weight * vector_score +
self.config.recency_weight * recency_score +
self.config.context_weight * importance_score +
0.1 * frequency_score # 访问频率权重固定10%
)
scored_memories.append((memory, final_score, {
"vector": vector_score,
"recency": recency_score,
"importance": importance_score,
"frequency": frequency_score,
"final": final_score
}))
# 更新访问记录
memory.update_access()
final_memories.append(memory)
# 按综合得分排序
scored_memories.sort(key=lambda x: x[1], reverse=True)
# 返回 Top-K
final_memories = [mem for mem, score, details in scored_memories[:effective_limit]]
retrieval_time = time.time() - start_time
# 详细日志
if scored_memories:
logger.info(f"[阶段三] 综合重排完成: Top 3 得分详情")
for i, (mem, score, details) in enumerate(scored_memories[:3], 1):
try:
summary = mem.content[:60] if hasattr(mem, 'content') and mem.content else ""
except Exception:
summary = ""
logger.info(
f" #{i} | final={details['final']:.3f} "
f"(vec={details['vector']:.3f}, rec={details['recency']:.3f}, "
f"imp={details['importance']:.3f}, freq={details['frequency']:.3f}) "
f"| {summary}"
)
logger.info(
"简化记忆检索完成"
"三阶段记忆检索完成"
f" | user={resolved_user_id}"
f" | count={len(final_memories)}"
f" | 粗筛={len(search_results)}"
f" | 精筛={len(scored_memories)}"
f" | 返回={len(final_memories)}"
f" | duration={retrieval_time:.3f}s"
f" | query='{raw_query}'"
f" | query='{optimized_query[:60]}...'"
)
self.last_retrieval_time = time.time()
@@ -717,8 +836,8 @@ class MemorySystem:
except Exception:
context = dict(raw_context or {})
# 基础字段(统一使用全局作用域)
context["user_id"] = GLOBAL_MEMORY_SCOPE
# 基础字段:强制使用传入的 user_id 参数(已统一为 GLOBAL_MEMORY_SCOPE
context["user_id"] = user_id or GLOBAL_MEMORY_SCOPE
context["timestamp"] = context.get("timestamp") or timestamp or time.time()
context["message_type"] = context.get("message_type") or "normal"
context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"


@@ -26,6 +26,7 @@ from src.common.vector_db import vector_db_service
from src.chat.utils.utils import get_embedding
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry
logger = get_logger(__name__)
@@ -38,7 +39,7 @@ class VectorStorageConfig:
metadata_collection: str = "memory_metadata_v2"
# 检索配置
similarity_threshold: float = 0.8
similarity_threshold: float = 0.5 # 降低阈值以提高召回率0.5-0.6 是合理范围)
search_limit: int = 20
batch_size: int = 100
@@ -50,6 +51,26 @@ class VectorStorageConfig:
# 遗忘配置
enable_forgetting: bool = True
retention_hours: int = 24 * 30 # 30天
@classmethod
def from_global_config(cls):
"""从全局配置创建实例"""
from src.config.config import global_config
memory_cfg = global_config.memory
return cls(
memory_collection=getattr(memory_cfg, 'vector_db_memory_collection', 'unified_memory_v2'),
metadata_collection=getattr(memory_cfg, 'vector_db_metadata_collection', 'memory_metadata_v2'),
similarity_threshold=getattr(memory_cfg, 'vector_db_similarity_threshold', 0.5),
search_limit=getattr(memory_cfg, 'vector_db_search_limit', 20),
batch_size=getattr(memory_cfg, 'vector_db_batch_size', 100),
enable_caching=getattr(memory_cfg, 'vector_db_enable_caching', True),
cache_size_limit=getattr(memory_cfg, 'vector_db_cache_size_limit', 1000),
auto_cleanup_interval=getattr(memory_cfg, 'vector_db_auto_cleanup_interval', 3600),
enable_forgetting=getattr(memory_cfg, 'enable_memory_forgetting', True),
retention_hours=getattr(memory_cfg, 'vector_db_retention_hours', 720),
)
class VectorMemoryStorage:
@@ -71,7 +92,16 @@ class VectorMemoryStorage:
"""基于Vector DB的记忆存储系统"""
def __init__(self, config: Optional[VectorStorageConfig] = None):
self.config = config or VectorStorageConfig()
# 默认从全局配置读取如果没有传入config
if config is None:
try:
self.config = VectorStorageConfig.from_global_config()
logger.info("✅ Vector存储配置已从全局配置加载")
except Exception as e:
logger.warning(f"从全局配置加载失败,使用默认配置: {e}")
self.config = VectorStorageConfig()
else:
self.config = config
# 从配置中获取批处理大小和集合名称
self.batch_size = self.config.batch_size
@@ -83,6 +113,9 @@ class VectorMemoryStorage:
self.cache_timestamps: Dict[str, float] = {}
self._cache = self.memory_cache # 别名,兼容旧代码
# 元数据索引管理器JSON文件索引
self.metadata_index = MemoryMetadataIndex()
# 遗忘引擎
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
if self.config.enable_forgetting:
@@ -354,14 +387,45 @@ class VectorMemoryStorage:
success = True
if success:
# 更新缓存
# 更新缓存和元数据索引
metadata_entries = []
for item in batch:
memory_id = item["id"]
# 从原始 memories 列表中找到对应的 MemoryChunk
memory = next((m for m in memories if (getattr(m.metadata, 'memory_id', None) or getattr(m, 'memory_id', None)) == memory_id), None)
if memory:
# 更新缓存
self._cache[memory_id] = memory
success_count += 1
# 创建元数据索引条目
try:
index_entry = MemoryMetadataIndexEntry(
memory_id=memory_id,
user_id=memory.metadata.user_id or "unknown",
memory_type=memory.memory_type.value,
subjects=memory.subjects,
objects=[str(memory.content.object)] if memory.content.object else [],
keywords=memory.keywords,
tags=memory.tags,
importance=memory.metadata.importance.value,
confidence=memory.metadata.confidence.value,
created_at=memory.metadata.created_at,
access_count=memory.metadata.access_count,
chat_id=memory.metadata.chat_id,
content_preview=str(memory.content)[:100] if memory.content else None
)
metadata_entries.append(index_entry)
except Exception as e:
logger.warning(f"创建元数据索引条目失败 (memory_id={memory_id}): {e}")
# 批量更新元数据索引
if metadata_entries:
try:
self.metadata_index.batch_add_or_update(metadata_entries)
logger.debug(f"更新元数据索引: {len(metadata_entries)}")
except Exception as e:
logger.error(f"批量更新元数据索引失败: {e}")
else:
logger.warning(f"批次存储失败,跳过 {len(batch)} 条记忆")
@@ -372,6 +436,14 @@ class VectorMemoryStorage:
duration = (datetime.now() - start_time).total_seconds()
logger.info(f"成功存储 {success_count}/{len(memories)} 条记忆,耗时 {duration:.2f}")
# 保存元数据索引到磁盘
if success_count > 0:
try:
self.metadata_index.save()
logger.debug("元数据索引已保存到磁盘")
except Exception as e:
logger.error(f"保存元数据索引失败: {e}")
return success_count
except Exception as e:
@@ -388,13 +460,57 @@ class VectorMemoryStorage:
query_text: str,
limit: int = 10,
similarity_threshold: Optional[float] = None,
filters: Optional[Dict[str, Any]] = None
filters: Optional[Dict[str, Any]] = None,
# 新增元数据过滤参数用于JSON索引粗筛
metadata_filters: Optional[Dict[str, Any]] = None
) -> List[Tuple[MemoryChunk, float]]:
"""搜索相似记忆"""
"""
搜索相似记忆(混合索引模式)
Args:
query_text: 查询文本
limit: 返回数量限制
similarity_threshold: 相似度阈值
filters: ChromaDB where条件保留用于兼容
metadata_filters: JSON元数据索引过滤条件支持:
- memory_types: List[str]
- subjects: List[str]
- keywords: List[str]
- tags: List[str]
- importance_min: int
- importance_max: int
- created_after: float
- created_before: float
- user_id: str
"""
if not query_text.strip():
return []
try:
# === 阶段一JSON元数据粗筛可选 ===
candidate_ids: Optional[List[str]] = None
if metadata_filters:
logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}")
candidate_ids = self.metadata_index.search(
memory_types=metadata_filters.get('memory_types'),
subjects=metadata_filters.get('subjects'),
keywords=metadata_filters.get('keywords'),
tags=metadata_filters.get('tags'),
importance_min=metadata_filters.get('importance_min'),
importance_max=metadata_filters.get('importance_max'),
created_after=metadata_filters.get('created_after'),
created_before=metadata_filters.get('created_before'),
user_id=metadata_filters.get('user_id'),
limit=self.config.search_limit * 2 # 粗筛返回更多候选
)
logger.info(f"[JSON元数据粗筛] 完成,筛选出 {len(candidate_ids)} 个候选ID")
# 如果粗筛后没有结果,直接返回
if not candidate_ids:
logger.warning("JSON元数据粗筛后无候选返回空结果")
return []
# === 阶段二:向量精筛 ===
# 生成查询向量
query_embedding = await get_embedding(query_text)
if not query_embedding:
@@ -405,7 +521,14 @@ class VectorMemoryStorage:
# 构建where条件
where_conditions = filters or {}
# 如果有候选ID列表添加到where条件
if candidate_ids:
# ChromaDB的where条件需要使用$in操作符
where_conditions["memory_id"] = {"$in": candidate_ids}
logger.debug(f"[向量精筛] 限制在 {len(candidate_ids)} 个候选ID内搜索")
# 查询Vector DB
logger.debug(f"[向量精筛] 开始limit={min(limit, self.config.search_limit)}")
results = vector_db_service.query(
collection_name=self.config.memory_collection,
query_embeddings=[query_embedding],
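A call-site sketch for the new hybrid search entry point shown in this file (the storage instance, query string and filter values are illustrative):

import time

async def example(storage: VectorMemoryStorage):
    results = await storage.search_similar_memories(
        query_text="杰瑞喵喜欢什么?",
        limit=10,
        metadata_filters={
            "memory_types": ["preference"],
            "keywords": ["喜欢"],
            "user_id": "global",
            "created_after": time.time() - 7 * 24 * 3600,
        },
    )
    for memory, similarity in results:
        print(memory.metadata.memory_id, f"{similarity:.3f}")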


@@ -311,11 +311,12 @@ class MemoryConfig(ValidatedConfigBase):
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
# Vector DB配置
vector_db_memory_collection: str = Field(default="unified_memory_v2", description="Vector DB集合名称")
vector_db_similarity_threshold: float = Field(default=0.8, description="Vector DB相似度阈值")
vector_db_memory_collection: str = Field(default="unified_memory_v2", description="Vector DB记忆集合名称")
vector_db_metadata_collection: str = Field(default="memory_metadata_v2", description="Vector DB元数据集合名称")
vector_db_similarity_threshold: float = Field(default=0.5, description="Vector DB相似度阈值推荐0.5-0.6,过高会导致检索不到结果)")
vector_db_search_limit: int = Field(default=20, description="Vector DB搜索限制")
vector_db_batch_size: int = Field(default=100, description="批处理大小")
vector_db_enable_caching: bool = Field(default=True, description="启用缓存")
vector_db_enable_caching: bool = Field(default=True, description="启用内存缓存")
vector_db_cache_size_limit: int = Field(default=1000, description="缓存大小限制")
vector_db_auto_cleanup_interval: int = Field(default=3600, description="自动清理间隔(秒)")
vector_db_retention_hours: int = Field(default=720, description="记忆保留时间小时默认30天")


@@ -1,5 +1,5 @@
[inner]
version = "7.1.3"
version = "7.1.4"
#----以下是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读----
#如果你想要修改配置文件请递增version的值
@@ -303,11 +303,46 @@ max_frequency_bonus = 10.0 # 最大激活频率奖励天数
# 休眠机制
dormant_threshold_days = 90 # 休眠状态判定天数(超过此天数未访问的记忆进入休眠状态)
# 统一存储配置 (新增)
unified_storage_path = "data/unified_memory" # 统一存储数据路径
unified_storage_cache_limit = 10000 # 内存缓存大小限制
unified_storage_auto_save_interval = 50 # 自动保存间隔(记忆条数)
unified_storage_enable_compression = true # 是否启用数据压缩
# 统一存储配置 (已弃用 - 请使用Vector DB配置)
# DEPRECATED: unified_storage_path = "data/unified_memory"
# DEPRECATED: unified_storage_cache_limit = 10000
# DEPRECATED: unified_storage_auto_save_interval = 50
# DEPRECATED: unified_storage_enable_compression = true
# Vector DB存储配置 (新增 - 替代JSON存储)
enable_vector_memory_storage = true # 启用Vector DB存储
enable_llm_instant_memory = true # 启用基于LLM的瞬时记忆
enable_vector_instant_memory = true # 启用基于向量的瞬时记忆
# Vector DB配置
vector_db_memory_collection = "unified_memory_v2" # Vector DB主记忆集合名称
vector_db_metadata_collection = "memory_metadata_v2" # Vector DB元数据集合名称
vector_db_similarity_threshold = 0.5 # Vector DB相似度阈值 (推荐范围: 0.5-0.6, 过高会导致检索不到结果)
vector_db_search_limit = 20 # Vector DB单次搜索返回的最大结果数
vector_db_batch_size = 100 # 批处理大小 (批量存储记忆时每批处理的记忆条数)
vector_db_enable_caching = true # 启用内存缓存
vector_db_cache_size_limit = 1000 # 缓存大小限制 (内存缓存最多保存的记忆条数)
vector_db_auto_cleanup_interval = 3600 # 自动清理间隔(秒)
vector_db_retention_hours = 720 # 记忆保留时间小时默认30天
# 多阶段召回配置(可选)
# 取消注释以启用更严格的粗筛,适用于大规模记忆库(>10万条)
# memory_importance_threshold = 0.3 # 重要性阈值过滤低价值记忆范围0.0-1.0
# memory_recency_days = 30 # 时间范围只搜索最近N天的记忆0表示不限制
# Vector DB配置 (ChromaDB)
[vector_db]
type = "chromadb" # Vector DB类型
path = "data/chroma_db" # Vector DB数据路径
[vector_db.settings]
anonymized_telemetry = false # 禁用匿名遥测
allow_reset = true # 允许重置
[vector_db.collections]
unified_memory_v2 = { description = "统一记忆存储V2", hnsw_space = "cosine", version = "2.0" }
memory_metadata_v2 = { description = "记忆元数据索引", hnsw_space = "cosine", version = "2.0" }
semantic_cache = { description = "语义缓存", hnsw_space = "cosine" }
[voice]
enable_asr = true # 是否启用语音识别启用后MoFox-Bot可以识别语音消息启用该功能需要配置语音识别模型[model.voice]