refactor(memory): 移除废弃的记忆系统备份文件,优化消息管理器架构

移除了deprecated_backup目录下的所有废弃记忆系统文件,包括增强记忆适配器、钩子、集成层、重排序器、元数据索引、多阶段检索和向量存储等模块。同时优化了消息管理器,集成了批量数据库写入器、流缓存管理器和自适应流管理器,提升了系统性能和可维护性。
This commit is contained in:
Windpicker-owo
2025-10-04 01:38:41 +08:00
parent 80fc37fd02
commit a5a16971e9
16 changed files with 1975 additions and 5203 deletions

View File

@@ -1,363 +0,0 @@
"""
增强记忆系统适配器
将增强记忆系统集成到现有MoFox Bot架构中
"""
import time
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
# Human-readable (Chinese) display labels for every supported memory type.
# NOTE: the values are user-facing runtime strings — do not translate them.
MEMORY_TYPE_LABELS = {
    MemoryType.PERSONAL_FACT: "个人事实",
    MemoryType.EVENT: "事件",
    MemoryType.PREFERENCE: "偏好",
    MemoryType.OPINION: "观点",
    MemoryType.RELATIONSHIP: "关系",
    MemoryType.EMOTION: "情感",
    MemoryType.KNOWLEDGE: "知识",
    MemoryType.SKILL: "技能",
    MemoryType.GOAL: "目标",
    MemoryType.EXPERIENCE: "经验",
    MemoryType.CONTEXTUAL: "上下文",
}
@dataclass
class AdapterConfig:
    """Configuration for the enhanced-memory adapter."""

    enable_enhanced_memory: bool = True  # master switch for the enhanced memory feature
    integration_mode: str = "enhanced_only"  # one of: replace, enhanced_only
    auto_migration: bool = True  # migrate legacy memories automatically
    memory_value_threshold: float = 0.6  # minimum value score for a memory to be kept
    fusion_threshold: float = 0.85  # similarity threshold for fusing duplicate memories
    max_retrieval_results: int = 10  # default cap on retrieved memories
class EnhancedMemoryAdapter:
    """Adapter that integrates the enhanced memory system into the existing MoFox Bot architecture.

    Wraps a :class:`MemoryIntegrationLayer` and exposes conversation processing,
    memory retrieval, prompt-context building and lifecycle management, while
    keeping rolling usage statistics.
    """

    def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
        self.llm_model = llm_model
        self.config = config or AdapterConfig()
        self.integration_layer: MemoryIntegrationLayer | None = None
        self._initialized = False
        # Rolling counters, exposed via get_adapter_stats().
        self.adapter_stats = {
            "total_processed": 0,
            "enhanced_used": 0,
            "legacy_used": 0,
            "hybrid_used": 0,
            "memories_created": 0,
            "memories_retrieved": 0,
            "average_processing_time": 0.0,
        }

    async def initialize(self):
        """Initialize the adapter (idempotent).

        On failure the enhanced-memory feature is disabled in-place instead of
        raising, so callers degrade gracefully.
        """
        if self._initialized:
            return
        try:
            logger.info("🚀 初始化增强记忆系统适配器...")
            # Translate the adapter config into the integration-layer format.
            integration_config = IntegrationConfig(
                mode=IntegrationMode(self.config.integration_mode),
                enable_enhanced_memory=self.config.enable_enhanced_memory,
                memory_value_threshold=self.config.memory_value_threshold,
                fusion_threshold=self.config.fusion_threshold,
                max_retrieval_results=self.config.max_retrieval_results,
                enable_learning=True,  # learning feature is always enabled here
            )
            # Create the integration layer.
            self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config)
            # Initialize the integration layer.
            await self.integration_layer.initialize()
            self._initialized = True
            logger.info("✅ 增强记忆系统适配器初始化完成")
        except Exception as e:
            logger.error(f"❌ 增强记忆系统适配器初始化失败: {e}", exc_info=True)
            # If initialization failed, disable the enhanced memory feature.
            self.config.enable_enhanced_memory = False

    async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
        """Process conversation memory; the context dict is the sole input.

        When ``conversation_text`` is absent or falsy, falls back to the first
        truthy of ``message_content`` / ``latest_message`` / ``raw_text``.
        Returns the integration-layer result dict (``{"success": bool, ...}``).
        """
        if not self._initialized or not self.config.enable_enhanced_memory:
            return {"success": False, "error": "Enhanced memory not available"}
        start_time = time.time()
        self.adapter_stats["total_processed"] += 1
        try:
            payload_context: dict[str, Any] = dict(context or {})
            conversation_text = payload_context.get("conversation_text")
            if not conversation_text:
                # `or`-chain: the first truthy candidate wins; if all are falsy
                # the last value (possibly "" rather than None) is used.
                conversation_candidate = (
                    payload_context.get("message_content")
                    or payload_context.get("latest_message")
                    or payload_context.get("raw_text")
                )
                if conversation_candidate is not None:
                    conversation_text = str(conversation_candidate)
                    payload_context["conversation_text"] = conversation_text
                else:
                    conversation_text = ""
            else:
                conversation_text = str(conversation_text)
            if "timestamp" not in payload_context:
                payload_context["timestamp"] = time.time()
            logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text))
            # Delegate the actual processing to the integration layer.
            result = await self.integration_layer.process_conversation(payload_context)
            # Update rolling statistics.
            processing_time = time.time() - start_time
            self._update_processing_stats(processing_time)
            # NOTE(review): assumes the result dict always carries "success".
            if result["success"]:
                created_count = len(result.get("created_memories", []))
                self.adapter_stats["memories_created"] += created_count
                logger.debug(f"对话记忆处理完成,创建 {created_count} 条记忆")
            return result
        except Exception as e:
            logger.error(f"处理对话记忆失败: {e}", exc_info=True)
            return {"success": False, "error": str(e)}

    async def retrieve_relevant_memories(
        self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
    ) -> list[MemoryChunk]:
        """Retrieve memories relevant to *query*; empty list when unavailable or on error."""
        if not self._initialized or not self.config.enable_enhanced_memory:
            return []
        try:
            limit = limit or self.config.max_retrieval_results
            # NOTE(review): user_id is NOT forwarded — the integration layer is
            # called with None; confirm this is intentional.
            memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit)
            self.adapter_stats["memories_retrieved"] += len(memories)
            logger.debug(f"检索到 {len(memories)} 条相关记忆")
            return memories
        except Exception as e:
            logger.error(f"检索相关记忆失败: {e}", exc_info=True)
            return []

    async def get_memory_context_for_prompt(
        self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
    ) -> str:
        """Build a formatted memory-context string for prompt injection ("" when none found)."""
        memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
        if not memories:
            return ""
        # Format through the shared memory formatter.
        formatter_config = FormatterConfig(
            include_timestamps=True,
            include_memory_types=True,
            include_confidence=False,
            use_emoji_icons=True,
            group_by_type=False,
            max_display_length=150,
        )
        return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)

    async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
        """Return an availability/status/statistics summary of the enhanced memory system."""
        if not self._initialized or not self.config.enable_enhanced_memory:
            return {"available": False, "reason": "Not initialized or disabled"}
        try:
            # System status from the integration layer.
            status = await self.integration_layer.get_system_status()
            # Adapter-level statistics.
            adapter_stats = self.adapter_stats.copy()
            # Integration-layer statistics.
            integration_stats = self.integration_layer.get_integration_stats()
            return {
                "available": True,
                "system_status": status,
                "adapter_stats": adapter_stats,
                "integration_stats": integration_stats,
                "total_memories_created": adapter_stats["memories_created"],
                "total_memories_retrieved": adapter_stats["memories_retrieved"],
            }
        except Exception as e:
            logger.error(f"获取增强记忆摘要失败: {e}", exc_info=True)
            return {"available": False, "error": str(e)}

    def _update_processing_stats(self, processing_time: float):
        """Fold *processing_time* into the running average processing time."""
        total_processed = self.adapter_stats["total_processed"]
        if total_processed > 0:
            current_avg = self.adapter_stats["average_processing_time"]
            new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
            self.adapter_stats["average_processing_time"] = new_avg

    def get_adapter_stats(self) -> dict[str, Any]:
        """Return a copy of the adapter statistics."""
        return self.adapter_stats.copy()

    async def maintenance(self):
        """Run periodic maintenance on the underlying integration layer."""
        if not self._initialized:
            return
        try:
            logger.info("🔧 增强记忆系统适配器维护...")
            await self.integration_layer.maintenance()
            logger.info("✅ 增强记忆系统适配器维护完成")
        except Exception as e:
            logger.error(f"❌ 增强记忆系统适配器维护失败: {e}", exc_info=True)

    async def shutdown(self):
        """Shut down the adapter and its integration layer."""
        if not self._initialized:
            return
        try:
            logger.info("🔄 关闭增强记忆系统适配器...")
            await self.integration_layer.shutdown()
            self._initialized = False
            logger.info("✅ 增强记忆系统适配器已关闭")
        except Exception as e:
            logger.error(f"❌ 关闭增强记忆系统适配器失败: {e}", exc_info=True)
# Global adapter instance (module-level lazy singleton).
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None


async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
    """Return the global enhanced-memory adapter, creating it on first call.

    NOTE(review): the lazy initialization is not guarded by a lock, so two
    concurrent first calls may each build and initialize an adapter (last
    writer wins). Subsequent calls ignore the ``llm_model`` argument and reuse
    the existing singleton — confirm callers expect both behaviors.
    """
    global _enhanced_memory_adapter
    if _enhanced_memory_adapter is None:
        # Pull adapter settings from the global config, falling back to defaults.
        from src.config.config import global_config

        adapter_config = AdapterConfig(
            enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True),
            integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"),
            auto_migration=getattr(global_config.memory, "enable_memory_migration", True),
            memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6),
            fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85),
            max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10),
        )
        _enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config)
        await _enhanced_memory_adapter.initialize()
    return _enhanced_memory_adapter
async def initialize_enhanced_memory_system(llm_model: LLMRequest):
    """Bootstrap the enhanced memory system; return the adapter, or None on failure."""
    logger.info("🚀 初始化增强记忆系统...")
    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
    except Exception as e:
        logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
        return None
    logger.info("✅ 增强记忆系统初始化完成")
    return adapter
async def process_conversation_with_enhanced_memory(
    context: dict[str, Any], llm_model: LLMRequest | None = None
) -> dict[str, Any]:
    """Process a conversation through the enhanced memory system.

    The context dict should carry ``conversation_text``; otherwise it is
    derived from the first truthy of message_content/latest_message/raw_text.
    """
    if not llm_model:
        # Fall back to the default LLM model.
        from src.llm_models.utils_model import get_global_llm_model

        llm_model = get_global_llm_model()
    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
        payload = dict(context or {})
        if "conversation_text" not in payload:
            candidate = (
                payload.get("message_content")
                or payload.get("latest_message")
                or payload.get("raw_text")
            )
            if candidate is not None:
                payload["conversation_text"] = str(candidate)
        return await adapter.process_conversation_memory(payload)
    except Exception as e:
        logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
        return {"success": False, "error": str(e)}
async def retrieve_memories_with_enhanced_system(
    query: str,
    user_id: str,
    context: dict[str, Any] | None = None,
    limit: int = 10,
    llm_model: LLMRequest | None = None,
) -> list[MemoryChunk]:
    """Retrieve relevant memories via the enhanced system; empty list on error."""
    if not llm_model:
        # Fall back to the default LLM model.
        from src.llm_models.utils_model import get_global_llm_model

        llm_model = get_global_llm_model()
    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
        found = await adapter.retrieve_relevant_memories(query, user_id, context, limit)
    except Exception as e:
        logger.error(f"使用增强记忆系统检索记忆失败: {e}", exc_info=True)
        found = []
    return found
async def get_memory_context_for_prompt(
    query: str,
    user_id: str,
    context: dict[str, Any] | None = None,
    max_memories: int = 5,
    llm_model: LLMRequest | None = None,
) -> str:
    """Return a formatted memory-context string for prompts ("" on error)."""
    if not llm_model:
        # Fall back to the default LLM model.
        from src.llm_models.utils_model import get_global_llm_model

        llm_model = get_global_llm_model()
    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
        prompt_context = await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories)
    except Exception as e:
        logger.error(f"获取记忆上下文失败: {e}", exc_info=True)
        prompt_context = ""
    return prompt_context

View File

@@ -1,194 +0,0 @@
"""
增强记忆系统钩子
用于在消息处理过程中自动构建和检索记忆
"""
from datetime import datetime
from typing import Any
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger(__name__)
class EnhancedMemoryHooks:
    """Enhanced-memory hooks — automatic memory building/retrieval around message handling."""

    def __init__(self):
        # Active only when both the memory system and its enhanced variant are
        # enabled in the global configuration.
        self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory
        # Dedup cache of already-handled message ids. FIX: this used to be a
        # plain set, but sets are unordered, so the "drop the oldest half"
        # trim evicted arbitrary ids. A dict preserves insertion order (arrival
        # order), so trimming now really keeps the most recent entries, while
        # `in`, len() and .clear() keep working for existing callers.
        self.processed_messages: dict[str, None] = {}

    async def process_message_for_memory(
        self,
        message_content: str,
        user_id: str,
        chat_id: str,
        message_id: str,
        context: dict[str, Any] | None = None,
    ) -> bool:
        """Process a message and build memories from it.

        Args:
            message_content: raw message text
            user_id: sender id
            chat_id: chat/conversation id
            message_id: unique message id, used for deduplication
            context: extra entries merged into the memory-building context

        Returns:
            True when at least one memory chunk was created.
        """
        if not self.enabled:
            return False
        if message_id in self.processed_messages:
            return False
        try:
            # Lazily initialize the enhanced memory manager.
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()
            # Inject the bot's base persona so memory building can avoid
            # recording information about the bot itself.
            bot_config = getattr(global_config, "bot", None)
            personality_config = getattr(global_config, "personality", None)
            bot_context = {}
            if bot_config is not None:
                bot_context["bot_name"] = getattr(bot_config, "nickname", None)
                bot_context["bot_aliases"] = list(getattr(bot_config, "alias_names", []) or [])
                bot_context["bot_account"] = getattr(bot_config, "qq_account", None)
            if personality_config is not None:
                bot_context["bot_identity"] = getattr(personality_config, "identity", None)
                bot_context["bot_personality"] = getattr(personality_config, "personality_core", None)
                bot_context["bot_personality_side"] = getattr(personality_config, "personality_side", None)
            # Build the memory context; caller-supplied context wins on key clashes.
            memory_context = {
                "chat_id": chat_id,
                "message_id": message_id,
                "timestamp": datetime.now().timestamp(),
                "message_type": "user_message",
                **bot_context,
                **(context or {}),
            }
            # Process the conversation and build memories.
            memory_chunks = await enhanced_memory_manager.process_conversation(
                conversation_text=message_content,
                context=memory_context,
                user_id=user_id,
                timestamp=memory_context["timestamp"],
            )
            # Mark the message as handled.
            self.processed_messages[message_id] = None
            # Bound the cache: once it exceeds 1000 entries, keep only the
            # 500 most recently inserted ids.
            if len(self.processed_messages) > 1000:
                self.processed_messages = dict.fromkeys(list(self.processed_messages)[-500:])
            logger.debug(f"为消息 {message_id} 构建了 {len(memory_chunks)} 条记忆")
            return len(memory_chunks) > 0
        except Exception as e:
            logger.error(f"处理消息记忆失败: {e}")
            return False

    async def get_memory_for_response(
        self,
        query_text: str,
        user_id: str,
        chat_id: str,
        limit: int = 5,
        extra_context: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Fetch memories relevant to a pending reply.

        Args:
            query_text: query text (usually the current user message)
            user_id: user id
            chat_id: chat id
            limit: maximum number of memories to return
            extra_context: extra entries merged into the retrieval context

        Returns:
            List of memory dicts (content/type/confidence/importance/...).
        """
        if not self.enabled:
            return []
        try:
            # Lazily initialize the enhanced memory manager.
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()
            # Build the retrieval context.
            context = {
                "chat_id": chat_id,
                "query_intent": "response_generation",
                "expected_memory_types": ["personal_fact", "event", "preference", "opinion"],
            }
            if extra_context:
                context.update(extra_context)
            # Retrieve relevant memories.
            enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
                query_text=query_text, user_id=user_id, context=context, limit=limit
            )
            # Convert results into plain dicts.
            results = []
            for result in enhanced_results:
                results.append(
                    {
                        "content": result.content,
                        "type": result.memory_type,
                        "confidence": result.confidence,
                        "importance": result.importance,
                        "timestamp": result.timestamp,
                        "source": result.source,
                        "relevance": result.relevance_score,
                        "structure": result.structure,
                    }
                )
            logger.debug(f"为回复查询到 {len(results)} 条相关记忆")
            return results
        except Exception as e:
            logger.error(f"获取回复记忆失败: {e}")
            return []

    async def cleanup_old_memories(self):
        """Run the enhanced memory system's maintenance to clean up old memories."""
        try:
            if enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.enhanced_system.maintenance()
                logger.debug("增强记忆系统维护完成")
        except Exception as e:
            logger.error(f"清理旧记忆失败: {e}")

    def clear_processed_cache(self):
        """Clear the processed-message dedup cache."""
        self.processed_messages.clear()
        logger.debug("已清除消息处理缓存")

    def enable(self):
        """Enable the memory hooks."""
        self.enabled = True
        logger.info("增强记忆钩子已启用")

    def disable(self):
        """Disable the memory hooks."""
        self.enabled = False
        logger.info("增强记忆钩子已禁用")


# Global hooks instance.
enhanced_memory_hooks = EnhancedMemoryHooks()

View File

@@ -1,177 +0,0 @@
"""
增强记忆系统集成脚本
用于在现有系统中无缝集成增强记忆功能
"""
from typing import Any
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
from src.common.logger import get_logger
logger = get_logger(__name__)
async def process_user_message_memory(
    message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
) -> bool:
    """Build memory from a user message.

    Args:
        message_content: message text
        user_id: user id
        chat_id: chat id
        message_id: message id
        context: extra context information

    Returns:
        True when a memory was successfully built.
    """
    try:
        created = await enhanced_memory_hooks.process_message_for_memory(
            message_content=message_content,
            user_id=user_id,
            chat_id=chat_id,
            message_id=message_id,
            context=context,
        )
    except Exception as e:
        logger.error(f"处理用户消息记忆失败: {e}")
        return False
    if created:
        logger.debug(f"成功为消息 {message_id} 构建记忆")
    return created
async def get_relevant_memories_for_response(
    query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Fetch memories relevant to a pending reply, wrapped in a summary dict.

    Args:
        query_text: query text (usually the current user message)
        user_id: user id
        chat_id: chat id
        limit: maximum number of memories
        extra_context: extra context information

    Returns:
        Dict with has_memories / memories / memory_count.
    """
    try:
        found = await enhanced_memory_hooks.get_memory_for_response(
            query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context
        )
    except Exception as e:
        logger.error(f"获取回复记忆失败: {e}")
        return {"has_memories": False, "memories": [], "memory_count": 0}
    logger.debug(f"为回复获取到 {len(found)} 条相关记忆")
    return {"has_memories": bool(found), "memories": found, "memory_count": len(found)}
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
    """Format retrieved memory information for inclusion in a prompt.

    Robustness fix: the original indexed the dict directly and raised KeyError
    on malformed input; missing keys now fall back to harmless defaults, and a
    falsy/absent ``has_memories`` yields "".

    Args:
        memories: summary dict as produced by get_relevant_memories_for_response

    Returns:
        Formatted memory text ("" when there is nothing to show).
    """
    if not memories or not memories.get("has_memories"):
        return ""
    memory_lines = ["以下是相关的记忆信息:"]
    for memory in memories.get("memories", []):
        content = memory.get("content", "")
        memory_type = memory.get("type", "")
        confidence = memory.get("confidence", 0)
        importance = memory.get("importance", 0)
        # Importance marker: >=3 hot, >=2 unmarked, otherwise a note icon.
        importance_marker = "🔥" if importance >= 3 else "" if importance >= 2 else "📝"
        # Confidence marker: >=3 unmarked, >=2 warning, otherwise speculative.
        confidence_marker = "" if confidence >= 3 else "⚠️" if confidence >= 2 else "💭"
        memory_lines.append(f"{importance_marker} {content} ({memory_type}, {confidence_marker}置信度)")
    return "\n".join(memory_lines)
async def cleanup_memory_system():
    """Run cleanup/maintenance on the memory system."""
    try:
        await enhanced_memory_hooks.cleanup_old_memories()
    except Exception as e:
        logger.error(f"记忆系统清理失败: {e}")
    else:
        logger.info("记忆系统清理完成")
def get_memory_system_status() -> dict[str, Any]:
    """Report the current status of the memory system.

    Returns:
        Dict with enabled flag, initialization state, processed-message count
        and the system type identifier.
    """
    from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager

    status = {
        "enabled": enhanced_memory_hooks.enabled,
        "enhanced_system_initialized": enhanced_memory_manager.is_initialized,
        "processed_messages_count": len(enhanced_memory_hooks.processed_messages),
        "system_type": "enhanced_memory_system",
    }
    return status
# Convenience helpers


async def remember_message(
    message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
) -> bool:
    """Remember *message* under a freshly generated message id.

    Args:
        message: text to remember
        user_id: user id
        chat_id: chat id
        context: extra context information

    Returns:
        True on success.
    """
    import uuid

    return await process_user_message_memory(
        message_content=message,
        user_id=user_id,
        chat_id=chat_id,
        message_id=str(uuid.uuid4()),
        context=context,
    )
async def recall_memories(
    query: str,
    user_id: str = "default_user",
    chat_id: str = "default_chat",
    limit: int = 5,
    context: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Recall memories matching *query*.

    Thin convenience wrapper over get_relevant_memories_for_response.

    Args:
        query: query text
        user_id: user id
        chat_id: chat id
        limit: maximum number of memories
        context: extra context information

    Returns:
        Memory summary dict.
    """
    return await get_relevant_memories_for_response(
        query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context
    )

View File

@@ -1,356 +0,0 @@
"""
增强重排序器
实现文档设计的多维度评分模型
"""
import math
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
logger = get_logger(__name__)
class IntentType(Enum):
    """Conversation intent categories used for intent/type-match scoring."""

    FACT_QUERY = "fact_query"  # factual lookup (who/what/where about the user)
    EVENT_RECALL = "event_recall"  # recalling a past event or experience
    PREFERENCE_CHECK = "preference_check"  # asking about likes/dislikes/habits
    GENERAL_CHAT = "general_chat"  # ordinary small talk
    UNKNOWN = "unknown"  # could not classify (e.g. empty query)
@dataclass
class ReRankingConfig:
    """Configuration for the multi-dimensional re-ranking score."""

    # Score weights (w1 + w2 + w3 + w4 should equal 1.0; EnhancedReRanker
    # normalizes them otherwise).
    semantic_weight: float = 0.5  # semantic similarity weight
    recency_weight: float = 0.2  # recency weight
    usage_freq_weight: float = 0.2  # usage-frequency weight
    type_match_weight: float = 0.1  # intent/type-match weight
    # Recency decay rate, applied per day of age.
    recency_decay_rate: float = 0.1
    # Usage-frequency scoring parameters.
    freq_log_base: float = 2.0  # logarithm base
    freq_max_score: float = 5.0  # maximum frequency score (divisor)
    # Intent -> memory-type weight table; filled with defaults in __post_init__
    # when left as None.
    type_match_weights: dict[str, dict[str, float]] | None = None

    def __post_init__(self):
        """Populate the default intent → memory-type weight table when none was supplied."""
        if self.type_match_weights is None:
            self.type_match_weights = {
                IntentType.FACT_QUERY.value: {
                    MemoryType.PERSONAL_FACT.value: 1.0,
                    MemoryType.KNOWLEDGE.value: 0.8,
                    MemoryType.PREFERENCE.value: 0.5,
                    MemoryType.EVENT.value: 0.3,
                    "default": 0.3,
                },
                IntentType.EVENT_RECALL.value: {
                    MemoryType.EVENT.value: 1.0,
                    MemoryType.EXPERIENCE.value: 0.8,
                    MemoryType.EMOTION.value: 0.6,
                    MemoryType.PERSONAL_FACT.value: 0.5,
                    "default": 0.5,
                },
                IntentType.PREFERENCE_CHECK.value: {
                    MemoryType.PREFERENCE.value: 1.0,
                    MemoryType.OPINION.value: 0.8,
                    MemoryType.GOAL.value: 0.6,
                    MemoryType.PERSONAL_FACT.value: 0.4,
                    "default": 0.4,
                },
                IntentType.GENERAL_CHAT.value: {"default": 0.8},
                IntentType.UNKNOWN.value: {"default": 0.8},
            }
class IntentClassifier:
    """Lightweight keyword-based intent classifier.

    Classification is a plain substring scan over per-intent keyword lists;
    the intent with the most matches wins, ties resolved in declaration order.
    """

    def __init__(self):
        # Keyword patterns per intent (Chinese cues first, then English).
        self.patterns = {
            IntentType.FACT_QUERY: [
                "我是", "我的", "我叫", "我在", "我住在", "我的职业", "我的工作",
                "什么时候", "在哪里", "是什么", "多少", "几岁", "年龄",
                "what is", "where is", "when is", "how old", "my name", "i am", "i live",
            ],
            IntentType.EVENT_RECALL: [
                "记得", "想起", "还记得", "那次", "上次", "之前", "以前", "曾经",
                "发生过", "经历", "做过", "去过", "见过",
                "remember", "recall", "last time", "before", "previously", "happened", "experience",
            ],
            IntentType.PREFERENCE_CHECK: [
                "喜欢", "不喜欢", "偏好", "爱好", "兴趣", "讨厌", "最爱", "最喜欢",
                "习惯", "通常", "一般", "倾向于", "更喜欢",
                "like", "love", "hate", "prefer", "favorite", "usually", "tend to", "interest",
            ],
        }

    def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
        """Classify the intent of *query* (*context* is currently unused)."""
        if not query:
            return IntentType.UNKNOWN
        lowered = query.lower()
        # Count keyword hits per intent; dict preserves declaration order so
        # max() breaks ties the same way a first-wins scan would.
        scores = {
            intent: sum(1 for keyword in keywords if keyword in lowered)
            for intent, keywords in self.patterns.items()
        }
        if not any(scores.values()):
            return IntentType.GENERAL_CHAT
        return max(scores, key=scores.get)
class EnhancedReRanker:
    """Enhanced re-ranker implementing the documented multi-dimensional scoring model.

    Final score = w1*semantic + w2*recency + w3*usage_freq + w4*type_match,
    with the four weights normalized to sum to 1.0.
    """

    def __init__(self, config: ReRankingConfig | None = None):
        self.config = config or ReRankingConfig()
        self.intent_classifier = IntentClassifier()
        # Validate that the four weights sum to 1.0; otherwise normalize them.
        # NOTE: normalization mutates the supplied config object in place.
        total_weight = (
            self.config.semantic_weight
            + self.config.recency_weight
            + self.config.usage_freq_weight
            + self.config.type_match_weight
        )
        if abs(total_weight - 1.0) > 0.01:
            logger.warning(f"重排序权重和不为1.0: {total_weight}, 将进行归一化")
            self.config.semantic_weight /= total_weight
            self.config.recency_weight /= total_weight
            self.config.usage_freq_weight /= total_weight
            self.config.type_match_weight /= total_weight

    def rerank_memories(
        self,
        query: str,
        candidate_memories: list[tuple[str, MemoryChunk, float]],  # (memory_id, memory, vector_similarity)
        context: dict[str, Any],
        limit: int = 10,
    ) -> list[tuple[str, MemoryChunk, float]]:
        """Re-rank candidate memories.

        Args:
            query: query text
            candidate_memories: candidates as [(memory_id, memory, vector_similarity)]
            context: contextual info passed to intent classification
            limit: maximum number of results

        Returns:
            Re-ranked list [(memory_id, memory, final_score)].
        """
        if not candidate_memories:
            return []
        # Classify the query intent once for the whole batch.
        intent = self.intent_classifier.classify_intent(query, context)
        logger.debug(f"识别到查询意图: {intent.value}")
        scored_memories = []
        current_time = time.time()
        for memory_id, memory, vector_sim in candidate_memories:
            try:
                # 1) semantic similarity (normalized to [0, 1])
                semantic_score = self._normalize_similarity(vector_sim)
                # 2) recency
                recency_score = self._calculate_recency_score(memory, current_time)
                # 3) usage frequency
                usage_freq_score = self._calculate_usage_frequency_score(memory)
                # 4) intent/type match
                type_match_score = self._calculate_type_match_score(memory, intent)
                # Weighted combination.
                final_score = (
                    self.config.semantic_weight * semantic_score
                    + self.config.recency_weight * recency_score
                    + self.config.usage_freq_weight * usage_freq_score
                    + self.config.type_match_weight * type_match_score
                )
                scored_memories.append((memory_id, memory, final_score))
                logger.debug(
                    f"记忆评分 {memory_id[:8]}: semantic={semantic_score:.3f}, "
                    f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, "
                    f"type={type_match_score:.3f}, final={final_score:.3f}"
                )
            except Exception as e:
                logger.error(f"计算记忆 {memory_id} 得分时出错: {e}")
                # Fall back to the raw vector similarity for this candidate.
                scored_memories.append((memory_id, memory, vector_sim))
        # Sort by final score, best first, and truncate.
        scored_memories.sort(key=lambda x: x[2], reverse=True)
        result = scored_memories[:limit]
        highest_score = result[0][2] if result else 0.0
        logger.info(
            f"重排序完成: 候选={len(candidate_memories)}, 返回={len(result)}, "
            f"意图={intent.value}, 最高分={highest_score:.3f}"
        )
        return result

    def _normalize_similarity(self, raw_similarity: float) -> float:
        """Normalize a similarity value into [0, 1].

        Assumes the raw similarity is already within [-1, 1] or [0, 1].
        """
        if raw_similarity < 0:
            return (raw_similarity + 1) / 2  # map [-1, 1] onto [0, 1]
        return min(1.0, max(0.0, raw_similarity))  # clamp to [0, 1]

    def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float:
        """Recency = 1 / (1 + decay_rate * days_old), clamped to [0, 1]."""
        last_accessed = memory.metadata.last_accessed or memory.metadata.created_at
        days_old = (current_time - last_accessed) / (24 * 3600)  # seconds -> days
        if days_old < 0:
            days_old = 0  # guard against future timestamps / clock skew
        score = 1 / (1 + self.config.recency_decay_rate * days_old)
        return min(1.0, max(0.0, score))

    def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float:
        """Usage_Freq = min(1.0, log_base(access_count + 1) / freq_max_score).

        Fix: now honors ``config.freq_log_base`` (previously hard-coded to
        log2, leaving the config field dead); behavior is unchanged for the
        default base of 2.0.
        """
        access_count = memory.metadata.access_count
        if access_count <= 0:
            return 0.0
        log_count = math.log(access_count + 1, self.config.freq_log_base)
        score = log_count / self.config.freq_max_score
        return min(1.0, max(0.0, score))

    def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float:
        """Weight for how well the memory type fits the classified intent."""
        memory_type = memory.memory_type.value
        intent_value = intent.value
        # Per-intent type weight table, falling back to its "default" entry.
        type_weights = self.config.type_match_weights.get(intent_value, {})
        score = type_weights.get(memory_type, type_weights.get("default", 0.8))
        return min(1.0, max(0.0, score))
# Module-level re-ranker used when no custom config is supplied.
default_reranker = EnhancedReRanker()


def rerank_candidate_memories(
    query: str,
    candidate_memories: list[tuple[str, MemoryChunk, float]],
    context: dict[str, Any],
    limit: int = 10,
    config: ReRankingConfig | None = None,
) -> list[tuple[str, MemoryChunk, float]]:
    """Convenience wrapper: re-rank candidates, optionally with a custom config."""
    reranker = EnhancedReRanker(config) if config else default_reranker
    return reranker.rerank_memories(query, candidate_memories, context, limit)

View File

@@ -1,245 +0,0 @@
"""
增强记忆系统集成层
现在只管理新的增强记忆系统,旧系统已被完全移除
"""
import asyncio
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
class IntegrationMode(Enum):
    """Integration modes for the memory system."""

    REPLACE = "replace"  # fully replace the existing memory system
    ENHANCED_ONLY = "enhanced_only"  # use only the enhanced memory system
@dataclass
class IntegrationConfig:
    """Configuration for the memory-system integration layer."""

    mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY  # integration strategy
    enable_enhanced_memory: bool = True  # master switch for the enhanced system
    memory_value_threshold: float = 0.6  # minimum value score for keeping a memory
    fusion_threshold: float = 0.85  # similarity threshold for fusing duplicates
    max_retrieval_results: int = 10  # default cap on retrieved memories
    enable_learning: bool = True  # enable the learning feature
class MemoryIntegrationLayer:
    """Memory-system integration layer — manages only the enhanced memory system.

    The legacy system has been fully removed; this layer wraps a single
    :class:`EnhancedMemorySystem` and tracks rolling usage statistics.
    """

    def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
        self.llm_model = llm_model
        self.config = config or IntegrationConfig()
        # Only the enhanced memory system is managed here.
        self.enhanced_memory: EnhancedMemorySystem | None = None
        # Rolling integration statistics.
        self.integration_stats = {
            "total_queries": 0,
            "enhanced_queries": 0,
            "memory_creations": 0,
            "average_response_time": 0.0,
            "success_rate": 0.0,
        }
        # Guards against concurrent double-initialization.
        self._initialization_lock = asyncio.Lock()
        self._initialized = False

    async def initialize(self):
        """Initialize the layer once; safe under concurrent callers (double-checked lock)."""
        if self._initialized:
            return
        async with self._initialization_lock:
            if self._initialized:
                return
            logger.info("🚀 开始初始化增强记忆系统集成层...")
            try:
                if self.config.enable_enhanced_memory:
                    await self._initialize_enhanced_memory()
                self._initialized = True
                logger.info("✅ 增强记忆系统集成层初始化完成")
            except Exception as e:
                logger.error(f"❌ 集成层初始化失败: {e}", exc_info=True)
                raise

    async def _initialize_enhanced_memory(self):
        """Create and initialize the EnhancedMemorySystem from global + layer config."""
        try:
            logger.debug("初始化增强记忆系统...")
            from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig

            memory_config = MemorySystemConfig.from_global_config()
            # Layer-level settings override the corresponding global values.
            memory_config.memory_value_threshold = self.config.memory_value_threshold
            memory_config.fusion_similarity_threshold = self.config.fusion_threshold
            memory_config.final_recall_limit = self.config.max_retrieval_results
            self.enhanced_memory = EnhancedMemorySystem(config=memory_config)
            # Inject an externally supplied LLM model, if any.
            if self.llm_model is not None:
                self.enhanced_memory.llm_model = self.llm_model
            await self.enhanced_memory.initialize()
            logger.info("✅ 增强记忆系统初始化完成")
        except Exception as e:
            logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
            raise

    async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
        """Process conversation memory using only the supplied context dict."""
        if not self._initialized or not self.enhanced_memory:
            return {"success": False, "error": "Memory system not available"}
        start_time = time.time()
        self.integration_stats["total_queries"] += 1
        self.integration_stats["enhanced_queries"] += 1
        try:
            payload_context = dict(context or {})
            conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
            logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text))
            # Delegate directly to the enhanced memory system.
            result = await self.enhanced_memory.process_conversation_memory(payload_context)
            # Update rolling statistics.
            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, result.get("success", False))
            if result.get("success"):
                created_count = len(result.get("created_memories", []))
                self.integration_stats["memory_creations"] += created_count
                logger.debug(f"对话处理完成,创建 {created_count} 条记忆")
            return result
        except Exception as e:
            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, False)
            logger.error(f"处理对话记忆失败: {e}", exc_info=True)
            return {"success": False, "error": str(e)}

    async def retrieve_relevant_memories(
        self,
        query: str,
        user_id: str | None = None,
        context: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> list[MemoryChunk]:
        """Retrieve memories relevant to *query*; empty list when unavailable or on error."""
        if not self._initialized or not self.enhanced_memory:
            return []
        try:
            limit = limit or self.config.max_retrieval_results
            # NOTE(review): user_id is deliberately passed as None to the
            # enhanced system — confirm retrieval is meant to be user-agnostic.
            memories = await self.enhanced_memory.retrieve_relevant_memories(
                query=query, user_id=None, context=context or {}, limit=limit
            )
            memory_count = len(memories)
            logger.debug(f"检索到 {memory_count} 条相关记忆")
            return memories
        except Exception as e:
            logger.error(f"检索相关记忆失败: {e}", exc_info=True)
            return []

    async def get_system_status(self) -> dict[str, Any]:
        """Return a status snapshot of the layer and the underlying system."""
        if not self._initialized:
            return {"status": "not_initialized"}
        try:
            enhanced_status = {}
            if self.enhanced_memory:
                enhanced_status = await self.enhanced_memory.get_system_status()
            return {
                "status": "initialized",
                "mode": self.config.mode.value,
                "enhanced_memory": enhanced_status,
                "integration_stats": self.integration_stats.copy(),
            }
        except Exception as e:
            logger.error(f"获取系统状态失败: {e}", exc_info=True)
            return {"status": "error", "error": str(e)}

    def get_integration_stats(self) -> dict[str, Any]:
        """Return a copy of the integration statistics."""
        return self.integration_stats.copy()

    def _update_response_stats(self, processing_time: float, success: bool):
        """Fold one response into the rolling average time and success rate.

        Fix: the success rate is now updated on EVERY call (failures
        contribute 0) — previously it was only updated on success, so it could
        never decrease and stayed inflated after failures.
        """
        total_queries = self.integration_stats["total_queries"]
        if total_queries > 0:
            # Rolling average response time.
            current_avg = self.integration_stats["average_response_time"]
            self.integration_stats["average_response_time"] = (
                current_avg * (total_queries - 1) + processing_time
            ) / total_queries
            # Rolling success rate; failures fold in as 0.
            current_success_rate = self.integration_stats["success_rate"]
            self.integration_stats["success_rate"] = (
                current_success_rate * (total_queries - 1) + (1.0 if success else 0.0)
            ) / total_queries

    async def maintenance(self):
        """Run maintenance on the underlying enhanced memory system."""
        if not self._initialized:
            return
        try:
            logger.info("🔧 执行记忆系统集成层维护...")
            if self.enhanced_memory:
                await self.enhanced_memory.maintenance()
            logger.info("✅ 记忆系统集成层维护完成")
        except Exception as e:
            logger.error(f"❌ 集成层维护失败: {e}", exc_info=True)

    async def shutdown(self):
        """Shut down the layer and the underlying enhanced memory system."""
        if not self._initialized:
            return
        try:
            logger.info("🔄 关闭记忆系统集成层...")
            if self.enhanced_memory:
                await self.enhanced_memory.shutdown()
            self._initialized = False
            logger.info("✅ 记忆系统集成层已关闭")
        except Exception as e:
            logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)

View File

@@ -1,526 +0,0 @@
"""
记忆系统集成钩子
提供与现有MoFox Bot系统的无缝集成点
"""
import time
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.enhanced_memory_adapter import (
get_memory_context_for_prompt,
process_conversation_with_enhanced_memory,
retrieve_memories_with_enhanced_system,
)
from src.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class HookResult:
    """Result of a single integration-hook execution."""
    success: bool  # whether the hook ran without error
    data: Any = None  # hook-specific payload (memories, context text, status note, ...)
    error: str | None = None  # error description when success is False
    processing_time: float = 0.0  # wall-clock duration of the hook, in seconds
class MemoryIntegrationHooks:
    """Integration hooks wiring the enhanced memory system into the bot.

    Each ``_register_*`` helper tries to attach a handler to an optional host
    sub-system (event bus, message manager, chat-stream manager, reply
    generator, knowledge base, prompt manager, task manager). Targets that
    cannot be imported in the current deployment are skipped with a debug log
    instead of failing.
    """
    def __init__(self):
        # Guard so register_hooks() is idempotent.
        self.hooks_registered = False
        # Execution counters per hook family, plus a running mean latency.
        self.hook_stats = {
            "message_processing_hooks": 0,
            "memory_retrieval_hooks": 0,
            "prompt_enhancement_hooks": 0,
            "total_hook_executions": 0,
            "average_hook_time": 0.0,
        }
    async def register_hooks(self):
        """Register every integration hook (no-op when already registered)."""
        if self.hooks_registered:
            return
        try:
            logger.info("🔗 注册记忆系统集成钩子...")
            # Message-processing hooks.
            await self._register_message_processing_hooks()
            # Memory-retrieval hooks.
            await self._register_memory_retrieval_hooks()
            # Prompt-enhancement hooks.
            await self._register_prompt_enhancement_hooks()
            # System-maintenance hooks.
            await self._register_maintenance_hooks()
            self.hooks_registered = True
            logger.info("✅ 记忆系统集成钩子注册完成")
        except Exception as e:
            logger.error(f"❌ 注册记忆系统集成钩子失败: {e}", exc_info=True)
    async def _register_message_processing_hooks(self):
        """Register hooks that create memories from processed messages."""
        try:
            # Hook 1: create memories after a message has been processed.
            await self._register_post_message_hook()
            # Hook 2: process memories when a chat stream is saved.
            await self._register_chat_stream_hook()
            logger.debug("消息处理钩子注册完成")
        except Exception as e:
            logger.error(f"注册消息处理钩子失败: {e}")
    async def _register_memory_retrieval_hooks(self):
        """Register hooks that retrieve memories for response generation."""
        try:
            # Hook 1: retrieve relevant memories before generating a reply.
            await self._register_pre_response_hook()
            # Hook 2: enrich the context before knowledge-base queries.
            await self._register_knowledge_query_hook()
            logger.debug("记忆检索钩子注册完成")
        except Exception as e:
            logger.error(f"注册记忆检索钩子失败: {e}")
    async def _register_prompt_enhancement_hooks(self):
        """Register hooks that inject memory context into prompt building."""
        try:
            # Hook 1: enhance prompt construction with memory context.
            await self._register_prompt_building_hook()
            logger.debug("提示词增强钩子注册完成")
        except Exception as e:
            logger.error(f"注册提示词增强钩子失败: {e}")
    async def _register_maintenance_hooks(self):
        """Register hooks that run memory maintenance with system maintenance."""
        try:
            # Hook 1: memory-system maintenance during system maintenance.
            await self._register_system_maintenance_hook()
            logger.debug("系统维护钩子注册完成")
        except Exception as e:
            logger.error(f"注册系统维护钩子失败: {e}")
    async def _register_post_message_hook(self):
        """Attach the post-message handler to the event system and/or message manager."""
        try:
            # Registration depends on the concrete host architecture; every
            # target below is optional and skipped when it cannot be imported.
            # Try the event system first.
            try:
                from src.plugin_system.base.component_types import EventType
                from src.plugin_system.core.event_manager import event_manager
                # Subscribe to the post-processing event.
                event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
                logger.debug("已注册到事件系统的消息处理钩子")
            except ImportError:
                logger.debug("事件系统不可用,跳过事件钩子注册")
            # Then try the message manager.
            try:
                from src.chat.message_manager import message_manager
                # Only when the message manager supports hook registration.
                if hasattr(message_manager, "register_post_process_hook"):
                    message_manager.register_post_process_hook(self._on_message_processed_hook)
                    logger.debug("已注册到消息管理器的处理钩子")
            except ImportError:
                logger.debug("消息管理器不可用,跳过消息管理器钩子注册")
        except Exception as e:
            logger.error(f"注册消息后处理钩子失败: {e}")
    async def _register_chat_stream_hook(self):
        """Attach the save handler to the chat-stream manager, if available."""
        try:
            # Try the chat-stream manager.
            try:
                from src.chat.message_receive.chat_stream import get_chat_manager
                chat_manager = get_chat_manager()
                if hasattr(chat_manager, "register_save_hook"):
                    chat_manager.register_save_hook(self._on_chat_stream_save_hook)
                    logger.debug("已注册到聊天流管理器的保存钩子")
            except ImportError:
                logger.debug("聊天流管理器不可用,跳过聊天流钩子注册")
        except Exception as e:
            logger.error(f"注册聊天流钩子失败: {e}")
    async def _register_pre_response_hook(self):
        """Attach the pre-generation handler to the reply generator, if available."""
        try:
            # Try the reply generator.
            try:
                from src.chat.replyer.default_generator import default_generator
                if hasattr(default_generator, "register_pre_generation_hook"):
                    default_generator.register_pre_generation_hook(self._on_pre_response_hook)
                    logger.debug("已注册到回复生成器的前置钩子")
            except ImportError:
                logger.debug("回复生成器不可用,跳过回复前钩子注册")
        except Exception as e:
            logger.error(f"注册回复前钩子失败: {e}")
    async def _register_knowledge_query_hook(self):
        """Attach the query enhancer to the knowledge-base system, if available."""
        try:
            # Try the knowledge-base system.
            try:
                from src.chat.knowledge.knowledge_lib import knowledge_manager
                if hasattr(knowledge_manager, "register_query_enhancer"):
                    knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook)
                    logger.debug("已注册到知识库的查询增强钩子")
            except ImportError:
                logger.debug("知识库系统不可用,跳过知识库钩子注册")
        except Exception as e:
            logger.error(f"注册知识库查询钩子失败: {e}")
    async def _register_prompt_building_hook(self):
        """Attach the prompt enhancer to the prompt manager, if available."""
        try:
            # Try the prompt system.
            try:
                from src.chat.utils.prompt import prompt_manager
                if hasattr(prompt_manager, "register_enhancer"):
                    prompt_manager.register_enhancer(self._on_prompt_building_hook)
                    logger.debug("已注册到提示词管理器的增强钩子")
            except ImportError:
                logger.debug("提示词系统不可用,跳过提示词钩子注册")
        except Exception as e:
            logger.error(f"注册提示词构建钩子失败: {e}")
    async def _register_system_maintenance_hook(self):
        """Schedule the periodic memory-maintenance task, if a task manager exists."""
        try:
            # Try the async task manager.
            try:
                from src.manager.async_task_manager import async_task_manager
                # Register the recurring maintenance task.
                async_task_manager.add_task(MemoryMaintenanceTask())
                logger.debug("已注册到系统维护器的定期任务")
            except ImportError:
                logger.debug("异步任务管理器不可用,跳过系统维护钩子注册")
        except Exception as e:
            logger.error(f"注册系统维护钩子失败: {e}")
    # Hook handler methods
    async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
        """Event-system entry point; delegates to the generic message hook."""
        return await self._on_message_processed_hook(event_data)
    async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
        """Create memories from a processed-message payload."""
        start_time = time.time()
        try:
            self.hook_stats["message_processing_hooks"] += 1
            # Pull out the fields the memory system needs.
            message_info = message_data.get("message_info", {})
            user_info = message_info.get("user_info", {})
            conversation_text = message_data.get("processed_plain_text", "")
            if not conversation_text:
                return HookResult(success=True, data="No conversation text")
            user_id = str(user_info.get("user_id", "unknown"))
            context = {
                "chat_id": message_data.get("chat_id"),
                "message_type": message_data.get("message_type", "normal"),
                "platform": message_info.get("platform", "unknown"),
                "interest_value": message_data.get("interest_value", 0.0),
                "keywords": message_data.get("key_words", []),
                "timestamp": message_data.get("time", time.time()),
            }
            # Hand the conversation to the enhanced memory system.
            memory_context = dict(context)
            memory_context["conversation_text"] = conversation_text
            memory_context["user_id"] = user_id
            result = await process_conversation_with_enhanced_memory(memory_context)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            if result["success"]:
                logger.debug(f"消息处理钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
                return HookResult(success=True, data=result, processing_time=processing_time)
            else:
                logger.warning(f"消息处理钩子执行失败: {result.get('error')}")
                return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)
    async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
        """Create memories from the recent messages of a saved chat stream."""
        start_time = time.time()
        try:
            self.hook_stats["message_processing_hooks"] += 1
            # Extract the conversation from the stream payload.
            stream_context = chat_stream_data.get("stream_context", {})
            user_id = stream_context.get("user_id", "unknown")
            messages = stream_context.get("messages", [])
            if not messages:
                return HookResult(success=True, data="No messages to process")
            # Build the conversation text.
            conversation_parts = []
            for msg in messages[-10:]:  # only the 10 most recent messages
                text = msg.get("processed_plain_text", "")
                if text:
                    conversation_parts.append(f"{msg.get('user_nickname', 'User')}: {text}")
            conversation_text = "\n".join(conversation_parts)
            if not conversation_text:
                return HookResult(success=True, data="No conversation text")
            context = {
                "chat_id": chat_stream_data.get("chat_id"),
                "stream_id": chat_stream_data.get("stream_id"),
                "platform": chat_stream_data.get("platform", "unknown"),
                "message_count": len(messages),
                "timestamp": time.time(),
            }
            # Hand the conversation to the enhanced memory system.
            memory_context = dict(context)
            memory_context["conversation_text"] = conversation_text
            memory_context["user_id"] = user_id
            result = await process_conversation_with_enhanced_memory(memory_context)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            if result["success"]:
                logger.debug(f"聊天流保存钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
                return HookResult(success=True, data=result, processing_time=processing_time)
            else:
                logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}")
                return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)
    async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
        """Retrieve relevant memories before a reply is generated.

        Mutates *response_data* in place, adding ``enhanced_memories`` and
        ``enhanced_memory_context``.
        """
        start_time = time.time()
        try:
            self.hook_stats["memory_retrieval_hooks"] += 1
            # Extract the query parameters.
            query = response_data.get("query", "")
            user_id = response_data.get("user_id", "unknown")
            context = response_data.get("context", {})
            if not query:
                return HookResult(success=True, data="No query provided")
            # Retrieve the relevant memories.
            memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            # Attach the memories to the response payload.
            response_data["enhanced_memories"] = memories
            response_data["enhanced_memory_context"] = await get_memory_context_for_prompt(
                query, user_id, context, max_memories=5
            )
            logger.debug(f"回复前钩子执行成功,检索到 {len(memories)} 条记忆")
            return HookResult(success=True, data=memories, processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)
    async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
        """Enrich a knowledge-base query with memory context (in place)."""
        start_time = time.time()
        try:
            self.hook_stats["memory_retrieval_hooks"] += 1
            query = query_data.get("query", "")
            user_id = query_data.get("user_id", "unknown")
            context = query_data.get("context", {})
            if not query:
                return HookResult(success=True, data="No query provided")
            # Build the memory context used to enhance the query.
            memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            # Attach the memory context to the query payload.
            query_data["enhanced_memory_context"] = memory_context
            logger.debug("知识库查询钩子执行成功")
            return HookResult(success=True, data=memory_context, processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)
    async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
        """Append memory context to a prompt under construction (in place)."""
        start_time = time.time()
        try:
            self.hook_stats["prompt_enhancement_hooks"] += 1
            query = prompt_data.get("query", "")
            user_id = prompt_data.get("user_id", "unknown")
            context = prompt_data.get("context", {})
            base_prompt = prompt_data.get("base_prompt", "")
            if not query:
                return HookResult(success=True, data="No query provided")
            # Fetch the memory context.
            memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            # Build the enhanced prompt.
            enhanced_prompt = base_prompt
            if memory_context:
                enhanced_prompt += f"\n\n### 相关记忆上下文 ###\n{memory_context}\n"
            # Attach the results to the prompt payload.
            prompt_data["enhanced_prompt"] = enhanced_prompt
            prompt_data["memory_context"] = memory_context
            logger.debug("提示词构建钩子执行成功")
            return HookResult(success=True, data=enhanced_prompt, processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"提示词构建钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)
    def _update_hook_stats(self, processing_time: float):
        """Fold one hook execution into the totals and running mean latency."""
        self.hook_stats["total_hook_executions"] += 1
        total_executions = self.hook_stats["total_hook_executions"]
        if total_executions > 0:
            # Incremental mean: new = (old * (n - 1) + sample) / n
            current_avg = self.hook_stats["average_hook_time"]
            new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
            self.hook_stats["average_hook_time"] = new_avg
    def get_hook_stats(self) -> dict[str, Any]:
        """Return a shallow copy of the hook execution statistics."""
        return self.hook_stats.copy()
class MemoryMaintenanceTask:
    """Periodic maintenance task for the enhanced memory system.

    Shaped for the async task manager: exposes a stable task name, a fixed
    hourly interval, and an ``execute`` entry point that never raises.
    """
    # Hourly cadence for the maintenance pass.
    _INTERVAL_SECONDS = 3600
    _TASK_NAME = "enhanced_memory_maintenance"
    def __init__(self):
        self.task_name = self._TASK_NAME
        self.interval = self._INTERVAL_SECONDS
    async def execute(self):
        """Run one maintenance pass; all failures are logged, never raised."""
        try:
            logger.info("🔧 执行增强记忆系统维护任务...")
            try:
                # Late import to grab the adapter singleton at call time.
                from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter
                if not _enhanced_memory_adapter:
                    logger.debug("增强记忆适配器未初始化,跳过维护")
                else:
                    await _enhanced_memory_adapter.maintenance()
                    logger.info("✅ 增强记忆系统维护任务完成")
            except Exception as e:
                logger.error(f"增强记忆系统维护失败: {e}")
        except Exception as e:
            logger.error(f"执行维护任务时发生异常: {e}", exc_info=True)
    def get_interval(self) -> int:
        """Return the execution interval in seconds."""
        return self.interval
    def get_task_name(self) -> str:
        """Return the unique name of this maintenance task."""
        return self.task_name
# Module-level singleton holding the lazily created hook registry.
_memory_hooks: MemoryIntegrationHooks | None = None
async def get_memory_integration_hooks() -> MemoryIntegrationHooks:
    """Return the process-wide hook registry, creating it on first call.

    The instance is published to the module global before ``register_hooks``
    runs, preserving the original initialization order.
    """
    global _memory_hooks
    if _memory_hooks is None:
        instance = MemoryIntegrationHooks()
        _memory_hooks = instance
        await instance.register_hooks()
    return _memory_hooks
async def initialize_memory_integration_hooks():
    """Initialize the global memory-integration hooks.

    Returns:
        The hook registry on success, or ``None`` when initialization failed
        (the error is logged with a traceback).
    """
    try:
        logger.info("🚀 初始化记忆集成钩子...")
        hooks = await get_memory_integration_hooks()
        logger.info("✅ 记忆集成钩子初始化完成")
    except Exception as e:
        logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True)
        return None
    return hooks

File diff suppressed because it is too large Load Diff

View File

@@ -1,875 +0,0 @@
"""
向量数据库存储接口
为记忆系统提供高效的向量存储和语义搜索能力
"""
import asyncio
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import orjson
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.config_helpers import resolve_embedding_dimension
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
# Optional dependency: prefer FAISS for vector indexing, and fall back to a
# simple in-process implementation when it is not installed.
try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    FAISS_AVAILABLE = False
    logger.warning("FAISS not available, using simple vector storage")
@dataclass
class VectorStorageConfig:
    """Configuration for the vector storage backend."""
    dimension: int = 1024  # embedding dimensionality; reconciled with the model config at runtime
    similarity_threshold: float = 0.8  # minimum similarity considered a match
    index_type: str = "flat"  # flat, ivf, hnsw
    max_index_size: int = 100000  # expected upper bound used to size IVF clustering
    storage_path: str = "data/memory_vectors"  # directory for persisted caches and indexes
    auto_save_interval: int = 10  # persist automatically after every N store operations
    enable_compression: bool = True  # NOTE(review): not read anywhere in this file — confirm before relying on it
class VectorStorageManager:
"""向量存储管理器"""
    def __init__(self, config: VectorStorageConfig | None = None):
        """Create the storage manager and build an empty vector index.

        Args:
            config: Optional storage configuration; defaults apply when omitted.
        """
        self.config = config or VectorStorageConfig()
        # Reconcile the configured dimension with the embedding model's own.
        resolved_dimension = resolve_embedding_dimension(self.config.dimension)
        if resolved_dimension and resolved_dimension != self.config.dimension:
            logger.info(
                "向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
                resolved_dimension,
                self.config.dimension,
            )
            self.config.dimension = resolved_dimension
        self.storage_path = Path(self.config.storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        # Vector index plus the bidirectional id <-> position mappings.
        self.vector_index = None
        self.memory_id_to_index = {}  # memory_id -> vector index
        self.index_to_memory_id = {}  # vector index -> memory_id
        # In-process caches of memory chunks and their raw embeddings.
        self.memory_cache: dict[str, MemoryChunk] = {}
        self.vector_cache: dict[str, list[float]] = {}
        # Operational statistics (counts, cache hits, timing).
        self.storage_stats = {
            "total_vectors": 0,
            "index_build_time": 0.0,
            "average_search_time": 0.0,
            "cache_hit_rate": 0.0,
            "total_searches": 0,
            "cache_hits": 0,
        }
        # Reentrant lock guarding index mutations across threads.
        self._lock = threading.RLock()
        self._operation_count = 0  # store operations since the last auto-save
        # Build the (empty) vector index up front.
        self._initialize_index()
        # Embedding model is created lazily by initialize_embedding_model().
        self.embedding_model: LLMRequest | None = None
    def _initialize_index(self):
        """Build the vector index selected by ``config.index_type``.

        Uses FAISS (inner-product metric) when available; any failure, or an
        unknown index type, falls back to the pure-Python SimpleVectorIndex.
        """
        try:
            if FAISS_AVAILABLE:
                if self.config.index_type == "flat":
                    self.vector_index = faiss.IndexFlatIP(self.config.dimension)
                elif self.config.index_type == "ivf":
                    quantizer = faiss.IndexFlatIP(self.config.dimension)
                    # Roughly one cluster per thousand expected vectors, capped at 100.
                    nlist = min(100, max(1, self.config.max_index_size // 1000))
                    self.vector_index = faiss.IndexIVFFlat(quantizer, self.config.dimension, nlist)
                elif self.config.index_type == "hnsw":
                    self.vector_index = faiss.IndexHNSWFlat(self.config.dimension, 32)
                    self.vector_index.hnsw.efConstruction = 40
                else:
                    self.vector_index = faiss.IndexFlatIP(self.config.dimension)
            else:
                # Simple pure-Python fallback index.
                self.vector_index = SimpleVectorIndex(self.config.dimension)
            logger.info(f"✅ 向量索引初始化完成,类型: {self.config.index_type}")
        except Exception as e:
            logger.error(f"❌ 向量索引初始化失败: {e}")
            # Fall back to the simple implementation.
            self.vector_index = SimpleVectorIndex(self.config.dimension)
    async def initialize_embedding_model(self):
        """Lazily create the embedding model on first use (idempotent)."""
        if self.embedding_model is None:
            self.embedding_model = LLMRequest(
                model_set=model_config.model_task_config.embedding, request_type="memory.embedding"
            )
            logger.info("✅ 嵌入模型初始化完成")
    async def generate_query_embedding(self, query_text: str) -> list[float] | None:
        """Embed *query_text* into a normalized vector for memory recall.

        Returns ``None`` for empty input, an empty model response, or a
        dimension mismatch against the configured index dimension.
        """
        if not query_text:
            logger.warning("查询文本为空,无法生成向量")
            return None
        try:
            await self.initialize_embedding_model()
            logger.debug(f"开始生成查询向量,文本: '{query_text[:50]}{'...' if len(query_text) > 50 else ''}'")
            embedding, _ = await self.embedding_model.get_embedding(query_text)
            if not embedding:
                logger.warning("嵌入模型返回空向量")
                return None
            logger.debug(f"生成的向量维度: {len(embedding)}, 期望维度: {self.config.dimension}")
            if len(embedding) != self.config.dimension:
                logger.error("查询向量维度不匹配: 期望 %d, 实际 %d", self.config.dimension, len(embedding))
                return None
            # L2-normalize so inner-product search behaves as cosine similarity.
            normalized_vector = self._normalize_vector(embedding)
            logger.debug(f"查询向量生成成功,向量范围: [{min(normalized_vector):.4f}, {max(normalized_vector):.4f}]")
            return normalized_vector
        except Exception as exc:
            logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
            return None
    async def store_memories(self, memories: list[MemoryChunk]):
        """Store memory chunks, generating embeddings for those lacking one.

        Chunks are cached immediately, embeddings are batch-generated for
        chunks without one, and the store auto-saves after every
        ``config.auto_save_interval`` operations. Failures are logged.
        """
        if not memories:
            return
        start_time = time.time()
        try:
            # Make sure the embedding model is ready.
            await self.initialize_embedding_model()
            # Split chunks into "needs embedding" and "already embedded".
            memory_texts = []
            for memory in memories:
                # Cache up front so downstream steps can resolve the id.
                self.memory_cache[memory.memory_id] = memory
                if memory.embedding is None:
                    # No embedding yet — queue its text for batch generation.
                    text = self._prepare_embedding_text(memory)
                    memory_texts.append((memory.memory_id, text))
                else:
                    # Embedding already present — index it directly.
                    await self._add_single_memory(memory, memory.embedding)
            # Generate the missing embeddings in one batch.
            if memory_texts:
                await self._batch_generate_and_store_embeddings(memory_texts)
            # Auto-save once enough operations have accumulated.
            self._operation_count += len(memories)
            if self._operation_count >= self.config.auto_save_interval:
                await self.save_storage()
                self._operation_count = 0
            storage_time = time.time() - start_time
            logger.debug(f"向量存储完成,{len(memories)} 条记忆,耗时 {storage_time:.3f}秒")
        except Exception as e:
            logger.error(f"❌ 向量存储失败: {e}", exc_info=True)
def _prepare_embedding_text(self, memory: MemoryChunk) -> str:
"""准备用于嵌入的文本,仅使用自然语言展示内容"""
display_text = (memory.display or "").strip()
if display_text:
return display_text
fallback_text = (memory.text_content or "").strip()
if fallback_text:
return fallback_text
subjects = "".join(s.strip() for s in memory.subjects if s and isinstance(s, str))
predicate = (memory.content.predicate or "").strip()
obj = memory.content.object
if isinstance(obj, dict):
object_parts = []
for key, value in obj.items():
if value is None:
continue
if isinstance(value, (list, tuple)):
preview = "".join(str(item) for item in value[:3])
object_parts.append(f"{key}:{preview}")
else:
object_parts.append(f"{key}:{value}")
object_text = ", ".join(object_parts)
else:
object_text = str(obj or "").strip()
composite_parts = [part for part in [subjects, predicate, object_text] if part]
if composite_parts:
return " ".join(composite_parts)
logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id)
return memory.memory_id
    async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
        """Generate embeddings for (memory_id, text) pairs and index the results.

        Entries whose embedding failed or has the wrong dimension are skipped;
        memories are expected to already be present in ``memory_cache``.
        """
        if not memory_texts:
            return
        try:
            texts = [text for _, text in memory_texts]
            memory_ids = [memory_id for memory_id, _ in memory_texts]
            # Generate all embeddings concurrently.
            embeddings = await self._batch_generate_embeddings(memory_ids, texts)
            # Index every successfully embedded memory.
            for memory_id, embedding in embeddings.items():
                if embedding and len(embedding) == self.config.dimension:
                    memory = self.memory_cache.get(memory_id)
                    if memory:
                        await self._add_single_memory(memory, embedding)
        except Exception as e:
            logger.error(f"❌ 批量生成嵌入向量失败: {e}")
    async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
        """Concurrently embed *texts*, returning results keyed by memory id.

        At most four embedding requests run at once. Failed or wrong-dimension
        results map to an empty list so callers can distinguish them.
        """
        if not texts:
            return {}
        results: dict[str, list[float]] = {}
        try:
            # Cap concurrency at four in-flight embedding requests.
            semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))
            async def generate_embedding(memory_id: str, text: str) -> None:
                async with semaphore:
                    try:
                        embedding, _ = await self.embedding_model.get_embedding(text)
                        if embedding and len(embedding) == self.config.dimension:
                            results[memory_id] = embedding
                        else:
                            logger.warning(
                                "嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。",
                                self.config.dimension,
                                len(embedding) if embedding else 0,
                                memory_id,
                            )
                            results[memory_id] = []
                    except Exception as exc:
                        logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
                        results[memory_id] = []
            tasks = [
                asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
            ]
            await asyncio.gather(*tasks, return_exceptions=True)
        except Exception as e:
            logger.error(f"❌ 批量生成嵌入向量失败: {e}")
            # Ensure every id has an entry even after a wholesale failure.
            for memory_id in memory_ids:
                results.setdefault(memory_id, [])
        return results
    async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
        """Index one memory chunk under its (normalized) embedding.

        Updates the caches, the id <-> position mappings, and the vector count.
        Errors are logged and swallowed.
        """
        with self._lock:
            try:
                # Normalize so inner-product search equals cosine similarity.
                if embedding:
                    embedding = self._normalize_vector(embedding)
                # Cache the chunk and its vector.
                self.memory_cache[memory.memory_id] = memory
                self.vector_cache[memory.memory_id] = embedding
                # Keep the chunk's own embedding attribute in sync.
                memory.set_embedding(embedding)
                # Insert into the vector index.
                if hasattr(self.vector_index, "add"):
                    # FAISS index: needs a (1, dim) float32 array.
                    if isinstance(embedding, np.ndarray):
                        vector_array = embedding.reshape(1, -1).astype("float32")
                    else:
                        vector_array = np.array([embedding], dtype="float32")
                    # IVF indexes must be trained before the first insertion.
                    if self.config.index_type == "ivf" and self.vector_index.ntotal == 0:
                        logger.debug("训练IVF索引...")
                        self.vector_index.train(vector_array)
                    self.vector_index.add(vector_array)
                    index_id = self.vector_index.ntotal - 1
                else:
                    # Simple fallback index.
                    index_id = self.vector_index.add_vector(embedding)
                # Maintain the bidirectional id <-> position mapping.
                self.memory_id_to_index[memory.memory_id] = index_id
                self.index_to_memory_id[index_id] = memory.memory_id
                # Bump the stored-vector counter.
                self.storage_stats["total_vectors"] += 1
            except Exception as e:
                logger.error(f"❌ 添加记忆到向量存储失败: {e}")
def _normalize_vector(self, vector: list[float]) -> list[float]:
"""L2归一化向量"""
if not vector:
return vector
try:
vector_array = np.array(vector, dtype=np.float32)
norm = np.linalg.norm(vector_array)
if norm == 0:
return vector
normalized = vector_array / norm
return normalized.tolist()
except Exception as e:
logger.warning(f"向量归一化失败: {e}")
return vector
    async def search_similar_memories(
        self,
        query_vector: list[float] | None = None,
        *,
        query_text: str | None = None,
        limit: int = 10,
        scope_id: str | None = None,
    ) -> list[tuple[str, float]]:
        """Search the index for memories most similar to a query.

        Args:
            query_vector: Pre-computed query embedding; generated from
                *query_text* when omitted.
            query_text: Natural-language query used when no vector is given.
            limit: Maximum number of results to return.
            scope_id: Optional user scope; empty, "global" or "global_memory"
                disables filtering.

        Returns:
            ``(memory_id, similarity)`` pairs with similarity clamped to [0, 1];
            empty list on any failure.
        """
        start_time = time.time()
        try:
            logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}")
            if query_vector is None:
                if not query_text:
                    logger.warning("查询向量和查询文本都为空")
                    return []
                query_vector = await self.generate_query_embedding(query_text)
                if not query_vector:
                    logger.warning("查询向量生成失败")
                    return []
            # Work out the scope filter; global scopes disable filtering.
            scope_filter: str | None = None
            if isinstance(scope_id, str):
                normalized_scope = scope_id.strip().lower()
                if normalized_scope and normalized_scope not in {"global", "global_memory"}:
                    scope_filter = scope_id
            elif scope_id:
                scope_filter = str(scope_id)
            # Normalize the query vector.
            query_vector = self._normalize_vector(query_vector)
            logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}")
            # Sanity-check the index before searching.
            if not self.vector_index:
                logger.error("向量索引未初始化")
                return []
            total_vectors = 0
            if hasattr(self.vector_index, "ntotal"):
                total_vectors = self.vector_index.ntotal
            elif hasattr(self.vector_index, "vectors"):
                total_vectors = len(self.vector_index.vectors)
            logger.debug(f"向量索引中实际向量数: {total_vectors}")
            if total_vectors == 0:
                logger.warning("向量索引为空,无法执行搜索")
                return []
            # Run the raw vector search under the lock.
            with self._lock:
                if hasattr(self.vector_index, "search"):
                    # FAISS index: query must be a (1, dim) float32 array.
                    if isinstance(query_vector, np.ndarray):
                        query_array = query_vector.reshape(1, -1).astype("float32")
                    else:
                        query_array = np.array([query_vector], dtype="float32")
                    if self.config.index_type == "ivf" and self.vector_index.ntotal > 0:
                        # Tune how many IVF clusters are probed.
                        nprobe = min(self.vector_index.nlist, 10)
                        self.vector_index.nprobe = nprobe
                        logger.debug(f"IVF搜索参数: nprobe={nprobe}")
                    search_limit = min(limit, total_vectors)
                    logger.debug(f"执行FAISS搜索搜索限制: {search_limit}")
                    distances, indices = self.vector_index.search(query_array, search_limit)
                    distances = distances.flatten().tolist()
                    indices = indices.flatten().tolist()
                    logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引")
                else:
                    # Simple fallback index.
                    logger.debug("使用简单向量索引执行搜索")
                    results = self.vector_index.search(query_vector, limit)
                    distances = [score for _, score in results]
                    indices = [idx for idx, _ in results]
                    logger.debug(f"简单索引搜索结果: {len(results)} 个结果")
            # Convert raw hits into (memory_id, similarity) pairs.
            results = []
            valid_results = 0
            invalid_indices = 0
            filtered_by_scope = 0
            for distance, index in zip(distances, indices, strict=False):
                if index == -1:  # FAISS marker for "no result"
                    invalid_indices += 1
                    continue
                memory_id = self.index_to_memory_id.get(index)
                if not memory_id:
                    logger.debug(f"索引 {index} 没有对应的记忆ID")
                    invalid_indices += 1
                    continue
                if scope_filter:
                    memory = self.memory_cache.get(memory_id)
                    if memory and str(memory.user_id) != scope_filter:
                        filtered_by_scope += 1
                        continue
                similarity = max(0.0, min(1.0, distance))  # clamp into [0, 1]
                results.append((memory_id, similarity))
                valid_results += 1
            logger.debug(
                f"搜索结果处理: 总距离={len(distances)}, 有效结果={valid_results}, "
                f"无效索引={invalid_indices}, 作用域过滤={filtered_by_scope}"
            )
            # Fold this search into the stats (incremental mean search time).
            search_time = time.time() - start_time
            self.storage_stats["total_searches"] += 1
            self.storage_stats["average_search_time"] = (
                self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time
            ) / self.storage_stats["total_searches"]
            final_results = results[:limit]
            logger.info(
                f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' "
                f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果"
            )
            return final_results
        except Exception as e:
            logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
            return []
    async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
        """Return the memory chunk for *memory_id* from the in-process cache.

        NOTE(review): only cache *misses* increment ``total_searches`` while
        hits increment ``cache_hits``, so ``cache_hit_rate = cache_hits /
        total_searches`` (computed in ``optimize_storage``) is not a true
        ratio — confirm the intended accounting before relying on it.
        """
        # Check the cache first.
        if memory_id in self.memory_cache:
            self.storage_stats["cache_hits"] += 1
            return self.memory_cache[memory_id]
        self.storage_stats["total_searches"] += 1
        return None
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
"""更新记忆的嵌入向量"""
with self._lock:
try:
if memory_id not in self.memory_id_to_index:
logger.warning(f"记忆 {memory_id} 不存在于向量索引中")
return
# 获取旧索引
old_index = self.memory_id_to_index[memory_id]
# 删除旧向量(如果支持)
if hasattr(self.vector_index, "remove_ids"):
try:
self.vector_index.remove_ids(np.array([old_index]))
except:
logger.warning("无法删除旧向量,将直接添加新向量")
# 规范化新向量
new_embedding = self._normalize_vector(new_embedding)
# 添加新向量
if hasattr(self.vector_index, "add"):
if isinstance(new_embedding, np.ndarray):
vector_array = new_embedding.reshape(1, -1).astype("float32")
else:
vector_array = np.array([new_embedding], dtype="float32")
self.vector_index.add(vector_array)
new_index = self.vector_index.ntotal - 1
else:
new_index = self.vector_index.add_vector(new_embedding)
# 更新映射关系
self.memory_id_to_index[memory_id] = new_index
self.index_to_memory_id[new_index] = memory_id
# 更新缓存
self.vector_cache[memory_id] = new_embedding
# 更新记忆对象
memory = self.memory_cache.get(memory_id)
if memory:
memory.set_embedding(new_embedding)
logger.debug(f"更新记忆 {memory_id} 的嵌入向量")
except Exception as e:
logger.error(f"❌ 更新记忆嵌入向量失败: {e}")
async def delete_memory(self, memory_id: str):
"""删除记忆"""
with self._lock:
try:
if memory_id not in self.memory_id_to_index:
return
# 获取索引
index = self.memory_id_to_index[memory_id]
# 从向量索引中删除(如果支持)
if hasattr(self.vector_index, "remove_ids"):
try:
self.vector_index.remove_ids(np.array([index]))
except:
logger.warning("无法从向量索引中删除,仅从缓存中移除")
# 删除映射关系
del self.memory_id_to_index[memory_id]
if index in self.index_to_memory_id:
del self.index_to_memory_id[index]
# 从缓存中删除
self.memory_cache.pop(memory_id, None)
self.vector_cache.pop(memory_id, None)
# 更新统计
self.storage_stats["total_vectors"] = max(0, self.storage_stats["total_vectors"] - 1)
logger.debug(f"删除记忆 {memory_id}")
except Exception as e:
logger.error(f"❌ 删除记忆失败: {e}")
    async def save_storage(self):
        """Persist caches, id mappings, the FAISS index and stats to disk.

        Everything except the FAISS index is serialized as JSON (via orjson)
        under ``storage_path``; failures are logged, never raised.
        """
        try:
            logger.info("正在保存向量存储...")
            # Memory chunks.
            cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()}
            cache_file = self.storage_path / "memory_cache.json"
            with open(cache_file, "w", encoding="utf-8") as f:
                f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
            # Raw embedding vectors.
            vector_cache_file = self.storage_path / "vector_cache.json"
            with open(vector_cache_file, "w", encoding="utf-8") as f:
                f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8"))
            # Bidirectional id <-> index-position mappings.
            mapping_file = self.storage_path / "id_mapping.json"
            mapping_data = {
                "memory_id_to_index": {
                    str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items()
                },
                "index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()},
            }
            with open(mapping_file, "w", encoding="utf-8") as f:
                f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
            # FAISS index (binary).
            # NOTE(review): FAISS Index objects generally do not expose a
            # `save` attribute, so this branch may never run and the index is
            # rebuilt from the vector cache on load — confirm the condition.
            if FAISS_AVAILABLE and hasattr(self.vector_index, "save"):
                index_file = self.storage_path / "vector_index.faiss"
                faiss.write_index(self.vector_index, str(index_file))
            # Operational statistics.
            stats_file = self.storage_path / "storage_stats.json"
            with open(stats_file, "w", encoding="utf-8") as f:
                f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8"))
            logger.info("✅ 向量存储保存完成")
        except Exception as e:
            logger.error(f"❌ 保存向量存储失败: {e}")
    async def load_storage(self):
        """Load caches, mappings, index and stats persisted by ``save_storage``.

        Missing files are skipped. When no usable FAISS index file exists, the
        index is rebuilt from the loaded vector cache. Failures are logged.
        """
        try:
            logger.info("正在加载向量存储...")
            # Memory chunks.
            cache_file = self.storage_path / "memory_cache.json"
            if cache_file.exists():
                with open(cache_file, encoding="utf-8") as f:
                    cache_data = orjson.loads(f.read())
                self.memory_cache = {
                    memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items()
                }
            # Raw embedding vectors.
            vector_cache_file = self.storage_path / "vector_cache.json"
            if vector_cache_file.exists():
                with open(vector_cache_file, encoding="utf-8") as f:
                    self.vector_cache = orjson.loads(f.read())
            # Id <-> index-position mappings (JSON keys arrive as strings).
            mapping_file = self.storage_path / "id_mapping.json"
            if mapping_file.exists():
                with open(mapping_file, encoding="utf-8") as f:
                    mapping_data = orjson.loads(f.read())
                raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
                self.memory_id_to_index = {
                    str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items()
                }
                raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
                self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()}
            # FAISS index file, when FAISS is available.
            index_loaded = False
            if FAISS_AVAILABLE:
                index_file = self.storage_path / "vector_index.faiss"
                if index_file.exists():
                    try:
                        loaded_index = faiss.read_index(str(index_file))
                        # Exact-type comparison: only adopt the persisted index
                        # when it matches the configured index class.
                        if type(loaded_index) == type(self.vector_index):
                            self.vector_index = loaded_index
                            index_loaded = True
                            logger.info("✅ FAISS索引文件加载完成")
                        else:
                            logger.warning("索引类型不匹配,重新构建索引")
                    except Exception as e:
                        logger.warning(f"加载FAISS索引失败: {e},重新构建")
                else:
                    logger.info("FAISS索引文件不存在将重新构建")
            # Rebuild from the vector cache when no index was loaded.
            if not index_loaded and self.vector_cache:
                logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引")
                await self._rebuild_index()
            # Operational statistics.
            stats_file = self.storage_path / "storage_stats.json"
            if stats_file.exists():
                with open(stats_file, encoding="utf-8") as f:
                    self.storage_stats = orjson.loads(f.read())
            # Re-derive the vector count from the loaded mapping.
            self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
            logger.info(f"✅ 向量存储加载完成,{self.storage_stats['total_vectors']} 个向量")
        except Exception as e:
            logger.error(f"❌ 加载向量存储失败: {e}")
async def _rebuild_index(self):
    """Rebuild the vector index from the in-memory vector cache.

    Re-initializes the underlying index, clears both id<->slot mappings,
    filters out embeddings whose dimension does not match the configured
    dimension, re-inserts everything, then refreshes the stored vector count.
    Errors are logged rather than raised.
    """
    try:
        logger.info(f"正在重建向量索引...向量数量: {len(self.vector_cache)}")

        # Start from a fresh, empty index and forget all previous mappings.
        self._initialize_index()
        self.memory_id_to_index.clear()
        self.index_to_memory_id.clear()

        if not self.vector_cache:
            logger.warning("没有向量缓存数据,跳过重建")
            return

        # Keep only embeddings with the expected dimensionality.
        valid_ids = []
        valid_vectors = []
        for memory_id, embedding in self.vector_cache.items():
            if embedding and len(embedding) == self.config.dimension:
                valid_ids.append(memory_id)
                valid_vectors.append(self._normalize_vector(embedding))
            else:
                logger.debug(f"跳过无效向量: {memory_id}, 维度: {len(embedding) if embedding else 0}")

        if not valid_vectors:
            logger.warning("没有有效的向量数据")
            return

        logger.info(f"准备重建 {len(valid_vectors)} 个向量到索引")

        if hasattr(self.vector_index, "add"):
            # FAISS-style index: insert all vectors in one batch.
            matrix = np.array(valid_vectors, dtype="float32")
            # IVF indexes must be trained before vectors can be added.
            if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"):
                logger.info("训练IVF索引...")
                self.vector_index.train(matrix)
            self.vector_index.add(matrix)
            # Batch insertion assigns consecutive slots starting at 0.
            for slot, memory_id in enumerate(valid_ids):
                self.memory_id_to_index[memory_id] = slot
                self.index_to_memory_id[slot] = memory_id
        else:
            # Fallback index: insert one at a time, recording each returned slot.
            for memory_id, vec in zip(valid_ids, valid_vectors, strict=False):
                slot = self.vector_index.add_vector(vec)
                self.memory_id_to_index[memory_id] = slot
                self.index_to_memory_id[slot] = memory_id

        self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
        final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index))
        logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}")
    except Exception as e:
        logger.error(f"❌ 重建向量索引失败: {e}", exc_info=True)
async def optimize_storage(self):
    """Run maintenance on the vector store.

    Prunes dangling id/slot references, rebuilds the index when it has grown
    large (a proxy for fragmentation), and refreshes the cache-hit ratio.
    Failures are logged, never raised.
    """
    try:
        logger.info("开始向量存储优化...")

        # Drop mappings whose memory entries no longer exist.
        self._cleanup_invalid_references()

        # Larger indexes accumulate fragmentation; rebuild from the cache.
        if self.storage_stats["total_vectors"] > 1000:
            await self._rebuild_index()

        # Derive the current cache-hit ratio from the raw counters.
        searches = self.storage_stats["total_searches"]
        if searches > 0:
            self.storage_stats["cache_hit_rate"] = self.storage_stats["cache_hits"] / searches

        logger.info("✅ 向量存储优化完成")
    except Exception as e:
        logger.error(f"❌ 向量存储优化失败: {e}")
def _cleanup_invalid_references(self):
"""清理无效引用"""
with self._lock:
# 清理无效的memory_id到index的映射
valid_memory_ids = set(self.memory_cache.keys())
invalid_memory_ids = set(self.memory_id_to_index.keys()) - valid_memory_ids
for memory_id in invalid_memory_ids:
index = self.memory_id_to_index[memory_id]
del self.memory_id_to_index[memory_id]
if index in self.index_to_memory_id:
del self.index_to_memory_id[index]
if invalid_memory_ids:
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
def get_storage_stats(self) -> dict[str, Any]:
    """Return a snapshot of storage statistics with a derived cache-hit rate.

    The returned dict is a shallow copy, so callers cannot mutate the live
    counters; ``cache_hit_rate`` is recomputed (0.0 when no searches yet).
    """
    snapshot = self.storage_stats.copy()
    searches = snapshot["total_searches"]
    snapshot["cache_hit_rate"] = snapshot["cache_hits"] / searches if searches > 0 else 0.0
    return snapshot
class SimpleVectorIndex:
"""简单的向量索引实现当FAISS不可用时的替代方案"""
def __init__(self, dimension: int):
self.dimension = dimension
self.vectors: list[list[float]] = []
self.vector_ids: list[int] = []
self.next_id = 0
def add_vector(self, vector: list[float]) -> int:
"""添加向量"""
if len(vector) != self.dimension:
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
vector_id = self.next_id
self.vectors.append(vector.copy())
self.vector_ids.append(vector_id)
self.next_id += 1
return vector_id
def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
"""搜索相似向量"""
if len(query_vector) != self.dimension:
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
results = []
for i, vector in enumerate(self.vectors):
similarity = self._calculate_cosine_similarity(query_vector, vector)
results.append((self.vector_ids[i], similarity))
# 按相似度排序
results.sort(key=lambda x: x[1], reverse=True)
return results[:limit]
def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
"""计算余弦相似度"""
try:
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
norm1 = sum(x * x for x in v1) ** 0.5
norm2 = sum(x * x for x in v2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
except Exception:
return 0.0
@property
def ntotal(self) -> int:
"""向量总数"""
return len(self.vectors)