refactor(memory): 移除废弃的记忆系统备份文件,优化消息管理器架构
移除了deprecated_backup目录下的所有废弃记忆系统文件,包括增强记忆适配器、钩子、集成层、重排序器、元数据索引、多阶段检索和向量存储等模块。同时优化了消息管理器,集成了批量数据库写入器、流缓存管理器和自适应流管理器,提升了系统性能和可维护性。
This commit is contained in:
@@ -1,363 +0,0 @@
|
||||
"""
|
||||
增强记忆系统适配器
|
||||
将增强记忆系统集成到现有MoFox Bot架构中
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
MEMORY_TYPE_LABELS = {
|
||||
MemoryType.PERSONAL_FACT: "个人事实",
|
||||
MemoryType.EVENT: "事件",
|
||||
MemoryType.PREFERENCE: "偏好",
|
||||
MemoryType.OPINION: "观点",
|
||||
MemoryType.RELATIONSHIP: "关系",
|
||||
MemoryType.EMOTION: "情感",
|
||||
MemoryType.KNOWLEDGE: "知识",
|
||||
MemoryType.SKILL: "技能",
|
||||
MemoryType.GOAL: "目标",
|
||||
MemoryType.EXPERIENCE: "经验",
|
||||
MemoryType.CONTEXTUAL: "上下文",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterConfig:
|
||||
"""适配器配置"""
|
||||
|
||||
enable_enhanced_memory: bool = True
|
||||
integration_mode: str = "enhanced_only" # replace, enhanced_only
|
||||
auto_migration: bool = True
|
||||
memory_value_threshold: float = 0.6
|
||||
fusion_threshold: float = 0.85
|
||||
max_retrieval_results: int = 10
|
||||
|
||||
|
||||
class EnhancedMemoryAdapter:
|
||||
"""增强记忆系统适配器"""
|
||||
|
||||
def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
|
||||
self.llm_model = llm_model
|
||||
self.config = config or AdapterConfig()
|
||||
self.integration_layer: MemoryIntegrationLayer | None = None
|
||||
self._initialized = False
|
||||
|
||||
# 统计信息
|
||||
self.adapter_stats = {
|
||||
"total_processed": 0,
|
||||
"enhanced_used": 0,
|
||||
"legacy_used": 0,
|
||||
"hybrid_used": 0,
|
||||
"memories_created": 0,
|
||||
"memories_retrieved": 0,
|
||||
"average_processing_time": 0.0,
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化适配器"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🚀 初始化增强记忆系统适配器...")
|
||||
|
||||
# 转换配置格式
|
||||
integration_config = IntegrationConfig(
|
||||
mode=IntegrationMode(self.config.integration_mode),
|
||||
enable_enhanced_memory=self.config.enable_enhanced_memory,
|
||||
memory_value_threshold=self.config.memory_value_threshold,
|
||||
fusion_threshold=self.config.fusion_threshold,
|
||||
max_retrieval_results=self.config.max_retrieval_results,
|
||||
enable_learning=True, # 启用学习功能
|
||||
)
|
||||
|
||||
# 创建集成层
|
||||
self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config)
|
||||
|
||||
# 初始化集成层
|
||||
await self.integration_layer.initialize()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("✅ 增强记忆系统适配器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 增强记忆系统适配器初始化失败: {e}", exc_info=True)
|
||||
# 如果初始化失败,禁用增强记忆功能
|
||||
self.config.enable_enhanced_memory = False
|
||||
|
||||
async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""处理对话记忆,以上下文为唯一输入"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return {"success": False, "error": "Enhanced memory not available"}
|
||||
|
||||
start_time = time.time()
|
||||
self.adapter_stats["total_processed"] += 1
|
||||
|
||||
try:
|
||||
payload_context: dict[str, Any] = dict(context or {})
|
||||
|
||||
conversation_text = payload_context.get("conversation_text")
|
||||
if not conversation_text:
|
||||
conversation_candidate = (
|
||||
payload_context.get("message_content")
|
||||
or payload_context.get("latest_message")
|
||||
or payload_context.get("raw_text")
|
||||
)
|
||||
if conversation_candidate is not None:
|
||||
conversation_text = str(conversation_candidate)
|
||||
payload_context["conversation_text"] = conversation_text
|
||||
else:
|
||||
conversation_text = ""
|
||||
else:
|
||||
conversation_text = str(conversation_text)
|
||||
|
||||
if "timestamp" not in payload_context:
|
||||
payload_context["timestamp"] = time.time()
|
||||
|
||||
logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text))
|
||||
|
||||
# 使用集成层处理对话
|
||||
result = await self.integration_layer.process_conversation(payload_context)
|
||||
|
||||
# 更新统计
|
||||
processing_time = time.time() - start_time
|
||||
self._update_processing_stats(processing_time)
|
||||
|
||||
if result["success"]:
|
||||
created_count = len(result.get("created_memories", []))
|
||||
self.adapter_stats["memories_created"] += created_count
|
||||
logger.debug(f"对话记忆处理完成,创建 {created_count} 条记忆")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理对话记忆失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return []
|
||||
|
||||
try:
|
||||
limit = limit or self.config.max_retrieval_results
|
||||
memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit)
|
||||
|
||||
self.adapter_stats["memories_retrieved"] += len(memories)
|
||||
logger.debug(f"检索到 {len(memories)} 条相关记忆")
|
||||
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_memory_context_for_prompt(
|
||||
self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
|
||||
) -> str:
|
||||
"""获取用于提示词的记忆上下文"""
|
||||
memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
|
||||
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
# 使用新的记忆格式化器
|
||||
formatter_config = FormatterConfig(
|
||||
include_timestamps=True,
|
||||
include_memory_types=True,
|
||||
include_confidence=False,
|
||||
use_emoji_icons=True,
|
||||
group_by_type=False,
|
||||
max_display_length=150,
|
||||
)
|
||||
|
||||
return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)
|
||||
|
||||
async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
|
||||
"""获取增强记忆系统摘要"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return {"available": False, "reason": "Not initialized or disabled"}
|
||||
|
||||
try:
|
||||
# 获取系统状态
|
||||
status = await self.integration_layer.get_system_status()
|
||||
|
||||
# 获取适配器统计
|
||||
adapter_stats = self.adapter_stats.copy()
|
||||
|
||||
# 获取集成统计
|
||||
integration_stats = self.integration_layer.get_integration_stats()
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
"system_status": status,
|
||||
"adapter_stats": adapter_stats,
|
||||
"integration_stats": integration_stats,
|
||||
"total_memories_created": adapter_stats["memories_created"],
|
||||
"total_memories_retrieved": adapter_stats["memories_retrieved"],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取增强记忆摘要失败: {e}", exc_info=True)
|
||||
return {"available": False, "error": str(e)}
|
||||
|
||||
def _update_processing_stats(self, processing_time: float):
|
||||
"""更新处理统计"""
|
||||
total_processed = self.adapter_stats["total_processed"]
|
||||
if total_processed > 0:
|
||||
current_avg = self.adapter_stats["average_processing_time"]
|
||||
new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
|
||||
self.adapter_stats["average_processing_time"] = new_avg
|
||||
|
||||
def get_adapter_stats(self) -> dict[str, Any]:
|
||||
"""获取适配器统计信息"""
|
||||
return self.adapter_stats.copy()
|
||||
|
||||
async def maintenance(self):
|
||||
"""维护操作"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🔧 增强记忆系统适配器维护...")
|
||||
await self.integration_layer.maintenance()
|
||||
logger.info("✅ 增强记忆系统适配器维护完成")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 增强记忆系统适配器维护失败: {e}", exc_info=True)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭适配器"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🔄 关闭增强记忆系统适配器...")
|
||||
await self.integration_layer.shutdown()
|
||||
self._initialized = False
|
||||
logger.info("✅ 增强记忆系统适配器已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 关闭增强记忆系统适配器失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 全局适配器实例
|
||||
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None
|
||||
|
||||
|
||||
async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
|
||||
"""获取全局增强记忆适配器实例"""
|
||||
global _enhanced_memory_adapter
|
||||
|
||||
if _enhanced_memory_adapter is None:
|
||||
# 从配置中获取适配器配置
|
||||
from src.config.config import global_config
|
||||
|
||||
adapter_config = AdapterConfig(
|
||||
enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True),
|
||||
integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"),
|
||||
auto_migration=getattr(global_config.memory, "enable_memory_migration", True),
|
||||
memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6),
|
||||
fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85),
|
||||
max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10),
|
||||
)
|
||||
|
||||
_enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config)
|
||||
await _enhanced_memory_adapter.initialize()
|
||||
|
||||
return _enhanced_memory_adapter
|
||||
|
||||
|
||||
async def initialize_enhanced_memory_system(llm_model: LLMRequest):
|
||||
"""初始化增强记忆系统"""
|
||||
try:
|
||||
logger.info("🚀 初始化增强记忆系统...")
|
||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||
logger.info("✅ 增强记忆系统初始化完成")
|
||||
return adapter
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def process_conversation_with_enhanced_memory(
|
||||
context: dict[str, Any], llm_model: LLMRequest | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
from src.llm_models.utils_model import get_global_llm_model
|
||||
|
||||
llm_model = get_global_llm_model()
|
||||
|
||||
try:
|
||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||
payload_context = dict(context or {})
|
||||
|
||||
if "conversation_text" not in payload_context:
|
||||
conversation_candidate = (
|
||||
payload_context.get("message_content")
|
||||
or payload_context.get("latest_message")
|
||||
or payload_context.get("raw_text")
|
||||
)
|
||||
if conversation_candidate is not None:
|
||||
payload_context["conversation_text"] = str(conversation_candidate)
|
||||
|
||||
return await adapter.process_conversation_memory(payload_context)
|
||||
except Exception as e:
|
||||
logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
async def retrieve_memories_with_enhanced_system(
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
llm_model: LLMRequest | None = None,
|
||||
) -> list[MemoryChunk]:
|
||||
"""使用增强记忆系统检索记忆"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
from src.llm_models.utils_model import get_global_llm_model
|
||||
|
||||
llm_model = get_global_llm_model()
|
||||
|
||||
try:
|
||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||
return await adapter.retrieve_relevant_memories(query, user_id, context, limit)
|
||||
except Exception as e:
|
||||
logger.error(f"使用增强记忆系统检索记忆失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
async def get_memory_context_for_prompt(
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
max_memories: int = 5,
|
||||
llm_model: LLMRequest | None = None,
|
||||
) -> str:
|
||||
"""获取用于提示词的记忆上下文"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
from src.llm_models.utils_model import get_global_llm_model
|
||||
|
||||
llm_model = get_global_llm_model()
|
||||
|
||||
try:
|
||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||
return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories)
|
||||
except Exception as e:
|
||||
logger.error(f"获取记忆上下文失败: {e}", exc_info=True)
|
||||
return ""
|
||||
@@ -1,194 +0,0 @@
|
||||
"""
|
||||
增强记忆系统钩子
|
||||
用于在消息处理过程中自动构建和检索记忆
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EnhancedMemoryHooks:
|
||||
"""增强记忆系统钩子 - 自动处理消息的记忆构建和检索"""
|
||||
|
||||
def __init__(self):
|
||||
self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory
|
||||
self.processed_messages = set() # 避免重复处理
|
||||
|
||||
async def process_message_for_memory(
|
||||
self,
|
||||
message_content: str,
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
处理消息并构建记忆
|
||||
|
||||
Args:
|
||||
message_content: 消息内容
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
message_id: 消息ID
|
||||
context: 上下文信息
|
||||
|
||||
Returns:
|
||||
bool: 是否成功处理
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
if message_id in self.processed_messages:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 确保增强记忆管理器已初始化
|
||||
if not enhanced_memory_manager.is_initialized:
|
||||
await enhanced_memory_manager.initialize()
|
||||
|
||||
# 注入机器人基础人设,帮助记忆构建时避免记录自身信息
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
personality_config = getattr(global_config, "personality", None)
|
||||
bot_context = {}
|
||||
if bot_config is not None:
|
||||
bot_context["bot_name"] = getattr(bot_config, "nickname", None)
|
||||
bot_context["bot_aliases"] = list(getattr(bot_config, "alias_names", []) or [])
|
||||
bot_context["bot_account"] = getattr(bot_config, "qq_account", None)
|
||||
|
||||
if personality_config is not None:
|
||||
bot_context["bot_identity"] = getattr(personality_config, "identity", None)
|
||||
bot_context["bot_personality"] = getattr(personality_config, "personality_core", None)
|
||||
bot_context["bot_personality_side"] = getattr(personality_config, "personality_side", None)
|
||||
|
||||
# 构建上下文
|
||||
memory_context = {
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"timestamp": datetime.now().timestamp(),
|
||||
"message_type": "user_message",
|
||||
**bot_context,
|
||||
**(context or {}),
|
||||
}
|
||||
|
||||
# 处理对话并构建记忆
|
||||
memory_chunks = await enhanced_memory_manager.process_conversation(
|
||||
conversation_text=message_content,
|
||||
context=memory_context,
|
||||
user_id=user_id,
|
||||
timestamp=memory_context["timestamp"],
|
||||
)
|
||||
|
||||
# 标记消息已处理
|
||||
self.processed_messages.add(message_id)
|
||||
|
||||
# 限制处理历史大小
|
||||
if len(self.processed_messages) > 1000:
|
||||
# 移除最旧的500个记录
|
||||
self.processed_messages = set(list(self.processed_messages)[-500:])
|
||||
|
||||
logger.debug(f"为消息 {message_id} 构建了 {len(memory_chunks)} 条记忆")
|
||||
return len(memory_chunks) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息记忆失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_memory_for_response(
|
||||
self,
|
||||
query_text: str,
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
limit: int = 5,
|
||||
extra_context: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
limit: 返回记忆数量限制
|
||||
|
||||
Returns:
|
||||
List[Dict]: 相关记忆列表
|
||||
"""
|
||||
if not self.enabled:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 确保增强记忆管理器已初始化
|
||||
if not enhanced_memory_manager.is_initialized:
|
||||
await enhanced_memory_manager.initialize()
|
||||
|
||||
# 构建查询上下文
|
||||
context = {
|
||||
"chat_id": chat_id,
|
||||
"query_intent": "response_generation",
|
||||
"expected_memory_types": ["personal_fact", "event", "preference", "opinion"],
|
||||
}
|
||||
|
||||
if extra_context:
|
||||
context.update(extra_context)
|
||||
|
||||
# 获取相关记忆
|
||||
enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
|
||||
query_text=query_text, user_id=user_id, context=context, limit=limit
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
results = []
|
||||
for result in enhanced_results:
|
||||
memory_dict = {
|
||||
"content": result.content,
|
||||
"type": result.memory_type,
|
||||
"confidence": result.confidence,
|
||||
"importance": result.importance,
|
||||
"timestamp": result.timestamp,
|
||||
"source": result.source,
|
||||
"relevance": result.relevance_score,
|
||||
"structure": result.structure,
|
||||
}
|
||||
results.append(memory_dict)
|
||||
|
||||
logger.debug(f"为回复查询到 {len(results)} 条相关记忆")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取回复记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def cleanup_old_memories(self):
|
||||
"""清理旧记忆"""
|
||||
try:
|
||||
if enhanced_memory_manager.is_initialized:
|
||||
# 调用增强记忆系统的维护功能
|
||||
await enhanced_memory_manager.enhanced_system.maintenance()
|
||||
logger.debug("增强记忆系统维护完成")
|
||||
except Exception as e:
|
||||
logger.error(f"清理旧记忆失败: {e}")
|
||||
|
||||
def clear_processed_cache(self):
|
||||
"""清除已处理消息的缓存"""
|
||||
self.processed_messages.clear()
|
||||
logger.debug("已清除消息处理缓存")
|
||||
|
||||
def enable(self):
|
||||
"""启用记忆钩子"""
|
||||
self.enabled = True
|
||||
logger.info("增强记忆钩子已启用")
|
||||
|
||||
def disable(self):
|
||||
"""禁用记忆钩子"""
|
||||
self.enabled = False
|
||||
logger.info("增强记忆钩子已禁用")
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
enhanced_memory_hooks = EnhancedMemoryHooks()
|
||||
@@ -1,177 +0,0 @@
|
||||
"""
|
||||
增强记忆系统集成脚本
|
||||
用于在现有系统中无缝集成增强记忆功能
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def process_user_message_memory(
|
||||
message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
处理用户消息并构建记忆
|
||||
|
||||
Args:
|
||||
message_content: 消息内容
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
message_id: 消息ID
|
||||
context: 额外的上下文信息
|
||||
|
||||
Returns:
|
||||
bool: 是否成功构建记忆
|
||||
"""
|
||||
try:
|
||||
success = await enhanced_memory_hooks.process_message_for_memory(
|
||||
message_content=message_content, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.debug(f"成功为消息 {message_id} 构建记忆")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理用户消息记忆失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_relevant_memories_for_response(
|
||||
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
|
||||
Args:
|
||||
query_text: 查询文本(通常是用户的当前消息)
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
limit: 返回记忆数量限制
|
||||
extra_context: 额外上下文信息
|
||||
|
||||
Returns:
|
||||
Dict: 包含记忆信息的字典
|
||||
"""
|
||||
try:
|
||||
memories = await enhanced_memory_hooks.get_memory_for_response(
|
||||
query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context
|
||||
)
|
||||
|
||||
result = {"has_memories": len(memories) > 0, "memories": memories, "memory_count": len(memories)}
|
||||
|
||||
logger.debug(f"为回复获取到 {len(memories)} 条相关记忆")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取回复记忆失败: {e}")
|
||||
return {"has_memories": False, "memories": [], "memory_count": 0}
|
||||
|
||||
|
||||
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化记忆信息用于Prompt
|
||||
|
||||
Args:
|
||||
memories: 记忆信息字典
|
||||
|
||||
Returns:
|
||||
str: 格式化后的记忆文本
|
||||
"""
|
||||
if not memories["has_memories"]:
|
||||
return ""
|
||||
|
||||
memory_lines = ["以下是相关的记忆信息:"]
|
||||
|
||||
for memory in memories["memories"]:
|
||||
content = memory["content"]
|
||||
memory_type = memory["type"]
|
||||
confidence = memory["confidence"]
|
||||
importance = memory["importance"]
|
||||
|
||||
# 根据重要性添加不同的标记
|
||||
importance_marker = "🔥" if importance >= 3 else "⭐" if importance >= 2 else "📝"
|
||||
confidence_marker = "✅" if confidence >= 3 else "⚠️" if confidence >= 2 else "💭"
|
||||
|
||||
memory_line = f"{importance_marker} {content} ({memory_type}, {confidence_marker}置信度)"
|
||||
memory_lines.append(memory_line)
|
||||
|
||||
return "\n".join(memory_lines)
|
||||
|
||||
|
||||
async def cleanup_memory_system():
|
||||
"""清理记忆系统"""
|
||||
try:
|
||||
await enhanced_memory_hooks.cleanup_old_memories()
|
||||
logger.info("记忆系统清理完成")
|
||||
except Exception as e:
|
||||
logger.error(f"记忆系统清理失败: {e}")
|
||||
|
||||
|
||||
def get_memory_system_status() -> dict[str, Any]:
|
||||
"""
|
||||
获取记忆系统状态
|
||||
|
||||
Returns:
|
||||
Dict: 系统状态信息
|
||||
"""
|
||||
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
|
||||
return {
|
||||
"enabled": enhanced_memory_hooks.enabled,
|
||||
"enhanced_system_initialized": enhanced_memory_manager.is_initialized,
|
||||
"processed_messages_count": len(enhanced_memory_hooks.processed_messages),
|
||||
"system_type": "enhanced_memory_system",
|
||||
}
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def remember_message(
|
||||
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
便捷的记忆构建函数
|
||||
|
||||
Args:
|
||||
message: 要记住的消息
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
import uuid
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
return await process_user_message_memory(
|
||||
message_content=message, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
|
||||
)
|
||||
|
||||
|
||||
async def recall_memories(
|
||||
query: str,
|
||||
user_id: str = "default_user",
|
||||
chat_id: str = "default_chat",
|
||||
limit: int = 5,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
便捷的记忆检索函数
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
Dict: 记忆信息
|
||||
"""
|
||||
return await get_relevant_memories_for_response(
|
||||
query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context
|
||||
)
|
||||
@@ -1,356 +0,0 @@
|
||||
"""
|
||||
增强重排序器
|
||||
实现文档设计的多维度评分模型
|
||||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class IntentType(Enum):
|
||||
"""对话意图类型"""
|
||||
|
||||
FACT_QUERY = "fact_query" # 事实查询
|
||||
EVENT_RECALL = "event_recall" # 事件回忆
|
||||
PREFERENCE_CHECK = "preference_check" # 偏好检查
|
||||
GENERAL_CHAT = "general_chat" # 一般对话
|
||||
UNKNOWN = "unknown" # 未知意图
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReRankingConfig:
|
||||
"""重排序配置"""
|
||||
|
||||
# 权重配置 (w1 + w2 + w3 + w4 = 1.0)
|
||||
semantic_weight: float = 0.5 # 语义相似度权重
|
||||
recency_weight: float = 0.2 # 时效性权重
|
||||
usage_freq_weight: float = 0.2 # 使用频率权重
|
||||
type_match_weight: float = 0.1 # 类型匹配权重
|
||||
|
||||
# 时效性衰减参数
|
||||
recency_decay_rate: float = 0.1 # 时效性衰减率 (天)
|
||||
|
||||
# 使用频率计算参数
|
||||
freq_log_base: float = 2.0 # 对数底数
|
||||
freq_max_score: float = 5.0 # 最大频率得分
|
||||
|
||||
# 类型匹配权重映射
|
||||
type_match_weights: dict[str, dict[str, float]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化类型匹配权重"""
|
||||
if self.type_match_weights is None:
|
||||
self.type_match_weights = {
|
||||
IntentType.FACT_QUERY.value: {
|
||||
MemoryType.PERSONAL_FACT.value: 1.0,
|
||||
MemoryType.KNOWLEDGE.value: 0.8,
|
||||
MemoryType.PREFERENCE.value: 0.5,
|
||||
MemoryType.EVENT.value: 0.3,
|
||||
"default": 0.3,
|
||||
},
|
||||
IntentType.EVENT_RECALL.value: {
|
||||
MemoryType.EVENT.value: 1.0,
|
||||
MemoryType.EXPERIENCE.value: 0.8,
|
||||
MemoryType.EMOTION.value: 0.6,
|
||||
MemoryType.PERSONAL_FACT.value: 0.5,
|
||||
"default": 0.5,
|
||||
},
|
||||
IntentType.PREFERENCE_CHECK.value: {
|
||||
MemoryType.PREFERENCE.value: 1.0,
|
||||
MemoryType.OPINION.value: 0.8,
|
||||
MemoryType.GOAL.value: 0.6,
|
||||
MemoryType.PERSONAL_FACT.value: 0.4,
|
||||
"default": 0.4,
|
||||
},
|
||||
IntentType.GENERAL_CHAT.value: {"default": 0.8},
|
||||
IntentType.UNKNOWN.value: {"default": 0.8},
|
||||
}
|
||||
|
||||
|
||||
class IntentClassifier:
|
||||
"""轻量级意图识别器"""
|
||||
|
||||
def __init__(self):
|
||||
# 关键词模式匹配规则
|
||||
self.patterns = {
|
||||
IntentType.FACT_QUERY: [
|
||||
# 中文模式
|
||||
"我是",
|
||||
"我的",
|
||||
"我叫",
|
||||
"我在",
|
||||
"我住在",
|
||||
"我的职业",
|
||||
"我的工作",
|
||||
"什么时候",
|
||||
"在哪里",
|
||||
"是什么",
|
||||
"多少",
|
||||
"几岁",
|
||||
"年龄",
|
||||
# 英文模式
|
||||
"what is",
|
||||
"where is",
|
||||
"when is",
|
||||
"how old",
|
||||
"my name",
|
||||
"i am",
|
||||
"i live",
|
||||
],
|
||||
IntentType.EVENT_RECALL: [
|
||||
# 中文模式
|
||||
"记得",
|
||||
"想起",
|
||||
"还记得",
|
||||
"那次",
|
||||
"上次",
|
||||
"之前",
|
||||
"以前",
|
||||
"曾经",
|
||||
"发生过",
|
||||
"经历",
|
||||
"做过",
|
||||
"去过",
|
||||
"见过",
|
||||
# 英文模式
|
||||
"remember",
|
||||
"recall",
|
||||
"last time",
|
||||
"before",
|
||||
"previously",
|
||||
"happened",
|
||||
"experience",
|
||||
],
|
||||
IntentType.PREFERENCE_CHECK: [
|
||||
# 中文模式
|
||||
"喜欢",
|
||||
"不喜欢",
|
||||
"偏好",
|
||||
"爱好",
|
||||
"兴趣",
|
||||
"讨厌",
|
||||
"最爱",
|
||||
"最喜欢",
|
||||
"习惯",
|
||||
"通常",
|
||||
"一般",
|
||||
"倾向于",
|
||||
"更喜欢",
|
||||
# 英文模式
|
||||
"like",
|
||||
"love",
|
||||
"hate",
|
||||
"prefer",
|
||||
"favorite",
|
||||
"usually",
|
||||
"tend to",
|
||||
"interest",
|
||||
],
|
||||
}
|
||||
|
||||
def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
|
||||
"""识别对话意图"""
|
||||
if not query:
|
||||
return IntentType.UNKNOWN
|
||||
|
||||
query_lower = query.lower()
|
||||
|
||||
# 统计各意图的匹配分数
|
||||
intent_scores = dict.fromkeys(IntentType, 0)
|
||||
|
||||
for intent, patterns in self.patterns.items():
|
||||
for pattern in patterns:
|
||||
if pattern in query_lower:
|
||||
intent_scores[intent] += 1
|
||||
|
||||
# 返回得分最高的意图
|
||||
max_score = max(intent_scores.values())
|
||||
if max_score == 0:
|
||||
return IntentType.GENERAL_CHAT
|
||||
|
||||
for intent, score in intent_scores.items():
|
||||
if score == max_score:
|
||||
return intent
|
||||
|
||||
return IntentType.GENERAL_CHAT
|
||||
|
||||
|
||||
class EnhancedReRanker:
|
||||
"""增强重排序器 - 实现文档设计的多维度评分模型"""
|
||||
|
||||
def __init__(self, config: ReRankingConfig | None = None):
|
||||
self.config = config or ReRankingConfig()
|
||||
self.intent_classifier = IntentClassifier()
|
||||
|
||||
# 验证权重和为1.0
|
||||
total_weight = (
|
||||
self.config.semantic_weight
|
||||
+ self.config.recency_weight
|
||||
+ self.config.usage_freq_weight
|
||||
+ self.config.type_match_weight
|
||||
)
|
||||
|
||||
if abs(total_weight - 1.0) > 0.01:
|
||||
logger.warning(f"重排序权重和不为1.0: {total_weight}, 将进行归一化")
|
||||
# 归一化权重
|
||||
self.config.semantic_weight /= total_weight
|
||||
self.config.recency_weight /= total_weight
|
||||
self.config.usage_freq_weight /= total_weight
|
||||
self.config.type_match_weight /= total_weight
|
||||
|
||||
def rerank_memories(
|
||||
self,
|
||||
query: str,
|
||||
candidate_memories: list[tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
|
||||
context: dict[str, Any],
|
||||
limit: int = 10,
|
||||
) -> list[tuple[str, MemoryChunk, float]]:
|
||||
"""
|
||||
对候选记忆进行重排序
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
candidate_memories: 候选记忆列表 [(memory_id, memory, vector_similarity)]
|
||||
context: 上下文信息
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
重排序后的记忆列表 [(memory_id, memory, final_score)]
|
||||
"""
|
||||
if not candidate_memories:
|
||||
return []
|
||||
|
||||
# 识别查询意图
|
||||
intent = self.intent_classifier.classify_intent(query, context)
|
||||
logger.debug(f"识别到查询意图: {intent.value}")
|
||||
|
||||
# 计算每个候选记忆的最终得分
|
||||
scored_memories = []
|
||||
current_time = time.time()
|
||||
|
||||
for memory_id, memory, vector_sim in candidate_memories:
|
||||
try:
|
||||
# 1. 语义相似度得分 (已归一化到[0,1])
|
||||
semantic_score = self._normalize_similarity(vector_sim)
|
||||
|
||||
# 2. 时效性得分
|
||||
recency_score = self._calculate_recency_score(memory, current_time)
|
||||
|
||||
# 3. 使用频率得分
|
||||
usage_freq_score = self._calculate_usage_frequency_score(memory)
|
||||
|
||||
# 4. 类型匹配得分
|
||||
type_match_score = self._calculate_type_match_score(memory, intent)
|
||||
|
||||
# 计算最终得分
|
||||
final_score = (
|
||||
self.config.semantic_weight * semantic_score
|
||||
+ self.config.recency_weight * recency_score
|
||||
+ self.config.usage_freq_weight * usage_freq_score
|
||||
+ self.config.type_match_weight * type_match_score
|
||||
)
|
||||
|
||||
scored_memories.append((memory_id, memory, final_score))
|
||||
|
||||
# 记录调试信息
|
||||
logger.debug(
|
||||
f"记忆评分 {memory_id[:8]}: semantic={semantic_score:.3f}, "
|
||||
f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, "
|
||||
f"type={type_match_score:.3f}, final={final_score:.3f}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算记忆 {memory_id} 得分时出错: {e}")
|
||||
# 使用向量相似度作为后备得分
|
||||
scored_memories.append((memory_id, memory, vector_sim))
|
||||
|
||||
# 按最终得分降序排序
|
||||
scored_memories.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# 返回前N个结果
|
||||
result = scored_memories[:limit]
|
||||
|
||||
highest_score = result[0][2] if result else 0.0
|
||||
logger.info(
|
||||
f"重排序完成: 候选={len(candidate_memories)}, 返回={len(result)}, "
|
||||
f"意图={intent.value}, 最高分={highest_score:.3f}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _normalize_similarity(self, raw_similarity: float) -> float:
|
||||
"""归一化相似度到[0,1]区间"""
|
||||
# 假设原始相似度已经在[-1,1]或[0,1]区间
|
||||
if raw_similarity < 0:
|
||||
return (raw_similarity + 1) / 2 # 从[-1,1]映射到[0,1]
|
||||
return min(1.0, max(0.0, raw_similarity)) # 确保在[0,1]区间
|
||||
|
||||
def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float:
|
||||
"""
|
||||
计算时效性得分
|
||||
公式: Recency = 1 / (1 + decay_rate * days_old)
|
||||
"""
|
||||
last_accessed = memory.metadata.last_accessed or memory.metadata.created_at
|
||||
days_old = (current_time - last_accessed) / (24 * 3600) # 转换为天数
|
||||
|
||||
if days_old < 0:
|
||||
days_old = 0 # 处理时间异常
|
||||
|
||||
score = 1 / (1 + self.config.recency_decay_rate * days_old)
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float:
|
||||
"""
|
||||
计算使用频率得分
|
||||
公式: Usage_Freq = min(1.0, log2(access_count + 1) / max_score)
|
||||
"""
|
||||
access_count = memory.metadata.access_count
|
||||
if access_count <= 0:
|
||||
return 0.0
|
||||
|
||||
log_count = math.log2(access_count + 1)
|
||||
score = log_count / self.config.freq_max_score
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float:
|
||||
"""计算类型匹配得分"""
|
||||
memory_type = memory.memory_type.value
|
||||
intent_value = intent.value
|
||||
|
||||
# 获取对应意图的类型权重映射
|
||||
type_weights = self.config.type_match_weights.get(intent_value, {})
|
||||
|
||||
# 查找具体类型的权重,如果没有则使用默认权重
|
||||
score = type_weights.get(memory_type, type_weights.get("default", 0.8))
|
||||
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
|
||||
# 创建默认的重排序器实例
|
||||
default_reranker = EnhancedReRanker()
|
||||
|
||||
|
||||
def rerank_candidate_memories(
|
||||
query: str,
|
||||
candidate_memories: list[tuple[str, MemoryChunk, float]],
|
||||
context: dict[str, Any],
|
||||
limit: int = 10,
|
||||
config: ReRankingConfig | None = None,
|
||||
) -> list[tuple[str, MemoryChunk, float]]:
|
||||
"""
|
||||
便捷函数:对候选记忆进行重排序
|
||||
"""
|
||||
if config:
|
||||
reranker = EnhancedReRanker(config)
|
||||
else:
|
||||
reranker = default_reranker
|
||||
|
||||
return reranker.rerank_memories(query, candidate_memories, context, limit)
|
||||
@@ -1,245 +0,0 @@
|
||||
"""
|
||||
增强记忆系统集成层
|
||||
现在只管理新的增强记忆系统,旧系统已被完全移除
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class IntegrationMode(Enum):
|
||||
"""集成模式"""
|
||||
|
||||
REPLACE = "replace" # 完全替换现有记忆系统
|
||||
ENHANCED_ONLY = "enhanced_only" # 仅使用增强记忆系统
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntegrationConfig:
|
||||
"""集成配置"""
|
||||
|
||||
mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY
|
||||
enable_enhanced_memory: bool = True
|
||||
memory_value_threshold: float = 0.6
|
||||
fusion_threshold: float = 0.85
|
||||
max_retrieval_results: int = 10
|
||||
enable_learning: bool = True
|
||||
|
||||
|
||||
class MemoryIntegrationLayer:
|
||||
"""记忆系统集成层 - 现在只管理增强记忆系统"""
|
||||
|
||||
def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
|
||||
self.llm_model = llm_model
|
||||
self.config = config or IntegrationConfig()
|
||||
|
||||
# 只初始化增强记忆系统
|
||||
self.enhanced_memory: EnhancedMemorySystem | None = None
|
||||
|
||||
# 集成统计
|
||||
self.integration_stats = {
|
||||
"total_queries": 0,
|
||||
"enhanced_queries": 0,
|
||||
"memory_creations": 0,
|
||||
"average_response_time": 0.0,
|
||||
"success_rate": 0.0,
|
||||
}
|
||||
|
||||
# 初始化锁
|
||||
self._initialization_lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化集成层"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
async with self._initialization_lock:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info("🚀 开始初始化增强记忆系统集成层...")
|
||||
|
||||
try:
|
||||
# 初始化增强记忆系统
|
||||
if self.config.enable_enhanced_memory:
|
||||
await self._initialize_enhanced_memory()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("✅ 增强记忆系统集成层初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 集成层初始化失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def _initialize_enhanced_memory(self):
|
||||
"""初始化增强记忆系统"""
|
||||
try:
|
||||
logger.debug("初始化增强记忆系统...")
|
||||
|
||||
# 创建增强记忆系统配置
|
||||
from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig
|
||||
|
||||
memory_config = MemorySystemConfig.from_global_config()
|
||||
|
||||
# 使用集成配置覆盖部分值
|
||||
memory_config.memory_value_threshold = self.config.memory_value_threshold
|
||||
memory_config.fusion_similarity_threshold = self.config.fusion_threshold
|
||||
memory_config.final_recall_limit = self.config.max_retrieval_results
|
||||
|
||||
# 创建增强记忆系统
|
||||
self.enhanced_memory = EnhancedMemorySystem(config=memory_config)
|
||||
|
||||
# 如果外部提供了LLM模型,注入到系统中
|
||||
if self.llm_model is not None:
|
||||
self.enhanced_memory.llm_model = self.llm_model
|
||||
|
||||
# 初始化系统
|
||||
await self.enhanced_memory.initialize()
|
||||
logger.info("✅ 增强记忆系统初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""处理对话记忆,仅使用上下文信息"""
|
||||
if not self._initialized or not self.enhanced_memory:
|
||||
return {"success": False, "error": "Memory system not available"}
|
||||
|
||||
start_time = time.time()
|
||||
self.integration_stats["total_queries"] += 1
|
||||
self.integration_stats["enhanced_queries"] += 1
|
||||
|
||||
try:
|
||||
payload_context = dict(context or {})
|
||||
conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
|
||||
logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text))
|
||||
|
||||
# 直接使用增强记忆系统处理
|
||||
result = await self.enhanced_memory.process_conversation_memory(payload_context)
|
||||
|
||||
# 更新统计
|
||||
processing_time = time.time() - start_time
|
||||
self._update_response_stats(processing_time, result.get("success", False))
|
||||
|
||||
if result.get("success"):
|
||||
created_count = len(result.get("created_memories", []))
|
||||
self.integration_stats["memory_creations"] += created_count
|
||||
logger.debug(f"对话处理完成,创建 {created_count} 条记忆")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
processing_time = time.time() - start_time
|
||||
self._update_response_stats(processing_time, False)
|
||||
logger.error(f"处理对话记忆失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self,
|
||||
query: str,
|
||||
user_id: str | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆"""
|
||||
if not self._initialized or not self.enhanced_memory:
|
||||
return []
|
||||
|
||||
try:
|
||||
limit = limit or self.config.max_retrieval_results
|
||||
memories = await self.enhanced_memory.retrieve_relevant_memories(
|
||||
query=query, user_id=None, context=context or {}, limit=limit
|
||||
)
|
||||
|
||||
memory_count = len(memories)
|
||||
logger.debug(f"检索到 {memory_count} 条相关记忆")
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_system_status(self) -> dict[str, Any]:
|
||||
"""获取系统状态"""
|
||||
if not self._initialized:
|
||||
return {"status": "not_initialized"}
|
||||
|
||||
try:
|
||||
enhanced_status = {}
|
||||
if self.enhanced_memory:
|
||||
enhanced_status = await self.enhanced_memory.get_system_status()
|
||||
|
||||
return {
|
||||
"status": "initialized",
|
||||
"mode": self.config.mode.value,
|
||||
"enhanced_memory": enhanced_status,
|
||||
"integration_stats": self.integration_stats.copy(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取系统状态失败: {e}", exc_info=True)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def get_integration_stats(self) -> dict[str, Any]:
|
||||
"""获取集成统计信息"""
|
||||
return self.integration_stats.copy()
|
||||
|
||||
def _update_response_stats(self, processing_time: float, success: bool):
|
||||
"""更新响应统计"""
|
||||
total_queries = self.integration_stats["total_queries"]
|
||||
if total_queries > 0:
|
||||
# 更新平均响应时间
|
||||
current_avg = self.integration_stats["average_response_time"]
|
||||
new_avg = (current_avg * (total_queries - 1) + processing_time) / total_queries
|
||||
self.integration_stats["average_response_time"] = new_avg
|
||||
|
||||
# 更新成功率
|
||||
if success:
|
||||
current_success_rate = self.integration_stats["success_rate"]
|
||||
new_success_rate = (current_success_rate * (total_queries - 1) + 1) / total_queries
|
||||
self.integration_stats["success_rate"] = new_success_rate
|
||||
|
||||
async def maintenance(self):
|
||||
"""执行维护操作"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🔧 执行记忆系统集成层维护...")
|
||||
|
||||
if self.enhanced_memory:
|
||||
await self.enhanced_memory.maintenance()
|
||||
|
||||
logger.info("✅ 记忆系统集成层维护完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 集成层维护失败: {e}", exc_info=True)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭集成层"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🔄 关闭记忆系统集成层...")
|
||||
|
||||
if self.enhanced_memory:
|
||||
await self.enhanced_memory.shutdown()
|
||||
|
||||
self._initialized = False
|
||||
logger.info("✅ 记忆系统集成层已关闭")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)
|
||||
@@ -1,526 +0,0 @@
|
||||
"""
|
||||
记忆系统集成钩子
|
||||
提供与现有MoFox Bot系统的无缝集成点
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_adapter import (
|
||||
get_memory_context_for_prompt,
|
||||
process_conversation_with_enhanced_memory,
|
||||
retrieve_memories_with_enhanced_system,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookResult:
|
||||
"""钩子执行结果"""
|
||||
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: str | None = None
|
||||
processing_time: float = 0.0
|
||||
|
||||
|
||||
class MemoryIntegrationHooks:
|
||||
"""记忆系统集成钩子"""
|
||||
|
||||
def __init__(self):
|
||||
self.hooks_registered = False
|
||||
self.hook_stats = {
|
||||
"message_processing_hooks": 0,
|
||||
"memory_retrieval_hooks": 0,
|
||||
"prompt_enhancement_hooks": 0,
|
||||
"total_hook_executions": 0,
|
||||
"average_hook_time": 0.0,
|
||||
}
|
||||
|
||||
async def register_hooks(self):
|
||||
"""注册所有集成钩子"""
|
||||
if self.hooks_registered:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🔗 注册记忆系统集成钩子...")
|
||||
|
||||
# 注册消息处理钩子
|
||||
await self._register_message_processing_hooks()
|
||||
|
||||
# 注册记忆检索钩子
|
||||
await self._register_memory_retrieval_hooks()
|
||||
|
||||
# 注册提示词增强钩子
|
||||
await self._register_prompt_enhancement_hooks()
|
||||
|
||||
# 注册系统维护钩子
|
||||
await self._register_maintenance_hooks()
|
||||
|
||||
self.hooks_registered = True
|
||||
logger.info("✅ 记忆系统集成钩子注册完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 注册记忆系统集成钩子失败: {e}", exc_info=True)
|
||||
|
||||
async def _register_message_processing_hooks(self):
|
||||
"""注册消息处理钩子"""
|
||||
try:
|
||||
# 钩子1: 在消息处理后创建记忆
|
||||
await self._register_post_message_hook()
|
||||
|
||||
# 钩子2: 在聊天流保存时处理记忆
|
||||
await self._register_chat_stream_hook()
|
||||
|
||||
logger.debug("消息处理钩子注册完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册消息处理钩子失败: {e}")
|
||||
|
||||
async def _register_memory_retrieval_hooks(self):
|
||||
"""注册记忆检索钩子"""
|
||||
try:
|
||||
# 钩子1: 在生成回复前检索相关记忆
|
||||
await self._register_pre_response_hook()
|
||||
|
||||
# 钩子2: 在知识库查询前增强上下文
|
||||
await self._register_knowledge_query_hook()
|
||||
|
||||
logger.debug("记忆检索钩子注册完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册记忆检索钩子失败: {e}")
|
||||
|
||||
async def _register_prompt_enhancement_hooks(self):
|
||||
"""注册提示词增强钩子"""
|
||||
try:
|
||||
# 钩子1: 增强提示词构建
|
||||
await self._register_prompt_building_hook()
|
||||
|
||||
logger.debug("提示词增强钩子注册完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册提示词增强钩子失败: {e}")
|
||||
|
||||
async def _register_maintenance_hooks(self):
|
||||
"""注册系统维护钩子"""
|
||||
try:
|
||||
# 钩子1: 系统维护时的记忆系统维护
|
||||
await self._register_system_maintenance_hook()
|
||||
|
||||
logger.debug("系统维护钩子注册完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册系统维护钩子失败: {e}")
|
||||
|
||||
async def _register_post_message_hook(self):
|
||||
"""注册消息后处理钩子"""
|
||||
try:
|
||||
# 这里需要根据实际的系统架构来注册钩子
|
||||
# 以下是一个示例实现,需要根据实际的插件系统或事件系统来调整
|
||||
|
||||
# 尝试注册到事件系统
|
||||
try:
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
# 注册消息后处理事件
|
||||
event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
|
||||
logger.debug("已注册到事件系统的消息处理钩子")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("事件系统不可用,跳过事件钩子注册")
|
||||
|
||||
# 尝试注册到消息管理器
|
||||
try:
|
||||
from src.chat.message_manager import message_manager
|
||||
|
||||
# 如果消息管理器支持钩子注册
|
||||
if hasattr(message_manager, "register_post_process_hook"):
|
||||
message_manager.register_post_process_hook(self._on_message_processed_hook)
|
||||
logger.debug("已注册到消息管理器的处理钩子")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("消息管理器不可用,跳过消息管理器钩子注册")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册消息后处理钩子失败: {e}")
|
||||
|
||||
async def _register_chat_stream_hook(self):
|
||||
"""注册聊天流钩子"""
|
||||
try:
|
||||
# 尝试注册到聊天流管理器
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
if hasattr(chat_manager, "register_save_hook"):
|
||||
chat_manager.register_save_hook(self._on_chat_stream_save_hook)
|
||||
logger.debug("已注册到聊天流管理器的保存钩子")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("聊天流管理器不可用,跳过聊天流钩子注册")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册聊天流钩子失败: {e}")
|
||||
|
||||
async def _register_pre_response_hook(self):
|
||||
"""注册回复前钩子"""
|
||||
try:
|
||||
# 尝试注册到回复生成器
|
||||
try:
|
||||
from src.chat.replyer.default_generator import default_generator
|
||||
|
||||
if hasattr(default_generator, "register_pre_generation_hook"):
|
||||
default_generator.register_pre_generation_hook(self._on_pre_response_hook)
|
||||
logger.debug("已注册到回复生成器的前置钩子")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("回复生成器不可用,跳过回复前钩子注册")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册回复前钩子失败: {e}")
|
||||
|
||||
async def _register_knowledge_query_hook(self):
|
||||
"""注册知识库查询钩子"""
|
||||
try:
|
||||
# 尝试注册到知识库系统
|
||||
try:
|
||||
from src.chat.knowledge.knowledge_lib import knowledge_manager
|
||||
|
||||
if hasattr(knowledge_manager, "register_query_enhancer"):
|
||||
knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook)
|
||||
logger.debug("已注册到知识库的查询增强钩子")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("知识库系统不可用,跳过知识库钩子注册")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册知识库查询钩子失败: {e}")
|
||||
|
||||
async def _register_prompt_building_hook(self):
|
||||
"""注册提示词构建钩子"""
|
||||
try:
|
||||
# 尝试注册到提示词系统
|
||||
try:
|
||||
from src.chat.utils.prompt import prompt_manager
|
||||
|
||||
if hasattr(prompt_manager, "register_enhancer"):
|
||||
prompt_manager.register_enhancer(self._on_prompt_building_hook)
|
||||
logger.debug("已注册到提示词管理器的增强钩子")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("提示词系统不可用,跳过提示词钩子注册")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册提示词构建钩子失败: {e}")
|
||||
|
||||
async def _register_system_maintenance_hook(self):
|
||||
"""注册系统维护钩子"""
|
||||
try:
|
||||
# 尝试注册到系统维护器
|
||||
try:
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
|
||||
# 注册定期维护任务
|
||||
async_task_manager.add_task(MemoryMaintenanceTask())
|
||||
logger.debug("已注册到系统维护器的定期任务")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("异步任务管理器不可用,跳过系统维护钩子注册")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"注册系统维护钩子失败: {e}")
|
||||
|
||||
# 钩子处理器方法
|
||||
|
||||
async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
|
||||
"""事件系统的消息处理处理器"""
|
||||
return await self._on_message_processed_hook(event_data)
|
||||
|
||||
async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
|
||||
"""消息后处理钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
self.hook_stats["message_processing_hooks"] += 1
|
||||
|
||||
# 提取必要的信息
|
||||
message_info = message_data.get("message_info", {})
|
||||
user_info = message_info.get("user_info", {})
|
||||
conversation_text = message_data.get("processed_plain_text", "")
|
||||
|
||||
if not conversation_text:
|
||||
return HookResult(success=True, data="No conversation text")
|
||||
|
||||
user_id = str(user_info.get("user_id", "unknown"))
|
||||
context = {
|
||||
"chat_id": message_data.get("chat_id"),
|
||||
"message_type": message_data.get("message_type", "normal"),
|
||||
"platform": message_info.get("platform", "unknown"),
|
||||
"interest_value": message_data.get("interest_value", 0.0),
|
||||
"keywords": message_data.get("key_words", []),
|
||||
"timestamp": message_data.get("time", time.time()),
|
||||
}
|
||||
|
||||
# 使用增强记忆系统处理对话
|
||||
memory_context = dict(context)
|
||||
memory_context["conversation_text"] = conversation_text
|
||||
memory_context["user_id"] = user_id
|
||||
|
||||
result = await process_conversation_with_enhanced_memory(memory_context)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
self._update_hook_stats(processing_time)
|
||||
|
||||
if result["success"]:
|
||||
logger.debug(f"消息处理钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
|
||||
return HookResult(success=True, data=result, processing_time=processing_time)
|
||||
else:
|
||||
logger.warning(f"消息处理钩子执行失败: {result.get('error')}")
|
||||
return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
|
||||
|
||||
except Exception as e:
|
||||
processing_time = time.time() - start_time
|
||||
logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
|
||||
"""聊天流保存钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
self.hook_stats["message_processing_hooks"] += 1
|
||||
|
||||
# 从聊天流数据中提取对话信息
|
||||
stream_context = chat_stream_data.get("stream_context", {})
|
||||
user_id = stream_context.get("user_id", "unknown")
|
||||
messages = stream_context.get("messages", [])
|
||||
|
||||
if not messages:
|
||||
return HookResult(success=True, data="No messages to process")
|
||||
|
||||
# 构建对话文本
|
||||
conversation_parts = []
|
||||
for msg in messages[-10:]: # 只处理最近10条消息
|
||||
text = msg.get("processed_plain_text", "")
|
||||
if text:
|
||||
conversation_parts.append(f"{msg.get('user_nickname', 'User')}: {text}")
|
||||
|
||||
conversation_text = "\n".join(conversation_parts)
|
||||
if not conversation_text:
|
||||
return HookResult(success=True, data="No conversation text")
|
||||
|
||||
context = {
|
||||
"chat_id": chat_stream_data.get("chat_id"),
|
||||
"stream_id": chat_stream_data.get("stream_id"),
|
||||
"platform": chat_stream_data.get("platform", "unknown"),
|
||||
"message_count": len(messages),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
# 使用增强记忆系统处理对话
|
||||
memory_context = dict(context)
|
||||
memory_context["conversation_text"] = conversation_text
|
||||
memory_context["user_id"] = user_id
|
||||
|
||||
result = await process_conversation_with_enhanced_memory(memory_context)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
self._update_hook_stats(processing_time)
|
||||
|
||||
if result["success"]:
|
||||
logger.debug(f"聊天流保存钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
|
||||
return HookResult(success=True, data=result, processing_time=processing_time)
|
||||
else:
|
||||
logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}")
|
||||
return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
|
||||
|
||||
except Exception as e:
|
||||
processing_time = time.time() - start_time
|
||||
logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
|
||||
"""回复前钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
self.hook_stats["memory_retrieval_hooks"] += 1
|
||||
|
||||
# 提取查询信息
|
||||
query = response_data.get("query", "")
|
||||
user_id = response_data.get("user_id", "unknown")
|
||||
context = response_data.get("context", {})
|
||||
|
||||
if not query:
|
||||
return HookResult(success=True, data="No query provided")
|
||||
|
||||
# 检索相关记忆
|
||||
memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
self._update_hook_stats(processing_time)
|
||||
|
||||
# 将记忆添加到响应数据中
|
||||
response_data["enhanced_memories"] = memories
|
||||
response_data["enhanced_memory_context"] = await get_memory_context_for_prompt(
|
||||
query, user_id, context, max_memories=5
|
||||
)
|
||||
|
||||
logger.debug(f"回复前钩子执行成功,检索到 {len(memories)} 条记忆")
|
||||
return HookResult(success=True, data=memories, processing_time=processing_time)
|
||||
|
||||
except Exception as e:
|
||||
processing_time = time.time() - start_time
|
||||
logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
|
||||
"""知识库查询钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
self.hook_stats["memory_retrieval_hooks"] += 1
|
||||
|
||||
query = query_data.get("query", "")
|
||||
user_id = query_data.get("user_id", "unknown")
|
||||
context = query_data.get("context", {})
|
||||
|
||||
if not query:
|
||||
return HookResult(success=True, data="No query provided")
|
||||
|
||||
# 获取记忆上下文并增强查询
|
||||
memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
self._update_hook_stats(processing_time)
|
||||
|
||||
# 将记忆上下文添加到查询数据中
|
||||
query_data["enhanced_memory_context"] = memory_context
|
||||
|
||||
logger.debug("知识库查询钩子执行成功")
|
||||
return HookResult(success=True, data=memory_context, processing_time=processing_time)
|
||||
|
||||
except Exception as e:
|
||||
processing_time = time.time() - start_time
|
||||
logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
|
||||
"""提示词构建钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
self.hook_stats["prompt_enhancement_hooks"] += 1
|
||||
|
||||
query = prompt_data.get("query", "")
|
||||
user_id = prompt_data.get("user_id", "unknown")
|
||||
context = prompt_data.get("context", {})
|
||||
base_prompt = prompt_data.get("base_prompt", "")
|
||||
|
||||
if not query:
|
||||
return HookResult(success=True, data="No query provided")
|
||||
|
||||
# 获取记忆上下文
|
||||
memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
self._update_hook_stats(processing_time)
|
||||
|
||||
# 构建增强的提示词
|
||||
enhanced_prompt = base_prompt
|
||||
if memory_context:
|
||||
enhanced_prompt += f"\n\n### 相关记忆上下文 ###\n{memory_context}\n"
|
||||
|
||||
# 将增强的提示词添加到数据中
|
||||
prompt_data["enhanced_prompt"] = enhanced_prompt
|
||||
prompt_data["memory_context"] = memory_context
|
||||
|
||||
logger.debug("提示词构建钩子执行成功")
|
||||
return HookResult(success=True, data=enhanced_prompt, processing_time=processing_time)
|
||||
|
||||
except Exception as e:
|
||||
processing_time = time.time() - start_time
|
||||
logger.error(f"提示词构建钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
def _update_hook_stats(self, processing_time: float):
|
||||
"""更新钩子统计"""
|
||||
self.hook_stats["total_hook_executions"] += 1
|
||||
|
||||
total_executions = self.hook_stats["total_hook_executions"]
|
||||
if total_executions > 0:
|
||||
current_avg = self.hook_stats["average_hook_time"]
|
||||
new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
|
||||
self.hook_stats["average_hook_time"] = new_avg
|
||||
|
||||
def get_hook_stats(self) -> dict[str, Any]:
|
||||
"""获取钩子统计信息"""
|
||||
return self.hook_stats.copy()
|
||||
|
||||
|
||||
class MemoryMaintenanceTask:
|
||||
"""记忆系统维护任务"""
|
||||
|
||||
def __init__(self):
|
||||
self.task_name = "enhanced_memory_maintenance"
|
||||
self.interval = 3600 # 1小时执行一次
|
||||
|
||||
async def execute(self):
|
||||
"""执行维护任务"""
|
||||
try:
|
||||
logger.info("🔧 执行增强记忆系统维护任务...")
|
||||
|
||||
# 获取适配器实例
|
||||
try:
|
||||
from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter
|
||||
|
||||
if _enhanced_memory_adapter:
|
||||
await _enhanced_memory_adapter.maintenance()
|
||||
logger.info("✅ 增强记忆系统维护任务完成")
|
||||
else:
|
||||
logger.debug("增强记忆适配器未初始化,跳过维护")
|
||||
except Exception as e:
|
||||
logger.error(f"增强记忆系统维护失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行维护任务时发生异常: {e}", exc_info=True)
|
||||
|
||||
def get_interval(self) -> int:
|
||||
"""获取执行间隔"""
|
||||
return self.interval
|
||||
|
||||
def get_task_name(self) -> str:
|
||||
"""获取任务名称"""
|
||||
return self.task_name
|
||||
|
||||
|
||||
# 全局钩子实例
|
||||
_memory_hooks: MemoryIntegrationHooks | None = None
|
||||
|
||||
|
||||
async def get_memory_integration_hooks() -> MemoryIntegrationHooks:
|
||||
"""获取全局记忆集成钩子实例"""
|
||||
global _memory_hooks
|
||||
|
||||
if _memory_hooks is None:
|
||||
_memory_hooks = MemoryIntegrationHooks()
|
||||
await _memory_hooks.register_hooks()
|
||||
|
||||
return _memory_hooks
|
||||
|
||||
|
||||
async def initialize_memory_integration_hooks():
|
||||
"""初始化记忆集成钩子"""
|
||||
try:
|
||||
logger.info("🚀 初始化记忆集成钩子...")
|
||||
hooks = await get_memory_integration_hooks()
|
||||
logger.info("✅ 记忆集成钩子初始化完成")
|
||||
return hooks
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True)
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,875 +0,0 @@
|
||||
"""
|
||||
向量数据库存储接口
|
||||
为记忆系统提供高效的向量存储和语义搜索能力
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 尝试导入FAISS,如果不可用则使用简单替代
|
||||
try:
|
||||
import faiss
|
||||
|
||||
FAISS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FAISS_AVAILABLE = False
|
||||
logger.warning("FAISS not available, using simple vector storage")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorStorageConfig:
|
||||
"""向量存储配置"""
|
||||
|
||||
dimension: int = 1024
|
||||
similarity_threshold: float = 0.8
|
||||
index_type: str = "flat" # flat, ivf, hnsw
|
||||
max_index_size: int = 100000
|
||||
storage_path: str = "data/memory_vectors"
|
||||
auto_save_interval: int = 10 # 每N次操作自动保存
|
||||
enable_compression: bool = True
|
||||
|
||||
|
||||
class VectorStorageManager:
|
||||
"""向量存储管理器"""
|
||||
|
||||
def __init__(self, config: VectorStorageConfig | None = None):
|
||||
self.config = config or VectorStorageConfig()
|
||||
|
||||
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
|
||||
if resolved_dimension and resolved_dimension != self.config.dimension:
|
||||
logger.info(
|
||||
"向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
|
||||
resolved_dimension,
|
||||
self.config.dimension,
|
||||
)
|
||||
self.config.dimension = resolved_dimension
|
||||
self.storage_path = Path(self.config.storage_path)
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 向量索引
|
||||
self.vector_index = None
|
||||
self.memory_id_to_index = {} # memory_id -> vector index
|
||||
self.index_to_memory_id = {} # vector index -> memory_id
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: dict[str, MemoryChunk] = {}
|
||||
self.vector_cache: dict[str, list[float]] = {}
|
||||
|
||||
# 统计信息
|
||||
self.storage_stats = {
|
||||
"total_vectors": 0,
|
||||
"index_build_time": 0.0,
|
||||
"average_search_time": 0.0,
|
||||
"cache_hit_rate": 0.0,
|
||||
"total_searches": 0,
|
||||
"cache_hits": 0,
|
||||
}
|
||||
|
||||
# 线程锁
|
||||
self._lock = threading.RLock()
|
||||
self._operation_count = 0
|
||||
|
||||
# 初始化索引
|
||||
self._initialize_index()
|
||||
|
||||
# 嵌入模型
|
||||
self.embedding_model: LLMRequest = None
|
||||
|
||||
def _initialize_index(self):
|
||||
"""初始化向量索引"""
|
||||
try:
|
||||
if FAISS_AVAILABLE:
|
||||
if self.config.index_type == "flat":
|
||||
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
|
||||
elif self.config.index_type == "ivf":
|
||||
quantizer = faiss.IndexFlatIP(self.config.dimension)
|
||||
nlist = min(100, max(1, self.config.max_index_size // 1000))
|
||||
self.vector_index = faiss.IndexIVFFlat(quantizer, self.config.dimension, nlist)
|
||||
elif self.config.index_type == "hnsw":
|
||||
self.vector_index = faiss.IndexHNSWFlat(self.config.dimension, 32)
|
||||
self.vector_index.hnsw.efConstruction = 40
|
||||
else:
|
||||
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
|
||||
else:
|
||||
# 简单的向量存储实现
|
||||
self.vector_index = SimpleVectorIndex(self.config.dimension)
|
||||
|
||||
logger.info(f"✅ 向量索引初始化完成,类型: {self.config.index_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 向量索引初始化失败: {e}")
|
||||
# 回退到简单实现
|
||||
self.vector_index = SimpleVectorIndex(self.config.dimension)
|
||||
|
||||
async def initialize_embedding_model(self):
|
||||
"""初始化嵌入模型"""
|
||||
if self.embedding_model is None:
|
||||
self.embedding_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.embedding, request_type="memory.embedding"
|
||||
)
|
||||
logger.info("✅ 嵌入模型初始化完成")
|
||||
|
||||
async def generate_query_embedding(self, query_text: str) -> list[float] | None:
|
||||
"""生成查询向量,用于记忆召回"""
|
||||
if not query_text:
|
||||
logger.warning("查询文本为空,无法生成向量")
|
||||
return None
|
||||
|
||||
try:
|
||||
await self.initialize_embedding_model()
|
||||
|
||||
logger.debug(f"开始生成查询向量,文本: '{query_text[:50]}{'...' if len(query_text) > 50 else ''}'")
|
||||
|
||||
embedding, _ = await self.embedding_model.get_embedding(query_text)
|
||||
if not embedding:
|
||||
logger.warning("嵌入模型返回空向量")
|
||||
return None
|
||||
|
||||
logger.debug(f"生成的向量维度: {len(embedding)}, 期望维度: {self.config.dimension}")
|
||||
|
||||
if len(embedding) != self.config.dimension:
|
||||
logger.error("查询向量维度不匹配: 期望 %d, 实际 %d", self.config.dimension, len(embedding))
|
||||
return None
|
||||
|
||||
normalized_vector = self._normalize_vector(embedding)
|
||||
logger.debug(f"查询向量生成成功,向量范围: [{min(normalized_vector):.4f}, {max(normalized_vector):.4f}]")
|
||||
return normalized_vector
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def store_memories(self, memories: list[MemoryChunk]):
|
||||
"""存储记忆向量"""
|
||||
if not memories:
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保嵌入模型已初始化
|
||||
await self.initialize_embedding_model()
|
||||
|
||||
# 批量获取嵌入向量
|
||||
memory_texts = []
|
||||
|
||||
for memory in memories:
|
||||
# 预先缓存记忆,确保后续流程可访问
|
||||
self.memory_cache[memory.memory_id] = memory
|
||||
if memory.embedding is None:
|
||||
# 如果没有嵌入向量,需要生成
|
||||
text = self._prepare_embedding_text(memory)
|
||||
memory_texts.append((memory.memory_id, text))
|
||||
else:
|
||||
# 已有嵌入向量,直接使用
|
||||
await self._add_single_memory(memory, memory.embedding)
|
||||
|
||||
# 批量生成缺失的嵌入向量
|
||||
if memory_texts:
|
||||
await self._batch_generate_and_store_embeddings(memory_texts)
|
||||
|
||||
# 自动保存检查
|
||||
self._operation_count += len(memories)
|
||||
if self._operation_count >= self.config.auto_save_interval:
|
||||
await self.save_storage()
|
||||
self._operation_count = 0
|
||||
|
||||
storage_time = time.time() - start_time
|
||||
logger.debug(f"向量存储完成,{len(memories)} 条记忆,耗时 {storage_time:.3f}秒")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 向量存储失败: {e}", exc_info=True)
|
||||
|
||||
def _prepare_embedding_text(self, memory: MemoryChunk) -> str:
|
||||
"""准备用于嵌入的文本,仅使用自然语言展示内容"""
|
||||
display_text = (memory.display or "").strip()
|
||||
if display_text:
|
||||
return display_text
|
||||
|
||||
fallback_text = (memory.text_content or "").strip()
|
||||
if fallback_text:
|
||||
return fallback_text
|
||||
|
||||
subjects = "、".join(s.strip() for s in memory.subjects if s and isinstance(s, str))
|
||||
predicate = (memory.content.predicate or "").strip()
|
||||
|
||||
obj = memory.content.object
|
||||
if isinstance(obj, dict):
|
||||
object_parts = []
|
||||
for key, value in obj.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, (list, tuple)):
|
||||
preview = "、".join(str(item) for item in value[:3])
|
||||
object_parts.append(f"{key}:{preview}")
|
||||
else:
|
||||
object_parts.append(f"{key}:{value}")
|
||||
object_text = ", ".join(object_parts)
|
||||
else:
|
||||
object_text = str(obj or "").strip()
|
||||
|
||||
composite_parts = [part for part in [subjects, predicate, object_text] if part]
|
||||
if composite_parts:
|
||||
return " ".join(composite_parts)
|
||||
|
||||
logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id)
|
||||
return memory.memory_id
|
||||
|
||||
async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
|
||||
"""批量生成和存储嵌入向量"""
|
||||
if not memory_texts:
|
||||
return
|
||||
|
||||
try:
|
||||
texts = [text for _, text in memory_texts]
|
||||
memory_ids = [memory_id for memory_id, _ in memory_texts]
|
||||
|
||||
# 批量生成嵌入向量
|
||||
embeddings = await self._batch_generate_embeddings(memory_ids, texts)
|
||||
|
||||
# 存储向量和记忆
|
||||
for memory_id, embedding in embeddings.items():
|
||||
if embedding and len(embedding) == self.config.dimension:
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if memory:
|
||||
await self._add_single_memory(memory, embedding)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
|
||||
|
||||
async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
|
||||
"""批量生成嵌入向量"""
|
||||
if not texts:
|
||||
return {}
|
||||
|
||||
results: dict[str, list[float]] = {}
|
||||
|
||||
try:
|
||||
semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))
|
||||
|
||||
async def generate_embedding(memory_id: str, text: str) -> None:
|
||||
async with semaphore:
|
||||
try:
|
||||
embedding, _ = await self.embedding_model.get_embedding(text)
|
||||
if embedding and len(embedding) == self.config.dimension:
|
||||
results[memory_id] = embedding
|
||||
else:
|
||||
logger.warning(
|
||||
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。",
|
||||
self.config.dimension,
|
||||
len(embedding) if embedding else 0,
|
||||
memory_id,
|
||||
)
|
||||
results[memory_id] = []
|
||||
except Exception as exc:
|
||||
logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
|
||||
results[memory_id] = []
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
|
||||
]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
|
||||
for memory_id in memory_ids:
|
||||
results.setdefault(memory_id, [])
|
||||
|
||||
return results
|
||||
|
||||
async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
|
||||
"""添加单个记忆到向量存储"""
|
||||
with self._lock:
|
||||
try:
|
||||
# 规范化向量
|
||||
if embedding:
|
||||
embedding = self._normalize_vector(embedding)
|
||||
|
||||
# 添加到缓存
|
||||
self.memory_cache[memory.memory_id] = memory
|
||||
self.vector_cache[memory.memory_id] = embedding
|
||||
|
||||
# 更新记忆的嵌入向量
|
||||
memory.set_embedding(embedding)
|
||||
|
||||
# 添加到向量索引
|
||||
if hasattr(self.vector_index, "add"):
|
||||
# FAISS索引
|
||||
if isinstance(embedding, np.ndarray):
|
||||
vector_array = embedding.reshape(1, -1).astype("float32")
|
||||
else:
|
||||
vector_array = np.array([embedding], dtype="float32")
|
||||
|
||||
# 特殊处理IVF索引
|
||||
if self.config.index_type == "ivf" and self.vector_index.ntotal == 0:
|
||||
# IVF索引需要先训练
|
||||
logger.debug("训练IVF索引...")
|
||||
self.vector_index.train(vector_array)
|
||||
|
||||
self.vector_index.add(vector_array)
|
||||
index_id = self.vector_index.ntotal - 1
|
||||
|
||||
else:
|
||||
# 简单索引
|
||||
index_id = self.vector_index.add_vector(embedding)
|
||||
|
||||
# 更新映射关系
|
||||
self.memory_id_to_index[memory.memory_id] = index_id
|
||||
self.index_to_memory_id[index_id] = memory.memory_id
|
||||
|
||||
# 更新统计
|
||||
self.storage_stats["total_vectors"] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 添加记忆到向量存储失败: {e}")
|
||||
|
||||
def _normalize_vector(self, vector: list[float]) -> list[float]:
|
||||
"""L2归一化向量"""
|
||||
if not vector:
|
||||
return vector
|
||||
|
||||
try:
|
||||
vector_array = np.array(vector, dtype=np.float32)
|
||||
norm = np.linalg.norm(vector_array)
|
||||
if norm == 0:
|
||||
return vector
|
||||
|
||||
normalized = vector_array / norm
|
||||
return normalized.tolist()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"向量归一化失败: {e}")
|
||||
return vector
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_vector: list[float] | None = None,
|
||||
*,
|
||||
query_text: str | None = None,
|
||||
limit: int = 10,
|
||||
scope_id: str | None = None,
|
||||
) -> list[tuple[str, float]]:
|
||||
"""搜索相似记忆"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}")
|
||||
|
||||
if query_vector is None:
|
||||
if not query_text:
|
||||
logger.warning("查询向量和查询文本都为空")
|
||||
return []
|
||||
|
||||
query_vector = await self.generate_query_embedding(query_text)
|
||||
if not query_vector:
|
||||
logger.warning("查询向量生成失败")
|
||||
return []
|
||||
|
||||
scope_filter: str | None = None
|
||||
if isinstance(scope_id, str):
|
||||
normalized_scope = scope_id.strip().lower()
|
||||
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
|
||||
scope_filter = scope_id
|
||||
elif scope_id:
|
||||
scope_filter = str(scope_id)
|
||||
|
||||
# 规范化查询向量
|
||||
query_vector = self._normalize_vector(query_vector)
|
||||
|
||||
logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}")
|
||||
|
||||
# 检查向量索引状态
|
||||
if not self.vector_index:
|
||||
logger.error("向量索引未初始化")
|
||||
return []
|
||||
|
||||
total_vectors = 0
|
||||
if hasattr(self.vector_index, "ntotal"):
|
||||
total_vectors = self.vector_index.ntotal
|
||||
elif hasattr(self.vector_index, "vectors"):
|
||||
total_vectors = len(self.vector_index.vectors)
|
||||
|
||||
logger.debug(f"向量索引中实际向量数: {total_vectors}")
|
||||
|
||||
if total_vectors == 0:
|
||||
logger.warning("向量索引为空,无法执行搜索")
|
||||
return []
|
||||
|
||||
# 执行向量搜索
|
||||
with self._lock:
|
||||
if hasattr(self.vector_index, "search"):
|
||||
# FAISS索引
|
||||
if isinstance(query_vector, np.ndarray):
|
||||
query_array = query_vector.reshape(1, -1).astype("float32")
|
||||
else:
|
||||
query_array = np.array([query_vector], dtype="float32")
|
||||
|
||||
if self.config.index_type == "ivf" and self.vector_index.ntotal > 0:
|
||||
# 设置IVF搜索参数
|
||||
nprobe = min(self.vector_index.nlist, 10)
|
||||
self.vector_index.nprobe = nprobe
|
||||
logger.debug(f"IVF搜索参数: nprobe={nprobe}")
|
||||
|
||||
search_limit = min(limit, total_vectors)
|
||||
logger.debug(f"执行FAISS搜索,搜索限制: {search_limit}")
|
||||
|
||||
distances, indices = self.vector_index.search(query_array, search_limit)
|
||||
distances = distances.flatten().tolist()
|
||||
indices = indices.flatten().tolist()
|
||||
|
||||
logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引")
|
||||
else:
|
||||
# 简单索引
|
||||
logger.debug("使用简单向量索引执行搜索")
|
||||
results = self.vector_index.search(query_vector, limit)
|
||||
distances = [score for _, score in results]
|
||||
indices = [idx for idx, _ in results]
|
||||
logger.debug(f"简单索引搜索结果: {len(results)} 个结果")
|
||||
|
||||
# 处理搜索结果
|
||||
results = []
|
||||
valid_results = 0
|
||||
invalid_indices = 0
|
||||
filtered_by_scope = 0
|
||||
|
||||
for distance, index in zip(distances, indices, strict=False):
|
||||
if index == -1: # FAISS的无效索引标记
|
||||
invalid_indices += 1
|
||||
continue
|
||||
|
||||
memory_id = self.index_to_memory_id.get(index)
|
||||
if not memory_id:
|
||||
logger.debug(f"索引 {index} 没有对应的记忆ID")
|
||||
invalid_indices += 1
|
||||
continue
|
||||
|
||||
if scope_filter:
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if memory and str(memory.user_id) != scope_filter:
|
||||
filtered_by_scope += 1
|
||||
continue
|
||||
|
||||
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
|
||||
results.append((memory_id, similarity))
|
||||
valid_results += 1
|
||||
|
||||
logger.debug(
|
||||
f"搜索结果处理: 总距离={len(distances)}, 有效结果={valid_results}, "
|
||||
f"无效索引={invalid_indices}, 作用域过滤={filtered_by_scope}"
|
||||
)
|
||||
|
||||
# 更新统计
|
||||
search_time = time.time() - start_time
|
||||
self.storage_stats["total_searches"] += 1
|
||||
self.storage_stats["average_search_time"] = (
|
||||
self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time
|
||||
) / self.storage_stats["total_searches"]
|
||||
|
||||
final_results = results[:limit]
|
||||
logger.info(
|
||||
f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' "
|
||||
f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果"
|
||||
)
|
||||
|
||||
return final_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
|
||||
"""根据ID获取记忆"""
|
||||
# 先检查缓存
|
||||
if memory_id in self.memory_cache:
|
||||
self.storage_stats["cache_hits"] += 1
|
||||
return self.memory_cache[memory_id]
|
||||
|
||||
self.storage_stats["total_searches"] += 1
|
||||
return None
|
||||
|
||||
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
|
||||
"""更新记忆的嵌入向量"""
|
||||
with self._lock:
|
||||
try:
|
||||
if memory_id not in self.memory_id_to_index:
|
||||
logger.warning(f"记忆 {memory_id} 不存在于向量索引中")
|
||||
return
|
||||
|
||||
# 获取旧索引
|
||||
old_index = self.memory_id_to_index[memory_id]
|
||||
|
||||
# 删除旧向量(如果支持)
|
||||
if hasattr(self.vector_index, "remove_ids"):
|
||||
try:
|
||||
self.vector_index.remove_ids(np.array([old_index]))
|
||||
except:
|
||||
logger.warning("无法删除旧向量,将直接添加新向量")
|
||||
|
||||
# 规范化新向量
|
||||
new_embedding = self._normalize_vector(new_embedding)
|
||||
|
||||
# 添加新向量
|
||||
if hasattr(self.vector_index, "add"):
|
||||
if isinstance(new_embedding, np.ndarray):
|
||||
vector_array = new_embedding.reshape(1, -1).astype("float32")
|
||||
else:
|
||||
vector_array = np.array([new_embedding], dtype="float32")
|
||||
|
||||
self.vector_index.add(vector_array)
|
||||
new_index = self.vector_index.ntotal - 1
|
||||
else:
|
||||
new_index = self.vector_index.add_vector(new_embedding)
|
||||
|
||||
# 更新映射关系
|
||||
self.memory_id_to_index[memory_id] = new_index
|
||||
self.index_to_memory_id[new_index] = memory_id
|
||||
|
||||
# 更新缓存
|
||||
self.vector_cache[memory_id] = new_embedding
|
||||
|
||||
# 更新记忆对象
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if memory:
|
||||
memory.set_embedding(new_embedding)
|
||||
|
||||
logger.debug(f"更新记忆 {memory_id} 的嵌入向量")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 更新记忆嵌入向量失败: {e}")
|
||||
|
||||
async def delete_memory(self, memory_id: str):
|
||||
"""删除记忆"""
|
||||
with self._lock:
|
||||
try:
|
||||
if memory_id not in self.memory_id_to_index:
|
||||
return
|
||||
|
||||
# 获取索引
|
||||
index = self.memory_id_to_index[memory_id]
|
||||
|
||||
# 从向量索引中删除(如果支持)
|
||||
if hasattr(self.vector_index, "remove_ids"):
|
||||
try:
|
||||
self.vector_index.remove_ids(np.array([index]))
|
||||
except:
|
||||
logger.warning("无法从向量索引中删除,仅从缓存中移除")
|
||||
|
||||
# 删除映射关系
|
||||
del self.memory_id_to_index[memory_id]
|
||||
if index in self.index_to_memory_id:
|
||||
del self.index_to_memory_id[index]
|
||||
|
||||
# 从缓存中删除
|
||||
self.memory_cache.pop(memory_id, None)
|
||||
self.vector_cache.pop(memory_id, None)
|
||||
|
||||
# 更新统计
|
||||
self.storage_stats["total_vectors"] = max(0, self.storage_stats["total_vectors"] - 1)
|
||||
|
||||
logger.debug(f"删除记忆 {memory_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 删除记忆失败: {e}")
|
||||
|
||||
async def save_storage(self):
|
||||
"""保存向量存储到文件"""
|
||||
try:
|
||||
logger.info("正在保存向量存储...")
|
||||
|
||||
# 保存记忆缓存
|
||||
cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()}
|
||||
|
||||
cache_file = self.storage_path / "memory_cache.json"
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
|
||||
# 保存向量缓存
|
||||
vector_cache_file = self.storage_path / "vector_cache.json"
|
||||
with open(vector_cache_file, "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
|
||||
# 保存映射关系
|
||||
mapping_file = self.storage_path / "id_mapping.json"
|
||||
mapping_data = {
|
||||
"memory_id_to_index": {
|
||||
str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items()
|
||||
},
|
||||
"index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()},
|
||||
}
|
||||
with open(mapping_file, "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
|
||||
# 保存FAISS索引(如果可用)
|
||||
if FAISS_AVAILABLE and hasattr(self.vector_index, "save"):
|
||||
index_file = self.storage_path / "vector_index.faiss"
|
||||
faiss.write_index(self.vector_index, str(index_file))
|
||||
|
||||
# 保存统计信息
|
||||
stats_file = self.storage_path / "storage_stats.json"
|
||||
with open(stats_file, "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
|
||||
logger.info("✅ 向量存储保存完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 保存向量存储失败: {e}")
|
||||
|
||||
async def load_storage(self):
|
||||
"""从文件加载向量存储"""
|
||||
try:
|
||||
logger.info("正在加载向量存储...")
|
||||
|
||||
# 加载记忆缓存
|
||||
cache_file = self.storage_path / "memory_cache.json"
|
||||
if cache_file.exists():
|
||||
with open(cache_file, encoding="utf-8") as f:
|
||||
cache_data = orjson.loads(f.read())
|
||||
|
||||
self.memory_cache = {
|
||||
memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items()
|
||||
}
|
||||
|
||||
# 加载向量缓存
|
||||
vector_cache_file = self.storage_path / "vector_cache.json"
|
||||
if vector_cache_file.exists():
|
||||
with open(vector_cache_file, encoding="utf-8") as f:
|
||||
self.vector_cache = orjson.loads(f.read())
|
||||
|
||||
# 加载映射关系
|
||||
mapping_file = self.storage_path / "id_mapping.json"
|
||||
if mapping_file.exists():
|
||||
with open(mapping_file, encoding="utf-8") as f:
|
||||
mapping_data = orjson.loads(f.read())
|
||||
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
|
||||
self.memory_id_to_index = {
|
||||
str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items()
|
||||
}
|
||||
|
||||
raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
|
||||
self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()}
|
||||
|
||||
# 加载FAISS索引(如果可用)
|
||||
index_loaded = False
|
||||
if FAISS_AVAILABLE:
|
||||
index_file = self.storage_path / "vector_index.faiss"
|
||||
if index_file.exists():
|
||||
try:
|
||||
loaded_index = faiss.read_index(str(index_file))
|
||||
# 如果索引类型匹配,则替换
|
||||
if type(loaded_index) == type(self.vector_index):
|
||||
self.vector_index = loaded_index
|
||||
index_loaded = True
|
||||
logger.info("✅ FAISS索引文件加载完成")
|
||||
else:
|
||||
logger.warning("索引类型不匹配,重新构建索引")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载FAISS索引失败: {e},重新构建")
|
||||
else:
|
||||
logger.info("FAISS索引文件不存在,将重新构建")
|
||||
|
||||
# 如果索引没有成功加载且有向量数据,则重建索引
|
||||
if not index_loaded and self.vector_cache:
|
||||
logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引")
|
||||
await self._rebuild_index()
|
||||
|
||||
# 加载统计信息
|
||||
stats_file = self.storage_path / "storage_stats.json"
|
||||
if stats_file.exists():
|
||||
with open(stats_file, encoding="utf-8") as f:
|
||||
self.storage_stats = orjson.loads(f.read())
|
||||
|
||||
# 更新向量计数
|
||||
self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
|
||||
|
||||
logger.info(f"✅ 向量存储加载完成,{self.storage_stats['total_vectors']} 个向量")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 加载向量存储失败: {e}")
|
||||
|
||||
async def _rebuild_index(self):
|
||||
"""重建向量索引"""
|
||||
try:
|
||||
logger.info(f"正在重建向量索引...向量数量: {len(self.vector_cache)}")
|
||||
|
||||
# 重新初始化索引
|
||||
self._initialize_index()
|
||||
|
||||
# 清空映射关系
|
||||
self.memory_id_to_index.clear()
|
||||
self.index_to_memory_id.clear()
|
||||
|
||||
if not self.vector_cache:
|
||||
logger.warning("没有向量缓存数据,跳过重建")
|
||||
return
|
||||
|
||||
# 准备向量数据
|
||||
memory_ids = []
|
||||
vectors = []
|
||||
|
||||
for memory_id, embedding in self.vector_cache.items():
|
||||
if embedding and len(embedding) == self.config.dimension:
|
||||
memory_ids.append(memory_id)
|
||||
vectors.append(self._normalize_vector(embedding))
|
||||
else:
|
||||
logger.debug(f"跳过无效向量: {memory_id}, 维度: {len(embedding) if embedding else 0}")
|
||||
|
||||
if not vectors:
|
||||
logger.warning("没有有效的向量数据")
|
||||
return
|
||||
|
||||
logger.info(f"准备重建 {len(vectors)} 个向量到索引")
|
||||
|
||||
# 批量添加向量到FAISS索引
|
||||
if hasattr(self.vector_index, "add"):
|
||||
# FAISS索引
|
||||
vector_array = np.array(vectors, dtype="float32")
|
||||
|
||||
# 特殊处理IVF索引
|
||||
if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"):
|
||||
logger.info("训练IVF索引...")
|
||||
self.vector_index.train(vector_array)
|
||||
|
||||
# 添加向量
|
||||
self.vector_index.add(vector_array)
|
||||
|
||||
# 重建映射关系
|
||||
for i, memory_id in enumerate(memory_ids):
|
||||
self.memory_id_to_index[memory_id] = i
|
||||
self.index_to_memory_id[i] = memory_id
|
||||
|
||||
else:
|
||||
# 简单索引
|
||||
for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors, strict=False)):
|
||||
index_id = self.vector_index.add_vector(vector)
|
||||
self.memory_id_to_index[memory_id] = index_id
|
||||
self.index_to_memory_id[index_id] = memory_id
|
||||
|
||||
# 更新统计
|
||||
self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
|
||||
|
||||
final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index))
|
||||
logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重建向量索引失败: {e}", exc_info=True)
|
||||
|
||||
async def optimize_storage(self):
|
||||
"""优化存储"""
|
||||
try:
|
||||
logger.info("开始向量存储优化...")
|
||||
|
||||
# 清理无效引用
|
||||
self._cleanup_invalid_references()
|
||||
|
||||
# 重新构建索引(如果碎片化严重)
|
||||
if self.storage_stats["total_vectors"] > 1000:
|
||||
await self._rebuild_index()
|
||||
|
||||
# 更新缓存命中率
|
||||
if self.storage_stats["total_searches"] > 0:
|
||||
self.storage_stats["cache_hit_rate"] = (
|
||||
self.storage_stats["cache_hits"] / self.storage_stats["total_searches"]
|
||||
)
|
||||
|
||||
logger.info("✅ 向量存储优化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 向量存储优化失败: {e}")
|
||||
|
||||
def _cleanup_invalid_references(self):
|
||||
"""清理无效引用"""
|
||||
with self._lock:
|
||||
# 清理无效的memory_id到index的映射
|
||||
valid_memory_ids = set(self.memory_cache.keys())
|
||||
invalid_memory_ids = set(self.memory_id_to_index.keys()) - valid_memory_ids
|
||||
|
||||
for memory_id in invalid_memory_ids:
|
||||
index = self.memory_id_to_index[memory_id]
|
||||
del self.memory_id_to_index[memory_id]
|
||||
if index in self.index_to_memory_id:
|
||||
del self.index_to_memory_id[index]
|
||||
|
||||
if invalid_memory_ids:
|
||||
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
|
||||
|
||||
def get_storage_stats(self) -> dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
stats = self.storage_stats.copy()
|
||||
if stats["total_searches"] > 0:
|
||||
stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_searches"]
|
||||
else:
|
||||
stats["cache_hit_rate"] = 0.0
|
||||
return stats
|
||||
|
||||
|
||||
class SimpleVectorIndex:
|
||||
"""简单的向量索引实现(当FAISS不可用时的替代方案)"""
|
||||
|
||||
def __init__(self, dimension: int):
|
||||
self.dimension = dimension
|
||||
self.vectors: list[list[float]] = []
|
||||
self.vector_ids: list[int] = []
|
||||
self.next_id = 0
|
||||
|
||||
def add_vector(self, vector: list[float]) -> int:
|
||||
"""添加向量"""
|
||||
if len(vector) != self.dimension:
|
||||
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
|
||||
|
||||
vector_id = self.next_id
|
||||
self.vectors.append(vector.copy())
|
||||
self.vector_ids.append(vector_id)
|
||||
self.next_id += 1
|
||||
|
||||
return vector_id
|
||||
|
||||
def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
|
||||
"""搜索相似向量"""
|
||||
if len(query_vector) != self.dimension:
|
||||
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
|
||||
|
||||
results = []
|
||||
|
||||
for i, vector in enumerate(self.vectors):
|
||||
similarity = self._calculate_cosine_similarity(query_vector, vector)
|
||||
results.append((self.vector_ids[i], similarity))
|
||||
|
||||
# 按相似度排序
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
|
||||
"""计算余弦相似度"""
|
||||
try:
|
||||
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
|
||||
norm1 = sum(x * x for x in v1) ** 0.5
|
||||
norm2 = sum(x * x for x in v2) ** 0.5
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm1 * norm2)
|
||||
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
@property
|
||||
def ntotal(self) -> int:
|
||||
"""向量总数"""
|
||||
return len(self.vectors)
|
||||
Reference in New Issue
Block a user