Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
139
scripts/rebuild_metadata_index.py
Normal file
@@ -0,0 +1,139 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
从现有ChromaDB数据重建JSON元数据索引
"""
import asyncio
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry
from src.common.logger import get_logger

logger = get_logger(__name__)

async def rebuild_metadata_index():
    """从ChromaDB重建元数据索引"""
    print("="*80)
    print("重建JSON元数据索引")
    print("="*80)

    # 初始化记忆系统
    print("\n🔧 初始化记忆系统...")
    ms = MemorySystem()
    await ms.initialize()
    print("✅ 记忆系统已初始化")

    if not hasattr(ms.unified_storage, 'metadata_index'):
        print("❌ 元数据索引管理器未初始化")
        return

    # 获取所有记忆
    print("\n📥 从ChromaDB获取所有记忆...")
    from src.common.vector_db import vector_db_service

    try:
        # 获取集合中的所有记忆ID
        collection_name = ms.unified_storage.config.memory_collection
        result = vector_db_service.get(
            collection_name=collection_name,
            include=["documents", "metadatas", "embeddings"]
        )

        if not result or not result.get("ids"):
            print("❌ ChromaDB中没有找到记忆数据")
            return

        ids = result["ids"]
        metadatas = result.get("metadatas", [])

        print(f"✅ 找到 {len(ids)} 条记忆")

        # 重建元数据索引
        print("\n🔨 开始重建元数据索引...")
        entries = []
        success_count = 0

        for i, (memory_id, metadata) in enumerate(zip(ids, metadatas), 1):
            try:
                # 从ChromaDB元数据重建索引条目
                import orjson

                entry = MemoryMetadataIndexEntry(
                    memory_id=memory_id,
                    user_id=metadata.get("user_id", "unknown"),
                    memory_type=metadata.get("memory_type", "general"),
                    subjects=orjson.loads(metadata.get("subjects", "[]")),
                    objects=[metadata.get("object")] if metadata.get("object") else [],
                    keywords=orjson.loads(metadata.get("keywords", "[]")),
                    tags=orjson.loads(metadata.get("tags", "[]")),
                    importance=2,  # 默认NORMAL
                    confidence=2,  # 默认MEDIUM
                    created_at=metadata.get("created_at", 0.0),
                    access_count=metadata.get("access_count", 0),
                    chat_id=metadata.get("chat_id"),
                    content_preview=None
                )

                # 尝试解析importance和confidence的枚举名称
                if "importance" in metadata:
                    imp_str = metadata["importance"]
                    if imp_str == "LOW":
                        entry.importance = 1
                    elif imp_str == "NORMAL":
                        entry.importance = 2
                    elif imp_str == "HIGH":
                        entry.importance = 3
                    elif imp_str == "CRITICAL":
                        entry.importance = 4

                if "confidence" in metadata:
                    conf_str = metadata["confidence"]
                    if conf_str == "LOW":
                        entry.confidence = 1
                    elif conf_str == "MEDIUM":
                        entry.confidence = 2
                    elif conf_str == "HIGH":
                        entry.confidence = 3
                    elif conf_str == "VERIFIED":
                        entry.confidence = 4

                entries.append(entry)
                success_count += 1

                if i % 100 == 0:
                    print(f"  处理进度: {i}/{len(ids)} ({success_count} 成功)")

            except Exception as e:
                logger.warning(f"处理记忆 {memory_id} 失败: {e}")
                continue

        print(f"\n✅ 成功解析 {success_count}/{len(ids)} 条记忆元数据")

        # 批量更新索引
        print("\n💾 保存元数据索引...")
        ms.unified_storage.metadata_index.batch_add_or_update(entries)
        ms.unified_storage.metadata_index.save()

        # 显示统计信息
        stats = ms.unified_storage.metadata_index.get_stats()
        print(f"\n📊 重建后的索引统计:")
        print(f"  - 总记忆数: {stats['total_memories']}")
        print(f"  - 主语数量: {stats['subjects_count']}")
        print(f"  - 关键词数量: {stats['keywords_count']}")
        print(f"  - 标签数量: {stats['tags_count']}")
        print(f"  - 类型分布:")
        for mtype, count in stats['types'].items():
            print(f"    - {mtype}: {count}")

        print("\n✅ 元数据索引重建完成!")

    except Exception as e:
        logger.error(f"重建索引失败: {e}", exc_info=True)
        print(f"❌ 重建索引失败: {e}")

if __name__ == '__main__':
    asyncio.run(rebuild_metadata_index())
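The long if/elif chains above map the enum names that ChromaDB stores ("LOW", "NORMAL", ...) back to integer levels. A minimal sketch of the same mapping done with dictionary lookups (a hypothetical helper, not part of the script), falling back to the script's defaults of NORMAL=2 / MEDIUM=2 for unknown names:

from typing import Tuple

# Hypothetical helper: enum-name -> integer level, with the script's defaults as fallback.
IMPORTANCE_LEVELS = {"LOW": 1, "NORMAL": 2, "HIGH": 3, "CRITICAL": 4}
CONFIDENCE_LEVELS = {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "VERIFIED": 4}

def parse_levels(metadata: dict) -> Tuple[int, int]:
    importance = IMPORTANCE_LEVELS.get(str(metadata.get("importance", "")), 2)
    confidence = CONFIDENCE_LEVELS.get(str(metadata.get("confidence", "")), 2)
    return importance, confidence

print(parse_levels({"importance": "HIGH"}))  # -> (3, 2)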
23
scripts/run_multi_stage_smoke.py
Normal file
@@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
轻量烟雾测试:初始化 MemorySystem 并运行一次检索,验证 MemoryMetadata.source 访问不再报错
"""
import asyncio
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.chat.memory_system.memory_system import MemorySystem

async def main():
    ms = MemorySystem()
    await ms.initialize()
    results = await ms.retrieve_relevant_memories(query_text="测试查询:杰瑞喵喜欢什么?", limit=3)
    print(f"检索到 {len(results)} 条记忆(如果 >0 则表明运行成功)")
    for i, m in enumerate(results, 1):
        print(f"{i}. id={m.metadata.memory_id} source={getattr(m.metadata, 'source', None)}")

if __name__ == '__main__':
    asyncio.run(main())
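If the same smoke check needs to run under a test runner instead of as a standalone script, a minimal sketch (assuming pytest with the pytest-asyncio plugin is installed and the repo root is on sys.path) could look like:

import pytest

from src.chat.memory_system.memory_system import MemorySystem

@pytest.mark.asyncio
async def test_metadata_source_access():
    ms = MemorySystem()
    await ms.initialize()
    results = await ms.retrieve_relevant_memories(query_text="测试查询", limit=3)
    for m in results:
        # Accessing the legacy field must not raise, even when it is None.
        _ = getattr(m.metadata, "source", None)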
@@ -122,6 +122,13 @@ class ChatterManager:
                actions_count = result.get("actions_count", 0)
                logger.debug(f"流 {stream_id} 处理完成: 成功={success}, 动作数={actions_count}")

                # 在处理完成后,清除该流的未读消息
                try:
                    from src.chat.message_manager.message_manager import message_manager
                    await message_manager.clear_stream_unread_messages(stream_id)
                except Exception as clear_e:
                    logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}")

                return result
            except Exception as e:
                self.stats["failed_executions"] += 1

@@ -22,12 +22,11 @@ from .memory_forgetting_engine import (
|
||||
get_memory_forgetting_engine
|
||||
)
|
||||
|
||||
# 统一存储系统
|
||||
from .unified_memory_storage import (
|
||||
UnifiedMemoryStorage,
|
||||
UnifiedStorageConfig,
|
||||
get_unified_memory_storage,
|
||||
initialize_unified_memory_storage
|
||||
# Vector DB存储系统
|
||||
from .vector_memory_storage_v2 import (
|
||||
VectorMemoryStorage,
|
||||
VectorStorageConfig,
|
||||
get_vector_memory_storage
|
||||
)
|
||||
|
||||
# 记忆核心系统
|
||||
@@ -79,11 +78,10 @@ __all__ = [
|
||||
"ForgettingConfig",
|
||||
"get_memory_forgetting_engine",
|
||||
|
||||
# 统一存储
|
||||
"UnifiedMemoryStorage",
|
||||
"UnifiedStorageConfig",
|
||||
"get_unified_memory_storage",
|
||||
"initialize_unified_memory_storage",
|
||||
# Vector DB存储
|
||||
"VectorMemoryStorage",
|
||||
"VectorStorageConfig",
|
||||
"get_vector_memory_storage",
|
||||
|
||||
# 记忆系统
|
||||
"MemorySystem",
|
||||
|
||||
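With the package now exporting the Vector DB storage instead of the removed unified storage, downstream code is expected to construct VectorMemoryStorage directly. A minimal sketch of that usage (the config values mirror the VectorStorageConfig(...) call added in memory_system.py later in this diff and are illustrative, not authoritative):

from src.chat.memory_system.vector_memory_storage_v2 import (
    VectorMemoryStorage,
    VectorStorageConfig,
)

# Illustrative values; the real ones come from global_config.memory in memory_system.py.
config = VectorStorageConfig(
    memory_collection="unified_memory_v2",
    metadata_collection="memory_metadata_v2",
    similarity_threshold=0.8,
)
storage = VectorMemoryStorage(config)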
@@ -96,19 +96,8 @@ class MemoryBuilder:
|
||||
try:
|
||||
logger.debug(f"开始从对话构建记忆,文本长度: {len(conversation_text)}")
|
||||
|
||||
# 预处理文本
|
||||
processed_text = self._preprocess_text(conversation_text)
|
||||
|
||||
# 确定提取策略
|
||||
strategy = self._determine_extraction_strategy(processed_text, context)
|
||||
|
||||
# 根据策略提取记忆
|
||||
if strategy == ExtractionStrategy.LLM_BASED:
|
||||
memories = await self._extract_with_llm(processed_text, context, user_id, timestamp)
|
||||
elif strategy == ExtractionStrategy.RULE_BASED:
|
||||
memories = self._extract_with_rules(processed_text, context, user_id, timestamp)
|
||||
else: # HYBRID
|
||||
memories = await self._extract_with_hybrid(processed_text, context, user_id, timestamp)
|
||||
# 使用LLM提取记忆
|
||||
memories = await self._extract_with_llm(conversation_text, context, user_id, timestamp)
|
||||
|
||||
# 后处理和验证
|
||||
validated_memories = self._validate_and_enhance_memories(memories, context)
|
||||
@@ -129,41 +118,6 @@ class MemoryBuilder:
|
||||
self.extraction_stats["failed_extractions"] += 1
|
||||
raise
|
||||
|
||||
def _preprocess_text(self, text: str) -> str:
|
||||
"""预处理文本"""
|
||||
# 移除多余的空白字符
|
||||
text = re.sub(r'\s+', ' ', text.strip())
|
||||
|
||||
# 移除特殊字符,但保留基本标点
|
||||
text = re.sub(r'[^\w\s\u4e00-\u9fff,。!?、;:""''()【】]', '', text)
|
||||
|
||||
# 截断过长的文本
|
||||
if len(text) > 2000:
|
||||
text = text[:2000] + "..."
|
||||
|
||||
return text
|
||||
|
||||
def _determine_extraction_strategy(self, text: str, context: Dict[str, Any]) -> ExtractionStrategy:
|
||||
"""确定提取策略"""
|
||||
text_length = len(text)
|
||||
has_structured_data = any(key in context for key in ["structured_data", "entities", "keywords"])
|
||||
message_type = context.get("message_type", "normal")
|
||||
|
||||
# 短文本使用规则提取
|
||||
if text_length < 50:
|
||||
return ExtractionStrategy.RULE_BASED
|
||||
|
||||
# 包含结构化数据使用混合策略
|
||||
if has_structured_data:
|
||||
return ExtractionStrategy.HYBRID
|
||||
|
||||
# 系统消息或命令使用规则提取
|
||||
if message_type in ["command", "system"]:
|
||||
return ExtractionStrategy.RULE_BASED
|
||||
|
||||
# 默认使用LLM提取
|
||||
return ExtractionStrategy.LLM_BASED
|
||||
|
||||
async def _extract_with_llm(
|
||||
self,
|
||||
text: str,
|
||||
@@ -190,79 +144,10 @@ class MemoryBuilder:
|
||||
logger.error(f"LLM提取失败: {e}")
|
||||
raise MemoryExtractionError(str(e)) from e
|
||||
|
||||
def _extract_with_rules(
|
||||
self,
|
||||
text: str,
|
||||
context: Dict[str, Any],
|
||||
user_id: str,
|
||||
timestamp: float
|
||||
) -> List[MemoryChunk]:
|
||||
"""使用规则提取记忆"""
|
||||
memories = []
|
||||
|
||||
subjects = self._resolve_conversation_participants(context, user_id)
|
||||
|
||||
# 规则1: 检测个人信息
|
||||
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subjects)
|
||||
memories.extend(personal_info)
|
||||
|
||||
# 规则2: 检测偏好信息
|
||||
preferences = self._extract_preferences(text, user_id, timestamp, context, subjects)
|
||||
memories.extend(preferences)
|
||||
|
||||
# 规则3: 检测事件信息
|
||||
events = self._extract_events(text, user_id, timestamp, context, subjects)
|
||||
memories.extend(events)
|
||||
|
||||
return memories
|
||||
|
||||
async def _extract_with_hybrid(
|
||||
self,
|
||||
text: str,
|
||||
context: Dict[str, Any],
|
||||
user_id: str,
|
||||
timestamp: float
|
||||
) -> List[MemoryChunk]:
|
||||
"""混合策略提取记忆"""
|
||||
all_memories = []
|
||||
|
||||
# 首先使用规则提取
|
||||
rule_memories = self._extract_with_rules(text, context, user_id, timestamp)
|
||||
all_memories.extend(rule_memories)
|
||||
|
||||
# 然后使用LLM提取
|
||||
llm_memories = await self._extract_with_llm(text, context, user_id, timestamp)
|
||||
|
||||
# 合并和去重
|
||||
final_memories = self._merge_hybrid_results(all_memories, llm_memories)
|
||||
|
||||
return final_memories
|
||||
|
||||
def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
|
||||
"""构建LLM提取提示"""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
chat_id = context.get("chat_id", "unknown")
|
||||
message_type = context.get("message_type", "normal")
|
||||
target_user_id = context.get("user_id", "用户")
|
||||
target_user_id = str(target_user_id)
|
||||
|
||||
target_user_name = (
|
||||
context.get("user_display_name")
|
||||
or context.get("user_name")
|
||||
or context.get("nickname")
|
||||
or context.get("sender_name")
|
||||
)
|
||||
if isinstance(target_user_name, str):
|
||||
target_user_name = target_user_name.strip()
|
||||
else:
|
||||
target_user_name = ""
|
||||
|
||||
if not target_user_name or self._looks_like_system_identifier(target_user_name):
|
||||
target_user_name = "该用户"
|
||||
|
||||
target_user_id_display = target_user_id
|
||||
if self._looks_like_system_identifier(target_user_id_display):
|
||||
target_user_id_display = "(系统ID,勿写入记忆)"
|
||||
|
||||
bot_name = context.get("bot_name")
|
||||
bot_identity = context.get("bot_identity")
|
||||
@@ -966,145 +851,6 @@ class MemoryBuilder:
|
||||
return f"{subject_phrase}{predicate}".strip()
|
||||
return subject_phrase
|
||||
|
||||
def _extract_personal_info(
|
||||
self,
|
||||
text: str,
|
||||
user_id: str,
|
||||
timestamp: float,
|
||||
context: Dict[str, Any],
|
||||
subjects: List[str]
|
||||
) -> List[MemoryChunk]:
|
||||
"""提取个人信息"""
|
||||
memories = []
|
||||
|
||||
# 常见个人信息模式
|
||||
patterns = {
|
||||
r"我叫(\w+)": ("is_named", {"name": "$1"}),
|
||||
r"我今年(\d+)岁": ("is_age", {"age": "$1"}),
|
||||
r"我是(\w+)": ("is_profession", {"profession": "$1"}),
|
||||
r"我住在(\w+)": ("lives_in", {"location": "$1"}),
|
||||
r"我的电话是(\d+)": ("has_phone", {"phone": "$1"}),
|
||||
r"我的邮箱是(\w+@\w+\.\w+)": ("has_email", {"email": "$1"}),
|
||||
}
|
||||
|
||||
for pattern, (predicate, obj_template) in patterns.items():
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
obj = obj_template
|
||||
for i, group in enumerate(match.groups(), 1):
|
||||
obj = {k: v.replace(f"${i}", group) for k, v in obj.items()}
|
||||
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=subjects,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
memory_type=MemoryType.PERSONAL_FACT,
|
||||
chat_id=context.get("chat_id"),
|
||||
importance=ImportanceLevel.HIGH,
|
||||
confidence=ConfidenceLevel.HIGH,
|
||||
display=self._compose_display_text(subjects, predicate, obj)
|
||||
)
|
||||
|
||||
memories.append(memory)
|
||||
|
||||
return memories
|
||||
|
||||
def _extract_preferences(
|
||||
self,
|
||||
text: str,
|
||||
user_id: str,
|
||||
timestamp: float,
|
||||
context: Dict[str, Any],
|
||||
subjects: List[str]
|
||||
) -> List[MemoryChunk]:
|
||||
"""提取偏好信息"""
|
||||
memories = []
|
||||
|
||||
# 偏好模式
|
||||
preference_patterns = [
|
||||
(r"我喜欢(.+)", "likes"),
|
||||
(r"我不喜欢(.+)", "dislikes"),
|
||||
(r"我爱吃(.+)", "likes_food"),
|
||||
(r"我讨厌(.+)", "hates"),
|
||||
(r"我最喜欢的(.+)", "favorite_is"),
|
||||
]
|
||||
|
||||
for pattern, predicate in preference_patterns:
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=subjects,
|
||||
predicate=predicate,
|
||||
obj=match.group(1),
|
||||
memory_type=MemoryType.PREFERENCE,
|
||||
chat_id=context.get("chat_id"),
|
||||
importance=ImportanceLevel.NORMAL,
|
||||
confidence=ConfidenceLevel.MEDIUM,
|
||||
display=self._compose_display_text(subjects, predicate, match.group(1))
|
||||
)
|
||||
|
||||
memories.append(memory)
|
||||
|
||||
return memories
|
||||
|
||||
def _extract_events(
|
||||
self,
|
||||
text: str,
|
||||
user_id: str,
|
||||
timestamp: float,
|
||||
context: Dict[str, Any],
|
||||
subjects: List[str]
|
||||
) -> List[MemoryChunk]:
|
||||
"""提取事件信息"""
|
||||
memories = []
|
||||
|
||||
# 事件关键词
|
||||
event_keywords = ["明天", "今天", "昨天", "上周", "下周", "约会", "会议", "活动", "旅行", "生日"]
|
||||
|
||||
if any(keyword in text for keyword in event_keywords):
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=subjects,
|
||||
predicate="mentioned_event",
|
||||
obj={"event_text": text, "timestamp": timestamp},
|
||||
memory_type=MemoryType.EVENT,
|
||||
chat_id=context.get("chat_id"),
|
||||
importance=ImportanceLevel.NORMAL,
|
||||
confidence=ConfidenceLevel.MEDIUM,
|
||||
display=self._compose_display_text(subjects, "mentioned_event", text)
|
||||
)
|
||||
|
||||
memories.append(memory)
|
||||
|
||||
return memories
|
||||
|
||||
def _merge_hybrid_results(
|
||||
self,
|
||||
rule_memories: List[MemoryChunk],
|
||||
llm_memories: List[MemoryChunk]
|
||||
) -> List[MemoryChunk]:
|
||||
"""合并混合策略结果"""
|
||||
all_memories = rule_memories.copy()
|
||||
|
||||
# 添加LLM记忆,避免重复
|
||||
for llm_memory in llm_memories:
|
||||
is_duplicate = False
|
||||
for rule_memory in rule_memories:
|
||||
if llm_memory.is_similar_to(rule_memory, threshold=0.7):
|
||||
is_duplicate = True
|
||||
# 合并置信度
|
||||
rule_memory.metadata.confidence = ConfidenceLevel(
|
||||
max(rule_memory.metadata.confidence.value, llm_memory.metadata.confidence.value)
|
||||
)
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
all_memories.append(llm_memory)
|
||||
|
||||
return all_memories
|
||||
|
||||
def _validate_and_enhance_memories(
|
||||
self,
|
||||
memories: List[MemoryChunk],
|
||||
|
||||
@@ -127,6 +127,8 @@ class MemoryMetadata:

    # 来源信息
    source_context: Optional[str] = None  # 来源上下文片段
    # 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source
    source: Optional[str] = None

    def __post_init__(self):
        """后初始化处理"""
@@ -150,6 +152,19 @@ class MemoryMetadata:
        if self.last_forgetting_check == 0:
            self.last_forgetting_check = current_time

        # 兼容性:如果旧字段 source 被使用,保证 source 与 source_context 同步
        if not getattr(self, 'source', None) and getattr(self, 'source_context', None):
            try:
                self.source = str(self.source_context)
            except Exception:
                self.source = None
        # 如果有 source 字段但 source_context 为空,也同步回去
        if not getattr(self, 'source_context', None) and getattr(self, 'source', None):
            try:
                self.source_context = str(self.source)
            except Exception:
                self.source_context = None

    def update_access(self):
        """更新访问信息"""
        current_time = time.time()

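As a standalone illustration of the compatibility shim above, a minimal sketch (simplified field set, not the real MemoryMetadata) shows the two fields mirroring each other after construction:

from dataclasses import dataclass
from typing import Optional

@dataclass
class _MetadataCompatSketch:
    source_context: Optional[str] = None
    source: Optional[str] = None

    def __post_init__(self):
        # Mirror whichever field is populated into the other, so old code reading
        # .source and new code reading .source_context both keep working.
        if not self.source and self.source_context:
            self.source = str(self.source_context)
        if not self.source_context and self.source:
            self.source_context = str(self.source)

print(_MetadataCompatSketch(source_context="chat log").source)  # -> "chat log"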
316
src/chat/memory_system/memory_metadata_index.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆元数据索引管理器
|
||||
使用JSON文件存储记忆元数据,支持快速模糊搜索和过滤
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryType, ImportanceLevel, ConfidenceLevel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryMetadataIndexEntry:
|
||||
"""元数据索引条目(轻量级,只用于快速过滤)"""
|
||||
memory_id: str
|
||||
user_id: str
|
||||
|
||||
# 分类信息
|
||||
memory_type: str # MemoryType.value
|
||||
subjects: List[str] # 主语列表
|
||||
objects: List[str] # 宾语列表
|
||||
keywords: List[str] # 关键词列表
|
||||
tags: List[str] # 标签列表
|
||||
|
||||
# 数值字段(用于范围过滤)
|
||||
importance: int # ImportanceLevel.value (1-4)
|
||||
confidence: int # ConfidenceLevel.value (1-4)
|
||||
created_at: float # 创建时间戳
|
||||
access_count: int # 访问次数
|
||||
|
||||
# 可选字段
|
||||
chat_id: Optional[str] = None
|
||||
content_preview: Optional[str] = None # 内容预览(前100字符)
|
||||
|
||||
|
||||
class MemoryMetadataIndex:
|
||||
"""记忆元数据索引管理器"""
|
||||
|
||||
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
|
||||
self.index_file = Path(index_file)
|
||||
self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
|
||||
|
||||
# 倒排索引(用于快速查找)
|
||||
self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids}
|
||||
self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids}
|
||||
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids}
|
||||
self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids}
|
||||
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# 加载已有索引
|
||||
self._load_index()
|
||||
|
||||
def _load_index(self):
|
||||
"""从文件加载索引"""
|
||||
if not self.index_file.exists():
|
||||
logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}")
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.index_file, 'rb') as f:
|
||||
data = orjson.loads(f.read())
|
||||
|
||||
# 重建内存索引
|
||||
for entry_data in data.get('entries', []):
|
||||
entry = MemoryMetadataIndexEntry(**entry_data)
|
||||
self.index[entry.memory_id] = entry
|
||||
self._update_inverted_indices(entry)
|
||||
|
||||
logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载元数据索引失败: {e}", exc_info=True)
|
||||
|
||||
def _save_index(self):
|
||||
"""保存索引到文件"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
self.index_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 序列化所有条目
|
||||
entries = [asdict(entry) for entry in self.index.values()]
|
||||
data = {
|
||||
'version': '1.0',
|
||||
'count': len(entries),
|
||||
'last_updated': datetime.now().isoformat(),
|
||||
'entries': entries
|
||||
}
|
||||
|
||||
# 写入文件(使用临时文件 + 原子重命名)
|
||||
temp_file = self.index_file.with_suffix('.tmp')
|
||||
with open(temp_file, 'wb') as f:
|
||||
f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
|
||||
|
||||
temp_file.replace(self.index_file)
|
||||
logger.debug(f"元数据索引已保存: {len(entries)} 条记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存元数据索引失败: {e}", exc_info=True)
|
||||
|
||||
def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry):
|
||||
"""更新倒排索引"""
|
||||
memory_id = entry.memory_id
|
||||
|
||||
# 类型索引
|
||||
self.type_index.setdefault(entry.memory_type, set()).add(memory_id)
|
||||
|
||||
# 主语索引
|
||||
for subject in entry.subjects:
|
||||
subject_norm = subject.strip().lower()
|
||||
if subject_norm:
|
||||
self.subject_index.setdefault(subject_norm, set()).add(memory_id)
|
||||
|
||||
# 关键词索引
|
||||
for keyword in entry.keywords:
|
||||
keyword_norm = keyword.strip().lower()
|
||||
if keyword_norm:
|
||||
self.keyword_index.setdefault(keyword_norm, set()).add(memory_id)
|
||||
|
||||
# 标签索引
|
||||
for tag in entry.tags:
|
||||
tag_norm = tag.strip().lower()
|
||||
if tag_norm:
|
||||
self.tag_index.setdefault(tag_norm, set()).add(memory_id)
|
||||
|
||||
def add_or_update(self, entry: MemoryMetadataIndexEntry):
|
||||
"""添加或更新索引条目"""
|
||||
with self.lock:
|
||||
# 如果已存在,先从倒排索引中移除旧记录
|
||||
if entry.memory_id in self.index:
|
||||
self._remove_from_inverted_indices(entry.memory_id)
|
||||
|
||||
# 添加新记录
|
||||
self.index[entry.memory_id] = entry
|
||||
self._update_inverted_indices(entry)
|
||||
|
||||
def _remove_from_inverted_indices(self, memory_id: str):
|
||||
"""从倒排索引中移除记录"""
|
||||
if memory_id not in self.index:
|
||||
return
|
||||
|
||||
entry = self.index[memory_id]
|
||||
|
||||
# 从类型索引移除
|
||||
if entry.memory_type in self.type_index:
|
||||
self.type_index[entry.memory_type].discard(memory_id)
|
||||
|
||||
# 从主语索引移除
|
||||
for subject in entry.subjects:
|
||||
subject_norm = subject.strip().lower()
|
||||
if subject_norm in self.subject_index:
|
||||
self.subject_index[subject_norm].discard(memory_id)
|
||||
|
||||
# 从关键词索引移除
|
||||
for keyword in entry.keywords:
|
||||
keyword_norm = keyword.strip().lower()
|
||||
if keyword_norm in self.keyword_index:
|
||||
self.keyword_index[keyword_norm].discard(memory_id)
|
||||
|
||||
# 从标签索引移除
|
||||
for tag in entry.tags:
|
||||
tag_norm = tag.strip().lower()
|
||||
if tag_norm in self.tag_index:
|
||||
self.tag_index[tag_norm].discard(memory_id)
|
||||
|
||||
def remove(self, memory_id: str):
|
||||
"""移除索引条目"""
|
||||
with self.lock:
|
||||
if memory_id in self.index:
|
||||
self._remove_from_inverted_indices(memory_id)
|
||||
del self.index[memory_id]
|
||||
|
||||
def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
|
||||
"""批量添加或更新"""
|
||||
with self.lock:
|
||||
for entry in entries:
|
||||
self.add_or_update(entry)
|
||||
|
||||
def save(self):
|
||||
"""保存索引到磁盘"""
|
||||
with self.lock:
|
||||
self._save_index()
|
||||
|
||||
def search(
|
||||
self,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
subjects: Optional[List[str]] = None,
|
||||
keywords: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
importance_min: Optional[int] = None,
|
||||
importance_max: Optional[int] = None,
|
||||
created_after: Optional[float] = None,
|
||||
created_before: Optional[float] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: Optional[int] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
搜索符合条件的记忆ID列表(支持模糊匹配)
|
||||
|
||||
Returns:
|
||||
List[str]: 符合条件的 memory_id 列表
|
||||
"""
|
||||
with self.lock:
|
||||
# 初始候选集(所有记忆)
|
||||
candidate_ids: Optional[Set[str]] = None
|
||||
|
||||
# 用户过滤(必选)
|
||||
if user_id:
|
||||
candidate_ids = {
|
||||
mid for mid, entry in self.index.items()
|
||||
if entry.user_id == user_id
|
||||
}
|
||||
else:
|
||||
candidate_ids = set(self.index.keys())
|
||||
|
||||
# 类型过滤(OR关系)
|
||||
if memory_types:
|
||||
type_ids = set()
|
||||
for mtype in memory_types:
|
||||
type_ids.update(self.type_index.get(mtype, set()))
|
||||
candidate_ids &= type_ids
|
||||
|
||||
# 主语过滤(OR关系,支持模糊匹配)
|
||||
if subjects:
|
||||
subject_ids = set()
|
||||
for subject in subjects:
|
||||
subject_norm = subject.strip().lower()
|
||||
# 精确匹配
|
||||
if subject_norm in self.subject_index:
|
||||
subject_ids.update(self.subject_index[subject_norm])
|
||||
# 模糊匹配(包含)
|
||||
for indexed_subject, ids in self.subject_index.items():
|
||||
if subject_norm in indexed_subject or indexed_subject in subject_norm:
|
||||
subject_ids.update(ids)
|
||||
candidate_ids &= subject_ids
|
||||
|
||||
# 关键词过滤(OR关系,支持模糊匹配)
|
||||
if keywords:
|
||||
keyword_ids = set()
|
||||
for keyword in keywords:
|
||||
keyword_norm = keyword.strip().lower()
|
||||
# 精确匹配
|
||||
if keyword_norm in self.keyword_index:
|
||||
keyword_ids.update(self.keyword_index[keyword_norm])
|
||||
# 模糊匹配(包含)
|
||||
for indexed_keyword, ids in self.keyword_index.items():
|
||||
if keyword_norm in indexed_keyword or indexed_keyword in keyword_norm:
|
||||
keyword_ids.update(ids)
|
||||
candidate_ids &= keyword_ids
|
||||
|
||||
# 标签过滤(OR关系)
|
||||
if tags:
|
||||
tag_ids = set()
|
||||
for tag in tags:
|
||||
tag_norm = tag.strip().lower()
|
||||
tag_ids.update(self.tag_index.get(tag_norm, set()))
|
||||
candidate_ids &= tag_ids
|
||||
|
||||
# 重要性过滤
|
||||
if importance_min is not None or importance_max is not None:
|
||||
importance_ids = {
|
||||
mid for mid in candidate_ids
|
||||
if (importance_min is None or self.index[mid].importance >= importance_min)
|
||||
and (importance_max is None or self.index[mid].importance <= importance_max)
|
||||
}
|
||||
candidate_ids &= importance_ids
|
||||
|
||||
# 时间范围过滤
|
||||
if created_after is not None or created_before is not None:
|
||||
time_ids = {
|
||||
mid for mid in candidate_ids
|
||||
if (created_after is None or self.index[mid].created_at >= created_after)
|
||||
and (created_before is None or self.index[mid].created_at <= created_before)
|
||||
}
|
||||
candidate_ids &= time_ids
|
||||
|
||||
# 转换为列表并排序(按创建时间倒序)
|
||||
result_ids = sorted(
|
||||
candidate_ids,
|
||||
key=lambda mid: self.index[mid].created_at,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 限制数量
|
||||
if limit:
|
||||
result_ids = result_ids[:limit]
|
||||
|
||||
logger.debug(
|
||||
f"元数据索引搜索: types={memory_types}, subjects={subjects}, "
|
||||
f"keywords={keywords}, 返回={len(result_ids)}条"
|
||||
)
|
||||
|
||||
return result_ids
|
||||
|
||||
def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
|
||||
"""获取单个索引条目"""
|
||||
return self.index.get(memory_id)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取索引统计信息"""
|
||||
with self.lock:
|
||||
return {
|
||||
'total_memories': len(self.index),
|
||||
'types': {mtype: len(ids) for mtype, ids in self.type_index.items()},
|
||||
'subjects_count': len(self.subject_index),
|
||||
'keywords_count': len(self.keyword_index),
|
||||
'tags_count': len(self.tag_index),
|
||||
}
|
||||
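A minimal usage sketch of the index defined above (memory ID, file path and field values are illustrative): build a lightweight entry, register it, persist, then run a fuzzy metadata search that returns matching memory IDs newest-first.

import time

from src.chat.memory_system.memory_metadata_index import (
    MemoryMetadataIndex,
    MemoryMetadataIndexEntry,
)

index = MemoryMetadataIndex(index_file="data/memory_metadata_index.json")

entry = MemoryMetadataIndexEntry(
    memory_id="mem-001",
    user_id="global",
    memory_type="preference",
    subjects=["杰瑞喵"],
    objects=["奶酪"],
    keywords=["喜欢", "奶酪"],
    tags=["食物"],
    importance=2,       # NORMAL
    confidence=2,       # MEDIUM
    created_at=time.time(),
    access_count=0,
)

index.add_or_update(entry)
index.save()

# Subject match is fuzzy (substring), keyword match here is exact.
hits = index.search(subjects=["杰瑞"], keywords=["奶酪"], user_id="global", limit=10)
print(hits)  # -> ["mem-001"]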
@@ -135,22 +135,76 @@ class MemoryQueryPlanner:
|
||||
|
||||
persona = context.get("bot_personality") or context.get("bot_identity") or "未知"
|
||||
|
||||
# 构建未读消息上下文信息
|
||||
context_section = ""
|
||||
if context.get("has_unread_context") and context.get("unread_messages_context"):
|
||||
unread_context = context["unread_messages_context"]
|
||||
unread_messages = unread_context.get("messages", [])
|
||||
unread_keywords = unread_context.get("keywords", [])
|
||||
unread_participants = unread_context.get("participants", [])
|
||||
context_summary = unread_context.get("context_summary", "")
|
||||
|
||||
if unread_messages:
|
||||
# 构建未读消息摘要
|
||||
message_previews = []
|
||||
for msg in unread_messages[:5]: # 最多显示5条
|
||||
sender = msg.get("sender", "未知")
|
||||
content = msg.get("content", "")[:100] # 限制每条消息长度
|
||||
message_previews.append(f"{sender}: {content}")
|
||||
|
||||
context_section = f"""
|
||||
|
||||
## 📋 未读消息上下文 (共{unread_context.get('total_count', 0)}条未读消息)
|
||||
### 最近消息预览:
|
||||
{chr(10).join(message_previews)}
|
||||
|
||||
### 上下文关键词:
|
||||
{', '.join(unread_keywords[:15]) if unread_keywords else '无'}
|
||||
|
||||
### 对话参与者:
|
||||
{', '.join(unread_participants) if unread_participants else '无'}
|
||||
|
||||
### 上下文摘要:
|
||||
{context_summary[:300] if context_summary else '无'}
|
||||
"""
|
||||
else:
|
||||
context_section = """
|
||||
|
||||
## 📋 未读消息上下文:
|
||||
无未读消息或上下文信息不可用
|
||||
"""
|
||||
|
||||
return f"""
|
||||
你是一名记忆检索规划助手,请基于输入生成一个简洁的 JSON 检索计划。
|
||||
你的任务是分析当前查询并结合未读消息的上下文,生成更精准的记忆检索策略。
|
||||
|
||||
仅需提供以下字段:
|
||||
- semantic_query: 用于向量召回的自然语言描述,要求具体且贴合当前查询;
|
||||
- semantic_query: 用于向量召回的自然语言描述,要求具体且贴合当前查询和上下文;
|
||||
- memory_types: 建议检索的记忆类型列表,取值范围来自 MemoryType 枚举 (personal_fact,event,preference,opinion,relationship,emotion,knowledge,skill,goal,experience,contextual);
|
||||
- subject_includes: 建议出现在记忆主语中的人物或角色;
|
||||
- object_includes: 建议关注的对象、主题或关键信息;
|
||||
- required_keywords: 建议必须包含的关键词(从上下文中提取);
|
||||
- recency: 推荐的时间偏好,可选 recent/any/historical;
|
||||
- limit: 推荐的最大返回数量 (1-15);
|
||||
- notes: 额外补充说明(可选)。
|
||||
- emphasis: 检索重点,可选 balanced/contextual/recent/comprehensive。
|
||||
|
||||
请不要生成谓语字段,也不要额外补充其它参数。
|
||||
|
||||
当前查询: "{query_text}"
|
||||
已知的对话参与者: {participant_preview}
|
||||
机器人设定: {persona}
|
||||
## 当前查询:
|
||||
"{query_text}"
|
||||
|
||||
## 已知对话参与者:
|
||||
{participant_preview}
|
||||
|
||||
## 机器人设定:
|
||||
{persona}{context_section}
|
||||
|
||||
## 🎯 指导原则:
|
||||
1. **上下文关联**: 优先分析与当前查询相关的未读消息内容和关键词
|
||||
2. **语义理解**: 结合上下文理解查询的真实意图,而非字面意思
|
||||
3. **参与者感知**: 考虑未读消息中的参与者,检索与他们相关的记忆
|
||||
4. **主题延续**: 关注未读消息中讨论的主题,检索相关的历史记忆
|
||||
5. **时间相关性**: 如果未读消息讨论最近的事件,偏向检索相关时期的记忆
|
||||
|
||||
请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。
|
||||
"""
|
||||
|
||||
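For reference, a well-formed plan for the prompt above, written as a Python literal (the field names are exactly those the prompt requests; the values are illustrative, not produced by the model):

example_plan = {
    "semantic_query": "杰瑞喵的饮食偏好和喜欢的事物",
    "memory_types": ["preference", "personal_fact"],
    "subject_includes": ["杰瑞喵"],
    "object_includes": ["食物", "爱好"],
    "required_keywords": ["喜欢"],
    "recency": "any",
    "limit": 5,
    "notes": "",
    "emphasis": "contextual",
}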
@@ -191,27 +191,27 @@ class MemorySystem:
|
||||
self.memory_builder = MemoryBuilder(self.memory_extraction_model)
|
||||
self.fusion_engine = MemoryFusionEngine(self.config.fusion_similarity_threshold)
|
||||
|
||||
# 初始化统一存储系统
|
||||
from src.chat.memory_system.unified_memory_storage import initialize_unified_memory_storage, UnifiedStorageConfig
|
||||
# 初始化Vector DB存储系统(替代旧的unified_memory_storage)
|
||||
from src.chat.memory_system.vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig
|
||||
|
||||
storage_config = UnifiedStorageConfig(
|
||||
dimension=self.config.vector_dimension,
|
||||
storage_config = VectorStorageConfig(
|
||||
memory_collection="unified_memory_v2",
|
||||
metadata_collection="memory_metadata_v2",
|
||||
similarity_threshold=self.config.similarity_threshold,
|
||||
storage_path=getattr(global_config.memory, 'unified_storage_path', 'data/unified_memory'),
|
||||
cache_size_limit=getattr(global_config.memory, 'unified_storage_cache_limit', 10000),
|
||||
auto_save_interval=getattr(global_config.memory, 'unified_storage_auto_save_interval', 50),
|
||||
enable_compression=getattr(global_config.memory, 'unified_storage_enable_compression', True),
|
||||
search_limit=getattr(global_config.memory, 'unified_storage_search_limit', 20),
|
||||
batch_size=getattr(global_config.memory, 'unified_storage_batch_size', 100),
|
||||
enable_caching=getattr(global_config.memory, 'unified_storage_enable_caching', True),
|
||||
cache_size_limit=getattr(global_config.memory, 'unified_storage_cache_limit', 1000),
|
||||
auto_cleanup_interval=getattr(global_config.memory, 'unified_storage_auto_cleanup_interval', 3600),
|
||||
enable_forgetting=getattr(global_config.memory, 'enable_memory_forgetting', True),
|
||||
forgetting_check_interval=getattr(global_config.memory, 'forgetting_check_interval_hours', 24)
|
||||
retention_hours=getattr(global_config.memory, 'memory_retention_hours', 720) # 30天
|
||||
)
|
||||
|
||||
try:
|
||||
self.unified_storage = await initialize_unified_memory_storage(storage_config)
|
||||
if self.unified_storage is None:
|
||||
raise RuntimeError("统一存储系统初始化返回None")
|
||||
logger.info("✅ 统一存储系统初始化成功")
|
||||
self.unified_storage = VectorMemoryStorage(storage_config)
|
||||
logger.info("✅ Vector DB存储系统初始化成功")
|
||||
except Exception as storage_error:
|
||||
logger.error(f"❌ 统一存储系统初始化失败: {storage_error}", exc_info=True)
|
||||
logger.error(f"❌ Vector DB存储系统初始化失败: {storage_error}", exc_info=True)
|
||||
raise
|
||||
|
||||
# 初始化遗忘引擎
|
||||
@@ -380,11 +380,11 @@ class MemorySystem:
|
||||
self.status = original_status
|
||||
return []
|
||||
|
||||
# 2. 构建记忆块
|
||||
# 2. 构建记忆块(所有记忆统一使用 global 作用域,实现完全共享)
|
||||
memory_chunks = await self.memory_builder.build_memories(
|
||||
conversation_text,
|
||||
normalized_context,
|
||||
GLOBAL_MEMORY_SCOPE,
|
||||
GLOBAL_MEMORY_SCOPE, # 强制使用 global,不区分用户
|
||||
timestamp or time.time()
|
||||
)
|
||||
|
||||
@@ -609,7 +609,7 @@ class MemorySystem:
|
||||
limit: int = 5,
|
||||
**kwargs
|
||||
) -> List[MemoryChunk]:
|
||||
"""检索相关记忆(简化版,使用统一存储)"""
|
||||
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)"""
|
||||
raw_query = query_text or kwargs.get("query")
|
||||
if not raw_query:
|
||||
raise ValueError("query_text 或 query 参数不能为空")
|
||||
@@ -619,6 +619,8 @@ class MemorySystem:
|
||||
return []
|
||||
|
||||
context = context or {}
|
||||
|
||||
# 所有记忆完全共享,统一使用 global 作用域,不区分用户
|
||||
resolved_user_id = GLOBAL_MEMORY_SCOPE
|
||||
|
||||
self.status = MemorySystemStatus.RETRIEVING
|
||||
@@ -626,50 +628,152 @@ class MemorySystem:
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
|
||||
|
||||
effective_limit = limit or self.config.final_recall_limit
|
||||
|
||||
# 构建过滤器
|
||||
filters = {
|
||||
"user_id": resolved_user_id
|
||||
effective_limit = self.config.final_recall_limit
|
||||
|
||||
# === 阶段一:元数据粗筛(软性过滤) ===
|
||||
coarse_filters = {
|
||||
"user_id": GLOBAL_MEMORY_SCOPE, # 必选:确保作用域正确
|
||||
}
|
||||
|
||||
# 应用查询规划结果
|
||||
# 应用查询规划(优化查询语句并构建元数据过滤)
|
||||
optimized_query = raw_query
|
||||
metadata_filters = {}
|
||||
|
||||
if self.query_planner:
|
||||
try:
|
||||
query_plan = await self.query_planner.plan_query(raw_query, normalized_context)
|
||||
if getattr(query_plan, "memory_types", None):
|
||||
filters["memory_types"] = [mt.value for mt in query_plan.memory_types]
|
||||
if getattr(query_plan, "subject_includes", None):
|
||||
filters["keywords"] = query_plan.subject_includes
|
||||
# 构建包含未读消息的增强上下文
|
||||
enhanced_context = await self._build_enhanced_query_context(raw_query, normalized_context)
|
||||
query_plan = await self.query_planner.plan_query(raw_query, enhanced_context)
|
||||
|
||||
# 使用LLM优化后的查询语句(更精确的语义表达)
|
||||
if getattr(query_plan, "semantic_query", None):
|
||||
raw_query = query_plan.semantic_query
|
||||
optimized_query = query_plan.semantic_query
|
||||
|
||||
# 构建JSON元数据过滤条件(用于阶段一粗筛)
|
||||
# 将查询规划的结果转换为元数据过滤条件
|
||||
if getattr(query_plan, "memory_types", None):
|
||||
metadata_filters['memory_types'] = [mt.value for mt in query_plan.memory_types]
|
||||
|
||||
if getattr(query_plan, "subject_includes", None):
|
||||
metadata_filters['subjects'] = query_plan.subject_includes
|
||||
|
||||
if getattr(query_plan, "required_keywords", None):
|
||||
metadata_filters['keywords'] = query_plan.required_keywords
|
||||
|
||||
# 时间范围过滤
|
||||
recency = getattr(query_plan, "recency_preference", "any")
|
||||
current_time = time.time()
|
||||
if recency == "recent":
|
||||
# 最近7天
|
||||
metadata_filters['created_after'] = current_time - (7 * 24 * 3600)
|
||||
elif recency == "historical":
|
||||
# 30天以前
|
||||
metadata_filters['created_before'] = current_time - (30 * 24 * 3600)
|
||||
|
||||
# 添加用户ID到元数据过滤
|
||||
metadata_filters['user_id'] = GLOBAL_MEMORY_SCOPE
|
||||
|
||||
logger.debug(f"[阶段一] 查询优化: '{raw_query}' → '{optimized_query}'")
|
||||
logger.debug(f"[阶段一] 元数据过滤条件: {metadata_filters}")
|
||||
|
||||
except Exception as plan_exc:
|
||||
logger.warning("查询规划失败,使用默认检索策略: %s", plan_exc, exc_info=True)
|
||||
logger.warning("查询规划失败,使用原始查询: %s", plan_exc, exc_info=True)
|
||||
# 即使查询规划失败,也保留基本的user_id过滤
|
||||
metadata_filters = {'user_id': GLOBAL_MEMORY_SCOPE}
|
||||
|
||||
# 使用统一存储搜索
|
||||
# === 阶段二:向量精筛 ===
|
||||
coarse_limit = self.config.coarse_recall_limit # 粗筛阶段返回更多候选
|
||||
|
||||
logger.debug(f"[阶段二] 开始向量搜索: query='{optimized_query[:60]}...', limit={coarse_limit}")
|
||||
|
||||
search_results = await self.unified_storage.search_similar_memories(
|
||||
query_text=raw_query,
|
||||
limit=effective_limit,
|
||||
filters=filters
|
||||
query_text=optimized_query,
|
||||
limit=coarse_limit,
|
||||
filters=coarse_filters, # ChromaDB where条件(保留兼容)
|
||||
metadata_filters=metadata_filters # JSON元数据索引过滤
|
||||
)
|
||||
|
||||
logger.info(f"[阶段二] 向量搜索完成: 返回 {len(search_results)} 条候选")
|
||||
|
||||
# 转换为记忆对象
|
||||
final_memories = []
|
||||
for memory_id, similarity_score in search_results:
|
||||
memory = self.unified_storage.get_memory_by_id(memory_id)
|
||||
if memory:
|
||||
memory.update_access()
|
||||
final_memories.append(memory)
|
||||
# === 阶段三:综合重排 ===
|
||||
scored_memories = []
|
||||
current_time = time.time()
|
||||
|
||||
for memory, vector_similarity in search_results:
|
||||
# 1. 向量相似度得分(已归一化到 0-1)
|
||||
vector_score = vector_similarity
|
||||
|
||||
# 2. 时效性得分(指数衰减,特征时间30天:30天后约降至 1/e ≈ 0.37,并非严格意义上的半衰期)
|
||||
age_seconds = current_time - memory.metadata.created_at
|
||||
age_days = age_seconds / (24 * 3600)
|
||||
# 使用 math.exp 而非 np.exp(避免依赖numpy)
|
||||
import math
|
||||
recency_score = math.exp(-age_days / 30)
|
||||
|
||||
# 3. 重要性得分(枚举值转换为归一化得分 0-1)
|
||||
# ImportanceLevel: LOW=1, NORMAL=2, HIGH=3, CRITICAL=4
|
||||
importance_enum = memory.metadata.importance
|
||||
if hasattr(importance_enum, 'value'):
|
||||
# 枚举类型,转换为0-1范围:(value - 1) / 3
|
||||
importance_score = (importance_enum.value - 1) / 3.0
|
||||
else:
|
||||
# 如果已经是数值,直接使用
|
||||
importance_score = float(importance_enum) if importance_enum else 0.5
|
||||
|
||||
# 4. 访问频率得分(归一化,访问10次以上得满分)
|
||||
access_count = memory.metadata.access_count
|
||||
frequency_score = min(access_count / 10.0, 1.0)
|
||||
|
||||
# 综合得分(加权平均)
|
||||
final_score = (
|
||||
self.config.vector_weight * vector_score +
|
||||
self.config.recency_weight * recency_score +
|
||||
self.config.context_weight * importance_score +
|
||||
0.1 * frequency_score # 访问频率权重(固定10%)
|
||||
)
|
||||
|
||||
scored_memories.append((memory, final_score, {
|
||||
"vector": vector_score,
|
||||
"recency": recency_score,
|
||||
"importance": importance_score,
|
||||
"frequency": frequency_score,
|
||||
"final": final_score
|
||||
}))
|
||||
|
||||
# 更新访问记录
|
||||
memory.update_access()
|
||||
|
||||
# 按综合得分排序
|
||||
scored_memories.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 返回 Top-K
|
||||
final_memories = [mem for mem, score, details in scored_memories[:effective_limit]]
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
|
||||
# 详细日志
|
||||
if scored_memories:
|
||||
logger.info(f"[阶段三] 综合重排完成: Top 3 得分详情")
|
||||
for i, (mem, score, details) in enumerate(scored_memories[:3], 1):
|
||||
try:
|
||||
summary = mem.content[:60] if hasattr(mem, 'content') and mem.content else ""
|
||||
except Exception:
|
||||
summary = ""
|
||||
logger.info(
|
||||
f" #{i} | final={details['final']:.3f} "
|
||||
f"(vec={details['vector']:.3f}, rec={details['recency']:.3f}, "
|
||||
f"imp={details['importance']:.3f}, freq={details['frequency']:.3f}) "
|
||||
f"| {summary}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"✅ 简化记忆检索完成"
|
||||
"✅ 三阶段记忆检索完成"
|
||||
f" | user={resolved_user_id}"
|
||||
f" | count={len(final_memories)}"
|
||||
f" | 粗筛={len(search_results)}"
|
||||
f" | 精筛={len(scored_memories)}"
|
||||
f" | 返回={len(final_memories)}"
|
||||
f" | duration={retrieval_time:.3f}s"
|
||||
f" | query='{raw_query}'"
|
||||
f" | query='{optimized_query[:60]}...'"
|
||||
)
|
||||
|
||||
self.last_retrieval_time = time.time()
|
||||
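To make the stage-three re-ranking concrete, here is a minimal standalone sketch of the composite score computed in the loop above (the weights are illustrative; the real ones come from self.config, with access frequency fixed at 10%):

import math
import time
from typing import Optional

def rerank_score(
    vector_sim: float,
    created_at: float,
    importance_value: int,
    access_count: int,
    now: Optional[float] = None,
    vector_w: float = 0.6,
    recency_w: float = 0.2,
    importance_w: float = 0.1,
    frequency_w: float = 0.1,
) -> float:
    """Weighted blend of vector similarity, recency, importance and access frequency."""
    now = now if now is not None else time.time()
    age_days = (now - created_at) / (24 * 3600)
    recency = math.exp(-age_days / 30)          # ~0.37 after 30 days (e-folding time)
    importance = (importance_value - 1) / 3.0   # LOW=1 .. CRITICAL=4 -> 0.0 .. 1.0
    frequency = min(access_count / 10.0, 1.0)   # saturates after 10 accesses
    return vector_w * vector_sim + recency_w * recency + importance_w * importance + frequency_w * frequency

# Example: a week-old, HIGH-importance memory with similarity 0.82, accessed twice.
print(round(rerank_score(0.82, time.time() - 7 * 24 * 3600, 3, 2), 3))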
@@ -719,8 +823,8 @@ class MemorySystem:
|
||||
except Exception:
|
||||
context = dict(raw_context or {})
|
||||
|
||||
# 基础字段(统一使用全局作用域)
|
||||
context["user_id"] = GLOBAL_MEMORY_SCOPE
|
||||
# 基础字段:强制使用传入的 user_id 参数(已统一为 GLOBAL_MEMORY_SCOPE)
|
||||
context["user_id"] = user_id or GLOBAL_MEMORY_SCOPE
|
||||
context["timestamp"] = context.get("timestamp") or timestamp or time.time()
|
||||
context["message_type"] = context.get("message_type") or "normal"
|
||||
context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"
|
||||
@@ -760,6 +864,150 @@ class MemorySystem:
|
||||
|
||||
return context
|
||||
|
||||
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""构建包含未读消息综合上下文的增强查询上下文
|
||||
|
||||
Args:
|
||||
raw_query: 原始查询文本
|
||||
normalized_context: 标准化后的基础上下文
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含未读消息综合信息的增强上下文
|
||||
"""
|
||||
enhanced_context = dict(normalized_context) # 复制基础上下文
|
||||
|
||||
try:
|
||||
# 获取stream_id以查找未读消息
|
||||
stream_id = normalized_context.get("stream_id")
|
||||
if not stream_id:
|
||||
logger.debug("未找到stream_id,使用基础上下文进行查询规划")
|
||||
return enhanced_context
|
||||
|
||||
# 获取未读消息作为上下文
|
||||
unread_messages_summary = await self._collect_unread_messages_context(stream_id)
|
||||
|
||||
if unread_messages_summary:
|
||||
enhanced_context["unread_messages_context"] = unread_messages_summary
|
||||
enhanced_context["has_unread_context"] = True
|
||||
|
||||
logger.debug(f"为查询规划构建了增强上下文,包含 {len(unread_messages_summary.get('messages', []))} 条未读消息")
|
||||
else:
|
||||
enhanced_context["has_unread_context"] = False
|
||||
logger.debug("未找到未读消息,使用基础上下文进行查询规划")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"构建增强查询上下文失败: {e}", exc_info=True)
|
||||
enhanced_context["has_unread_context"] = False
|
||||
|
||||
return enhanced_context
|
||||
|
||||
async def _collect_unread_messages_context(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""收集未读消息的综合上下文信息
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 未读消息的综合信息,包含消息列表、关键词、主题等
|
||||
"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
|
||||
if not chat_stream or not hasattr(chat_stream, "context_manager"):
|
||||
logger.debug(f"未找到stream_id={stream_id}的聊天流或上下文管理器")
|
||||
return None
|
||||
|
||||
# 获取未读消息
|
||||
context_manager = chat_stream.context_manager
|
||||
unread_messages = context_manager.get_unread_messages()
|
||||
|
||||
if not unread_messages:
|
||||
logger.debug(f"stream_id={stream_id}没有未读消息")
|
||||
return None
|
||||
|
||||
# 构建未读消息摘要
|
||||
messages_summary = []
|
||||
all_keywords = set()
|
||||
participant_names = set()
|
||||
|
||||
for msg in unread_messages[:10]: # 限制处理最近10条未读消息
|
||||
try:
|
||||
# 提取消息内容
|
||||
content = (getattr(msg, "processed_plain_text", None) or
|
||||
getattr(msg, "display_message", None) or "")
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# 提取发送者信息
|
||||
sender_name = "未知用户"
|
||||
if hasattr(msg, "user_info") and msg.user_info:
|
||||
sender_name = (getattr(msg.user_info, "user_nickname", None) or
|
||||
getattr(msg.user_info, "user_cardname", None) or
|
||||
getattr(msg.user_info, "user_id", None) or "未知用户")
|
||||
|
||||
participant_names.add(sender_name)
|
||||
|
||||
# 添加到消息摘要
|
||||
messages_summary.append({
|
||||
"sender": sender_name,
|
||||
"content": content[:200], # 限制长度避免过长
|
||||
"timestamp": getattr(msg, "time", None)
|
||||
})
|
||||
|
||||
# 提取关键词(简单实现)
|
||||
content_lower = content.lower()
|
||||
# 这里可以添加更复杂的关键词提取逻辑
|
||||
words = [w.strip() for w in content_lower.split() if len(w.strip()) > 1]
|
||||
all_keywords.update(words[:5]) # 每条消息最多取5个词
|
||||
|
||||
except Exception as msg_e:
|
||||
logger.debug(f"处理未读消息时出错: {msg_e}")
|
||||
continue
|
||||
|
||||
if not messages_summary:
|
||||
return None
|
||||
|
||||
# 构建综合上下文信息
|
||||
unread_context = {
|
||||
"messages": messages_summary,
|
||||
"total_count": len(unread_messages),
|
||||
"processed_count": len(messages_summary),
|
||||
"keywords": list(all_keywords)[:20], # 最多20个关键词
|
||||
"participants": list(participant_names),
|
||||
"context_summary": self._build_unread_context_summary(messages_summary)
|
||||
}
|
||||
|
||||
logger.debug(f"收集到未读消息上下文: {len(messages_summary)}条消息,{len(all_keywords)}个关键词,{len(participant_names)}个参与者")
|
||||
return unread_context
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"收集未读消息上下文失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _build_unread_context_summary(self, messages_summary: List[Dict[str, Any]]) -> str:
|
||||
"""构建未读消息的文本摘要
|
||||
|
||||
Args:
|
||||
messages_summary: 未读消息摘要列表
|
||||
|
||||
Returns:
|
||||
str: 未读消息的文本摘要
|
||||
"""
|
||||
if not messages_summary:
|
||||
return ""
|
||||
|
||||
summary_parts = []
|
||||
for msg_info in messages_summary:
|
||||
sender = msg_info.get("sender", "未知")
|
||||
content = msg_info.get("content", "")
|
||||
if content:
|
||||
summary_parts.append(f"{sender}: {content}")
|
||||
|
||||
return " | ".join(summary_parts)
|
||||
|
||||
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
||||
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
|
||||
if not context:
|
||||
|
||||
@@ -1,577 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统一记忆存储系统
|
||||
简化后的记忆存储,整合向量存储和元数据索引
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import orjson
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple, Set, Any
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 尝试导入FAISS
|
||||
try:
|
||||
import faiss
|
||||
FAISS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FAISS_AVAILABLE = False
|
||||
logger.warning("FAISS not available, using simple vector storage")
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedStorageConfig:
|
||||
"""统一存储配置"""
|
||||
# 向量存储配置
|
||||
dimension: int = 1024
|
||||
similarity_threshold: float = 0.8
|
||||
storage_path: str = "data/unified_memory"
|
||||
|
||||
# 性能配置
|
||||
cache_size_limit: int = 10000
|
||||
auto_save_interval: int = 50
|
||||
search_limit: int = 20
|
||||
enable_compression: bool = True
|
||||
|
||||
# 遗忘配置
|
||||
enable_forgetting: bool = True
|
||||
forgetting_check_interval: int = 24 # 小时
|
||||
|
||||
|
||||
class UnifiedMemoryStorage:
|
||||
"""统一记忆存储系统"""
|
||||
|
||||
def __init__(self, config: Optional[UnifiedStorageConfig] = None):
|
||||
self.config = config or UnifiedStorageConfig()
|
||||
|
||||
# 存储路径
|
||||
self.storage_path = Path(self.config.storage_path)
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 向量索引
|
||||
self.vector_index = None
|
||||
self.memory_id_to_index: Dict[str, int] = {}
|
||||
self.index_to_memory_id: Dict[int, str] = {}
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: Dict[str, MemoryChunk] = {}
|
||||
self.vector_cache: Dict[str, np.ndarray] = {}
|
||||
|
||||
# 元数据索引(简化版)
|
||||
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> memory_id set
|
||||
self.type_index: Dict[str, Set[str]] = {} # type -> memory_id set
|
||||
self.user_index: Dict[str, Set[str]] = {} # user_id -> memory_id set
|
||||
|
||||
# 遗忘引擎
|
||||
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
|
||||
if self.config.enable_forgetting:
|
||||
self.forgetting_engine = MemoryForgettingEngine()
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
"total_memories": 0,
|
||||
"total_vectors": 0,
|
||||
"cache_size": 0,
|
||||
"last_save_time": 0.0,
|
||||
"total_searches": 0,
|
||||
"total_stores": 0,
|
||||
"forgetting_stats": {}
|
||||
}
|
||||
|
||||
# 线程锁
|
||||
self._lock = threading.RLock()
|
||||
self._operation_count = 0
|
||||
|
||||
# 嵌入模型
|
||||
self.embedding_model: Optional[LLMRequest] = None
|
||||
|
||||
# 初始化
|
||||
self._initialize_storage()
|
||||
|
||||
def _initialize_storage(self):
|
||||
"""初始化存储系统"""
|
||||
try:
|
||||
# 初始化向量索引
|
||||
if FAISS_AVAILABLE:
|
||||
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
|
||||
logger.info(f"FAISS向量索引初始化完成,维度: {self.config.dimension}")
|
||||
else:
|
||||
# 简单向量存储
|
||||
self.vector_index = {}
|
||||
logger.info("使用简单向量存储(FAISS不可用)")
|
||||
|
||||
# 尝试加载现有数据
|
||||
self._load_storage()
|
||||
|
||||
logger.info(f"统一记忆存储初始化完成,当前记忆数: {len(self.memory_cache)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储系统初始化失败: {e}", exc_info=True)
|
||||
|
||||
def set_embedding_model(self, model: LLMRequest):
|
||||
"""设置嵌入模型"""
|
||||
self.embedding_model = model
|
||||
|
||||
async def _generate_embedding(self, text: str) -> Optional[np.ndarray]:
|
||||
"""生成文本的向量表示"""
|
||||
if not self.embedding_model:
|
||||
logger.warning("未设置嵌入模型,无法生成向量")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 使用嵌入模型生成向量
|
||||
embedding, _ = await self.embedding_model.get_embedding(text)
|
||||
|
||||
if embedding is None:
|
||||
logger.warning(f"嵌入模型返回空向量,文本: {text[:50]}...")
|
||||
return None
|
||||
|
||||
# 转换为numpy数组
|
||||
embedding_array = np.array(embedding, dtype=np.float32)
|
||||
|
||||
# 归一化向量
|
||||
norm = np.linalg.norm(embedding_array)
|
||||
if norm > 0:
|
||||
embedding_array = embedding_array / norm
|
||||
|
||||
return embedding_array
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成向量失败: {e}")
|
||||
return None
|
||||
|
||||
def _add_to_keyword_index(self, memory: MemoryChunk):
|
||||
"""添加到关键词索引"""
|
||||
for keyword in memory.keywords:
|
||||
if keyword not in self.keyword_index:
|
||||
self.keyword_index[keyword] = set()
|
||||
self.keyword_index[keyword].add(memory.memory_id)
|
||||
|
||||
def _add_to_type_index(self, memory: MemoryChunk):
|
||||
"""添加到类型索引"""
|
||||
memory_type = memory.memory_type.value
|
||||
if memory_type not in self.type_index:
|
||||
self.type_index[memory_type] = set()
|
||||
self.type_index[memory_type].add(memory.memory_id)
|
||||
|
||||
def _add_to_user_index(self, memory: MemoryChunk):
|
||||
"""添加到用户索引"""
|
||||
user_id = memory.user_id
|
||||
if user_id not in self.user_index:
|
||||
self.user_index[user_id] = set()
|
||||
self.user_index[user_id].add(memory.memory_id)
|
||||
|
||||
def _remove_from_indexes(self, memory: MemoryChunk):
|
||||
"""从所有索引中移除记忆"""
|
||||
memory_id = memory.memory_id
|
||||
|
||||
# 从关键词索引移除
|
||||
for keyword, memory_ids in self.keyword_index.items():
|
||||
memory_ids.discard(memory_id)
|
||||
if not memory_ids:
|
||||
del self.keyword_index[keyword]
|
||||
|
||||
# 从类型索引移除
|
||||
memory_type = memory.memory_type.value
|
||||
if memory_type in self.type_index:
|
||||
self.type_index[memory_type].discard(memory_id)
|
||||
if not self.type_index[memory_type]:
|
||||
del self.type_index[memory_type]
|
||||
|
||||
# 从用户索引移除
|
||||
if memory.user_id in self.user_index:
|
||||
self.user_index[memory.user_id].discard(memory_id)
|
||||
if not self.user_index[memory.user_id]:
|
||||
del self.user_index[memory.user_id]
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
"""存储记忆列表"""
|
||||
if not memories:
|
||||
return 0
|
||||
|
||||
stored_count = 0
|
||||
|
||||
with self._lock:
|
||||
for memory in memories:
|
||||
try:
|
||||
# 生成向量
|
||||
vector = None
|
||||
if memory.display and memory.display.strip():
|
||||
vector = await self._generate_embedding(memory.display)
|
||||
elif memory.text_content and memory.text_content.strip():
|
||||
vector = await self._generate_embedding(memory.text_content)
|
||||
|
||||
# 存储到缓存
|
||||
self.memory_cache[memory.memory_id] = memory
|
||||
if vector is not None:
|
||||
self.vector_cache[memory.memory_id] = vector
|
||||
|
||||
# 添加到向量索引
|
||||
if FAISS_AVAILABLE:
|
||||
index_id = self.vector_index.ntotal
|
||||
self.vector_index.add(vector.reshape(1, -1))
|
||||
self.memory_id_to_index[memory.memory_id] = index_id
|
||||
self.index_to_memory_id[index_id] = memory.memory_id
|
||||
else:
|
||||
# 简单存储
|
||||
self.vector_index[memory.memory_id] = vector
|
||||
|
||||
# 更新元数据索引
|
||||
self._add_to_keyword_index(memory)
|
||||
self._add_to_type_index(memory)
|
||||
self._add_to_user_index(memory)
|
||||
|
||||
stored_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储记忆 {memory.memory_id[:8]} 失败: {e}")
|
||||
continue
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_memories"] = len(self.memory_cache)
|
||||
self.stats["total_vectors"] = len(self.vector_cache)
|
||||
self.stats["total_stores"] += stored_count
|
||||
|
||||
# 自动保存
|
||||
self._operation_count += stored_count
|
||||
if self._operation_count >= self.config.auto_save_interval:
|
||||
await self._save_storage()
|
||||
self._operation_count = 0
|
||||
|
||||
logger.debug(f"成功存储 {stored_count}/{len(memories)} 条记忆")
|
||||
return stored_count
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_text: str,
|
||||
limit: int = 10,
|
||||
scope_id: Optional[str] = None,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""搜索相似记忆"""
|
||||
if not query_text or not self.vector_cache:
|
||||
return []
|
||||
|
||||
# 生成查询向量
|
||||
query_vector = await self._generate_embedding(query_text)
|
||||
if query_vector is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
results = []
|
||||
|
||||
if FAISS_AVAILABLE and self.vector_index.ntotal > 0:
|
||||
# 使用FAISS搜索
|
||||
query_vector = query_vector.reshape(1, -1)
|
||||
scores, indices = self.vector_index.search(
|
||||
query_vector,
|
||||
min(limit, self.vector_index.ntotal)
|
||||
)
|
||||
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx >= 0 and score >= self.config.similarity_threshold:
|
||||
memory_id = self.index_to_memory_id.get(idx)
|
||||
if memory_id and memory_id in self.memory_cache:
|
||||
# 应用过滤器
|
||||
if self._apply_filters(self.memory_cache[memory_id], filters):
|
||||
results.append((memory_id, float(score)))
|
||||
|
||||
else:
|
||||
# 简单余弦相似度搜索
|
||||
for memory_id, vector in self.vector_cache.items():
|
||||
if memory_id not in self.memory_cache:
|
||||
continue
|
||||
|
||||
# 计算余弦相似度
|
||||
similarity = np.dot(query_vector, vector)
|
||||
if similarity >= self.config.similarity_threshold:
|
||||
# 应用过滤器
|
||||
if self._apply_filters(self.memory_cache[memory_id], filters):
|
||||
results.append((memory_id, float(similarity)))
|
||||
|
||||
# 排序并限制结果
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
self.stats["total_searches"] += 1
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索相似记忆失败: {e}")
|
||||
return []
|
||||
|
||||
def _apply_filters(self, memory: MemoryChunk, filters: Optional[Dict[str, Any]]) -> bool:
|
||||
"""应用搜索过滤器"""
|
||||
if not filters:
|
||||
return True
|
||||
|
||||
# 用户过滤器
|
||||
if "user_id" in filters and memory.user_id != filters["user_id"]:
|
||||
return False
|
||||
|
||||
# 类型过滤器
|
||||
if "memory_types" in filters and memory.memory_type.value not in filters["memory_types"]:
|
||||
return False
|
||||
|
||||
# 关键词过滤器
|
||||
if "keywords" in filters:
|
||||
memory_keywords = set(k.lower() for k in memory.keywords)
|
||||
filter_keywords = set(k.lower() for k in filters["keywords"])
|
||||
if not memory_keywords.intersection(filter_keywords):
|
||||
return False
|
||||
|
||||
# 重要性过滤器
|
||||
if "min_importance" in filters and memory.metadata.importance.value < filters["min_importance"]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
"""根据ID获取记忆"""
|
||||
return self.memory_cache.get(memory_id)
|
||||
|
||||
def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 50) -> List[MemoryChunk]:
|
||||
"""根据过滤器获取记忆"""
|
||||
results = []
|
||||
|
||||
for memory in self.memory_cache.values():
|
||||
if self._apply_filters(memory, filters):
|
||||
results.append(memory)
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
async def forget_memories(self, memory_ids: List[str]) -> int:
|
||||
"""遗忘指定的记忆"""
|
||||
if not memory_ids:
|
||||
return 0
|
||||
|
||||
forgotten_count = 0
|
||||
|
||||
with self._lock:
|
||||
for memory_id in memory_ids:
|
||||
try:
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if not memory:
|
||||
continue
|
||||
|
||||
# 从向量索引移除
|
||||
if FAISS_AVAILABLE and memory_id in self.memory_id_to_index:
|
||||
# FAISS不支持直接删除,这里简化处理
|
||||
# 在实际使用中,可能需要重建索引
|
||||
logger.debug(f"FAISS索引删除 {memory_id} (需要重建索引)")
|
||||
elif memory_id in self.vector_index:
|
||||
del self.vector_index[memory_id]
|
||||
|
||||
# 从缓存移除
|
||||
self.memory_cache.pop(memory_id, None)
|
||||
self.vector_cache.pop(memory_id, None)
|
||||
|
||||
# 从索引移除
|
||||
self._remove_from_indexes(memory)
|
||||
|
||||
forgotten_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"遗忘记忆 {memory_id[:8]} 失败: {e}")
|
||||
continue
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_memories"] = len(self.memory_cache)
|
||||
self.stats["total_vectors"] = len(self.vector_cache)
|
||||
|
||||
logger.info(f"成功遗忘 {forgotten_count}/{len(memory_ids)} 条记忆")
|
||||
return forgotten_count
|
||||
|
||||
async def perform_forgetting_check(self) -> Dict[str, Any]:
|
||||
"""执行遗忘检查"""
|
||||
if not self.forgetting_engine:
|
||||
return {"error": "遗忘引擎未启用"}
|
||||
|
||||
try:
|
||||
# 执行遗忘检查
|
||||
result = await self.forgetting_engine.perform_forgetting_check(list(self.memory_cache.values()))
|
||||
|
||||
# 遗忘标记的记忆
|
||||
forgetting_ids = result["normal_forgetting"] + result["force_forgetting"]
|
||||
if forgetting_ids:
|
||||
forgotten_count = await self.forget_memories(forgetting_ids)
|
||||
result["forgotten_count"] = forgotten_count
|
||||
|
||||
# 更新统计
|
||||
self.stats["forgetting_stats"] = self.forgetting_engine.get_forgetting_stats()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行遗忘检查失败: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def _load_storage(self):
|
||||
"""加载存储数据"""
|
||||
try:
|
||||
# 加载记忆缓存
|
||||
memory_file = self.storage_path / "memory_cache.json"
|
||||
if memory_file.exists():
|
||||
with open(memory_file, 'rb') as f:
|
||||
memory_data = orjson.loads(f.read())
|
||||
for memory_id, memory_dict in memory_data.items():
|
||||
self.memory_cache[memory_id] = MemoryChunk.from_dict(memory_dict)
|
||||
|
||||
# 加载向量缓存(仅在未启用压缩时从 npz 文件读取)
|
||||
if not self.config.enable_compression:
|
||||
vector_file = self.storage_path / "vectors.npz"
|
||||
if vector_file.exists():
|
||||
vectors = np.load(vector_file)
|
||||
self.vector_cache = {
|
||||
memory_id: vectors[memory_id]
|
||||
for memory_id in vectors.files
|
||||
if memory_id in self.memory_cache
|
||||
}
|
||||
|
||||
# 重建向量索引
|
||||
if FAISS_AVAILABLE and self.vector_cache:
|
||||
logger.info("重建FAISS向量索引...")
|
||||
vectors = []
|
||||
memory_ids = []
|
||||
|
||||
for memory_id, vector in self.vector_cache.items():
|
||||
vectors.append(vector)
|
||||
memory_ids.append(memory_id)
|
||||
|
||||
if vectors:
|
||||
vectors_array = np.vstack(vectors)
|
||||
self.vector_index.reset()
|
||||
self.vector_index.add(vectors_array)
|
||||
|
||||
# 重建映射
|
||||
for idx, memory_id in enumerate(memory_ids):
|
||||
self.memory_id_to_index[memory_id] = idx
|
||||
self.index_to_memory_id[idx] = memory_id
|
||||
|
||||
logger.info(f"存储数据加载完成,记忆数: {len(self.memory_cache)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"加载存储数据失败: {e}")
|
||||
|
||||
async def _save_storage(self):
|
||||
"""保存存储数据"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 保存记忆缓存
|
||||
memory_data = {
|
||||
memory_id: memory.to_dict()
|
||||
for memory_id, memory in self.memory_cache.items()
|
||||
}
|
||||
|
||||
memory_file = self.storage_path / "memory_cache.json"
|
||||
with open(memory_file, 'wb') as f:
|
||||
f.write(orjson.dumps(memory_data, option=orjson.OPT_INDENT_2))
|
||||
|
||||
# 保存向量缓存(仅在未启用压缩时写入 npz 文件)
|
||||
if not self.config.enable_compression and self.vector_cache:
|
||||
vector_file = self.storage_path / "vectors.npz"
|
||||
np.savez_compressed(vector_file, **self.vector_cache)
|
||||
|
||||
save_time = time.time() - start_time
|
||||
self.stats["last_save_time"] = time.time()
|
||||
|
||||
logger.debug(f"存储数据保存完成,耗时: {save_time:.3f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存存储数据失败: {e}")
|
||||
|
||||
async def save_storage(self):
|
||||
"""手动保存存储数据"""
|
||||
await self._save_storage()
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
stats = self.stats.copy()
|
||||
stats.update({
|
||||
"cache_size": len(self.memory_cache),
|
||||
"vector_count": len(self.vector_cache),
|
||||
"keyword_index_size": len(self.keyword_index),
|
||||
"type_index_size": len(self.type_index),
|
||||
"user_index_size": len(self.user_index),
|
||||
"config": {
|
||||
"dimension": self.config.dimension,
|
||||
"similarity_threshold": self.config.similarity_threshold,
|
||||
"enable_forgetting": self.config.enable_forgetting
|
||||
}
|
||||
})
|
||||
return stats
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理存储系统"""
|
||||
try:
|
||||
logger.info("开始清理统一记忆存储...")
|
||||
|
||||
# 保存数据
|
||||
await self._save_storage()
|
||||
|
||||
# 清空缓存
|
||||
self.memory_cache.clear()
|
||||
self.vector_cache.clear()
|
||||
self.keyword_index.clear()
|
||||
self.type_index.clear()
|
||||
self.user_index.clear()
|
||||
|
||||
# 重置索引
|
||||
if FAISS_AVAILABLE:
|
||||
self.vector_index.reset()
|
||||
|
||||
self.memory_id_to_index.clear()
|
||||
self.index_to_memory_id.clear()
|
||||
|
||||
logger.info("统一记忆存储清理完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理存储系统失败: {e}")
|
||||
|
||||
|
||||
# 创建全局存储实例
|
||||
unified_memory_storage: Optional[UnifiedMemoryStorage] = None
|
||||
|
||||
|
||||
def get_unified_memory_storage() -> Optional[UnifiedMemoryStorage]:
|
||||
"""获取统一存储实例"""
|
||||
return unified_memory_storage
|
||||
|
||||
|
||||
async def initialize_unified_memory_storage(config: Optional[UnifiedStorageConfig] = None) -> UnifiedMemoryStorage:
|
||||
"""初始化统一记忆存储"""
|
||||
global unified_memory_storage
|
||||
|
||||
if unified_memory_storage is None:
|
||||
unified_memory_storage = UnifiedMemoryStorage(config)
|
||||
|
||||
# 设置嵌入模型
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task:
|
||||
unified_memory_storage.set_embedding_model(
|
||||
LLMRequest(model_set=embedding_task, request_type="memory.embedding")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"设置嵌入模型失败: {e}")
|
||||
|
||||
return unified_memory_storage
|
||||
908
src/chat/memory_system/vector_memory_storage_v2.py
Normal file
@@ -0,0 +1,908 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基于Vector DB的统一记忆存储系统 V2
|
||||
使用ChromaDB作为底层存储,替代JSON存储方式
|
||||
|
||||
主要特性:
|
||||
- 统一的向量存储接口
|
||||
- 高效的语义检索
|
||||
- 元数据过滤支持
|
||||
- 批量操作优化
|
||||
- 自动清理过期记忆
|
||||
"""
|
||||
|
||||
import time
|
||||
import orjson
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple, Set, Any, Union
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from src.common.logger import get_logger
|
||||
from src.common.vector_db import vector_db_service
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorStorageConfig:
|
||||
"""Vector存储配置"""
|
||||
# 集合配置
|
||||
memory_collection: str = "unified_memory_v2"
|
||||
metadata_collection: str = "memory_metadata_v2"
|
||||
|
||||
# 检索配置
|
||||
similarity_threshold: float = 0.5 # 降低阈值以提高召回率(0.5-0.6 是合理范围)
|
||||
search_limit: int = 20
|
||||
batch_size: int = 100
|
||||
|
||||
# 性能配置
|
||||
enable_caching: bool = True
|
||||
cache_size_limit: int = 1000
|
||||
auto_cleanup_interval: int = 3600 # 1小时
|
||||
|
||||
# 遗忘配置
|
||||
enable_forgetting: bool = True
|
||||
retention_hours: int = 24 * 30 # 30天
|
||||
|
||||
@classmethod
|
||||
def from_global_config(cls):
|
||||
"""从全局配置创建实例"""
|
||||
from src.config.config import global_config
|
||||
|
||||
memory_cfg = global_config.memory
|
||||
|
||||
return cls(
|
||||
memory_collection=getattr(memory_cfg, 'vector_db_memory_collection', 'unified_memory_v2'),
|
||||
metadata_collection=getattr(memory_cfg, 'vector_db_metadata_collection', 'memory_metadata_v2'),
|
||||
similarity_threshold=getattr(memory_cfg, 'vector_db_similarity_threshold', 0.5),
|
||||
search_limit=getattr(memory_cfg, 'vector_db_search_limit', 20),
|
||||
batch_size=getattr(memory_cfg, 'vector_db_batch_size', 100),
|
||||
enable_caching=getattr(memory_cfg, 'vector_db_enable_caching', True),
|
||||
cache_size_limit=getattr(memory_cfg, 'vector_db_cache_size_limit', 1000),
|
||||
auto_cleanup_interval=getattr(memory_cfg, 'vector_db_auto_cleanup_interval', 3600),
|
||||
enable_forgetting=getattr(memory_cfg, 'enable_memory_forgetting', True),
|
||||
retention_hours=getattr(memory_cfg, 'vector_db_retention_hours', 720),
|
||||
)
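
# 调用示例(示意;具体以 VectorMemoryStorage 的实际构造逻辑为准):
# config = VectorStorageConfig(similarity_threshold=0.55, search_limit=10)
# storage = VectorMemoryStorage(config)  # 不传 config 时会尝试 from_global_config(),失败则退回默认配置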
|
||||
|
||||
|
||||
class VectorMemoryStorage:
    """基于Vector DB的记忆存储系统"""
|
||||
@property
|
||||
def keyword_index(self) -> dict:
|
||||
"""
|
||||
动态构建关键词倒排索引(仅兼容旧接口,基于当前缓存)
|
||||
返回: {keyword: [memory_id, ...]}
|
||||
"""
|
||||
index = {}
|
||||
for memory in self.memory_cache.values():
|
||||
for kw in getattr(memory, 'keywords', []):
|
||||
if not kw:
|
||||
continue
|
||||
kw_norm = kw.strip().lower()
|
||||
if kw_norm:
|
||||
index.setdefault(kw_norm, []).append(getattr(memory.metadata, 'memory_id', None))
|
||||
return index
|
||||
"""基于Vector DB的记忆存储系统"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
# 默认从全局配置读取,如果没有传入config
|
||||
if config is None:
|
||||
try:
|
||||
self.config = VectorStorageConfig.from_global_config()
|
||||
logger.info("✅ Vector存储配置已从全局配置加载")
|
||||
except Exception as e:
|
||||
logger.warning(f"从全局配置加载失败,使用默认配置: {e}")
|
||||
self.config = VectorStorageConfig()
|
||||
else:
|
||||
self.config = config
|
||||
|
||||
# 从配置中获取批处理大小和集合名称
|
||||
self.batch_size = self.config.batch_size
|
||||
self.collection_name = self.config.memory_collection
|
||||
self.vector_db_service = vector_db_service
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: Dict[str, MemoryChunk] = {}
|
||||
self.cache_timestamps: Dict[str, float] = {}
|
||||
self._cache = self.memory_cache # 别名,兼容旧代码
|
||||
|
||||
# 元数据索引管理器(JSON文件索引)
|
||||
self.metadata_index = MemoryMetadataIndex()
|
||||
|
||||
# 遗忘引擎
|
||||
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
|
||||
if self.config.enable_forgetting:
|
||||
self.forgetting_engine = MemoryForgettingEngine()
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
"total_memories": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"total_searches": 0,
|
||||
"total_stores": 0,
|
||||
"last_cleanup_time": 0.0,
|
||||
"forgetting_stats": {}
|
||||
}
|
||||
|
||||
# 线程锁
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# 定时清理任务
|
||||
self._cleanup_task = None
|
||||
self._stop_cleanup = False
|
||||
|
||||
# 初始化系统
|
||||
self._initialize_storage()
|
||||
self._start_cleanup_task()
|
||||
|
||||
def _initialize_storage(self):
|
||||
"""初始化Vector DB存储"""
|
||||
try:
|
||||
# 创建记忆集合
|
||||
vector_db_service.get_or_create_collection(
|
||||
name=self.config.memory_collection,
|
||||
metadata={
|
||||
"description": "统一记忆存储V2",
|
||||
"hnsw:space": "cosine",
|
||||
"version": "2.0"
|
||||
}
|
||||
)
|
||||
|
||||
# 创建元数据集合(用于复杂查询)
|
||||
vector_db_service.get_or_create_collection(
|
||||
name=self.config.metadata_collection,
|
||||
metadata={
|
||||
"description": "记忆元数据索引",
|
||||
"hnsw:space": "cosine",
|
||||
"version": "2.0"
|
||||
}
|
||||
)
|
||||
|
||||
# 获取当前记忆总数
|
||||
self.stats["total_memories"] = vector_db_service.count(self.config.memory_collection)
|
||||
|
||||
logger.info(f"Vector记忆存储初始化完成,当前记忆数: {self.stats['total_memories']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector存储系统初始化失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动定时清理任务"""
|
||||
if self.config.auto_cleanup_interval > 0:
|
||||
def cleanup_worker():
|
||||
while not self._stop_cleanup:
|
||||
try:
|
||||
time.sleep(self.config.auto_cleanup_interval)
|
||||
if not self._stop_cleanup:
|
||||
asyncio.create_task(self._perform_auto_cleanup())
|
||||
except Exception as e:
|
||||
logger.error(f"定时清理任务出错: {e}")
|
||||
|
||||
self._cleanup_task = threading.Thread(target=cleanup_worker, daemon=True)
|
||||
self._cleanup_task.start()
|
||||
logger.info(f"定时清理任务已启动,间隔: {self.config.auto_cleanup_interval}秒")
|
||||
|
||||
async def _perform_auto_cleanup(self):
|
||||
"""执行自动清理"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# 清理过期缓存
|
||||
if self.config.enable_caching:
|
||||
expired_keys = [
|
||||
memory_id for memory_id, timestamp in self.cache_timestamps.items()
|
||||
if current_time - timestamp > 3600 # 1小时过期
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
self.memory_cache.pop(key, None)
|
||||
self.cache_timestamps.pop(key, None)
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了 {len(expired_keys)} 个过期缓存项")
|
||||
|
||||
# 执行遗忘检查
|
||||
if self.forgetting_engine:
|
||||
await self.perform_forgetting_check()
|
||||
|
||||
self.stats["last_cleanup_time"] = current_time
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"自动清理失败: {e}")
|
||||
|
||||
def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]:
|
||||
"""将MemoryChunk转换为向量存储格式"""
|
||||
try:
|
||||
# 获取memory_id
|
||||
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None)
|
||||
|
||||
# 生成向量表示的文本
|
||||
display_text = getattr(memory, 'display', None) or getattr(memory, 'text_content', None) or str(memory.content)
|
||||
if not display_text.strip():
|
||||
logger.warning(f"记忆 {memory_id} 缺少有效的显示文本")
|
||||
display_text = f"{memory.memory_type.value}: {', '.join(memory.subjects)}"
|
||||
|
||||
# 构建元数据 - 修复枚举值和列表序列化
|
||||
metadata = {
|
||||
"memory_id": memory_id,
|
||||
"user_id": memory.metadata.user_id or "unknown",
|
||||
"memory_type": memory.memory_type.value,
|
||||
"importance": memory.metadata.importance.name, # 使用 .name 而不是枚举对象
|
||||
"confidence": memory.metadata.confidence.name, # 使用 .name 而不是枚举对象
|
||||
"created_at": memory.metadata.created_at,
|
||||
"last_accessed": memory.metadata.last_accessed or memory.metadata.created_at,
|
||||
"access_count": memory.metadata.access_count,
|
||||
"subjects": orjson.dumps(memory.subjects).decode("utf-8"), # 列表转JSON字符串
|
||||
"keywords": orjson.dumps(memory.keywords).decode("utf-8"), # 列表转JSON字符串
|
||||
"tags": orjson.dumps(memory.tags).decode("utf-8"), # 列表转JSON字符串
|
||||
"categories": orjson.dumps(memory.categories).decode("utf-8"), # 列表转JSON字符串
|
||||
"relevance_score": memory.metadata.relevance_score
|
||||
}
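# 说明(示意):ChromaDB 元数据只接受标量,因此列表字段以 JSON 字符串存储;
# 读取时需要还原,例如 keywords = orjson.loads(metadata["keywords"])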
|
||||
|
||||
# 添加可选字段
|
||||
if memory.metadata.source_context:
|
||||
metadata["source_context"] = str(memory.metadata.source_context)
|
||||
|
||||
if memory.content.predicate:
|
||||
metadata["predicate"] = memory.content.predicate
|
||||
|
||||
if memory.content.object:
|
||||
if isinstance(memory.content.object, (dict, list)):
|
||||
metadata["object"] = orjson.dumps(memory.content.object).decode()
|
||||
else:
|
||||
metadata["object"] = str(memory.content.object)
|
||||
|
||||
return {
|
||||
"id": memory_id,
|
||||
"embedding": None, # 将由vector_db_service生成
|
||||
"metadata": metadata,
|
||||
"document": display_text
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown')
|
||||
logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]:
|
||||
"""将Vector DB结果转换为MemoryChunk"""
|
||||
try:
|
||||
# 从元数据中恢复完整记忆
|
||||
if "memory_data" in metadata:
|
||||
memory_dict = orjson.loads(metadata["memory_data"])
|
||||
return MemoryChunk.from_dict(memory_dict)
|
||||
|
||||
# 兜底:从基础字段重建(使用新的结构化格式)
|
||||
logger.warning(f"未找到memory_data,使用兜底逻辑重建记忆 (id={metadata.get('memory_id', 'unknown')})")
|
||||
|
||||
# 构建符合MemoryChunk.from_dict期望的结构
|
||||
memory_dict = {
|
||||
"metadata": {
|
||||
"memory_id": metadata.get("memory_id", f"recovered_{int(time.time())}"),
|
||||
"user_id": metadata.get("user_id", "unknown"),
|
||||
"created_at": metadata.get("timestamp", time.time()),
|
||||
"last_accessed": metadata.get("last_access_time", time.time()),
|
||||
"last_modified": metadata.get("timestamp", time.time()),
|
||||
"access_count": metadata.get("access_count", 0),
|
||||
"relevance_score": 0.0,
|
||||
"confidence": int(metadata.get("confidence", 2)), # MEDIUM
|
||||
"importance": int(metadata.get("importance", 2)), # NORMAL
|
||||
"source_context": None,
|
||||
},
|
||||
"content": {
|
||||
"subject": "",
|
||||
"predicate": "",
|
||||
"object": "",
|
||||
"display": document # 使用document作为显示文本
|
||||
},
|
||||
"memory_type": metadata.get("memory_type", "contextual"),
|
||||
"keywords": orjson.loads(metadata.get("keywords", "[]")) if isinstance(metadata.get("keywords"), str) else metadata.get("keywords", []),
|
||||
"tags": [],
|
||||
"categories": [],
|
||||
"embedding": None,
|
||||
"semantic_hash": None,
|
||||
"related_memories": [],
|
||||
"temporal_context": None
|
||||
}
|
||||
|
||||
return MemoryChunk.from_dict(memory_dict)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换Vector结果到MemoryChunk失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
"""从缓存获取记忆"""
|
||||
if not self.config.enable_caching:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
if memory_id in self.memory_cache:
|
||||
self.cache_timestamps[memory_id] = time.time()
|
||||
self.stats["cache_hits"] += 1
|
||||
return self.memory_cache[memory_id]
|
||||
|
||||
self.stats["cache_misses"] += 1
|
||||
return None
|
||||
|
||||
def _add_to_cache(self, memory: MemoryChunk):
|
||||
"""添加记忆到缓存"""
|
||||
if not self.config.enable_caching:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
# 检查缓存大小限制
|
||||
if len(self.memory_cache) >= self.config.cache_size_limit:
|
||||
# 移除最老的缓存项
|
||||
oldest_id = min(self.cache_timestamps.keys(),
|
||||
key=lambda k: self.cache_timestamps[k])
|
||||
self.memory_cache.pop(oldest_id, None)
|
||||
self.cache_timestamps.pop(oldest_id, None)
|
||||
|
||||
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None)
|
||||
if memory_id:
|
||||
self.memory_cache[memory_id] = memory
|
||||
self.cache_timestamps[memory_id] = time.time()
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
"""批量存储记忆"""
|
||||
if not memories:
|
||||
return 0
|
||||
|
||||
start_time = datetime.now()
|
||||
success_count = 0
|
||||
|
||||
try:
|
||||
# 转换为向量格式
|
||||
vector_data_list = []
|
||||
for memory in memories:
|
||||
try:
|
||||
vector_data = self._memory_to_vector_format(memory)
|
||||
vector_data_list.append(vector_data)
|
||||
except Exception as e:
|
||||
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown')
|
||||
logger.error(f"处理记忆 {memory_id} 失败: {e}")
|
||||
continue
|
||||
|
||||
if not vector_data_list:
|
||||
logger.warning("没有有效的记忆数据可存储")
|
||||
return 0
|
||||
|
||||
# 批量存储到向量数据库
|
||||
for i in range(0, len(vector_data_list), self.batch_size):
|
||||
batch = vector_data_list[i:i + self.batch_size]
|
||||
|
||||
try:
|
||||
# 生成embeddings
|
||||
embeddings = []
|
||||
for item in batch:
|
||||
try:
|
||||
embedding = await get_embedding(item["document"])
|
||||
embeddings.append(embedding)
|
||||
except Exception as e:
|
||||
logger.error(f"生成embedding失败: {e}")
|
||||
# 使用零向量作为后备
|
||||
embeddings.append([0.0] * 768) # 默认维度
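# 注意:零向量仅作占位,在余弦空间下基本无法被语义检索命中;
# 768 只是假设的默认维度,应与实际嵌入模型的输出维度保持一致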
|
||||
|
||||
# vector_db_service.add 需要embeddings参数
|
||||
self.vector_db_service.add(
|
||||
collection_name=self.collection_name,
|
||||
embeddings=embeddings,
|
||||
ids=[item["id"] for item in batch],
|
||||
documents=[item["document"] for item in batch],
|
||||
metadatas=[item["metadata"] for item in batch]
|
||||
)
|
||||
success = True
|
||||
|
||||
if success:
|
||||
# 更新缓存和元数据索引
|
||||
metadata_entries = []
|
||||
for item in batch:
|
||||
memory_id = item["id"]
|
||||
# 从原始 memories 列表中找到对应的 MemoryChunk
|
||||
memory = next((m for m in memories if (getattr(m.metadata, 'memory_id', None) or getattr(m, 'memory_id', None)) == memory_id), None)
|
||||
if memory:
|
||||
# 更新缓存
|
||||
self._cache[memory_id] = memory
|
||||
success_count += 1
|
||||
|
||||
# 创建元数据索引条目
|
||||
try:
|
||||
index_entry = MemoryMetadataIndexEntry(
|
||||
memory_id=memory_id,
|
||||
user_id=memory.metadata.user_id or "unknown",
|
||||
memory_type=memory.memory_type.value,
|
||||
subjects=memory.subjects,
|
||||
objects=[str(memory.content.object)] if memory.content.object else [],
|
||||
keywords=memory.keywords,
|
||||
tags=memory.tags,
|
||||
importance=memory.metadata.importance.value,
|
||||
confidence=memory.metadata.confidence.value,
|
||||
created_at=memory.metadata.created_at,
|
||||
access_count=memory.metadata.access_count,
|
||||
chat_id=memory.metadata.chat_id,
|
||||
content_preview=str(memory.content)[:100] if memory.content else None
|
||||
)
|
||||
metadata_entries.append(index_entry)
|
||||
except Exception as e:
|
||||
logger.warning(f"创建元数据索引条目失败 (memory_id={memory_id}): {e}")
|
||||
|
||||
# 批量更新元数据索引
|
||||
if metadata_entries:
|
||||
try:
|
||||
self.metadata_index.batch_add_or_update(metadata_entries)
|
||||
logger.debug(f"更新元数据索引: {len(metadata_entries)} 条")
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新元数据索引失败: {e}")
|
||||
else:
|
||||
logger.warning(f"批次存储失败,跳过 {len(batch)} 条记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量存储失败: {e}", exc_info=True)
|
||||
continue
|
||||
|
||||
duration = (datetime.now() - start_time).total_seconds()
|
||||
logger.info(f"成功存储 {success_count}/{len(memories)} 条记忆,耗时 {duration:.2f}秒")
|
||||
|
||||
# 保存元数据索引到磁盘
|
||||
if success_count > 0:
|
||||
try:
|
||||
self.metadata_index.save()
|
||||
logger.debug("元数据索引已保存到磁盘")
|
||||
except Exception as e:
|
||||
logger.error(f"保存元数据索引失败: {e}")
|
||||
|
||||
return success_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量存储记忆失败: {e}", exc_info=True)
|
||||
return success_count
|
||||
|
||||
async def store_memory(self, memory: MemoryChunk) -> bool:
|
||||
"""存储单条记忆"""
|
||||
result = await self.store_memories([memory])
|
||||
return result > 0
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_text: str,
|
||||
limit: int = 10,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
# 新增:元数据过滤参数(用于JSON索引粗筛)
|
||||
metadata_filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Tuple[MemoryChunk, float]]:
|
||||
"""
|
||||
搜索相似记忆(混合索引模式)
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
limit: 返回数量限制
|
||||
similarity_threshold: 相似度阈值
|
||||
filters: ChromaDB where条件(保留用于兼容)
|
||||
metadata_filters: JSON元数据索引过滤条件,支持:
|
||||
- memory_types: List[str]
|
||||
- subjects: List[str]
|
||||
- keywords: List[str]
|
||||
- tags: List[str]
|
||||
- importance_min: int
|
||||
- importance_max: int
|
||||
- created_after: float
|
||||
- created_before: float
|
||||
- user_id: str
|
||||
"""
|
||||
if not query_text.strip():
|
||||
return []
|
||||
|
||||
try:
|
||||
# === 阶段一:JSON元数据粗筛(可选) ===
|
||||
candidate_ids: Optional[List[str]] = None
|
||||
if metadata_filters:
|
||||
logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}")
|
||||
candidate_ids = self.metadata_index.search(
|
||||
memory_types=metadata_filters.get('memory_types'),
|
||||
subjects=metadata_filters.get('subjects'),
|
||||
keywords=metadata_filters.get('keywords'),
|
||||
tags=metadata_filters.get('tags'),
|
||||
importance_min=metadata_filters.get('importance_min'),
|
||||
importance_max=metadata_filters.get('importance_max'),
|
||||
created_after=metadata_filters.get('created_after'),
|
||||
created_before=metadata_filters.get('created_before'),
|
||||
user_id=metadata_filters.get('user_id'),
|
||||
limit=self.config.search_limit * 2 # 粗筛返回更多候选
|
||||
)
|
||||
logger.info(f"[JSON元数据粗筛] 完成,筛选出 {len(candidate_ids)} 个候选ID")
|
||||
|
||||
# 如果粗筛后没有结果,直接返回
|
||||
if not candidate_ids:
|
||||
logger.warning("JSON元数据粗筛后无候选,返回空结果")
|
||||
return []
|
||||
|
||||
# === 阶段二:向量精筛 ===
|
||||
# 生成查询向量
|
||||
query_embedding = await get_embedding(query_text)
|
||||
if not query_embedding:
|
||||
return []
|
||||
|
||||
threshold = similarity_threshold or self.config.similarity_threshold
|
||||
|
||||
# 构建where条件
|
||||
where_conditions = filters or {}
|
||||
|
||||
# 如果有候选ID列表,添加到where条件
|
||||
if candidate_ids:
|
||||
# ChromaDB的where条件需要使用$in操作符
|
||||
where_conditions["memory_id"] = {"$in": candidate_ids}
|
||||
logger.debug(f"[向量精筛] 限制在 {len(candidate_ids)} 个候选ID内搜索")
|
||||
|
||||
# 查询Vector DB
|
||||
logger.debug(f"[向量精筛] 开始,limit={min(limit, self.config.search_limit)}")
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.config.memory_collection,
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=min(limit, self.config.search_limit),
|
||||
where=where_conditions if where_conditions else None
|
||||
)
|
||||
|
||||
# 处理结果
|
||||
similar_memories = []
|
||||
|
||||
if results.get("documents") and results["documents"][0]:
|
||||
documents = results["documents"][0]
|
||||
distances = results.get("distances", [[]])[0]
|
||||
metadatas = results.get("metadatas", [[]])[0]
|
||||
ids = results.get("ids", [[]])[0]
|
||||
|
||||
logger.info(f"向量检索返回原始结果:documents={len(documents)}, ids={len(ids)}, metadatas={len(metadatas)}")
|
||||
for i, (doc, metadata, memory_id) in enumerate(zip(documents, metadatas, ids)):
|
||||
# 计算相似度
|
||||
distance = distances[i] if i < len(distances) else 1.0
|
||||
similarity = 1 - distance # ChromaDB返回距离,转换为相似度
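# 注:相似度 = 1 - 距离,依赖集合创建时设置的 "hnsw:space": "cosine";若改用其他距离度量,此换算不再成立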
|
||||
|
||||
if similarity < threshold:
|
||||
continue
|
||||
|
||||
# 首先尝试从缓存获取
|
||||
memory = self._get_from_cache(memory_id)
|
||||
|
||||
if not memory:
|
||||
# 从Vector结果重建
|
||||
memory = self._vector_result_to_memory(doc, metadata)
|
||||
if memory:
|
||||
self._add_to_cache(memory)
|
||||
|
||||
if memory:
|
||||
similar_memories.append((memory, similarity))
|
||||
# 记录单条结果的关键日志(id,相似度,简短文本)
|
||||
try:
|
||||
short_text = (str(memory.content)[:120]) if hasattr(memory, 'content') else (doc[:120] if isinstance(doc, str) else '')
|
||||
except Exception:
|
||||
short_text = ''
|
||||
logger.info(f"检索结果 - id={memory_id}, similarity={similarity:.4f}, summary={short_text}")
|
||||
|
||||
# 按相似度排序
|
||||
similar_memories.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
self.stats["total_searches"] += 1
|
||||
logger.info(f"搜索相似记忆: query='{query_text[:60]}...', limit={limit}, threshold={threshold}, filters={where_conditions}, 返回数={len(similar_memories)}")
|
||||
logger.debug(f"搜索相似记忆 详细结果数={len(similar_memories)}")
|
||||
|
||||
return similar_memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索相似记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
"""根据ID获取记忆"""
|
||||
# 首先尝试从缓存获取
|
||||
memory = self._get_from_cache(memory_id)
|
||||
if memory:
|
||||
return memory
|
||||
|
||||
try:
|
||||
# 从Vector DB获取
|
||||
results = vector_db_service.get(
|
||||
collection_name=self.config.memory_collection,
|
||||
ids=[memory_id]
|
||||
)
|
||||
|
||||
if results.get("documents") and results["documents"]:
|
||||
document = results["documents"][0]
|
||||
metadata = results["metadatas"][0] if results.get("metadatas") else {}
|
||||
|
||||
memory = self._vector_result_to_memory(document, metadata)
|
||||
if memory:
|
||||
self._add_to_cache(memory)
|
||||
|
||||
return memory
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取记忆 {memory_id} 失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def get_memories_by_filters(
|
||||
self,
|
||||
filters: Dict[str, Any],
|
||||
limit: int = 100
|
||||
) -> List[MemoryChunk]:
|
||||
"""根据过滤条件获取记忆"""
|
||||
try:
|
||||
results = vector_db_service.get(
|
||||
collection_name=self.config.memory_collection,
|
||||
where=filters,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
memories = []
|
||||
if results.get("documents"):
|
||||
documents = results["documents"]
|
||||
metadatas = results.get("metadatas", [{}] * len(documents))
|
||||
ids = results.get("ids", [])
|
||||
|
||||
logger.info(f"按过滤条件获取返回: docs={len(documents)}, ids={len(ids)}")
|
||||
for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
|
||||
memory_id = ids[i] if i < len(ids) else None
|
||||
|
||||
# 首先尝试从缓存获取
|
||||
if memory_id:
|
||||
memory = self._get_from_cache(memory_id)
|
||||
if memory:
|
||||
memories.append(memory)
|
||||
logger.debug(f"过滤获取命中缓存: id={memory_id}")
|
||||
continue
|
||||
|
||||
# 从Vector结果重建
|
||||
memory = self._vector_result_to_memory(doc, metadata)
|
||||
if memory:
|
||||
memories.append(memory)
|
||||
if memory_id:
|
||||
self._add_to_cache(memory)
|
||||
logger.debug(f"过滤获取结果: id={memory_id}, meta_keys={list(metadata.keys())}")
|
||||
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"根据过滤条件获取记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def update_memory(self, memory: MemoryChunk) -> bool:
|
||||
"""更新记忆"""
|
||||
try:
|
||||
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None)
|
||||
if not memory_id:
|
||||
logger.error("无法更新记忆:缺少memory_id")
|
||||
return False
|
||||
|
||||
# 先删除旧记忆
|
||||
await self.delete_memory(memory_id)
|
||||
|
||||
# 重新存储更新后的记忆
|
||||
return await self.store_memory(memory)
|
||||
|
||||
except Exception as e:
|
||||
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown')
|
||||
logger.error(f"更新记忆 {memory_id} 失败: {e}")
|
||||
return False
|
||||
|
||||
async def delete_memory(self, memory_id: str) -> bool:
|
||||
"""删除记忆"""
|
||||
try:
|
||||
# 从Vector DB删除
|
||||
vector_db_service.delete(
|
||||
collection_name=self.config.memory_collection,
|
||||
ids=[memory_id]
|
||||
)
|
||||
|
||||
# 从缓存删除
|
||||
with self._lock:
|
||||
self.memory_cache.pop(memory_id, None)
|
||||
self.cache_timestamps.pop(memory_id, None)
|
||||
|
||||
self.stats["total_memories"] = max(0, self.stats["total_memories"] - 1)
|
||||
logger.debug(f"删除记忆: {memory_id}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除记忆 {memory_id} 失败: {e}")
|
||||
return False
|
||||
|
||||
async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int:
|
||||
"""根据过滤条件批量删除记忆"""
|
||||
try:
|
||||
# 先获取要删除的记忆ID
|
||||
results = vector_db_service.get(
|
||||
collection_name=self.config.memory_collection,
|
||||
where=filters,
|
||||
include=["metadatas"]
|
||||
)
|
||||
|
||||
if not results.get("ids"):
|
||||
return 0
|
||||
|
||||
memory_ids = results["ids"]
|
||||
|
||||
# 批量删除
|
||||
vector_db_service.delete(
|
||||
collection_name=self.config.memory_collection,
|
||||
where=filters
|
||||
)
|
||||
|
||||
# 从缓存删除
|
||||
with self._lock:
|
||||
for memory_id in memory_ids:
|
||||
self.memory_cache.pop(memory_id, None)
|
||||
self.cache_timestamps.pop(memory_id, None)
|
||||
|
||||
deleted_count = len(memory_ids)
|
||||
self.stats["total_memories"] = max(0, self.stats["total_memories"] - deleted_count)
|
||||
logger.info(f"批量删除记忆: {deleted_count} 条")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量删除记忆失败: {e}")
|
||||
return 0
|
||||
|
||||
async def perform_forgetting_check(self) -> Dict[str, Any]:
|
||||
"""执行遗忘检查"""
|
||||
if not self.forgetting_engine:
|
||||
return {"error": "遗忘引擎未启用"}
|
||||
|
||||
try:
|
||||
# 获取所有记忆进行遗忘检查
|
||||
# 注意:对于大型数据集,这里应该分批处理
|
||||
current_time = time.time()
|
||||
cutoff_time = current_time - (self.config.retention_hours * 3600)
|
||||
|
||||
# 先删除明显过期的记忆
|
||||
expired_filters = {"timestamp": {"$lt": cutoff_time}}
|
||||
expired_count = await self.delete_memories_by_filters(expired_filters)
|
||||
|
||||
# 对剩余记忆执行智能遗忘检查
|
||||
# 这里为了性能考虑,只检查一部分记忆
|
||||
sample_memories = await self.get_memories_by_filters({}, limit=500)
|
||||
|
||||
if sample_memories:
|
||||
result = await self.forgetting_engine.perform_forgetting_check(sample_memories)
|
||||
|
||||
# 遗忘标记的记忆
|
||||
forgetting_ids = result.get("normal_forgetting", []) + result.get("force_forgetting", [])
|
||||
forgotten_count = 0
|
||||
|
||||
for memory_id in forgetting_ids:
|
||||
if await self.delete_memory(memory_id):
|
||||
forgotten_count += 1
|
||||
|
||||
result["forgotten_count"] = forgotten_count
|
||||
result["expired_count"] = expired_count
|
||||
|
||||
# 更新统计
|
||||
self.stats["forgetting_stats"] = self.forgetting_engine.get_forgetting_stats()
|
||||
|
||||
logger.info(f"遗忘检查完成: 过期删除 {expired_count}, 智能遗忘 {forgotten_count}")
|
||||
return result
|
||||
|
||||
return {"expired_count": expired_count, "forgotten_count": 0}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行遗忘检查失败: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
try:
|
||||
current_total = vector_db_service.count(self.config.memory_collection)
|
||||
self.stats["total_memories"] = current_total
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
"cache_size": len(self.memory_cache),
|
||||
"collection_name": self.config.memory_collection,
|
||||
"storage_type": "vector_db_v2",
|
||||
"uptime": time.time() - self.stats.get("start_time", time.time())
|
||||
}
|
||||
|
||||
def stop(self):
|
||||
"""停止存储系统"""
|
||||
self._stop_cleanup = True
|
||||
|
||||
if self._cleanup_task and self._cleanup_task.is_alive():
|
||||
logger.info("正在停止定时清理任务...")
|
||||
|
||||
# 清空缓存
|
||||
with self._lock:
|
||||
self.memory_cache.clear()
|
||||
self.cache_timestamps.clear()
|
||||
|
||||
logger.info("Vector记忆存储系统已停止")
|
||||
|
||||
|
||||
# 全局实例(可选)
|
||||
_global_vector_storage = None
|
||||
|
||||
|
||||
def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage:
|
||||
"""获取全局Vector记忆存储实例"""
|
||||
global _global_vector_storage
|
||||
|
||||
if _global_vector_storage is None:
|
||||
_global_vector_storage = VectorMemoryStorage(config)
|
||||
|
||||
return _global_vector_storage
|
||||
|
||||
|
||||
# 兼容性接口
|
||||
class VectorMemoryStorageAdapter:
|
||||
"""适配器类,提供与原UnifiedMemoryStorage兼容的接口"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
self.storage = VectorMemoryStorage(config)
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
return await self.storage.store_memories(memories)
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_text: str,
|
||||
limit: int = 10,
|
||||
scope_id: Optional[str] = None,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
results = await self.storage.search_similar_memories(
|
||||
query_text, limit, filters=filters
|
||||
)
|
||||
# 转换为原格式:(memory_id, similarity)
|
||||
return [(getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown'), similarity) for memory, similarity in results]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self.storage.get_storage_stats()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 简单测试
|
||||
async def test_vector_storage():
|
||||
storage = VectorMemoryStorage()
|
||||
|
||||
# 创建测试记忆
|
||||
from src.chat.memory_system.memory_chunk import MemoryType
|
||||
test_memory = MemoryChunk(
|
||||
memory_id="test_001",
|
||||
user_id="test_user",
|
||||
text_content="今天天气很好,适合出门散步",
|
||||
memory_type=MemoryType.FACT,
|
||||
keywords=["天气", "散步"],
|
||||
importance=0.7
|
||||
)
|
||||
|
||||
# 存储记忆
|
||||
success = await storage.store_memory(test_memory)
|
||||
print(f"存储结果: {success}")
|
||||
|
||||
# 搜索记忆
|
||||
results = await storage.search_similar_memories("天气怎么样", limit=5)
|
||||
print(f"搜索结果: {len(results)} 条")
|
||||
|
||||
for memory, similarity in results:
|
||||
print(f" - {memory.text_content[:50]}... (相似度: {similarity:.3f})")
|
||||
|
||||
# 获取统计信息
|
||||
stats = storage.get_storage_stats()
|
||||
print(f"存储统计: {stats}")
|
||||
|
||||
storage.stop()
|
||||
|
||||
asyncio.run(test_vector_storage())
|
||||
@@ -59,7 +59,7 @@ class SingleStreamContextManager:
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
-await stream_loop_manager.start_stream_loop(self.stream_id)
+asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
||||
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -275,7 +275,7 @@ class SingleStreamContextManager:
|
||||
self.last_access_time = time.time()
|
||||
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
-await stream_loop_manager.start_stream_loop(self.stream_id)
+asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
||||
|
||||
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度待计算)")
|
||||
return True
|
||||
|
||||
@@ -354,6 +354,25 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"清除未读消息时发生错误: {e}")
|
||||
|
||||
async def clear_stream_unread_messages(self, stream_id: str):
|
||||
"""清除指定聊天流的所有未读消息"""
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"clear_stream_unread_messages: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
context = chat_stream.context_manager.context
|
||||
if hasattr(context, 'unread_messages') and context.unread_messages:
|
||||
logger.debug(f"正在为流 {stream_id} 清除 {len(context.unread_messages)} 条未读消息")
|
||||
context.unread_messages.clear()
|
||||
else:
|
||||
logger.debug(f"流 {stream_id} 没有需要清除的未读消息")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清除流 {stream_id} 的未读消息时发生错误: {e}")
|
||||
|
||||
|
||||
# 创建全局消息管理器实例
|
||||
message_manager = MessageManager()
|
||||
|
||||
@@ -587,15 +587,25 @@ class DefaultReplyer:
|
||||
# 转换格式以兼容现有代码
|
||||
running_memories = []
|
||||
if enhanced_memories:
|
||||
-for memory_chunk in enhanced_memories:
+logger.debug(f"[记忆转换] 收到 {len(enhanced_memories)} 条原始记忆")
+for idx, memory_chunk in enumerate(enhanced_memories, 1):
|
||||
# 获取结构化内容的字符串表示
|
||||
structure_display = str(memory_chunk.content) if hasattr(memory_chunk, 'content') else "unknown"
|
||||
|
||||
# 获取记忆内容,优先使用display
|
||||
content = memory_chunk.display or memory_chunk.text_content or ""
|
||||
|
||||
# 调试:记录每条记忆的内容获取情况
|
||||
logger.debug(f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}")
|
||||
|
||||
running_memories.append({
|
||||
"content": memory_chunk.display or memory_chunk.text_content or "",
|
||||
"content": content,
|
||||
"memory_type": memory_chunk.memory_type.value,
|
||||
"confidence": memory_chunk.metadata.confidence.value,
|
||||
"importance": memory_chunk.metadata.importance.value,
|
||||
"relevance": getattr(memory_chunk, 'relevance_score', 0.5),
|
||||
"relevance": getattr(memory_chunk.metadata, 'relevance_score', 0.5),
|
||||
"source": memory_chunk.metadata.source,
|
||||
"structure": memory_chunk.content_structure.value if memory_chunk.content_structure else "unknown",
|
||||
"structure": structure_display,
|
||||
})
|
||||
|
||||
# 构建瞬时记忆字符串
|
||||
@@ -604,7 +614,7 @@ class DefaultReplyer:
|
||||
if top_memory:
|
||||
instant_memory = top_memory[0].get("content", "")
|
||||
|
||||
logger.info(f"增强记忆系统检索到 {len(running_memories)} 条记忆")
|
||||
logger.info(f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"增强记忆系统检索失败: {e}")
|
||||
@@ -640,10 +650,17 @@ class DefaultReplyer:
|
||||
# 调试相关度信息
|
||||
relevance_info = [(m.get('memory_type', 'unknown'), m.get('relevance', 0.0)) for m in sorted_memories]
|
||||
logger.debug(f"记忆相关度信息: {relevance_info}")
|
||||
logger.debug(f"[记忆构建] 准备将 {len(sorted_memories)} 条记忆添加到提示词")
|
||||
|
||||
-for running_memory in sorted_memories:
+for idx, running_memory in enumerate(sorted_memories, 1):
|
||||
content = running_memory.get('content', '')
|
||||
memory_type = running_memory.get('memory_type', 'unknown')
|
||||
|
||||
# 跳过空内容
|
||||
if not content or not content.strip():
|
||||
logger.warning(f"[记忆构建] 跳过第 {idx} 条记忆:内容为空 (type={memory_type})")
|
||||
logger.debug(f"[记忆构建] 空记忆详情: {running_memory}")
|
||||
continue
|
||||
|
||||
# 映射记忆类型到中文标签
|
||||
type_mapping = {
|
||||
@@ -661,10 +678,12 @@ class DefaultReplyer:
|
||||
if "(类型:" in content and ")" in content:
|
||||
clean_content = content.split("(类型:")[0].strip()
|
||||
|
||||
logger.debug(f"[记忆构建] 添加第 {idx} 条记忆: [{chinese_type}] {clean_content[:50]}...")
|
||||
memory_parts.append(f"- **[{chinese_type}]** {clean_content}")
|
||||
|
||||
memory_str = "\n".join(memory_parts) + "\n"
|
||||
has_any_memory = True
|
||||
logger.debug(f"[记忆构建] 成功构建记忆字符串,包含 {len(memory_parts) - 2} 条记忆")
|
||||
|
||||
# 添加瞬时记忆
|
||||
if instant_memory:
|
||||
|
||||
@@ -98,13 +98,79 @@ class ChromaDBImpl(VectorDBBase):
|
||||
"n_results": n_results,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# 修复ChromaDB的where条件格式
|
||||
if where:
|
||||
query_params["where"] = where
|
||||
processed_where = self._process_where_condition(where)
|
||||
if processed_where:
|
||||
query_params["where"] = processed_where
|
||||
|
||||
return collection.query(**query_params)
|
||||
except Exception as e:
|
||||
logger.error(f"查询集合 '{collection_name}' 失败: {e}")
|
||||
# 如果查询失败,尝试不使用where条件重新查询
|
||||
try:
|
||||
fallback_params = {
|
||||
"query_embeddings": query_embeddings,
|
||||
"n_results": n_results,
|
||||
}
|
||||
logger.warning(f"使用回退查询模式(无where条件)")
|
||||
return collection.query(**fallback_params)
|
||||
except Exception as fallback_e:
|
||||
logger.error(f"回退查询也失败: {fallback_e}")
|
||||
return {}
|
||||
|
||||
def _process_where_condition(self, where: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
处理where条件,转换为ChromaDB支持的格式
|
||||
ChromaDB支持的格式:
|
||||
- 简单条件: {"field": "value"}
|
||||
- 操作符条件: {"field": {"$op": "value"}}
|
||||
- AND条件: {"$and": [condition1, condition2]}
|
||||
- OR条件: {"$or": [condition1, condition2]}
|
||||
"""
|
||||
if not where:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 如果只有一个字段,直接返回
|
||||
if len(where) == 1:
|
||||
key, value = next(iter(where.items()))
|
||||
|
||||
# 处理列表值(如memory_types)
|
||||
if isinstance(value, list):
|
||||
if len(value) == 1:
|
||||
return {key: value[0]}
|
||||
else:
|
||||
# 多个值使用 $in 操作符
|
||||
return {key: {"$in": value}}
|
||||
else:
|
||||
return {key: value}
|
||||
|
||||
# 多个字段使用 $and 操作符
|
||||
conditions = []
|
||||
for key, value in where.items():
|
||||
if isinstance(value, list):
|
||||
if len(value) == 1:
|
||||
conditions.append({key: value[0]})
|
||||
else:
|
||||
conditions.append({key: {"$in": value}})
|
||||
else:
|
||||
conditions.append({key: value})
|
||||
|
||||
return {"$and": conditions}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"处理where条件失败: {e}, 使用简化条件")
|
||||
# 回退到只使用第一个条件
|
||||
if where:
|
||||
key, value = next(iter(where.items()))
|
||||
if isinstance(value, list) and value:
|
||||
return {key: value[0]}
|
||||
elif not isinstance(value, list):
|
||||
return {key: value}
|
||||
return None
|
||||
|
||||
def get(
|
||||
self,
|
||||
collection_name: str,
|
||||
@@ -119,16 +185,33 @@ class ChromaDBImpl(VectorDBBase):
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
# 处理where条件
|
||||
processed_where = None
|
||||
if where:
|
||||
processed_where = self._process_where_condition(where)
|
||||
|
||||
return collection.get(
|
||||
ids=ids,
|
||||
-where=where,
+where=processed_where,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
where_document=where_document,
|
||||
-include=include,
+include=include or ["documents", "metadatas", "embeddings"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}")
|
||||
# 如果获取失败,尝试不使用where条件重新获取
|
||||
try:
|
||||
logger.warning(f"使用回退获取模式(无where条件)")
|
||||
return collection.get(
|
||||
ids=ids,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
where_document=where_document,
|
||||
include=include or ["documents", "metadatas", "embeddings"],
|
||||
)
|
||||
except Exception as fallback_e:
|
||||
logger.error(f"回退获取也失败: {fallback_e}")
|
||||
return {}
|
||||
|
||||
def delete(
|
||||
|
||||
@@ -305,6 +305,48 @@ class MemoryConfig(ValidatedConfigBase):
|
||||
cache_ttl_seconds: int = Field(default=300, description="缓存生存时间(秒)")
|
||||
max_cache_size: int = Field(default=1000, description="最大缓存大小")
|
||||
|
||||
# Vector DB记忆存储配置 (替代JSON存储)
|
||||
enable_vector_memory_storage: bool = Field(default=True, description="启用Vector DB记忆存储")
|
||||
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
|
||||
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
|
||||
|
||||
# Vector DB配置
|
||||
vector_db_memory_collection: str = Field(default="unified_memory_v2", description="Vector DB记忆集合名称")
|
||||
vector_db_metadata_collection: str = Field(default="memory_metadata_v2", description="Vector DB元数据集合名称")
|
||||
vector_db_similarity_threshold: float = Field(default=0.5, description="Vector DB相似度阈值(推荐0.5-0.6,过高会导致检索不到结果)")
|
||||
vector_db_search_limit: int = Field(default=20, description="Vector DB搜索限制")
|
||||
vector_db_batch_size: int = Field(default=100, description="批处理大小")
|
||||
vector_db_enable_caching: bool = Field(default=True, description="启用内存缓存")
|
||||
vector_db_cache_size_limit: int = Field(default=1000, description="缓存大小限制")
|
||||
vector_db_auto_cleanup_interval: int = Field(default=3600, description="自动清理间隔(秒)")
|
||||
vector_db_retention_hours: int = Field(default=720, description="记忆保留时间(小时,默认30天)")
|
||||
|
||||
# 遗忘引擎配置
|
||||
enable_memory_forgetting: bool = Field(default=True, description="启用智能遗忘机制")
|
||||
forgetting_check_interval_hours: int = Field(default=24, description="遗忘检查间隔(小时)")
|
||||
base_forgetting_days: float = Field(default=30.0, description="基础遗忘天数")
|
||||
min_forgetting_days: float = Field(default=7.0, description="最小遗忘天数")
|
||||
max_forgetting_days: float = Field(default=365.0, description="最大遗忘天数")
|
||||
|
||||
# 重要程度权重
|
||||
critical_importance_bonus: float = Field(default=45.0, description="关键重要性额外天数")
|
||||
high_importance_bonus: float = Field(default=30.0, description="高重要性额外天数")
|
||||
normal_importance_bonus: float = Field(default=15.0, description="一般重要性额外天数")
|
||||
low_importance_bonus: float = Field(default=0.0, description="低重要性额外天数")
|
||||
|
||||
# 置信度权重
|
||||
verified_confidence_bonus: float = Field(default=30.0, description="已验证置信度额外天数")
|
||||
high_confidence_bonus: float = Field(default=20.0, description="高置信度额外天数")
|
||||
medium_confidence_bonus: float = Field(default=10.0, description="中等置信度额外天数")
|
||||
low_confidence_bonus: float = Field(default=0.0, description="低置信度额外天数")
|
||||
|
||||
# 激活频率权重
|
||||
activation_frequency_weight: float = Field(default=0.5, description="每次激活增加的天数权重")
|
||||
max_frequency_bonus: float = Field(default=10.0, description="最大激活频率奖励天数")
|
||||
|
||||
# 休眠机制
|
||||
dormant_threshold_days: int = Field(default=90, description="休眠状态判定天数")
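
# 保留天数的大致构成(示意,具体计算以 MemoryForgettingEngine 的实现为准):
# days ≈ clamp(base_forgetting_days + importance_bonus + confidence_bonus
#              + min(access_count * activation_frequency_weight, max_frequency_bonus),
#              min_forgetting_days, max_forgetting_days)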
|
||||
|
||||
|
||||
class MoodConfig(ValidatedConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
@@ -795,7 +795,7 @@ class LLMRequest:
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
-self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
+await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
|
||||
content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
|
||||
reasoning = response.reasoning_content or reasoning
|
||||
|
||||
|
||||
@@ -382,19 +382,21 @@ class BaseAction(ABC):
|
||||
# 构造命令数据
|
||||
command_data = {"name": command_name, "args": args or {}}
|
||||
|
||||
-success = await send_api.command_to_stream(
-    command=command_data,
+response = await send_api.adapter_command_to_stream(
+    action=command_name,
+    params=args or {},
|
||||
stream_id=self.chat_id,
|
||||
storage_message=storage_message,
|
||||
display_message=display_message,
|
||||
platform=self.platform
|
||||
)
|
||||
|
||||
-if success:
-    logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
+# 根据响应判断成功与否
+if response and response.get("status") == "ok":
+    logger.info(f"{self.log_prefix} 成功执行适配器命令: {command_name}, 响应: {response.get('data')}")
+    return True
 else:
-    logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
-
-return success
+    error_message = response.get('message', '未知错误')
+    logger.error(f"{self.log_prefix} 执行适配器命令失败: {command_name}, 错误: {error_message}")
+    return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
||||
|
||||
@@ -127,7 +127,7 @@ class ChatterActionPlanner:
|
||||
}
|
||||
)
|
||||
|
||||
-logger.debug(
+logger.info(
|
||||
f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}"
|
||||
)
|
||||
|
||||
|
||||
@@ -828,6 +828,63 @@ class MessageHandler:
|
||||
data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}",
|
||||
)
|
||||
|
||||
# 检查是否是音乐分享 (QQ音乐类型)
|
||||
if nested_data.get("view") == "music" and "com.tencent.music" in str(nested_data.get("app", "")):
|
||||
meta = nested_data.get("meta", {})
|
||||
music = meta.get("music", {})
|
||||
if music:
|
||||
tag = music.get("tag", "未知来源")
|
||||
logger.debug(f"检测到【{tag}】音乐分享消息 (music view),开始提取信息")
|
||||
|
||||
title = music.get("title", "未知歌曲")
|
||||
desc = music.get("desc", "未知艺术家")
|
||||
jump_url = music.get("jumpUrl", "")
|
||||
preview_url = music.get("preview", "")
|
||||
|
||||
artist = "未知艺术家"
|
||||
song_title = title
|
||||
|
||||
if "网易云音乐" in tag:
|
||||
artist = desc
|
||||
elif "QQ音乐" in tag:
|
||||
if " - " in title:
|
||||
parts = title.split(" - ", 1)
|
||||
song_title = parts[0]
|
||||
artist = parts[1]
|
||||
else:
|
||||
artist = desc
|
||||
|
||||
formatted_content = (
|
||||
f"这是一张来自【{tag}】的音乐分享卡片:\n"
|
||||
f"歌曲: {song_title}\n"
|
||||
f"艺术家: {artist}\n"
|
||||
f"跳转链接: {jump_url}\n"
|
||||
f"封面图: {preview_url}"
|
||||
)
|
||||
return Seg(type="text", data=formatted_content)
|
||||
|
||||
# 检查是否是新闻/图文分享 (网易云音乐可能伪装成这种)
|
||||
elif nested_data.get("view") == "news" and "com.tencent.tuwen" in str(nested_data.get("app", "")):
|
||||
meta = nested_data.get("meta", {})
|
||||
news = meta.get("news", {})
|
||||
if news and "网易云音乐" in news.get("tag", ""):
|
||||
tag = news.get("tag")
|
||||
logger.debug(f"检测到【{tag}】音乐分享消息 (news view),开始提取信息")
|
||||
|
||||
title = news.get("title", "未知歌曲")
|
||||
desc = news.get("desc", "未知艺术家")
|
||||
jump_url = news.get("jumpUrl", "")
|
||||
preview_url = news.get("preview", "")
|
||||
|
||||
formatted_content = (
|
||||
f"这是一张来自【{tag}】的音乐分享卡片:\n"
|
||||
f"标题: {title}\n"
|
||||
f"描述: {desc}\n"
|
||||
f"跳转链接: {jump_url}\n"
|
||||
f"封面图: {preview_url}"
|
||||
)
|
||||
return Seg(type="text", data=formatted_content)
|
||||
|
||||
# 如果没有提取到关键信息,返回None
|
||||
return None
|
||||
|
||||
|
||||
@@ -6,6 +6,33 @@ import urllib3
|
||||
import ssl
|
||||
import io
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from asyncio import Lock
|
||||
|
||||
_internal_cache = {}
|
||||
_cache_lock = Lock()
|
||||
CACHE_TIMEOUT = 300 # 缓存5分钟
|
||||
|
||||
|
||||
async def get_from_cache(key: str):
|
||||
async with _cache_lock:
|
||||
data = _internal_cache.get(key)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
result, timestamp = data
|
||||
if time.time() - timestamp < CACHE_TIMEOUT:
|
||||
logger.debug(f"从缓存命中: {key}")
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
async def set_to_cache(key: str, value: any):
|
||||
async with _cache_lock:
|
||||
_internal_cache[key] = (value, time.time())
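
# 使用示例(示意,fetch_group_info 为假设的取数函数):
# data = await get_from_cache("group_info:123")
# if data is None:
#     data = await fetch_group_info(123)
#     await set_to_cache("group_info:123", data)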
|
||||
|
||||
|
||||
from .database import BanUser, db_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -27,11 +54,16 @@ class SSLAdapter(urllib3.PoolManager):
|
||||
|
||||
async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||
"""
|
||||
获取群相关信息
|
||||
获取群相关信息 (带缓存)
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群聊信息中")
|
||||
cache_key = f"group_info:{group_id}"
|
||||
cached_data = await get_from_cache(cache_key)
|
||||
if cached_data:
|
||||
return cached_data
|
||||
|
||||
logger.debug(f"获取群聊信息中 (无缓存): {group_id}")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||
try:
|
||||
@@ -43,8 +75,11 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
|
||||
except Exception as e:
|
||||
logger.error(f"获取群信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
data = socket_response.get("data")
|
||||
if data:
|
||||
await set_to_cache(cache_key, data)
|
||||
return data
|
||||
|
||||
|
||||
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||
@@ -71,11 +106,16 @@ async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in
|
||||
|
||||
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None:
|
||||
"""
|
||||
获取群成员信息
|
||||
获取群成员信息 (带缓存)
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群成员信息中")
|
||||
cache_key = f"member_info:{group_id}:{user_id}"
|
||||
cached_data = await get_from_cache(cache_key)
|
||||
if cached_data:
|
||||
return cached_data
|
||||
|
||||
logger.debug(f"获取群成员信息中 (无缓存): group={group_id}, user={user_id}")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
@@ -93,8 +133,11 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use
|
||||
except Exception as e:
|
||||
logger.error(f"获取成员信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
data = socket_response.get("data")
|
||||
if data:
|
||||
await set_to_cache(cache_key, data)
|
||||
return data
|
||||
|
||||
|
||||
async def get_image_base64(url: str) -> str:
|
||||
@@ -137,13 +180,18 @@ def convert_image_to_gif(image_base64: str) -> str:
|
||||
|
||||
async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
||||
"""
|
||||
获取自身信息
|
||||
获取自身信息 (带缓存)
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
Returns:
|
||||
data: dict: 返回的自身信息
|
||||
"""
|
||||
logger.debug("获取自身信息中")
|
||||
cache_key = "self_info"
|
||||
cached_data = await get_from_cache(cache_key)
|
||||
if cached_data:
|
||||
return cached_data
|
||||
|
||||
logger.debug("获取自身信息中 (无缓存)")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
|
||||
try:
|
||||
@@ -155,8 +203,11 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
||||
except Exception as e:
|
||||
logger.error(f"获取自身信息失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
data = response.get("data")
|
||||
if data:
|
||||
await set_to_cache(cache_key, data)
|
||||
return data
|
||||
|
||||
|
||||
def get_image_format(raw_data: str) -> str:
|
||||
|
||||
@@ -191,13 +191,21 @@ class PokeAction(BaseAction):

        display_name = user_name or user_id

        # 构建戳一戳的参数
        poke_args = {"user_id": str(user_id)}
        if self.is_group and self.chat_stream.group_info:
            poke_args["group_id"] = self.chat_stream.group_info.group_id
            logger.info(f"在群聊 {poke_args['group_id']} 中执行戳一戳")
        else:
            logger.info("在私聊中执行戳一戳")

        for i in range(times):
            logger.info(f"正在向 {display_name} ({user_id}) 发送第 {i + 1}/{times} 次戳一戳...")
            await self.send_command(
                "SEND_POKE", args={"qq_id": user_id}, display_message=f"戳了戳 {display_name} ({i + 1}/{times})"
                "send_poke", args=poke_args, display_message=f"戳了戳 {display_name} ({i + 1}/{times})"
            )
            # 添加一个小的延迟,以避免发送过快
            await asyncio.sleep(0.5)
            # 添加一个延迟,避免因发送过快导致后续戳一戳失败
            await asyncio.sleep(1.5)

        success_message = f"已向 {display_name} 发送 {times} 次戳一戳。"
        await self.store_action_info(
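The command rename from "SEND_POKE" to "send_poke" and the new poke_args dict matter because the group_id now has to reach the protocol side. A rough sketch of the payload the adapter might emit for this command, reusing the json.dumps({"action": ..., "params": ..., "echo": ...}) pattern from the info getters above; the actual command routing is not shown in this diff:

# Sketch (assumption): adapter-side translation of the "send_poke" command.
import json
import uuid

def build_poke_payload(poke_args: dict) -> str:
    params = {"user_id": poke_args["user_id"]}
    if "group_id" in poke_args:
        params["group_id"] = poke_args["group_id"]  # only present in group chats
    return json.dumps({"action": "send_poke", "params": params, "echo": str(uuid.uuid4())})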
@@ -212,138 +220,126 @@ class SetEmojiLikeAction(BaseAction):
    # === 基本信息(必须填写)===
    action_name = "set_emoji_like"
    action_description = "为某条已经存在的消息添加‘贴表情’回应(类似点赞),而不是发送新消息。可以在觉得某条消息非常有趣、值得赞同或者需要特殊情感回应时主动使用。"
    activation_type = ActionActivationType.ALWAYS # 消息接收时激活(?)
    activation_type = ActionActivationType.ALWAYS
    chat_type_allow = ChatType.GROUP
    parallel_action = True

    # === 功能描述(必须填写)===
    # 从 qq_face 字典中提取所有表情名称用于提示
    emoji_options = []
    for name in qq_face.values():
        match = re.search(r"\[表情:(.+?)\]", name)
        if match:
            emoji_options.append(match.group(1))

    action_parameters = {
        "set": "是否设置回应 (True/False)",
    }
    action_require = [
        "当需要对一个已存在消息进行‘贴表情’回应时使用",
        "这是一个对旧消息的操作,而不是发送新消息",
        "如果你想发送一个新的表情包消息,请使用 'emoji' 动作",
    ]
    llm_judge_prompt = """
    判定是否需要使用贴表情动作的条件:
    1. 用户明确要求使用贴表情包
    2. 这是一个适合表达强烈情绪的场合
    3. 不要发送太多表情包,如果你已经发送过多个表情包则回答"否"

    1. 这是一个适合表达强烈情绪的场合,例如非常有趣、赞同、惊讶等。
    2. 不要发送太多表情包,如果最近已经发送过表情包,请回答"否"。
    3. 仅在群聊中使用。

    请回答"是"或"否"。
    """
    associated_types = ["text"]

    # 重新启用完整的表情库
    emoji_options = []
    for name in qq_face.values():
        match = re.search(r"\[表情:(.+?)\]", name)
        if match:
            emoji_options.append(match.group(1))

    async def execute(self) -> Tuple[bool, str]:
        """执行设置表情回应的动作"""
        message_id = None
        set_like = self.action_data.get("set", True)
        if self.has_action_message:
            logger.debug(str(self.action_message))
            if isinstance(self.action_message, dict):
                message_id = self.action_message.get("message_id")

        if self.has_action_message and isinstance(self.action_message, dict):
            message_id = self.action_message.get("message_id")
            logger.info(f"获取到的消息ID: {message_id}")
        else:
            logger.error("未提供消息ID")
            await self.store_action_info(
                action_build_into_prompt=True,
                action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: 未提供消息ID",
                action_done=False,
            )
            logger.error("未提供有效的消息或消息ID")
            await self.store_action_info(action_prompt_display="贴表情失败: 未提供消息ID", action_done=False)
            return False, "未提供消息ID"

        if not message_id:
            logger.error("消息ID为空")
            await self.store_action_info(action_prompt_display="贴表情失败: 消息ID为空", action_done=False)
            return False, "消息ID为空"

        available_models = llm_api.get_available_models()
        if "utils_small" not in available_models:
            logger.error("未找到 'utils_small' 模型配置,无法选择表情")
            return False, "表情选择功能配置错误"
            logger.error("未找到 'utils_small' 模型配置,无法选择表情")
            return False, "表情选择功能配置错误"

        model_to_use = available_models["utils_small"]

        # 获取最近的对话历史作为上下文
        context_text = ""
        if self.action_message:
            context_text = self.action_message.get("processed_plain_text", "")
        else:
            logger.error("无法找到动作选择的原始消息")
            return False, "无法找到动作选择的原始消息"


        context_text = self.action_message.get("processed_plain_text", "")
        if not context_text:
            logger.error("无法找到动作选择的原始消息文本")
            return False, "无法找到动作选择的原始消息文本"

        prompt = (
            f"根据以下这条消息,从列表中选择一个最合适的表情名称来回应这条消息。\n"
            f"消息内容: '{context_text}'\n"
            f"可用表情列表: {', '.join(self.emoji_options)}\n"
            f"你的任务是:只输出你选择的表情的名称,不要包含任何其他文字或标点。\n"
            f"例如,如果觉得应该用'赞',就只输出'赞'。"
        )
            f"**任务:**\n"
            f"根据以下消息,从“可用表情列表”中选择一个最合适的表情名称来回应。\n\n"
            f"**规则(必须严格遵守):**\n"
            f"1. **只能**从下面的“可用表情列表”中选择一个表情名称。\n"
            f"2. 你的回答**必须**只包含你选择的表情名称,**不能**有任何其他文字、标点、解释或空格。\n"
            f"3. 你的回答**不能**包含 `[表情:]` 或 `[]` 等符号。\n\n"
            f"**消息内容:**\n"
            f"'{context_text}'\n\n"
            f"**可用表情列表:**\n"
            f"{', '.join(self.emoji_options)}\n\n"
            f"**示例:**\n"
            f"- 如果认为“赞”最合适,你的回答**必须**是:`赞`\n"
            f"- 如果认为“笑哭”最合适,你的回答**必须**是:`笑哭`\n\n"
            f"**你的回答:**"
        )

        success, response, _, _ = await llm_api.generate_with_model(
            prompt, model_config=model_to_use, request_type="plugin.set_emoji_like.select_emoji"
        )
            prompt, model_config=model_to_use, request_type="plugin.set_emoji_like.select_emoji"
        )

        if not success or not response:
            logger.error("二级LLM未能选择有效的表情。")
            return False, "无法找到合适的表情。"
            logger.error("表情选择模型未能返回有效的表情名称。")
            return False, "无法选择合适的表情。"

        chosen_emoji_name = response.strip()
        logger.info(f"二级LLM选择的表情是: '{chosen_emoji_name}'")
        logger.info(f"模型选择的表情是: '{chosen_emoji_name}'")

        emoji_id = get_emoji_id(chosen_emoji_name)

        if not emoji_id:
            logger.error(f"二级LLM选择的表情 '{chosen_emoji_name}' 仍然无法匹配到有效的表情ID。")
            await self.store_action_info(
                action_build_into_prompt=True,
                action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: 找不到表情: '{chosen_emoji_name}'",
                action_done=False,
            )
            return False, f"找不到表情: '{chosen_emoji_name}'。"

        # 4. 使用适配器API发送命令
        if not message_id:
            logger.error("未提供消息ID")
            logger.error(f"模型选择的表情 '{chosen_emoji_name}' 无法匹配到有效的表情ID。可能是模型违反了规则。")
            await self.store_action_info(
                action_build_into_prompt=True,
                action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: 未提供消息ID",
                action_prompt_display=f"贴表情失败: 找不到表情 '{chosen_emoji_name}'",
                action_done=False,
            )
            return False, "未提供消息ID"
            return False, f"找不到表情: '{chosen_emoji_name}'"

        try:
            # 使用适配器API发送贴表情命令
            success = await self.send_command(
                command_name="set_emoji_like",
                command_name="set_msg_emoji_like",
                args={"message_id": message_id, "emoji_id": emoji_id, "set": set_like},
                storage_message=False,
            )
            if success:
                logger.info("设置表情回应成功")
                display_message = f"贴上了表情: {chosen_emoji_name}"
                logger.info(display_message)
                await self.store_action_info(
                    action_build_into_prompt=True,
                    action_prompt_display=f"执行了set_emoji_like动作,{chosen_emoji_name},设置表情回应: {emoji_id}, 是否设置: {set_like}",
                    action_prompt_display=display_message,
                    action_done=True,
                )
                return True, "成功设置表情回应"
            else:
                logger.error("设置表情回应失败")
                await self.store_action_info(
                    action_build_into_prompt=True,
                    action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败",
                    action_done=False,
                )
                logger.error("通过适配器设置表情回应失败")
                await self.store_action_info(action_prompt_display="贴表情失败: 适配器返回失败", action_done=False)
                return False, "设置表情回应失败"

        except Exception as e:
            logger.error(f"设置表情回应失败: {e}")
            await self.store_action_info(
                action_build_into_prompt=True,
                action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: {e}",
                action_done=False,
            )
            logger.error(f"设置表情回应时发生异常: {e}", exc_info=True)
            await self.store_action_info(action_prompt_display=f"贴表情失败: {e}", action_done=False)
            return False, f"设置表情回应失败: {e}"

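For context on the emoji lookup step above: get_emoji_id is not shown in this diff. A minimal sketch under the assumption that qq_face maps an emoji id to a display string such as "[表情:赞]", i.e. the inverse of the extraction the class-level loop performs on qq_face.values():

# Sketch (assumption): one plausible shape for the get_emoji_id helper.
import re

def get_emoji_id(emoji_name: str) -> str | None:
    for emoji_id, display in qq_face.items():
        match = re.search(r"\[表情:(.+?)\]", display)
        if match and match.group(1) == emoji_name:
            return str(emoji_id)
    return None  # the LLM picked a name that is not in the emoji table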
@@ -255,11 +255,46 @@ max_frequency_bonus = 10.0 # 最大激活频率奖励天数
# 休眠机制
dormant_threshold_days = 90 # 休眠状态判定天数(超过此天数未访问的记忆进入休眠状态)

# 统一存储配置 (新增)
unified_storage_path = "data/unified_memory" # 统一存储数据路径
unified_storage_cache_limit = 10000 # 内存缓存大小限制
unified_storage_auto_save_interval = 50 # 自动保存间隔(记忆条数)
unified_storage_enable_compression = true # 是否启用数据压缩
# 统一存储配置 (已弃用 - 请使用Vector DB配置)
# DEPRECATED: unified_storage_path = "data/unified_memory"
# DEPRECATED: unified_storage_cache_limit = 10000
# DEPRECATED: unified_storage_auto_save_interval = 50
# DEPRECATED: unified_storage_enable_compression = true

# Vector DB存储配置 (新增 - 替代JSON存储)
enable_vector_memory_storage = true # 启用Vector DB存储
enable_llm_instant_memory = true # 启用基于LLM的瞬时记忆
enable_vector_instant_memory = true # 启用基于向量的瞬时记忆

# Vector DB配置
vector_db_memory_collection = "unified_memory_v2" # Vector DB主记忆集合名称
vector_db_metadata_collection = "memory_metadata_v2" # Vector DB元数据集合名称
vector_db_similarity_threshold = 0.5 # Vector DB相似度阈值 (推荐范围: 0.5-0.6, 过高会导致检索不到结果)
vector_db_search_limit = 20 # Vector DB单次搜索返回的最大结果数
vector_db_batch_size = 100 # 批处理大小 (批量存储记忆时每批处理的记忆条数)
vector_db_enable_caching = true # 启用内存缓存
vector_db_cache_size_limit = 1000 # 缓存大小限制 (内存缓存最多保存的记忆条数)
vector_db_auto_cleanup_interval = 3600 # 自动清理间隔(秒)
vector_db_retention_hours = 720 # 记忆保留时间(小时,默认30天)

# 多阶段召回配置(可选)
# 取消注释以启用更严格的粗筛,适用于大规模记忆库(>10万条)
# memory_importance_threshold = 0.3 # 重要性阈值(过滤低价值记忆,范围0.0-1.0)
# memory_recency_days = 30 # 时间范围(只搜索最近N天的记忆,0表示不限制)

# Vector DB配置 (ChromaDB)
[vector_db]
type = "chromadb" # Vector DB类型
path = "data/chroma_db" # Vector DB数据路径

[vector_db.settings]
anonymized_telemetry = false # 禁用匿名遥测
allow_reset = true # 允许重置

[vector_db.collections]
unified_memory_v2 = { description = "统一记忆存储V2", hnsw_space = "cosine", version = "2.0" }
memory_metadata_v2 = { description = "记忆元数据索引", hnsw_space = "cosine", version = "2.0" }
semantic_cache = { description = "语义缓存", hnsw_space = "cosine" }

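These [vector_db] settings map directly onto ChromaDB client options. A minimal sketch of how they might be applied, assuming the chromadb Python package; the project's own service wrapper around it is not shown in this diff, and the hnsw_space-to-"hnsw:space" mapping is an assumption:

# Sketch (assumption): applying the config above to a ChromaDB client.
import chromadb
from chromadb.config import Settings

client = chromadb.PersistentClient(
    path="data/chroma_db",
    settings=Settings(anonymized_telemetry=False, allow_reset=True),
)
collection = client.get_or_create_collection(
    name="unified_memory_v2",
    metadata={"hnsw:space": "cosine"},  # assumed mapping of the config's hnsw_space key
)
# With cosine space, ChromaDB returns distance = 1 - similarity, so a
# similarity threshold of 0.5 keeps results with distance <= 0.5.
results = collection.query(query_texts=["用户的近期偏好"], n_results=20)
for doc, dist in zip(results["documents"][0], results["distances"][0]):
    if 1.0 - dist >= 0.5:  # vector_db_similarity_threshold
        print(doc)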
[voice]
enable_asr = true # 是否启用语音识别,启用后MoFox-Bot可以识别语音消息,启用该功能需要配置语音识别模型[model.voice]