refactor(memory): 移除废弃的记忆系统备份文件,优化消息管理器架构

移除了deprecated_backup目录下的所有废弃记忆系统文件,包括增强记忆适配器、钩子、集成层、重排序器、元数据索引、多阶段检索和向量存储等模块。同时优化了消息管理器,集成了批量数据库写入器、流缓存管理器和自适应流管理器,提升了系统性能和可维护性。
This commit is contained in:
Windpicker-owo
2025-10-04 01:38:41 +08:00
parent 80fc37fd02
commit a5a16971e9
16 changed files with 1975 additions and 5203 deletions

View File

@@ -1,363 +0,0 @@
"""
增强记忆系统适配器
将增强记忆系统集成到现有MoFox Bot架构中
"""
import time
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
# Human-readable (Chinese) display labels for every MemoryType member.
# NOTE(review): not referenced by the visible code in this module — presumably
# consumed by external callers; confirm before removing.
MEMORY_TYPE_LABELS = {
    MemoryType.PERSONAL_FACT: "个人事实",
    MemoryType.EVENT: "事件",
    MemoryType.PREFERENCE: "偏好",
    MemoryType.OPINION: "观点",
    MemoryType.RELATIONSHIP: "关系",
    MemoryType.EMOTION: "情感",
    MemoryType.KNOWLEDGE: "知识",
    MemoryType.SKILL: "技能",
    MemoryType.GOAL: "目标",
    MemoryType.EXPERIENCE: "经验",
    MemoryType.CONTEXTUAL: "上下文",
}
@dataclass
class AdapterConfig:
    """Configuration for the enhanced-memory adapter."""

    # Master switch; set to False automatically if initialization fails.
    enable_enhanced_memory: bool = True
    # Must match an IntegrationMode value: "replace" or "enhanced_only".
    integration_mode: str = "enhanced_only"  # replace, enhanced_only
    # NOTE(review): read when building the config but not otherwise used here.
    auto_migration: bool = True
    # Forwarded to IntegrationConfig.memory_value_threshold.
    memory_value_threshold: float = 0.6
    # Forwarded to IntegrationConfig.fusion_threshold.
    fusion_threshold: float = 0.85
    # Default retrieval limit when callers pass none.
    max_retrieval_results: int = 10
class EnhancedMemoryAdapter:
    """Adapter that plugs the enhanced memory system into the existing MoFox Bot architecture.

    Wraps a MemoryIntegrationLayer, exposing conversation processing,
    retrieval, prompt-context formatting, maintenance/shutdown, and running
    usage statistics.
    """

    def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
        self.llm_model = llm_model
        self.config = config or AdapterConfig()
        self.integration_layer: MemoryIntegrationLayer | None = None
        self._initialized = False
        # Usage counters; average_processing_time is a running mean keyed off
        # total_processed (see _update_processing_stats).
        self.adapter_stats = {
            "total_processed": 0,
            "enhanced_used": 0,
            "legacy_used": 0,
            "hybrid_used": 0,
            "memories_created": 0,
            "memories_retrieved": 0,
            "average_processing_time": 0.0,
        }

    async def initialize(self):
        """Create and initialize the integration layer (idempotent).

        On failure the exception is swallowed and enhanced memory is disabled
        for this adapter (config.enable_enhanced_memory is set to False).
        """
        if self._initialized:
            return
        try:
            logger.info("🚀 初始化增强记忆系统适配器...")
            # Translate the adapter-level config into the integration layer's
            # config type.
            integration_config = IntegrationConfig(
                mode=IntegrationMode(self.config.integration_mode),
                enable_enhanced_memory=self.config.enable_enhanced_memory,
                memory_value_threshold=self.config.memory_value_threshold,
                fusion_threshold=self.config.fusion_threshold,
                max_retrieval_results=self.config.max_retrieval_results,
                enable_learning=True,  # learning features are always enabled
            )
            self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config)
            await self.integration_layer.initialize()
            self._initialized = True
            logger.info("✅ 增强记忆系统适配器初始化完成")
        except Exception as e:
            logger.error(f"❌ 增强记忆系统适配器初始化失败: {e}", exc_info=True)
            # Fail-safe: disable enhanced memory instead of propagating.
            self.config.enable_enhanced_memory = False

    async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
        """Build memories from a conversation; *context* is the sole input.

        conversation_text is read from the context, falling back to
        message_content / latest_message / raw_text; a timestamp is added when
        missing. Returns the integration layer's result dict, or
        {"success": False, "error": ...} when unavailable or on exception.
        """
        if not self._initialized or not self.config.enable_enhanced_memory:
            return {"success": False, "error": "Enhanced memory not available"}
        start_time = time.time()
        self.adapter_stats["total_processed"] += 1
        try:
            payload_context: dict[str, Any] = dict(context or {})
            conversation_text = payload_context.get("conversation_text")
            if not conversation_text:
                # Fall back to alternative keys that may carry the raw text.
                conversation_candidate = (
                    payload_context.get("message_content")
                    or payload_context.get("latest_message")
                    or payload_context.get("raw_text")
                )
                if conversation_candidate is not None:
                    conversation_text = str(conversation_candidate)
                    payload_context["conversation_text"] = conversation_text
                else:
                    conversation_text = ""
            else:
                conversation_text = str(conversation_text)
            if "timestamp" not in payload_context:
                payload_context["timestamp"] = time.time()
            logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text))
            result = await self.integration_layer.process_conversation(payload_context)
            # Fold this request's duration into the running mean.
            processing_time = time.time() - start_time
            self._update_processing_stats(processing_time)
            # NOTE(review): assumes the result dict always carries a "success"
            # key (true for MemoryIntegrationLayer.process_conversation).
            if result["success"]:
                created_count = len(result.get("created_memories", []))
                self.adapter_stats["memories_created"] += created_count
                logger.debug(f"对话记忆处理完成,创建 {created_count} 条记忆")
            return result
        except Exception as e:
            logger.error(f"处理对话记忆失败: {e}", exc_info=True)
            return {"success": False, "error": str(e)}

    async def retrieve_relevant_memories(
        self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
    ) -> list[MemoryChunk]:
        """Return up to *limit* memories relevant to *query* ([] when disabled or on error).

        NOTE(review): user_id is accepted but not forwarded — None is passed to
        the integration layer, making retrieval effectively global; confirm
        this is intentional.
        """
        if not self._initialized or not self.config.enable_enhanced_memory:
            return []
        try:
            limit = limit or self.config.max_retrieval_results
            memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit)
            self.adapter_stats["memories_retrieved"] += len(memories)
            logger.debug(f"检索到 {len(memories)} 条相关记忆")
            return memories
        except Exception as e:
            logger.error(f"检索相关记忆失败: {e}", exc_info=True)
            return []

    async def get_memory_context_for_prompt(
        self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
    ) -> str:
        """Format up to *max_memories* relevant memories as prompt-ready text ("" when none)."""
        memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
        if not memories:
            return ""
        # Delegate presentation to the shared memory formatter.
        formatter_config = FormatterConfig(
            include_timestamps=True,
            include_memory_types=True,
            include_confidence=False,
            use_emoji_icons=True,
            group_by_type=False,
            max_display_length=150,
        )
        return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)

    async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
        """Collect system status plus adapter/integration statistics.

        Returns {"available": False, ...} when disabled, uninitialized, or on
        error. NOTE(review): user_id is accepted but unused — confirm intent.
        """
        if not self._initialized or not self.config.enable_enhanced_memory:
            return {"available": False, "reason": "Not initialized or disabled"}
        try:
            status = await self.integration_layer.get_system_status()
            adapter_stats = self.adapter_stats.copy()
            integration_stats = self.integration_layer.get_integration_stats()
            return {
                "available": True,
                "system_status": status,
                "adapter_stats": adapter_stats,
                "integration_stats": integration_stats,
                "total_memories_created": adapter_stats["memories_created"],
                "total_memories_retrieved": adapter_stats["memories_retrieved"],
            }
        except Exception as e:
            logger.error(f"获取增强记忆摘要失败: {e}", exc_info=True)
            return {"available": False, "error": str(e)}

    def _update_processing_stats(self, processing_time: float):
        """Fold *processing_time* into the running mean over total_processed requests."""
        total_processed = self.adapter_stats["total_processed"]
        if total_processed > 0:
            current_avg = self.adapter_stats["average_processing_time"]
            new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
            self.adapter_stats["average_processing_time"] = new_avg

    def get_adapter_stats(self) -> dict[str, Any]:
        """Return a snapshot copy of the adapter's usage counters."""
        return self.adapter_stats.copy()

    async def maintenance(self):
        """Forward a maintenance pass to the integration layer (no-op before init)."""
        if not self._initialized:
            return
        try:
            logger.info("🔧 增强记忆系统适配器维护...")
            await self.integration_layer.maintenance()
            logger.info("✅ 增强记忆系统适配器维护完成")
        except Exception as e:
            logger.error(f"❌ 增强记忆系统适配器维护失败: {e}", exc_info=True)

    async def shutdown(self):
        """Shut down the integration layer and mark this adapter uninitialized."""
        if not self._initialized:
            return
        try:
            logger.info("🔄 关闭增强记忆系统适配器...")
            await self.integration_layer.shutdown()
            self._initialized = False
            logger.info("✅ 增强记忆系统适配器已关闭")
        except Exception as e:
            logger.error(f"❌ 关闭增强记忆系统适配器失败: {e}", exc_info=True)
# Module-level singleton adapter instance (created lazily on first use).
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None


async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
    """Return the global adapter, creating and initializing it on first call.

    NOTE(review): there is no lock around the lazy creation, so two
    overlapping awaits may each construct an adapter; also, a failed
    initialize() still caches the instance (it disables itself instead of
    raising). Confirm both behaviors are intended.
    """
    global _enhanced_memory_adapter
    if _enhanced_memory_adapter is None:
        # Pull adapter settings from the global config, with safe defaults.
        from src.config.config import global_config

        adapter_config = AdapterConfig(
            enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True),
            integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"),
            auto_migration=getattr(global_config.memory, "enable_memory_migration", True),
            memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6),
            fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85),
            max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10),
        )
        _enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config)
        await _enhanced_memory_adapter.initialize()
    return _enhanced_memory_adapter
async def initialize_enhanced_memory_system(llm_model: LLMRequest):
    """Bootstrap the enhanced memory system.

    Returns the global adapter on success, or None if anything goes wrong.
    """
    try:
        logger.info("🚀 初始化增强记忆系统...")
        memory_adapter = await get_enhanced_memory_adapter(llm_model)
        logger.info("✅ 增强记忆系统初始化完成")
    except Exception as e:
        logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
        return None
    return memory_adapter
async def process_conversation_with_enhanced_memory(
    context: dict[str, Any], llm_model: LLMRequest | None = None
) -> dict[str, Any]:
    """Process a conversation through the enhanced memory system.

    The context should carry conversation_text (or one of the fallback keys:
    message_content, latest_message, raw_text). Returns the adapter's result
    dict, or a {"success": False, ...} error payload.
    """
    if not llm_model:
        # No model supplied — fall back to the globally configured one.
        from src.llm_models.utils_model import get_global_llm_model

        llm_model = get_global_llm_model()
    try:
        memory_adapter = await get_enhanced_memory_adapter(llm_model)
        payload = dict(context or {})
        if "conversation_text" not in payload:
            fallback_text = (
                payload.get("message_content")
                or payload.get("latest_message")
                or payload.get("raw_text")
            )
            if fallback_text is not None:
                payload["conversation_text"] = str(fallback_text)
        return await memory_adapter.process_conversation_memory(payload)
    except Exception as e:
        logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
        return {"success": False, "error": str(e)}
async def retrieve_memories_with_enhanced_system(
    query: str,
    user_id: str,
    context: dict[str, Any] | None = None,
    limit: int = 10,
    llm_model: LLMRequest | None = None,
) -> list[MemoryChunk]:
    """Retrieve up to *limit* memories for *query* via the enhanced system ([] on failure)."""
    if not llm_model:
        # No model supplied — fall back to the globally configured one.
        from src.llm_models.utils_model import get_global_llm_model

        llm_model = get_global_llm_model()
    try:
        memory_adapter = await get_enhanced_memory_adapter(llm_model)
        return await memory_adapter.retrieve_relevant_memories(query, user_id, context, limit)
    except Exception as e:
        logger.error(f"使用增强记忆系统检索记忆失败: {e}", exc_info=True)
        return []
async def get_memory_context_for_prompt(
    query: str,
    user_id: str,
    context: dict[str, Any] | None = None,
    max_memories: int = 5,
    llm_model: LLMRequest | None = None,
) -> str:
    """Return prompt-ready memory context text for *query* ("" on failure)."""
    if not llm_model:
        # No model supplied — fall back to the globally configured one.
        from src.llm_models.utils_model import get_global_llm_model

        llm_model = get_global_llm_model()
    try:
        memory_adapter = await get_enhanced_memory_adapter(llm_model)
        return await memory_adapter.get_memory_context_for_prompt(query, user_id, context, max_memories)
    except Exception as e:
        logger.error(f"获取记忆上下文失败: {e}", exc_info=True)
        return ""

View File

@@ -1,194 +0,0 @@
"""
增强记忆系统钩子
用于在消息处理过程中自动构建和检索记忆
"""
from datetime import datetime
from typing import Any
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger(__name__)
class EnhancedMemoryHooks:
    """Enhanced-memory hooks: automatically build and retrieve memories during message processing."""

    def __init__(self):
        # Both the base memory switch and the enhanced-memory switch must be on.
        self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory
        # Ids of messages already processed (fast membership test for dedup)...
        self.processed_messages: set[str] = set()
        # ...plus the same ids in insertion order, so trimming can drop the
        # OLDEST entries. (The previous implementation sliced list(set), which
        # discarded arbitrary ids because sets are unordered.)
        self._processed_order: list[str] = []

    def _mark_processed(self, message_id: str) -> None:
        """Record *message_id* as processed; once the history exceeds 1000 ids,
        keep only the most recent 500."""
        if message_id in self.processed_messages:
            return
        self.processed_messages.add(message_id)
        self._processed_order.append(message_id)
        if len(self._processed_order) > 1000:
            # Drop the oldest half of the history, keeping the newest 500.
            self._processed_order = self._processed_order[-500:]
            self.processed_messages = set(self._processed_order)

    async def process_message_for_memory(
        self,
        message_content: str,
        user_id: str,
        chat_id: str,
        message_id: str,
        context: dict[str, Any] | None = None,
    ) -> bool:
        """Build memories from one user message.

        Args:
            message_content: raw message text
            user_id: author's id
            chat_id: conversation id
            message_id: unique message id (used for dedup)
            context: extra context merged into the memory-building payload

        Returns:
            True when at least one memory chunk was created.
        """
        if not self.enabled:
            return False
        if message_id in self.processed_messages:
            return False
        try:
            # Lazily initialize the enhanced memory manager.
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()
            # Inject the bot's persona so memory building can avoid recording
            # the bot's own information as user facts.
            bot_config = getattr(global_config, "bot", None)
            personality_config = getattr(global_config, "personality", None)
            bot_context = {}
            if bot_config is not None:
                bot_context["bot_name"] = getattr(bot_config, "nickname", None)
                bot_context["bot_aliases"] = list(getattr(bot_config, "alias_names", []) or [])
                bot_context["bot_account"] = getattr(bot_config, "qq_account", None)
            if personality_config is not None:
                bot_context["bot_identity"] = getattr(personality_config, "identity", None)
                bot_context["bot_personality"] = getattr(personality_config, "personality_core", None)
                bot_context["bot_personality_side"] = getattr(personality_config, "personality_side", None)
            # Assemble the memory-building context; explicit keys may be
            # overridden by caller-supplied context.
            memory_context = {
                "chat_id": chat_id,
                "message_id": message_id,
                "timestamp": datetime.now().timestamp(),
                "message_type": "user_message",
                **bot_context,
                **(context or {}),
            }
            memory_chunks = await enhanced_memory_manager.process_conversation(
                conversation_text=message_content,
                context=memory_context,
                user_id=user_id,
                timestamp=memory_context["timestamp"],
            )
            # Remember the id (bounded history) so this message is not
            # processed again.
            self._mark_processed(message_id)
            logger.debug(f"为消息 {message_id} 构建了 {len(memory_chunks)} 条记忆")
            return len(memory_chunks) > 0
        except Exception as e:
            logger.error(f"处理消息记忆失败: {e}")
            return False

    async def get_memory_for_response(
        self,
        query_text: str,
        user_id: str,
        chat_id: str,
        limit: int = 5,
        extra_context: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Retrieve memories relevant to *query_text* for response generation.

        Returns a list of plain dicts (content/type/confidence/importance/
        timestamp/source/relevance/structure); [] when disabled or on error.
        """
        if not self.enabled:
            return []
        try:
            # Lazily initialize the enhanced memory manager.
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()
            # Query context steering retrieval toward response generation.
            context = {
                "chat_id": chat_id,
                "query_intent": "response_generation",
                "expected_memory_types": ["personal_fact", "event", "preference", "opinion"],
            }
            if extra_context:
                context.update(extra_context)
            enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
                query_text=query_text, user_id=user_id, context=context, limit=limit
            )
            # Flatten result objects into plain dicts for the caller.
            results = [
                {
                    "content": result.content,
                    "type": result.memory_type,
                    "confidence": result.confidence,
                    "importance": result.importance,
                    "timestamp": result.timestamp,
                    "source": result.source,
                    "relevance": result.relevance_score,
                    "structure": result.structure,
                }
                for result in enhanced_results
            ]
            logger.debug(f"为回复查询到 {len(results)} 条相关记忆")
            return results
        except Exception as e:
            logger.error(f"获取回复记忆失败: {e}")
            return []

    async def cleanup_old_memories(self):
        """Run the enhanced system's maintenance task to prune old memories."""
        try:
            if enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.enhanced_system.maintenance()
                logger.debug("增强记忆系统维护完成")
        except Exception as e:
            logger.error(f"清理旧记忆失败: {e}")

    def clear_processed_cache(self):
        """Forget which message ids have been processed."""
        self.processed_messages.clear()
        self._processed_order.clear()
        logger.debug("已清除消息处理缓存")

    def enable(self):
        """Turn the hooks on at runtime."""
        self.enabled = True
        logger.info("增强记忆钩子已启用")

    def disable(self):
        """Turn the hooks off at runtime."""
        self.enabled = False
        logger.info("增强记忆钩子已禁用")
# Module-level singleton used by the integration helper functions.
enhanced_memory_hooks = EnhancedMemoryHooks()

View File

@@ -1,177 +0,0 @@
"""
增强记忆系统集成脚本
用于在现有系统中无缝集成增强记忆功能
"""
from typing import Any
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
from src.common.logger import get_logger
logger = get_logger(__name__)
async def process_user_message_memory(
    message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
) -> bool:
    """Build memory from one user message via the global hooks.

    Args:
        message_content: raw message text
        user_id: author's id
        chat_id: conversation id
        message_id: unique message id
        context: extra context forwarded to the hooks

    Returns:
        True when memory was successfully built.
    """
    try:
        created = await enhanced_memory_hooks.process_message_for_memory(
            message_content=message_content, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
        )
        if created:
            logger.debug(f"成功为消息 {message_id} 构建记忆")
        return created
    except Exception as e:
        logger.error(f"处理用户消息记忆失败: {e}")
        return False
async def get_relevant_memories_for_response(
    query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Fetch memories relevant to *query_text* and wrap them in a summary dict.

    Returns {"has_memories": bool, "memories": list, "memory_count": int};
    an empty payload is returned on any error.
    """
    try:
        found = await enhanced_memory_hooks.get_memory_for_response(
            query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context
        )
        summary = {"has_memories": len(found) > 0, "memories": found, "memory_count": len(found)}
        logger.debug(f"为回复获取到 {len(found)} 条相关记忆")
        return summary
    except Exception as e:
        logger.error(f"获取回复记忆失败: {e}")
        return {"has_memories": False, "memories": [], "memory_count": 0}
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
    """Render a memory summary dict (from get_relevant_memories_for_response)
    as prompt text.

    Returns "" when the dict reports no memories; otherwise a header line
    followed by one line per memory, decorated with importance/confidence
    icons.
    """
    if not memories["has_memories"]:
        return ""

    def _importance_icon(level) -> str:
        # 🔥 for high importance, 📝 for low, nothing in between.
        if level >= 3:
            return "🔥"
        return "" if level >= 2 else "📝"

    def _confidence_icon(level) -> str:
        # ⚠️ / 💭 flag medium / low confidence; high confidence is unmarked.
        if level >= 3:
            return ""
        return "⚠️" if level >= 2 else "💭"

    lines = ["以下是相关的记忆信息:"]
    lines.extend(
        f"{_importance_icon(entry['importance'])} {entry['content']} "
        f"({entry['type']}, {_confidence_icon(entry['confidence'])}置信度)"
        for entry in memories["memories"]
    )
    return "\n".join(lines)
async def cleanup_memory_system():
    """Trigger a cleanup/maintenance pass on the memory system; errors are logged, never raised."""
    try:
        await enhanced_memory_hooks.cleanup_old_memories()
        logger.info("记忆系统清理完成")
    except Exception as e:
        logger.error(f"记忆系统清理失败: {e}")
def get_memory_system_status() -> dict[str, Any]:
    """Report the current state of the enhanced memory hook system."""
    # Imported here to avoid a module-level import cycle with the manager.
    from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager

    status: dict[str, Any] = {
        "enabled": enhanced_memory_hooks.enabled,
        "enhanced_system_initialized": enhanced_memory_manager.is_initialized,
        "processed_messages_count": len(enhanced_memory_hooks.processed_messages),
        "system_type": "enhanced_memory_system",
    }
    return status
# 便捷函数
async def remember_message(
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
) -> bool:
"""
便捷的记忆构建函数
Args:
message: 要记住的消息
user_id: 用户ID
chat_id: 聊天ID
Returns:
bool: 是否成功
"""
import uuid
message_id = str(uuid.uuid4())
return await process_user_message_memory(
message_content=message, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
)
async def recall_memories(
    query: str,
    user_id: str = "default_user",
    chat_id: str = "default_chat",
    limit: int = 5,
    context: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Convenience wrapper around get_relevant_memories_for_response.

    Returns the same summary dict: has_memories / memories / memory_count.
    """
    return await get_relevant_memories_for_response(
        query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context
    )

View File

@@ -1,356 +0,0 @@
"""
增强重排序器
实现文档设计的多维度评分模型
"""
import math
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
logger = get_logger(__name__)
class IntentType(Enum):
    """Conversation intent categories used to weight memory types during re-ranking."""

    FACT_QUERY = "fact_query"  # fact lookup
    EVENT_RECALL = "event_recall"  # recalling a past event
    PREFERENCE_CHECK = "preference_check"  # likes / dislikes / habits
    GENERAL_CHAT = "general_chat"  # generic conversation
    UNKNOWN = "unknown"  # unclassifiable (e.g. empty query)
@dataclass
class ReRankingConfig:
    """Re-ranking configuration.

    The four *_weight fields are intended to sum to 1.0; EnhancedReRanker
    normalizes them if they do not.
    """

    # Score weights (w1 + w2 + w3 + w4 = 1.0)
    semantic_weight: float = 0.5  # semantic similarity weight
    recency_weight: float = 0.2  # recency weight
    usage_freq_weight: float = 0.2  # usage-frequency weight
    type_match_weight: float = 0.1  # intent/type-match weight

    # Recency decay rate (per day).
    recency_decay_rate: float = 0.1

    # Usage-frequency scoring parameters.
    freq_log_base: float = 2.0  # logarithm base
    freq_max_score: float = 5.0  # maximum frequency score

    # Intent -> (memory type -> weight) mapping; when left as None it is
    # populated with the defaults in __post_init__.
    # Fix: annotated as Optional — previously declared `dict[...] = None`,
    # i.e. a non-optional dict type with a None default.
    type_match_weights: dict[str, dict[str, float]] | None = None

    def __post_init__(self):
        """Populate the default type-match weight table when none was given."""
        if self.type_match_weights is None:
            self.type_match_weights = {
                IntentType.FACT_QUERY.value: {
                    MemoryType.PERSONAL_FACT.value: 1.0,
                    MemoryType.KNOWLEDGE.value: 0.8,
                    MemoryType.PREFERENCE.value: 0.5,
                    MemoryType.EVENT.value: 0.3,
                    "default": 0.3,
                },
                IntentType.EVENT_RECALL.value: {
                    MemoryType.EVENT.value: 1.0,
                    MemoryType.EXPERIENCE.value: 0.8,
                    MemoryType.EMOTION.value: 0.6,
                    MemoryType.PERSONAL_FACT.value: 0.5,
                    "default": 0.5,
                },
                IntentType.PREFERENCE_CHECK.value: {
                    MemoryType.PREFERENCE.value: 1.0,
                    MemoryType.OPINION.value: 0.8,
                    MemoryType.GOAL.value: 0.6,
                    MemoryType.PERSONAL_FACT.value: 0.4,
                    "default": 0.4,
                },
                IntentType.GENERAL_CHAT.value: {"default": 0.8},
                IntentType.UNKNOWN.value: {"default": 0.8},
            }
class IntentClassifier:
    """Lightweight keyword-based intent classifier (no model calls)."""

    def __init__(self):
        # Keyword substrings per intent; each hit scores one point.
        self.patterns = {
            IntentType.FACT_QUERY: [
                # Chinese patterns
                "我是", "我的", "我叫", "我在", "我住在", "我的职业", "我的工作",
                "什么时候", "在哪里", "是什么", "多少", "几岁", "年龄",
                # English patterns
                "what is", "where is", "when is", "how old", "my name", "i am", "i live",
            ],
            IntentType.EVENT_RECALL: [
                # Chinese patterns
                "记得", "想起", "还记得", "那次", "上次", "之前", "以前", "曾经",
                "发生过", "经历", "做过", "去过", "见过",
                # English patterns
                "remember", "recall", "last time", "before", "previously", "happened", "experience",
            ],
            IntentType.PREFERENCE_CHECK: [
                # Chinese patterns
                "喜欢", "不喜欢", "偏好", "爱好", "兴趣", "讨厌", "最爱", "最喜欢",
                "习惯", "通常", "一般", "倾向于", "更喜欢",
                # English patterns
                "like", "love", "hate", "prefer", "favorite", "usually", "tend to", "interest",
            ],
        }

    def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
        """Classify *query* into an IntentType.

        An empty query is UNKNOWN; a query with no keyword hits is
        GENERAL_CHAT; ties resolve to the earliest IntentType member.
        """
        if not query:
            return IntentType.UNKNOWN
        lowered = query.lower()
        # Score every intent; enum definition order is preserved so that
        # tie-breaking below is deterministic.
        scores = {intent: 0 for intent in IntentType}
        for intent, keywords in self.patterns.items():
            scores[intent] = sum(1 for keyword in keywords if keyword in lowered)
        best = max(scores.values())
        if best == 0:
            return IntentType.GENERAL_CHAT
        return next(intent for intent, score in scores.items() if score == best)
class EnhancedReRanker:
    """Multi-dimension memory re-ranker implementing the design doc's scoring model.

    final = w1*semantic + w2*recency + w3*usage_frequency + w4*type_match,
    where the four weights are normalized to sum to 1.0.
    """

    def __init__(self, config: ReRankingConfig | None = None):
        self.config = config or ReRankingConfig()
        self.intent_classifier = IntentClassifier()
        # Validate that the four weights sum to 1.0; otherwise normalize them
        # in place (with a warning) so final scores stay in [0, 1].
        total_weight = (
            self.config.semantic_weight
            + self.config.recency_weight
            + self.config.usage_freq_weight
            + self.config.type_match_weight
        )
        if abs(total_weight - 1.0) > 0.01:
            logger.warning(f"重排序权重和不为1.0: {total_weight}, 将进行归一化")
            self.config.semantic_weight /= total_weight
            self.config.recency_weight /= total_weight
            self.config.usage_freq_weight /= total_weight
            self.config.type_match_weight /= total_weight

    def rerank_memories(
        self,
        query: str,
        candidate_memories: list[tuple[str, MemoryChunk, float]],  # (memory_id, memory, vector_similarity)
        context: dict[str, Any],
        limit: int = 10,
    ) -> list[tuple[str, MemoryChunk, float]]:
        """Re-rank candidates and return the top *limit* tuples (memory_id, memory, final_score).

        Each candidate is scored on semantic similarity, recency, usage
        frequency and intent/type match. A scoring failure for one candidate
        falls back to its raw vector similarity instead of dropping it.
        """
        if not candidate_memories:
            return []
        # Classify the query's intent once; it drives the type-match weights.
        intent = self.intent_classifier.classify_intent(query, context)
        logger.debug(f"识别到查询意图: {intent.value}")
        scored_memories = []
        current_time = time.time()
        for memory_id, memory, vector_sim in candidate_memories:
            try:
                # 1. semantic similarity (normalized to [0, 1])
                semantic_score = self._normalize_similarity(vector_sim)
                # 2. recency
                recency_score = self._calculate_recency_score(memory, current_time)
                # 3. usage frequency
                usage_freq_score = self._calculate_usage_frequency_score(memory)
                # 4. intent/type match
                type_match_score = self._calculate_type_match_score(memory, intent)
                final_score = (
                    self.config.semantic_weight * semantic_score
                    + self.config.recency_weight * recency_score
                    + self.config.usage_freq_weight * usage_freq_score
                    + self.config.type_match_weight * type_match_score
                )
                scored_memories.append((memory_id, memory, final_score))
                logger.debug(
                    f"记忆评分 {memory_id[:8]}: semantic={semantic_score:.3f}, "
                    f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, "
                    f"type={type_match_score:.3f}, final={final_score:.3f}"
                )
            except Exception as e:
                logger.error(f"计算记忆 {memory_id} 得分时出错: {e}")
                # Fallback: keep the candidate, ranked by raw vector similarity.
                scored_memories.append((memory_id, memory, vector_sim))
        # Sort by final score, descending, and truncate.
        scored_memories.sort(key=lambda x: x[2], reverse=True)
        result = scored_memories[:limit]
        highest_score = result[0][2] if result else 0.0
        logger.info(
            f"重排序完成: 候选={len(candidate_memories)}, 返回={len(result)}, "
            f"意图={intent.value}, 最高分={highest_score:.3f}"
        )
        return result

    def _normalize_similarity(self, raw_similarity: float) -> float:
        """Normalize a raw similarity to [0, 1].

        Negative values are assumed to come from a [-1, 1] scale and are
        mapped linearly; non-negative values are clamped to [0, 1].
        """
        if raw_similarity < 0:
            return (raw_similarity + 1) / 2  # map [-1, 1] -> [0, 1]
        return min(1.0, max(0.0, raw_similarity))

    def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float:
        """Recency score: 1 / (1 + decay_rate * days_old).

        Uses last_accessed when available, otherwise created_at; negative ages
        (clock skew) are clamped to zero.
        """
        last_accessed = memory.metadata.last_accessed or memory.metadata.created_at
        days_old = (current_time - last_accessed) / (24 * 3600)
        if days_old < 0:
            days_old = 0
        score = 1 / (1 + self.config.recency_decay_rate * days_old)
        return min(1.0, max(0.0, score))

    def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float:
        """Usage-frequency score: min(1.0, log_base(access_count + 1) / freq_max_score).

        Fix: honors the configured ``freq_log_base`` (it was declared in
        ReRankingConfig but previously ignored — log2 was hard-coded). For the
        default base of 2.0 the exact ``math.log2`` path is kept so default
        behavior is bit-identical.
        """
        access_count = memory.metadata.access_count
        if access_count <= 0:
            return 0.0
        base = self.config.freq_log_base
        log_count = math.log2(access_count + 1) if base == 2.0 else math.log(access_count + 1, base)
        score = log_count / self.config.freq_max_score
        return min(1.0, max(0.0, score))

    def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float:
        """Look up the intent->type weight, falling back to the intent's "default" (or 0.8)."""
        memory_type = memory.memory_type.value
        intent_value = intent.value
        type_weights = self.config.type_match_weights.get(intent_value, {})
        score = type_weights.get(memory_type, type_weights.get("default", 0.8))
        return min(1.0, max(0.0, score))
# Shared default re-ranker used by the convenience helper below.
default_reranker = EnhancedReRanker()


def rerank_candidate_memories(
    query: str,
    candidate_memories: list[tuple[str, MemoryChunk, float]],
    context: dict[str, Any],
    limit: int = 10,
    config: ReRankingConfig | None = None,
) -> list[tuple[str, MemoryChunk, float]]:
    """Convenience wrapper: re-rank with a custom config, or with the shared default instance."""
    chosen = EnhancedReRanker(config) if config else default_reranker
    return chosen.rerank_memories(query, candidate_memories, context, limit)

View File

@@ -1,245 +0,0 @@
"""
增强记忆系统集成层
现在只管理新的增强记忆系统,旧系统已被完全移除
"""
import asyncio
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
class IntegrationMode(Enum):
    """How the enhanced memory system is integrated."""

    REPLACE = "replace"  # fully replace the existing memory system
    ENHANCED_ONLY = "enhanced_only"  # use only the enhanced memory system
@dataclass
class IntegrationConfig:
    """Configuration for the memory integration layer."""

    # Integration strategy (see IntegrationMode).
    mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY
    # Master switch for the enhanced memory backend.
    enable_enhanced_memory: bool = True
    # Forwarded to MemorySystemConfig.memory_value_threshold.
    memory_value_threshold: float = 0.6
    # Forwarded to MemorySystemConfig.fusion_similarity_threshold.
    fusion_threshold: float = 0.85
    # Forwarded to MemorySystemConfig.final_recall_limit; also the default
    # retrieval limit.
    max_retrieval_results: int = 10
    # NOTE(review): set by callers but not read within this module's visible
    # code — presumably consumed by the enhanced system; confirm.
    enable_learning: bool = True
class MemoryIntegrationLayer:
    """Integration layer that now manages only the enhanced memory system.

    The legacy memory system was removed; this layer owns the
    EnhancedMemorySystem lifecycle, forwards processing/retrieval calls, and
    tracks aggregate statistics.
    """

    def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
        self.llm_model = llm_model
        self.config = config or IntegrationConfig()
        # Single (enhanced) memory backend, created during initialize().
        self.enhanced_memory: EnhancedMemorySystem | None = None
        # Aggregate usage statistics; the average and the success rate are
        # running means over total_queries.
        self.integration_stats = {
            "total_queries": 0,
            "enhanced_queries": 0,
            "memory_creations": 0,
            "average_response_time": 0.0,
            "success_rate": 0.0,
        }
        # Serializes concurrent first calls to initialize().
        self._initialization_lock = asyncio.Lock()
        self._initialized = False

    async def initialize(self):
        """Initialize the layer once; safe to call concurrently (double-checked under the lock)."""
        if self._initialized:
            return
        async with self._initialization_lock:
            if self._initialized:
                return
            logger.info("🚀 开始初始化增强记忆系统集成层...")
            try:
                if self.config.enable_enhanced_memory:
                    await self._initialize_enhanced_memory()
                self._initialized = True
                logger.info("✅ 增强记忆系统集成层初始化完成")
            except Exception as e:
                logger.error(f"❌ 集成层初始化失败: {e}", exc_info=True)
                raise

    async def _initialize_enhanced_memory(self):
        """Build and initialize the EnhancedMemorySystem from global + layer config."""
        try:
            logger.debug("初始化增强记忆系统...")
            # Local import to avoid a module-level import cycle.
            from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig

            memory_config = MemorySystemConfig.from_global_config()
            # Layer-level settings override the global values.
            memory_config.memory_value_threshold = self.config.memory_value_threshold
            memory_config.fusion_similarity_threshold = self.config.fusion_threshold
            memory_config.final_recall_limit = self.config.max_retrieval_results
            self.enhanced_memory = EnhancedMemorySystem(config=memory_config)
            # Inject an externally supplied LLM model, if any.
            if self.llm_model is not None:
                self.enhanced_memory.llm_model = self.llm_model
            await self.enhanced_memory.initialize()
            logger.info("✅ 增强记忆系统初始化完成")
        except Exception as e:
            logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
            raise

    async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
        """Build memories from a conversation context via the enhanced backend.

        Returns the backend's result dict, or {"success": False, "error": ...}
        when the layer is unavailable or processing raises.
        """
        if not self._initialized or not self.enhanced_memory:
            return {"success": False, "error": "Memory system not available"}
        start_time = time.time()
        self.integration_stats["total_queries"] += 1
        self.integration_stats["enhanced_queries"] += 1
        try:
            payload_context = dict(context or {})
            conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
            logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text))
            result = await self.enhanced_memory.process_conversation_memory(payload_context)
            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, result.get("success", False))
            if result.get("success"):
                created_count = len(result.get("created_memories", []))
                self.integration_stats["memory_creations"] += created_count
                logger.debug(f"对话处理完成,创建 {created_count} 条记忆")
            return result
        except Exception as e:
            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, False)
            logger.error(f"处理对话记忆失败: {e}", exc_info=True)
            return {"success": False, "error": str(e)}

    async def retrieve_relevant_memories(
        self,
        query: str,
        user_id: str | None = None,
        context: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> list[MemoryChunk]:
        """Fetch memories relevant to *query* ([] when unavailable or on error).

        NOTE(review): user_id is accepted but None is always forwarded to the
        backend, making retrieval effectively global — confirm intended.
        """
        if not self._initialized or not self.enhanced_memory:
            return []
        try:
            limit = limit or self.config.max_retrieval_results
            memories = await self.enhanced_memory.retrieve_relevant_memories(
                query=query, user_id=None, context=context or {}, limit=limit
            )
            memory_count = len(memories)
            logger.debug(f"检索到 {memory_count} 条相关记忆")
            return memories
        except Exception as e:
            logger.error(f"检索相关记忆失败: {e}", exc_info=True)
            return []

    async def get_system_status(self) -> dict[str, Any]:
        """Return layer status, backend status and a snapshot of the stats."""
        if not self._initialized:
            return {"status": "not_initialized"}
        try:
            enhanced_status = {}
            if self.enhanced_memory:
                enhanced_status = await self.enhanced_memory.get_system_status()
            return {
                "status": "initialized",
                "mode": self.config.mode.value,
                "enhanced_memory": enhanced_status,
                "integration_stats": self.integration_stats.copy(),
            }
        except Exception as e:
            logger.error(f"获取系统状态失败: {e}", exc_info=True)
            return {"status": "error", "error": str(e)}

    def get_integration_stats(self) -> dict[str, Any]:
        """Return a snapshot copy of the integration statistics."""
        return self.integration_stats.copy()

    def _update_response_stats(self, processing_time: float, success: bool):
        """Fold one request's timing and outcome into the running statistics.

        Fix: the success rate is now updated on BOTH outcomes. The previous
        implementation only updated it on success, so failed requests never
        lowered the reported rate.
        """
        total_queries = self.integration_stats["total_queries"]
        if total_queries > 0:
            current_avg = self.integration_stats["average_response_time"]
            self.integration_stats["average_response_time"] = (
                current_avg * (total_queries - 1) + processing_time
            ) / total_queries
            current_rate = self.integration_stats["success_rate"]
            self.integration_stats["success_rate"] = (
                current_rate * (total_queries - 1) + (1.0 if success else 0.0)
            ) / total_queries

    async def maintenance(self):
        """Run a maintenance pass on the enhanced backend (no-op before init)."""
        if not self._initialized:
            return
        try:
            logger.info("🔧 执行记忆系统集成层维护...")
            if self.enhanced_memory:
                await self.enhanced_memory.maintenance()
            logger.info("✅ 记忆系统集成层维护完成")
        except Exception as e:
            logger.error(f"❌ 集成层维护失败: {e}", exc_info=True)

    async def shutdown(self):
        """Shut the backend down and mark the layer uninitialized."""
        if not self._initialized:
            return
        try:
            logger.info("🔄 关闭记忆系统集成层...")
            if self.enhanced_memory:
                await self.enhanced_memory.shutdown()
            self._initialized = False
            logger.info("✅ 记忆系统集成层已关闭")
        except Exception as e:
            logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)

View File

@@ -1,526 +0,0 @@
"""
记忆系统集成钩子
提供与现有MoFox Bot系统的无缝集成点
"""
import time
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.enhanced_memory_adapter import (
get_memory_context_for_prompt,
process_conversation_with_enhanced_memory,
retrieve_memories_with_enhanced_system,
)
from src.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class HookResult:
"""钩子执行结果"""
success: bool
data: Any = None
error: str | None = None
processing_time: float = 0.0
class MemoryIntegrationHooks:
"""记忆系统集成钩子"""
def __init__(self):
self.hooks_registered = False
self.hook_stats = {
"message_processing_hooks": 0,
"memory_retrieval_hooks": 0,
"prompt_enhancement_hooks": 0,
"total_hook_executions": 0,
"average_hook_time": 0.0,
}
async def register_hooks(self):
"""注册所有集成钩子"""
if self.hooks_registered:
return
try:
logger.info("🔗 注册记忆系统集成钩子...")
# 注册消息处理钩子
await self._register_message_processing_hooks()
# 注册记忆检索钩子
await self._register_memory_retrieval_hooks()
# 注册提示词增强钩子
await self._register_prompt_enhancement_hooks()
# 注册系统维护钩子
await self._register_maintenance_hooks()
self.hooks_registered = True
logger.info("✅ 记忆系统集成钩子注册完成")
except Exception as e:
logger.error(f"❌ 注册记忆系统集成钩子失败: {e}", exc_info=True)
async def _register_message_processing_hooks(self):
"""注册消息处理钩子"""
try:
# 钩子1: 在消息处理后创建记忆
await self._register_post_message_hook()
# 钩子2: 在聊天流保存时处理记忆
await self._register_chat_stream_hook()
logger.debug("消息处理钩子注册完成")
except Exception as e:
logger.error(f"注册消息处理钩子失败: {e}")
async def _register_memory_retrieval_hooks(self):
"""注册记忆检索钩子"""
try:
# 钩子1: 在生成回复前检索相关记忆
await self._register_pre_response_hook()
# 钩子2: 在知识库查询前增强上下文
await self._register_knowledge_query_hook()
logger.debug("记忆检索钩子注册完成")
except Exception as e:
logger.error(f"注册记忆检索钩子失败: {e}")
async def _register_prompt_enhancement_hooks(self):
"""注册提示词增强钩子"""
try:
# 钩子1: 增强提示词构建
await self._register_prompt_building_hook()
logger.debug("提示词增强钩子注册完成")
except Exception as e:
logger.error(f"注册提示词增强钩子失败: {e}")
async def _register_maintenance_hooks(self):
"""注册系统维护钩子"""
try:
# 钩子1: 系统维护时的记忆系统维护
await self._register_system_maintenance_hook()
logger.debug("系统维护钩子注册完成")
except Exception as e:
logger.error(f"注册系统维护钩子失败: {e}")
async def _register_post_message_hook(self):
"""注册消息后处理钩子"""
try:
# 这里需要根据实际的系统架构来注册钩子
# 以下是一个示例实现,需要根据实际的插件系统或事件系统来调整
# 尝试注册到事件系统
try:
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager
# 注册消息后处理事件
event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
logger.debug("已注册到事件系统的消息处理钩子")
except ImportError:
logger.debug("事件系统不可用,跳过事件钩子注册")
# 尝试注册到消息管理器
try:
from src.chat.message_manager import message_manager
# 如果消息管理器支持钩子注册
if hasattr(message_manager, "register_post_process_hook"):
message_manager.register_post_process_hook(self._on_message_processed_hook)
logger.debug("已注册到消息管理器的处理钩子")
except ImportError:
logger.debug("消息管理器不可用,跳过消息管理器钩子注册")
except Exception as e:
logger.error(f"注册消息后处理钩子失败: {e}")
async def _register_chat_stream_hook(self):
"""注册聊天流钩子"""
try:
# 尝试注册到聊天流管理器
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
if hasattr(chat_manager, "register_save_hook"):
chat_manager.register_save_hook(self._on_chat_stream_save_hook)
logger.debug("已注册到聊天流管理器的保存钩子")
except ImportError:
logger.debug("聊天流管理器不可用,跳过聊天流钩子注册")
except Exception as e:
logger.error(f"注册聊天流钩子失败: {e}")
async def _register_pre_response_hook(self):
"""注册回复前钩子"""
try:
# 尝试注册到回复生成器
try:
from src.chat.replyer.default_generator import default_generator
if hasattr(default_generator, "register_pre_generation_hook"):
default_generator.register_pre_generation_hook(self._on_pre_response_hook)
logger.debug("已注册到回复生成器的前置钩子")
except ImportError:
logger.debug("回复生成器不可用,跳过回复前钩子注册")
except Exception as e:
logger.error(f"注册回复前钩子失败: {e}")
async def _register_knowledge_query_hook(self):
"""注册知识库查询钩子"""
try:
# 尝试注册到知识库系统
try:
from src.chat.knowledge.knowledge_lib import knowledge_manager
if hasattr(knowledge_manager, "register_query_enhancer"):
knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook)
logger.debug("已注册到知识库的查询增强钩子")
except ImportError:
logger.debug("知识库系统不可用,跳过知识库钩子注册")
except Exception as e:
logger.error(f"注册知识库查询钩子失败: {e}")
async def _register_prompt_building_hook(self):
"""注册提示词构建钩子"""
try:
# 尝试注册到提示词系统
try:
from src.chat.utils.prompt import prompt_manager
if hasattr(prompt_manager, "register_enhancer"):
prompt_manager.register_enhancer(self._on_prompt_building_hook)
logger.debug("已注册到提示词管理器的增强钩子")
except ImportError:
logger.debug("提示词系统不可用,跳过提示词钩子注册")
except Exception as e:
logger.error(f"注册提示词构建钩子失败: {e}")
async def _register_system_maintenance_hook(self):
"""注册系统维护钩子"""
try:
# 尝试注册到系统维护器
try:
from src.manager.async_task_manager import async_task_manager
# 注册定期维护任务
async_task_manager.add_task(MemoryMaintenanceTask())
logger.debug("已注册到系统维护器的定期任务")
except ImportError:
logger.debug("异步任务管理器不可用,跳过系统维护钩子注册")
except Exception as e:
logger.error(f"注册系统维护钩子失败: {e}")
# 钩子处理器方法
async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
"""事件系统的消息处理处理器"""
return await self._on_message_processed_hook(event_data)
async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
"""消息后处理钩子"""
start_time = time.time()
try:
self.hook_stats["message_processing_hooks"] += 1
# 提取必要的信息
message_info = message_data.get("message_info", {})
user_info = message_info.get("user_info", {})
conversation_text = message_data.get("processed_plain_text", "")
if not conversation_text:
return HookResult(success=True, data="No conversation text")
user_id = str(user_info.get("user_id", "unknown"))
context = {
"chat_id": message_data.get("chat_id"),
"message_type": message_data.get("message_type", "normal"),
"platform": message_info.get("platform", "unknown"),
"interest_value": message_data.get("interest_value", 0.0),
"keywords": message_data.get("key_words", []),
"timestamp": message_data.get("time", time.time()),
}
# 使用增强记忆系统处理对话
memory_context = dict(context)
memory_context["conversation_text"] = conversation_text
memory_context["user_id"] = user_id
result = await process_conversation_with_enhanced_memory(memory_context)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
if result["success"]:
logger.debug(f"消息处理钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
return HookResult(success=True, data=result, processing_time=processing_time)
else:
logger.warning(f"消息处理钩子执行失败: {result.get('error')}")
return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
"""聊天流保存钩子"""
start_time = time.time()
try:
self.hook_stats["message_processing_hooks"] += 1
# 从聊天流数据中提取对话信息
stream_context = chat_stream_data.get("stream_context", {})
user_id = stream_context.get("user_id", "unknown")
messages = stream_context.get("messages", [])
if not messages:
return HookResult(success=True, data="No messages to process")
# 构建对话文本
conversation_parts = []
for msg in messages[-10:]: # 只处理最近10条消息
text = msg.get("processed_plain_text", "")
if text:
conversation_parts.append(f"{msg.get('user_nickname', 'User')}: {text}")
conversation_text = "\n".join(conversation_parts)
if not conversation_text:
return HookResult(success=True, data="No conversation text")
context = {
"chat_id": chat_stream_data.get("chat_id"),
"stream_id": chat_stream_data.get("stream_id"),
"platform": chat_stream_data.get("platform", "unknown"),
"message_count": len(messages),
"timestamp": time.time(),
}
# 使用增强记忆系统处理对话
memory_context = dict(context)
memory_context["conversation_text"] = conversation_text
memory_context["user_id"] = user_id
result = await process_conversation_with_enhanced_memory(memory_context)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
if result["success"]:
logger.debug(f"聊天流保存钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
return HookResult(success=True, data=result, processing_time=processing_time)
else:
logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}")
return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
"""回复前钩子"""
start_time = time.time()
try:
self.hook_stats["memory_retrieval_hooks"] += 1
# 提取查询信息
query = response_data.get("query", "")
user_id = response_data.get("user_id", "unknown")
context = response_data.get("context", {})
if not query:
return HookResult(success=True, data="No query provided")
# 检索相关记忆
memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
# 将记忆添加到响应数据中
response_data["enhanced_memories"] = memories
response_data["enhanced_memory_context"] = await get_memory_context_for_prompt(
query, user_id, context, max_memories=5
)
logger.debug(f"回复前钩子执行成功,检索到 {len(memories)} 条记忆")
return HookResult(success=True, data=memories, processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
"""知识库查询钩子"""
start_time = time.time()
try:
self.hook_stats["memory_retrieval_hooks"] += 1
query = query_data.get("query", "")
user_id = query_data.get("user_id", "unknown")
context = query_data.get("context", {})
if not query:
return HookResult(success=True, data="No query provided")
# 获取记忆上下文并增强查询
memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
# 将记忆上下文添加到查询数据中
query_data["enhanced_memory_context"] = memory_context
logger.debug("知识库查询钩子执行成功")
return HookResult(success=True, data=memory_context, processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
"""提示词构建钩子"""
start_time = time.time()
try:
self.hook_stats["prompt_enhancement_hooks"] += 1
query = prompt_data.get("query", "")
user_id = prompt_data.get("user_id", "unknown")
context = prompt_data.get("context", {})
base_prompt = prompt_data.get("base_prompt", "")
if not query:
return HookResult(success=True, data="No query provided")
# 获取记忆上下文
memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
# 构建增强的提示词
enhanced_prompt = base_prompt
if memory_context:
enhanced_prompt += f"\n\n### 相关记忆上下文 ###\n{memory_context}\n"
# 将增强的提示词添加到数据中
prompt_data["enhanced_prompt"] = enhanced_prompt
prompt_data["memory_context"] = memory_context
logger.debug("提示词构建钩子执行成功")
return HookResult(success=True, data=enhanced_prompt, processing_time=processing_time)
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"提示词构建钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
def _update_hook_stats(self, processing_time: float):
    """Fold one hook execution into the running statistics.

    Increments the execution counter and updates the incremental mean of
    per-hook processing time.
    """
    stats = self.hook_stats
    stats["total_hook_executions"] += 1
    n = stats["total_hook_executions"]
    if n > 0:
        prev_avg = stats["average_hook_time"]
        stats["average_hook_time"] = (prev_avg * (n - 1) + processing_time) / n
def get_hook_stats(self) -> dict[str, Any]:
    """Return a shallow copy of the hook execution statistics."""
    return dict(self.hook_stats)
class MemoryMaintenanceTask:
    """Periodic maintenance task for the enhanced memory system."""

    def __init__(self):
        # Task identity and schedule (runs once per hour).
        self.task_name = "enhanced_memory_maintenance"
        self.interval = 3600

    async def execute(self):
        """Run one maintenance pass against the global adapter, if present."""
        try:
            logger.info("🔧 执行增强记忆系统维护任务...")
            # The adapter is imported lazily to avoid a hard dependency at
            # module load time.
            try:
                from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter

                if not _enhanced_memory_adapter:
                    logger.debug("增强记忆适配器未初始化,跳过维护")
                else:
                    await _enhanced_memory_adapter.maintenance()
                    logger.info("✅ 增强记忆系统维护任务完成")
            except Exception as e:
                logger.error(f"增强记忆系统维护失败: {e}")
        except Exception as e:
            logger.error(f"执行维护任务时发生异常: {e}", exc_info=True)

    def get_interval(self) -> int:
        """Return the execution interval in seconds."""
        return self.interval

    def get_task_name(self) -> str:
        """Return the unique task name."""
        return self.task_name
# 全局钩子实例
_memory_hooks: MemoryIntegrationHooks | None = None
async def get_memory_integration_hooks() -> MemoryIntegrationHooks:
    """Return the process-wide hooks singleton, creating and registering it lazily.

    NOTE(review): not guarded against concurrent first calls — two tasks
    awaiting here simultaneously may create/register twice; confirm that
    startup is serialized by the caller.
    """
    global _memory_hooks
    if _memory_hooks is None:
        _memory_hooks = MemoryIntegrationHooks()
        await _memory_hooks.register_hooks()
    return _memory_hooks
async def initialize_memory_integration_hooks():
    """Initialize the global memory-integration hooks.

    Returns:
        The hooks instance on success, or ``None`` when initialization
        failed (the error is logged, not raised).
    """
    try:
        logger.info("🚀 初始化记忆集成钩子...")
        hooks = await get_memory_integration_hooks()
        logger.info("✅ 记忆集成钩子初始化完成")
        return hooks
    except Exception as e:
        logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True)
        return None

File diff suppressed because it is too large Load Diff

View File

@@ -1,875 +0,0 @@
"""
向量数据库存储接口
为记忆系统提供高效的向量存储和语义搜索能力
"""
import asyncio
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import orjson
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.config_helpers import resolve_embedding_dimension
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
# 尝试导入FAISS如果不可用则使用简单替代
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
logger.warning("FAISS not available, using simple vector storage")
@dataclass
class VectorStorageConfig:
    """Configuration for the vector storage backend."""

    dimension: int = 1024  # embedding dimensionality
    similarity_threshold: float = 0.8  # minimum similarity considered a match
    index_type: str = "flat"  # one of: flat, ivf, hnsw
    max_index_size: int = 100000  # soft cap on the number of indexed vectors
    storage_path: str = "data/memory_vectors"  # on-disk persistence directory
    auto_save_interval: int = 10  # auto-save after this many write operations
    enable_compression: bool = True  # compression toggle
class VectorStorageManager:
"""向量存储管理器"""
def __init__(self, config: VectorStorageConfig | None = None):
self.config = config or VectorStorageConfig()
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
if resolved_dimension and resolved_dimension != self.config.dimension:
logger.info(
"向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
resolved_dimension,
self.config.dimension,
)
self.config.dimension = resolved_dimension
self.storage_path = Path(self.config.storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
# 向量索引
self.vector_index = None
self.memory_id_to_index = {} # memory_id -> vector index
self.index_to_memory_id = {} # vector index -> memory_id
# 内存缓存
self.memory_cache: dict[str, MemoryChunk] = {}
self.vector_cache: dict[str, list[float]] = {}
# 统计信息
self.storage_stats = {
"total_vectors": 0,
"index_build_time": 0.0,
"average_search_time": 0.0,
"cache_hit_rate": 0.0,
"total_searches": 0,
"cache_hits": 0,
}
# 线程锁
self._lock = threading.RLock()
self._operation_count = 0
# 初始化索引
self._initialize_index()
# 嵌入模型
self.embedding_model: LLMRequest = None
def _initialize_index(self):
"""初始化向量索引"""
try:
if FAISS_AVAILABLE:
if self.config.index_type == "flat":
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
elif self.config.index_type == "ivf":
quantizer = faiss.IndexFlatIP(self.config.dimension)
nlist = min(100, max(1, self.config.max_index_size // 1000))
self.vector_index = faiss.IndexIVFFlat(quantizer, self.config.dimension, nlist)
elif self.config.index_type == "hnsw":
self.vector_index = faiss.IndexHNSWFlat(self.config.dimension, 32)
self.vector_index.hnsw.efConstruction = 40
else:
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
else:
# 简单的向量存储实现
self.vector_index = SimpleVectorIndex(self.config.dimension)
logger.info(f"✅ 向量索引初始化完成,类型: {self.config.index_type}")
except Exception as e:
logger.error(f"❌ 向量索引初始化失败: {e}")
# 回退到简单实现
self.vector_index = SimpleVectorIndex(self.config.dimension)
async def initialize_embedding_model(self):
    """Lazily create the embedding-model client on first use.

    Builds an ``LLMRequest`` from ``model_config.model_task_config.embedding``
    tagged as "memory.embedding". Idempotent: subsequent calls are no-ops.
    """
    if self.embedding_model is None:
        self.embedding_model = LLMRequest(
            model_set=model_config.model_task_config.embedding, request_type="memory.embedding"
        )
        logger.info("✅ 嵌入模型初始化完成")
async def generate_query_embedding(self, query_text: str) -> list[float] | None:
"""生成查询向量,用于记忆召回"""
if not query_text:
logger.warning("查询文本为空,无法生成向量")
return None
try:
await self.initialize_embedding_model()
logger.debug(f"开始生成查询向量,文本: '{query_text[:50]}{'...' if len(query_text) > 50 else ''}'")
embedding, _ = await self.embedding_model.get_embedding(query_text)
if not embedding:
logger.warning("嵌入模型返回空向量")
return None
logger.debug(f"生成的向量维度: {len(embedding)}, 期望维度: {self.config.dimension}")
if len(embedding) != self.config.dimension:
logger.error("查询向量维度不匹配: 期望 %d, 实际 %d", self.config.dimension, len(embedding))
return None
normalized_vector = self._normalize_vector(embedding)
logger.debug(f"查询向量生成成功,向量范围: [{min(normalized_vector):.4f}, {max(normalized_vector):.4f}]")
return normalized_vector
except Exception as exc:
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
return None
async def store_memories(self, memories: list[MemoryChunk]):
"""存储记忆向量"""
if not memories:
return
start_time = time.time()
try:
# 确保嵌入模型已初始化
await self.initialize_embedding_model()
# 批量获取嵌入向量
memory_texts = []
for memory in memories:
# 预先缓存记忆,确保后续流程可访问
self.memory_cache[memory.memory_id] = memory
if memory.embedding is None:
# 如果没有嵌入向量,需要生成
text = self._prepare_embedding_text(memory)
memory_texts.append((memory.memory_id, text))
else:
# 已有嵌入向量,直接使用
await self._add_single_memory(memory, memory.embedding)
# 批量生成缺失的嵌入向量
if memory_texts:
await self._batch_generate_and_store_embeddings(memory_texts)
# 自动保存检查
self._operation_count += len(memories)
if self._operation_count >= self.config.auto_save_interval:
await self.save_storage()
self._operation_count = 0
storage_time = time.time() - start_time
logger.debug(f"向量存储完成,{len(memories)} 条记忆,耗时 {storage_time:.3f}")
except Exception as e:
logger.error(f"❌ 向量存储失败: {e}", exc_info=True)
def _prepare_embedding_text(self, memory: MemoryChunk) -> str:
    """Build the natural-language text used to embed *memory*.

    Preference order: display text → raw text content → a composite of
    subjects/predicate/object → the memory id as a last-resort placeholder.
    """
    display_text = (memory.display or "").strip()
    if display_text:
        return display_text
    fallback_text = (memory.text_content or "").strip()
    if fallback_text:
        return fallback_text
    # NOTE(review): subjects are joined with an EMPTY separator, producing
    # e.g. "AliceBob" — confirm whether a delimiter (", ") was intended.
    subjects = "".join(s.strip() for s in memory.subjects if s and isinstance(s, str))
    predicate = (memory.content.predicate or "").strip()
    obj = memory.content.object
    if isinstance(obj, dict):
        # Structured object: render up to three preview items per key.
        object_parts = []
        for key, value in obj.items():
            if value is None:
                continue
            if isinstance(value, (list, tuple)):
                # NOTE(review): list preview also uses an empty join separator.
                preview = "".join(str(item) for item in value[:3])
                object_parts.append(f"{key}:{preview}")
            else:
                object_parts.append(f"{key}:{value}")
        object_text = ", ".join(object_parts)
    else:
        object_text = str(obj or "").strip()
    composite_parts = [part for part in [subjects, predicate, object_text] if part]
    if composite_parts:
        return " ".join(composite_parts)
    logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id)
    return memory.memory_id
async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
"""批量生成和存储嵌入向量"""
if not memory_texts:
return
try:
texts = [text for _, text in memory_texts]
memory_ids = [memory_id for memory_id, _ in memory_texts]
# 批量生成嵌入向量
embeddings = await self._batch_generate_embeddings(memory_ids, texts)
# 存储向量和记忆
for memory_id, embedding in embeddings.items():
if embedding and len(embedding) == self.config.dimension:
memory = self.memory_cache.get(memory_id)
if memory:
await self._add_single_memory(memory, embedding)
except Exception as e:
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
"""批量生成嵌入向量"""
if not texts:
return {}
results: dict[str, list[float]] = {}
try:
semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))
async def generate_embedding(memory_id: str, text: str) -> None:
async with semaphore:
try:
embedding, _ = await self.embedding_model.get_embedding(text)
if embedding and len(embedding) == self.config.dimension:
results[memory_id] = embedding
else:
logger.warning(
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。",
self.config.dimension,
len(embedding) if embedding else 0,
memory_id,
)
results[memory_id] = []
except Exception as exc:
logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
results[memory_id] = []
tasks = [
asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
]
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
for memory_id in memory_ids:
results.setdefault(memory_id, [])
return results
async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
"""添加单个记忆到向量存储"""
with self._lock:
try:
# 规范化向量
if embedding:
embedding = self._normalize_vector(embedding)
# 添加到缓存
self.memory_cache[memory.memory_id] = memory
self.vector_cache[memory.memory_id] = embedding
# 更新记忆的嵌入向量
memory.set_embedding(embedding)
# 添加到向量索引
if hasattr(self.vector_index, "add"):
# FAISS索引
if isinstance(embedding, np.ndarray):
vector_array = embedding.reshape(1, -1).astype("float32")
else:
vector_array = np.array([embedding], dtype="float32")
# 特殊处理IVF索引
if self.config.index_type == "ivf" and self.vector_index.ntotal == 0:
# IVF索引需要先训练
logger.debug("训练IVF索引...")
self.vector_index.train(vector_array)
self.vector_index.add(vector_array)
index_id = self.vector_index.ntotal - 1
else:
# 简单索引
index_id = self.vector_index.add_vector(embedding)
# 更新映射关系
self.memory_id_to_index[memory.memory_id] = index_id
self.index_to_memory_id[index_id] = memory.memory_id
# 更新统计
self.storage_stats["total_vectors"] += 1
except Exception as e:
logger.error(f"❌ 添加记忆到向量存储失败: {e}")
def _normalize_vector(self, vector: list[float]) -> list[float]:
    """Return an L2-normalized copy of *vector* as a plain list.

    The input is returned unchanged when it is empty, has zero norm, or
    normalization fails for any reason (failure is logged).
    """
    if not vector:
        return vector
    try:
        arr = np.asarray(vector, dtype=np.float32)
        magnitude = np.linalg.norm(arr)
        if magnitude == 0:
            return vector
        return (arr / magnitude).tolist()
    except Exception as e:
        logger.warning(f"向量归一化失败: {e}")
        return vector
async def search_similar_memories(
self,
query_vector: list[float] | None = None,
*,
query_text: str | None = None,
limit: int = 10,
scope_id: str | None = None,
) -> list[tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()
try:
logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}")
if query_vector is None:
if not query_text:
logger.warning("查询向量和查询文本都为空")
return []
query_vector = await self.generate_query_embedding(query_text)
if not query_vector:
logger.warning("查询向量生成失败")
return []
scope_filter: str | None = None
if isinstance(scope_id, str):
normalized_scope = scope_id.strip().lower()
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
scope_filter = scope_id
elif scope_id:
scope_filter = str(scope_id)
# 规范化查询向量
query_vector = self._normalize_vector(query_vector)
logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}")
# 检查向量索引状态
if not self.vector_index:
logger.error("向量索引未初始化")
return []
total_vectors = 0
if hasattr(self.vector_index, "ntotal"):
total_vectors = self.vector_index.ntotal
elif hasattr(self.vector_index, "vectors"):
total_vectors = len(self.vector_index.vectors)
logger.debug(f"向量索引中实际向量数: {total_vectors}")
if total_vectors == 0:
logger.warning("向量索引为空,无法执行搜索")
return []
# 执行向量搜索
with self._lock:
if hasattr(self.vector_index, "search"):
# FAISS索引
if isinstance(query_vector, np.ndarray):
query_array = query_vector.reshape(1, -1).astype("float32")
else:
query_array = np.array([query_vector], dtype="float32")
if self.config.index_type == "ivf" and self.vector_index.ntotal > 0:
# 设置IVF搜索参数
nprobe = min(self.vector_index.nlist, 10)
self.vector_index.nprobe = nprobe
logger.debug(f"IVF搜索参数: nprobe={nprobe}")
search_limit = min(limit, total_vectors)
logger.debug(f"执行FAISS搜索搜索限制: {search_limit}")
distances, indices = self.vector_index.search(query_array, search_limit)
distances = distances.flatten().tolist()
indices = indices.flatten().tolist()
logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引")
else:
# 简单索引
logger.debug("使用简单向量索引执行搜索")
results = self.vector_index.search(query_vector, limit)
distances = [score for _, score in results]
indices = [idx for idx, _ in results]
logger.debug(f"简单索引搜索结果: {len(results)} 个结果")
# 处理搜索结果
results = []
valid_results = 0
invalid_indices = 0
filtered_by_scope = 0
for distance, index in zip(distances, indices, strict=False):
if index == -1: # FAISS的无效索引标记
invalid_indices += 1
continue
memory_id = self.index_to_memory_id.get(index)
if not memory_id:
logger.debug(f"索引 {index} 没有对应的记忆ID")
invalid_indices += 1
continue
if scope_filter:
memory = self.memory_cache.get(memory_id)
if memory and str(memory.user_id) != scope_filter:
filtered_by_scope += 1
continue
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
results.append((memory_id, similarity))
valid_results += 1
logger.debug(
f"搜索结果处理: 总距离={len(distances)}, 有效结果={valid_results}, "
f"无效索引={invalid_indices}, 作用域过滤={filtered_by_scope}"
)
# 更新统计
search_time = time.time() - start_time
self.storage_stats["total_searches"] += 1
self.storage_stats["average_search_time"] = (
self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time
) / self.storage_stats["total_searches"]
final_results = results[:limit]
logger.info(
f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' "
f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果"
)
return final_results
except Exception as e:
logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
return []
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
    """Return the cached memory chunk for *memory_id*, or ``None``.

    Also maintains cache statistics: every lookup increments
    ``total_searches`` and a hit additionally increments ``cache_hits``, so
    that the hit rate ``cache_hits / total_searches`` stays within [0, 1].

    Bug fix: previously the denominator (``total_searches``) was only
    incremented on a miss while ``cache_hits`` grew on every hit, which let
    the reported cache hit rate exceed 100%.
    """
    self.storage_stats["total_searches"] += 1
    if memory_id in self.memory_cache:
        self.storage_stats["cache_hits"] += 1
        return self.memory_cache[memory_id]
    return None
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
    """Replace the stored embedding for *memory_id* with *new_embedding*.

    Normalizes the new vector, appends it to the index, and repoints the
    id↔index mappings, the vector cache, and the cached memory object.
    Unknown ids are logged and ignored.

    Bug fixes:
    - The reverse mapping for the superseded slot (``index_to_memory_id[old_index]``)
      was never removed, so a search hitting the old slot could still resolve
      to this memory with a stale vector. It is now dropped.
    - The bare ``except:`` around ``remove_ids`` also caught SystemExit and
      KeyboardInterrupt; narrowed to ``Exception``.
    """
    with self._lock:
        try:
            if memory_id not in self.memory_id_to_index:
                logger.warning(f"记忆 {memory_id} 不存在于向量索引中")
                return
            old_index = self.memory_id_to_index[memory_id]
            # Remove the old vector when the index supports deletion.
            if hasattr(self.vector_index, "remove_ids"):
                try:
                    self.vector_index.remove_ids(np.array([old_index]))
                except Exception:
                    logger.warning("无法删除旧向量,将直接添加新向量")
            # Normalize before insertion so cosine/IP scores stay comparable.
            new_embedding = self._normalize_vector(new_embedding)
            if hasattr(self.vector_index, "add"):
                # FAISS index: append and take the last slot.
                if isinstance(new_embedding, np.ndarray):
                    vector_array = new_embedding.reshape(1, -1).astype("float32")
                else:
                    vector_array = np.array([new_embedding], dtype="float32")
                self.vector_index.add(vector_array)
                new_index = self.vector_index.ntotal - 1
            else:
                # Simple fallback index.
                new_index = self.vector_index.add_vector(new_embedding)
            # Drop the stale reverse mapping before installing the new one.
            if self.index_to_memory_id.get(old_index) == memory_id:
                del self.index_to_memory_id[old_index]
            self.memory_id_to_index[memory_id] = new_index
            self.index_to_memory_id[new_index] = memory_id
            self.vector_cache[memory_id] = new_embedding
            memory = self.memory_cache.get(memory_id)
            if memory:
                memory.set_embedding(new_embedding)
            logger.debug(f"更新记忆 {memory_id} 的嵌入向量")
        except Exception as e:
            logger.error(f"❌ 更新记忆嵌入向量失败: {e}")
async def delete_memory(self, memory_id: str):
    """Remove a memory's vector and all bookkeeping for it.

    Removes the vector from the index when the backend supports deletion,
    drops both id↔index mappings, evicts the memory and vector caches, and
    decrements the vector counter. Unknown ids are silently ignored.
    """
    with self._lock:
        try:
            if memory_id not in self.memory_id_to_index:
                return
            index = self.memory_id_to_index[memory_id]
            # Remove from the vector index when supported.
            if hasattr(self.vector_index, "remove_ids"):
                try:
                    self.vector_index.remove_ids(np.array([index]))
                except Exception:
                    # Bug fix: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit. Narrowed to Exception.
                    logger.warning("无法从向量索引中删除,仅从缓存中移除")
            # Drop both directions of the id↔index mapping.
            del self.memory_id_to_index[memory_id]
            if index in self.index_to_memory_id:
                del self.index_to_memory_id[index]
            # Evict caches.
            self.memory_cache.pop(memory_id, None)
            self.vector_cache.pop(memory_id, None)
            # Keep the counter non-negative even if stats drifted.
            self.storage_stats["total_vectors"] = max(0, self.storage_stats["total_vectors"] - 1)
            logger.debug(f"删除记忆 {memory_id}")
        except Exception as e:
            logger.error(f"❌ 删除记忆失败: {e}")
async def save_storage(self):
"""保存向量存储到文件"""
try:
logger.info("正在保存向量存储...")
# 保存记忆缓存
cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()}
cache_file = self.storage_path / "memory_cache.json"
with open(cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
# 保存向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
with open(vector_cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8"))
# 保存映射关系
mapping_file = self.storage_path / "id_mapping.json"
mapping_data = {
"memory_id_to_index": {
str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items()
},
"index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()},
}
with open(mapping_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
# 保存FAISS索引如果可用
if FAISS_AVAILABLE and hasattr(self.vector_index, "save"):
index_file = self.storage_path / "vector_index.faiss"
faiss.write_index(self.vector_index, str(index_file))
# 保存统计信息
stats_file = self.storage_path / "storage_stats.json"
with open(stats_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8"))
logger.info("✅ 向量存储保存完成")
except Exception as e:
logger.error(f"❌ 保存向量存储失败: {e}")
async def load_storage(self):
"""从文件加载向量存储"""
try:
logger.info("正在加载向量存储...")
# 加载记忆缓存
cache_file = self.storage_path / "memory_cache.json"
if cache_file.exists():
with open(cache_file, encoding="utf-8") as f:
cache_data = orjson.loads(f.read())
self.memory_cache = {
memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items()
}
# 加载向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
if vector_cache_file.exists():
with open(vector_cache_file, encoding="utf-8") as f:
self.vector_cache = orjson.loads(f.read())
# 加载映射关系
mapping_file = self.storage_path / "id_mapping.json"
if mapping_file.exists():
with open(mapping_file, encoding="utf-8") as f:
mapping_data = orjson.loads(f.read())
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items()
}
raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()}
# 加载FAISS索引如果可用
index_loaded = False
if FAISS_AVAILABLE:
index_file = self.storage_path / "vector_index.faiss"
if index_file.exists():
try:
loaded_index = faiss.read_index(str(index_file))
# 如果索引类型匹配,则替换
if type(loaded_index) == type(self.vector_index):
self.vector_index = loaded_index
index_loaded = True
logger.info("✅ FAISS索引文件加载完成")
else:
logger.warning("索引类型不匹配,重新构建索引")
except Exception as e:
logger.warning(f"加载FAISS索引失败: {e},重新构建")
else:
logger.info("FAISS索引文件不存在将重新构建")
# 如果索引没有成功加载且有向量数据,则重建索引
if not index_loaded and self.vector_cache:
logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引")
await self._rebuild_index()
# 加载统计信息
stats_file = self.storage_path / "storage_stats.json"
if stats_file.exists():
with open(stats_file, encoding="utf-8") as f:
self.storage_stats = orjson.loads(f.read())
# 更新向量计数
self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
logger.info(f"✅ 向量存储加载完成,{self.storage_stats['total_vectors']} 个向量")
except Exception as e:
logger.error(f"❌ 加载向量存储失败: {e}")
    async def _rebuild_index(self):
        """Rebuild the vector index from the in-memory vector cache.

        Re-initializes the index, clears the id<->index mappings, then
        re-inserts every cached embedding whose length matches
        ``self.config.dimension``. Handles both a FAISS index (batched
        ``add``) and the SimpleVectorIndex fallback (per-vector
        ``add_vector``). Errors are logged, never raised.
        """
        try:
            logger.info(f"正在重建向量索引...向量数量: {len(self.vector_cache)}")
            # Start from a fresh, empty index so stale entries cannot survive.
            self._initialize_index()
            # The old mappings refer to positions in the discarded index.
            self.memory_id_to_index.clear()
            self.index_to_memory_id.clear()
            if not self.vector_cache:
                logger.warning("没有向量缓存数据,跳过重建")
                return
            # Collect only embeddings of the configured dimension; skip the rest.
            memory_ids = []
            vectors = []
            for memory_id, embedding in self.vector_cache.items():
                if embedding and len(embedding) == self.config.dimension:
                    memory_ids.append(memory_id)
                    vectors.append(self._normalize_vector(embedding))
                else:
                    logger.debug(f"跳过无效向量: {memory_id}, 维度: {len(embedding) if embedding else 0}")
            if not vectors:
                logger.warning("没有有效的向量数据")
                return
            logger.info(f"准备重建 {len(vectors)} 个向量到索引")
            # FAISS indexes expose ``add``; the pure-Python fallback does not.
            if hasattr(self.vector_index, "add"):
                # FAISS path: a single batched insert of all vectors.
                vector_array = np.array(vectors, dtype="float32")
                # IVF indexes must be trained before vectors can be added.
                if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"):
                    logger.info("训练IVF索引...")
                    self.vector_index.train(vector_array)
                self.vector_index.add(vector_array)
                # FAISS assigns sequential positions, so slot i maps to memory_ids[i].
                for i, memory_id in enumerate(memory_ids):
                    self.memory_id_to_index[memory_id] = i
                    self.index_to_memory_id[i] = memory_id
            else:
                # Fallback path: insert one vector at a time, recording the id returned.
                for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors, strict=False)):
                    index_id = self.vector_index.add_vector(vector)
                    self.memory_id_to_index[memory_id] = index_id
                    self.index_to_memory_id[index_id] = memory_id
            self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
            # FAISS exposes ``ntotal``; fall back to the mapping size otherwise.
            final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index))
            logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}")
        except Exception as e:
            logger.error(f"❌ 重建向量索引失败: {e}", exc_info=True)
async def optimize_storage(self):
"""优化存储"""
try:
logger.info("开始向量存储优化...")
# 清理无效引用
self._cleanup_invalid_references()
# 重新构建索引(如果碎片化严重)
if self.storage_stats["total_vectors"] > 1000:
await self._rebuild_index()
# 更新缓存命中率
if self.storage_stats["total_searches"] > 0:
self.storage_stats["cache_hit_rate"] = (
self.storage_stats["cache_hits"] / self.storage_stats["total_searches"]
)
logger.info("✅ 向量存储优化完成")
except Exception as e:
logger.error(f"❌ 向量存储优化失败: {e}")
def _cleanup_invalid_references(self):
"""清理无效引用"""
with self._lock:
# 清理无效的memory_id到index的映射
valid_memory_ids = set(self.memory_cache.keys())
invalid_memory_ids = set(self.memory_id_to_index.keys()) - valid_memory_ids
for memory_id in invalid_memory_ids:
index = self.memory_id_to_index[memory_id]
del self.memory_id_to_index[memory_id]
if index in self.index_to_memory_id:
del self.index_to_memory_id[index]
if invalid_memory_ids:
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
def get_storage_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
stats = self.storage_stats.copy()
if stats["total_searches"] > 0:
stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_searches"]
else:
stats["cache_hit_rate"] = 0.0
return stats
class SimpleVectorIndex:
"""简单的向量索引实现当FAISS不可用时的替代方案"""
def __init__(self, dimension: int):
self.dimension = dimension
self.vectors: list[list[float]] = []
self.vector_ids: list[int] = []
self.next_id = 0
def add_vector(self, vector: list[float]) -> int:
"""添加向量"""
if len(vector) != self.dimension:
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
vector_id = self.next_id
self.vectors.append(vector.copy())
self.vector_ids.append(vector_id)
self.next_id += 1
return vector_id
def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
"""搜索相似向量"""
if len(query_vector) != self.dimension:
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
results = []
for i, vector in enumerate(self.vectors):
similarity = self._calculate_cosine_similarity(query_vector, vector)
results.append((self.vector_ids[i], similarity))
# 按相似度排序
results.sort(key=lambda x: x[1], reverse=True)
return results[:limit]
def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
"""计算余弦相似度"""
try:
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
norm1 = sum(x * x for x in v1) ** 0.5
norm2 = sum(x * x for x in v2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
except Exception:
return 0.0
@property
def ntotal(self) -> int:
"""向量总数"""
return len(self.vectors)

View File

@@ -0,0 +1,489 @@
"""
自适应流管理器 - 动态并发限制和异步流池管理
根据系统负载和流优先级动态调整并发限制
"""
import asyncio
import psutil
import time
from typing import Dict, List, Optional, Set, Tuple
from dataclasses import dataclass, field
from enum import Enum
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("adaptive_stream_manager")
class StreamPriority(Enum):
    """Dispatch priority of a chat stream; larger values take precedence."""

    LOW = 1
    NORMAL = 2
    HIGH = 3
    CRITICAL = 4
@dataclass
class SystemMetrics:
    """Point-in-time snapshot of host/process load used for limit tuning."""

    cpu_usage: float = 0.0  # CPU load as a fraction in [0, 1] (psutil percent / 100)
    memory_usage: float = 0.0  # virtual-memory usage as a fraction in [0, 1]
    active_coroutines: int = 0  # number of asyncio tasks at sample time
    event_loop_lag: float = 0.0  # seconds an `await sleep(0)` took to resume
    timestamp: float = field(default_factory=time.time)  # sample time, epoch seconds
@dataclass
class StreamMetrics:
    """Per-stream activity metrics tracked by the adaptive manager."""

    stream_id: str
    priority: StreamPriority
    message_rate: float = 0.0  # message rate (messages/minute)
    response_time: float = 0.0  # average response time, seconds
    last_activity: float = field(default_factory=time.time)  # last touch, epoch seconds
    consecutive_failures: int = 0  # never updated in this module — TODO confirm it is used elsewhere
    is_active: bool = True
class AdaptiveStreamManager:
    """Dynamically tunes how many chat streams may be processed concurrently.

    Two background tasks sample system load (CPU, memory, coroutine count,
    event-loop lag) and periodically scale ``current_limit`` between
    ``min_concurrent_limit`` and ``max_concurrent_limit``. Callers obtain a
    slot per stream via :meth:`acquire_stream_slot` and must pair it with
    :meth:`release_stream_slot`.

    NOTE(review): slot accounting is not fully symmetric — several accept
    paths (already-active, manager-not-running, force-acquire) return True
    without acquiring a semaphore permit, while ``release_stream_slot``
    always releases one. Over time this can inflate the effective limit;
    confirm and fix before relying on strict limits.
    """

    def __init__(
        self,
        base_concurrent_limit: int = 50,
        max_concurrent_limit: int = 200,
        min_concurrent_limit: int = 10,
        metrics_window: float = 60.0,  # seconds of system metrics kept for averaging
        adjustment_interval: float = 30.0,  # seconds between limit adjustments
        cpu_threshold_high: float = 0.8,  # CPU fraction above which the limit shrinks
        cpu_threshold_low: float = 0.3,  # CPU fraction below which the limit grows
        memory_threshold_high: float = 0.85,  # memory fraction above which the limit shrinks
    ):
        self.base_concurrent_limit = base_concurrent_limit
        self.max_concurrent_limit = max_concurrent_limit
        self.min_concurrent_limit = min_concurrent_limit
        self.metrics_window = metrics_window
        self.adjustment_interval = adjustment_interval
        self.cpu_threshold_high = cpu_threshold_high
        self.cpu_threshold_low = cpu_threshold_low
        self.memory_threshold_high = memory_threshold_high
        # Current state
        self.current_limit = base_concurrent_limit
        self.active_streams: Set[str] = set()
        self.pending_streams: Set[str] = set()  # NOTE(review): never written in this class — confirm needed
        self.stream_metrics: Dict[str, StreamMetrics] = {}
        # Semaphores gating regular and high-priority streams respectively.
        self.semaphore = asyncio.Semaphore(base_concurrent_limit)
        self.priority_semaphore = asyncio.Semaphore(5)  # dedicated slots for HIGH/CRITICAL
        # System monitoring
        self.system_metrics: List[SystemMetrics] = []
        self.last_adjustment_time = 0.0
        # Counters exposed via get_stats().
        self.stats = {
            "total_requests": 0,
            "accepted_requests": 0,
            "rejected_requests": 0,
            "priority_accepts": 0,
            "limit_adjustments": 0,
            "avg_concurrent_streams": 0,
            "peak_concurrent_streams": 0,
        }
        # Background tasks (created in start()).
        self.monitor_task: Optional[asyncio.Task] = None
        self.adjustment_task: Optional[asyncio.Task] = None
        self.is_running = False
        logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")

    async def start(self):
        """Start the monitoring and limit-adjustment background tasks (idempotent)."""
        if self.is_running:
            logger.warning("自适应流管理器已经在运行")
            return
        self.is_running = True
        self.monitor_task = asyncio.create_task(self._system_monitor_loop(), name="system_monitor")
        self.adjustment_task = asyncio.create_task(self._adjustment_loop(), name="limit_adjustment")
        logger.info("自适应流管理器已启动")

    async def stop(self):
        """Cancel both background tasks, waiting up to 10s for each.

        NOTE(review): awaiting a just-cancelled task raises CancelledError,
        which is a BaseException on Python 3.8+ and is NOT caught by the
        ``except Exception`` below — it may propagate to the caller; confirm.
        """
        if not self.is_running:
            return
        self.is_running = False
        # Stop the system-monitor task.
        if self.monitor_task and not self.monitor_task.done():
            self.monitor_task.cancel()
            try:
                await asyncio.wait_for(self.monitor_task, timeout=10.0)
            except asyncio.TimeoutError:
                logger.warning("系统监控任务停止超时")
            except Exception as e:
                logger.error(f"停止系统监控任务时出错: {e}")
        if self.adjustment_task and not self.adjustment_task.done():
            self.adjustment_task.cancel()
            try:
                await asyncio.wait_for(self.adjustment_task, timeout=10.0)
            except asyncio.TimeoutError:
                logger.warning("限制调整任务停止超时")
            except Exception as e:
                logger.error(f"停止限制调整任务时出错: {e}")
        logger.info("自适应流管理器已停止")

    async def acquire_stream_slot(
        self,
        stream_id: str,
        priority: StreamPriority = StreamPriority.NORMAL,
        force: bool = False
    ) -> bool:
        """Try to obtain a processing slot for *stream_id*.

        Args:
            stream_id: stream identifier.
            priority: dispatch priority; HIGH/CRITICAL go through the
                dedicated priority semaphore.
            force: bypass the limit (bounded by ``max_concurrent_limit``).

        Returns:
            bool: True if the stream may be processed now.
        """
        # If the manager is not running, do not gate anything.
        if not self.is_running:
            logger.warning(f"自适应流管理器未运行,直接允许流 {stream_id}")
            return True
        self.stats["total_requests"] += 1
        current_time = time.time()
        # Create/update this stream's metrics record.
        if stream_id not in self.stream_metrics:
            self.stream_metrics[stream_id] = StreamMetrics(
                stream_id=stream_id,
                priority=priority
            )
        self.stream_metrics[stream_id].last_activity = current_time
        # Already holding a slot — nothing to do.
        if stream_id in self.active_streams:
            logger.debug(f"流 {stream_id} 已经在活跃列表中")
            return True
        # High/critical priorities use the dedicated semaphore path.
        if priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
            return await self._acquire_priority_slot(stream_id, priority, force)
        # Escalate to a forced dispatch when the stream looks backlogged.
        if not force and self._should_force_dispatch(stream_id):
            force = True
            logger.info(f"流 {stream_id} 消息积压严重,强制分发")
        # Try the regular semaphore.
        try:
            # Near-zero timeout makes this effectively non-blocking.
            # NOTE(review): wait_for+acquire has a known race — on timeout the
            # permit may still be consumed and lost; confirm acceptable here.
            acquired = await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
            if acquired:
                self.active_streams.add(stream_id)
                self.stats["accepted_requests"] += 1
                logger.debug(f"流 {stream_id} 获取常规槽位成功 (当前活跃: {len(self.active_streams)})")
                return True
        except asyncio.TimeoutError:
            logger.debug(f"常规信号量已满: {stream_id}")
        except Exception as e:
            logger.warning(f"获取常规槽位时出错: {e}")
        # Forced dispatch may still break through the limit.
        if force:
            return await self._force_acquire_slot(stream_id)
        # No slot available.
        self.stats["rejected_requests"] += 1
        logger.debug(f"流 {stream_id} 获取槽位失败,当前限制: {self.current_limit}, 活跃流: {len(self.active_streams)}")
        return False

    async def _acquire_priority_slot(self, stream_id: str, priority: StreamPriority, force: bool) -> bool:
        """Acquire a slot from the small priority semaphore; force/CRITICAL may break through."""
        try:
            # The priority semaphore has only a handful of slots.
            acquired = await asyncio.wait_for(self.priority_semaphore.acquire(), timeout=0.001)
            if acquired:
                self.active_streams.add(stream_id)
                self.stats["priority_accepts"] += 1
                self.stats["accepted_requests"] += 1
                logger.debug(f"流 {stream_id} 获取优先级槽位成功 (优先级: {priority.name})")
                return True
        except asyncio.TimeoutError:
            logger.debug(f"优先级信号量已满: {stream_id}")
        except Exception as e:
            logger.warning(f"获取优先级槽位时出错: {e}")
        # Priority slots exhausted — CRITICAL (or forced) still gets through.
        if force or priority == StreamPriority.CRITICAL:
            return await self._force_acquire_slot(stream_id)
        return False

    async def _force_acquire_slot(self, stream_id: str) -> bool:
        """Admit the stream past the soft limit, bounded by ``max_concurrent_limit``.

        NOTE(review): no semaphore permit is taken on this path, yet
        ``release_stream_slot`` will release one later — see class note.
        """
        if len(self.active_streams) >= self.max_concurrent_limit:
            logger.warning(f"达到最大并发限制 {self.max_concurrent_limit},无法为流 {stream_id} 强制分发")
            return False
        self.active_streams.add(stream_id)
        self.stats["accepted_requests"] += 1
        logger.warning(f"流 {stream_id} 突破并发限制强制分发 (当前活跃: {len(self.active_streams)})")
        return True

    def release_stream_slot(self, stream_id: str):
        """Release the slot held by *stream_id* (no-op if it holds none)."""
        if stream_id in self.active_streams:
            self.active_streams.remove(stream_id)
            # Release whichever semaphore this stream's priority maps to.
            metrics = self.stream_metrics.get(stream_id)
            if metrics and metrics.priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
                self.priority_semaphore.release()
            else:
                self.semaphore.release()
            logger.debug(f"流 {stream_id} 释放槽位 (当前活跃: {len(self.active_streams)})")

    def _should_force_dispatch(self, stream_id: str) -> bool:
        """Heuristic: should this stream bypass the limit?

        Based on priority and recent activity/response time — not on actual
        message backlog (see inline comment).
        """
        # Placeholder for a real backlog-based decision; currently uses
        # priority and recent activity only.
        metrics = self.stream_metrics.get(stream_id)
        if not metrics:
            return False
        # High-priority streams are always eligible for forced dispatch.
        if metrics.priority == StreamPriority.HIGH:
            return True
        # Recently-active streams with slow responses also qualify.
        current_time = time.time()
        if (current_time - metrics.last_activity < 300 and  # activity within 5 minutes
            metrics.response_time > 5.0):  # average response slower than 5s
            return True
        return False

    async def _system_monitor_loop(self):
        """Background loop: sample system metrics every 5 seconds until stopped."""
        logger.info("系统监控循环启动")
        while self.is_running:
            try:
                await asyncio.sleep(5.0)  # sampling period
                await self._collect_system_metrics()
            except asyncio.CancelledError:
                logger.info("系统监控循环被取消")
                break
            except Exception as e:
                logger.error(f"系统监控出错: {e}")
        logger.info("系统监控循环结束")

    async def _collect_system_metrics(self):
        """Sample CPU/memory/coroutine-count/loop-lag and append to the window."""
        try:
            # CPU usage as a fraction (non-blocking sample since the last call).
            cpu_usage = psutil.cpu_percent(interval=None) / 100.0
            # Memory usage as a fraction.
            memory = psutil.virtual_memory()
            memory_usage = memory.percent / 100.0
            # Number of live asyncio tasks.
            try:
                active_coroutines = len(asyncio.all_tasks())
            except:  # NOTE(review): bare except — narrow to Exception
                active_coroutines = 0
            # Event-loop lag: how long a zero-sleep takes to resume.
            event_loop_lag = 0.0
            try:
                loop = asyncio.get_running_loop()
                start_time = time.time()
                await asyncio.sleep(0)
                event_loop_lag = time.time() - start_time
            except:  # NOTE(review): bare except — narrow to Exception
                pass
            metrics = SystemMetrics(
                cpu_usage=cpu_usage,
                memory_usage=memory_usage,
                active_coroutines=active_coroutines,
                event_loop_lag=event_loop_lag,
                timestamp=time.time()
            )
            self.system_metrics.append(metrics)
            # Keep only samples inside the configured window.
            cutoff_time = time.time() - self.metrics_window
            self.system_metrics = [
                m for m in self.system_metrics
                if m.timestamp > cutoff_time
            ]
            # Exponential moving average of active streams, plus the peak.
            self.stats["avg_concurrent_streams"] = (
                self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1
            )
            self.stats["peak_concurrent_streams"] = max(
                self.stats["peak_concurrent_streams"],
                len(self.active_streams)
            )
        except Exception as e:
            logger.error(f"收集系统指标失败: {e}")

    async def _adjustment_loop(self):
        """Background loop: recompute the concurrency limit every ``adjustment_interval``."""
        logger.info("限制调整循环启动")
        while self.is_running:
            try:
                await asyncio.sleep(self.adjustment_interval)
                await self._adjust_concurrent_limit()
            except asyncio.CancelledError:
                logger.info("限制调整循环被取消")
                break
            except Exception as e:
                logger.error(f"限制调整出错: {e}")
        logger.info("限制调整循环结束")

    async def _adjust_concurrent_limit(self):
        """Scale ``current_limit`` up/down based on recent average load."""
        if not self.system_metrics:
            return
        current_time = time.time()
        # Throttle: never adjust more often than the configured interval.
        if current_time - self.last_adjustment_time < self.adjustment_interval:
            return
        # Average over the most recent (up to 10) samples.
        recent_metrics = self.system_metrics[-10:] if len(self.system_metrics) >= 10 else self.system_metrics
        if not recent_metrics:
            return
        avg_cpu = sum(m.cpu_usage for m in recent_metrics) / len(recent_metrics)
        avg_memory = sum(m.memory_usage for m in recent_metrics) / len(recent_metrics)
        avg_coroutines = sum(m.active_coroutines for m in recent_metrics) / len(recent_metrics)
        # Multiplicative adjustment, clamped to [min, max] below.
        old_limit = self.current_limit
        adjustment_factor = 1.0
        # CPU pressure: shrink 20% when hot, grow 20% when idle.
        if avg_cpu > self.cpu_threshold_high:
            adjustment_factor *= 0.8
        elif avg_cpu < self.cpu_threshold_low:
            adjustment_factor *= 1.2
        # Memory pressure: shrink 30%.
        if avg_memory > self.memory_threshold_high:
            adjustment_factor *= 0.7
        # Many live coroutines: shrink 10%.
        if avg_coroutines > 1000:
            adjustment_factor *= 0.9
        # Apply and clamp.
        new_limit = int(self.current_limit * adjustment_factor)
        new_limit = max(self.min_concurrent_limit, min(self.max_concurrent_limit, new_limit))
        # Resize the semaphore if the limit actually changed.
        if new_limit != self.current_limit:
            await self._adjust_semaphore(self.current_limit, new_limit)
            self.current_limit = new_limit
            self.stats["limit_adjustments"] += 1
            self.last_adjustment_time = current_time
            logger.info(
                f"并发限制调整: {old_limit} -> {new_limit} "
                f"(CPU: {avg_cpu:.2f}, 内存: {avg_memory:.2f}, 协程: {avg_coroutines:.0f})"
            )

    async def _adjust_semaphore(self, old_limit: int, new_limit: int):
        """Grow or shrink the regular semaphore's permit count toward *new_limit*.

        Growing releases extra permits; shrinking best-effort acquires permits
        and gives up as soon as none are immediately available.
        """
        if new_limit > old_limit:
            # Add permits.
            for _ in range(new_limit - old_limit):
                self.semaphore.release()
        elif new_limit < old_limit:
            # Remove permits by acquiring them (best effort).
            reduction = old_limit - new_limit
            for _ in range(reduction):
                try:
                    await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
                except:  # NOTE(review): bare except — narrow to (TimeoutError, Exception)
                    # Could not acquire immediately: usage is near the limit; stop shrinking.
                    break

    def update_stream_metrics(self, stream_id: str, **kwargs):
        """Set known attributes on an existing stream's metrics; unknown streams are ignored."""
        if stream_id not in self.stream_metrics:
            return
        metrics = self.stream_metrics[stream_id]
        for key, value in kwargs.items():
            if hasattr(metrics, key):
                setattr(metrics, key, value)

    def get_stats(self) -> Dict:
        """Return a snapshot of counters plus live state and acceptance rate."""
        stats = self.stats.copy()
        stats.update({
            "current_limit": self.current_limit,
            "active_streams": len(self.active_streams),
            "pending_streams": len(self.pending_streams),
            "is_running": self.is_running,
            "system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
            "system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
        })
        # Fraction of acquire requests that were accepted.
        if stats["total_requests"] > 0:
            stats["acceptance_rate"] = stats["accepted_requests"] / stats["total_requests"]
        else:
            stats["acceptance_rate"] = 0
        return stats
# Process-wide singleton; created lazily by get_adaptive_stream_manager().
_adaptive_manager: Optional[AdaptiveStreamManager] = None


def get_adaptive_stream_manager() -> AdaptiveStreamManager:
    """Return the global AdaptiveStreamManager, creating it on first use."""
    global _adaptive_manager
    if _adaptive_manager is None:
        _adaptive_manager = AdaptiveStreamManager()
    return _adaptive_manager
async def init_adaptive_stream_manager():
    """Create (if needed) and start the global adaptive stream manager."""
    await get_adaptive_stream_manager().start()
async def shutdown_adaptive_stream_manager():
    """Stop the global adaptive stream manager and its background tasks."""
    await get_adaptive_stream_manager().stop()

View File

@@ -0,0 +1,348 @@
"""
异步批量数据库写入器
优化频繁的数据库写入操作减少I/O阻塞
"""
import asyncio
import time
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from collections import defaultdict
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("batch_database_writer")
@dataclass
class StreamUpdatePayload:
    """One pending upsert of a chat-stream row, queued for batched writing."""

    stream_id: str
    update_data: Dict[str, Any]
    priority: int = 0  # larger numbers are written first within a batch
    timestamp: float = field(default_factory=time.time)  # enqueue time, epoch seconds
class BatchDatabaseWriter:
    """Asynchronous batched writer for ChatStreams upserts.

    Updates are queued via :meth:`schedule_stream_update`; a background task
    drains the queue into batches (bounded by size and a flush interval),
    merges duplicate stream ids keeping the newest payload, and upserts them
    in one session. If the writer is not running, writes go straight to the
    database.

    NOTE(review): ``self.priority_batches`` is initialized but never used in
    this class — confirm it is dead and remove, or wire it up.
    """

    def __init__(self, batch_size: int = 50, flush_interval: float = 5.0, max_queue_size: int = 1000):
        """Initialize the batch writer.

        Args:
            batch_size: maximum number of payloads per batch.
            flush_interval: maximum seconds to wait before flushing a partial batch.
            max_queue_size: queue capacity; beyond this, updates are dropped.
        """
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.max_queue_size = max_queue_size
        # Pending payloads awaiting the writer loop.
        self.write_queue: asyncio.Queue[StreamUpdatePayload] = asyncio.Queue(maxsize=max_queue_size)
        # Lifecycle state.
        self.is_running = False
        self.writer_task: Optional[asyncio.Task] = None
        # Counters exposed via get_stats().
        self.stats = {
            "total_writes": 0,
            "batch_writes": 0,
            "failed_writes": 0,
            "queue_size": 0,
            "avg_batch_size": 0,
            "last_flush_time": 0,
        }
        # Batches keyed by priority (see class NOTE: currently unused).
        self.priority_batches: Dict[int, List[StreamUpdatePayload]] = defaultdict(list)
        logger.info(f"批量数据库写入器初始化完成 (batch_size={batch_size}, interval={flush_interval}s)")

    async def start(self):
        """Start the background writer loop (idempotent)."""
        if self.is_running:
            logger.warning("批量写入器已经在运行")
            return
        self.is_running = True
        self.writer_task = asyncio.create_task(self._batch_writer_loop(), name="batch_database_writer")
        logger.info("批量数据库写入器已启动")

    async def stop(self):
        """Flush remaining payloads, then cancel and await the writer task.

        NOTE(review): awaiting a just-cancelled task raises CancelledError
        (a BaseException on 3.8+) which ``except Exception`` does not catch —
        it may propagate to the caller; confirm.
        """
        if not self.is_running:
            return
        self.is_running = False
        # Let the current batch finish and drain the queue before cancelling.
        if self.writer_task and not self.writer_task.done():
            try:
                # Write out whatever is still queued.
                await self._flush_all_batches()
                # Cancel the loop task.
                self.writer_task.cancel()
                await asyncio.wait_for(self.writer_task, timeout=10.0)
            except asyncio.TimeoutError:
                logger.warning("批量写入器停止超时")
            except Exception as e:
                logger.error(f"停止批量写入器时出错: {e}")
        logger.info("批量数据库写入器已停止")

    async def schedule_stream_update(
        self,
        stream_id: str,
        update_data: Dict[str, Any],
        priority: int = 0
    ) -> bool:
        """Queue an upsert for *stream_id*.

        Args:
            stream_id: stream identifier (conflict key of the upsert).
            update_data: column -> value mapping to write.
            priority: batch ordering hint; larger writes first.

        Returns:
            bool: True if queued (or written directly when the writer is
            stopped); False if the queue was full or an error occurred.
        """
        try:
            # Degraded mode: write synchronously when the loop is not running.
            if not self.is_running:
                logger.warning("批量写入器未运行,直接写入数据库")
                await self._direct_write(stream_id, update_data)
                return True
            # Wrap the update so ordering metadata travels with it.
            payload = StreamUpdatePayload(
                stream_id=stream_id,
                update_data=update_data,
                priority=priority
            )
            # Non-blocking enqueue; a full queue drops the update.
            try:
                self.write_queue.put_nowait(payload)
                self.stats["total_writes"] += 1
                self.stats["queue_size"] = self.write_queue.qsize()
                return True
            except asyncio.QueueFull:
                logger.warning(f"写入队列已满,丢弃低优先级更新: stream_id={stream_id}")
                return False
        except Exception as e:
            logger.error(f"调度流更新失败: {e}")
            return False

    async def _batch_writer_loop(self):
        """Main loop: collect a batch, write it, repeat until stopped."""
        logger.info("批量写入循环启动")
        while self.is_running:
            try:
                # Block until the batch is full or the flush interval elapses.
                batch = await self._collect_batch()
                if batch:
                    await self._write_batch(batch)
                self.stats["queue_size"] = self.write_queue.qsize()
            except asyncio.CancelledError:
                logger.info("批量写入循环被取消")
                break
            except Exception as e:
                logger.error(f"批量写入循环出错: {e}")
                # Back off briefly so a persistent error cannot spin the loop.
                await asyncio.sleep(1.0)
        # Drain anything still queued before exiting.
        await self._flush_all_batches()
        logger.info("批量写入循环结束")

    async def _collect_batch(self) -> List[StreamUpdatePayload]:
        """Pull up to ``batch_size`` payloads, waiting at most ``flush_interval`` seconds."""
        batch = []
        deadline = time.time() + self.flush_interval
        while len(batch) < self.batch_size and time.time() < deadline:
            try:
                # Only wait as long as the flush deadline allows.
                remaining_time = max(0, deadline - time.time())
                if remaining_time == 0:
                    break
                payload = await asyncio.wait_for(
                    self.write_queue.get(),
                    timeout=remaining_time
                )
                batch.append(payload)
            except asyncio.TimeoutError:
                break
        return batch

    async def _write_batch(self, batch: List[StreamUpdatePayload]):
        """Merge and upsert one batch; fall back to per-item writes on failure."""
        if not batch:
            return
        start_time = time.time()
        try:
            # High priority first; ties broken by enqueue time.
            batch.sort(key=lambda x: (-x.priority, x.timestamp))
            # Deduplicate by stream id, keeping the newest payload per stream.
            merged_updates = {}
            for payload in batch:
                if payload.stream_id not in merged_updates or payload.timestamp > merged_updates[payload.stream_id].timestamp:
                    merged_updates[payload.stream_id] = payload
            # Single-session upsert of the merged set.
            await self._batch_write_to_database(list(merged_updates.values()))
            self.stats["batch_writes"] += 1
            # Exponential moving average of batch size.
            self.stats["avg_batch_size"] = (
                self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1
            )
            self.stats["last_flush_time"] = start_time
            logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")
        except Exception as e:
            self.stats["failed_writes"] += 1
            logger.error(f"批量写入失败: {e}")
            # Degrade to one-at-a-time writes so a single bad row cannot sink the batch.
            for payload in batch:
                try:
                    await self._direct_write(payload.stream_id, payload.update_data)
                except Exception as single_e:
                    logger.error(f"单个写入也失败: {single_e}")

    async def _batch_write_to_database(self, payloads: List[StreamUpdatePayload]):
        """Execute one dialect-appropriate upsert per payload in a single session."""
        async with get_db_session() as session:
            for payload in payloads:
                stream_id = payload.stream_id
                update_data = payload.update_data
                # Pick the upsert construct matching the configured database dialect.
                if global_config.database.database_type == "sqlite":
                    from sqlalchemy.dialects.sqlite import insert as sqlite_insert
                    stmt = sqlite_insert(ChatStreams).values(
                        stream_id=stream_id, **update_data
                    )
                    stmt = stmt.on_conflict_do_update(
                        index_elements=["stream_id"],
                        set_=update_data
                    )
                elif global_config.database.database_type == "mysql":
                    from sqlalchemy.dialects.mysql import insert as mysql_insert
                    stmt = mysql_insert(ChatStreams).values(
                        stream_id=stream_id, **update_data
                    )
                    stmt = stmt.on_duplicate_key_update(
                        **{key: value for key, value in update_data.items() if key != "stream_id"}
                    )
                else:
                    # Unknown dialect: default to SQLite syntax.
                    from sqlalchemy.dialects.sqlite import insert as sqlite_insert
                    stmt = sqlite_insert(ChatStreams).values(
                        stream_id=stream_id, **update_data
                    )
                    stmt = stmt.on_conflict_do_update(
                        index_elements=["stream_id"],
                        set_=update_data
                    )
                await session.execute(stmt)
            await session.commit()

    async def _direct_write(self, stream_id: str, update_data: Dict[str, Any]):
        """Synchronous-path upsert of a single row (fallback when batching is unavailable)."""
        async with get_db_session() as session:
            if global_config.database.database_type == "sqlite":
                from sqlalchemy.dialects.sqlite import insert as sqlite_insert
                stmt = sqlite_insert(ChatStreams).values(
                    stream_id=stream_id, **update_data
                )
                stmt = stmt.on_conflict_do_update(
                    index_elements=["stream_id"],
                    set_=update_data
                )
            elif global_config.database.database_type == "mysql":
                from sqlalchemy.dialects.mysql import insert as mysql_insert
                stmt = mysql_insert(ChatStreams).values(
                    stream_id=stream_id, **update_data
                )
                stmt = stmt.on_duplicate_key_update(
                    **{key: value for key, value in update_data.items() if key != "stream_id"}
                )
            else:
                from sqlalchemy.dialects.sqlite import insert as sqlite_insert
                stmt = sqlite_insert(ChatStreams).values(
                    stream_id=stream_id, **update_data
                )
                stmt = stmt.on_conflict_do_update(
                    index_elements=["stream_id"],
                    set_=update_data
                )
            await session.execute(stmt)
            await session.commit()

    async def _flush_all_batches(self):
        """Drain the queue completely and write everything as one final batch."""
        remaining_batch = []
        while not self.write_queue.empty():
            try:
                payload = self.write_queue.get_nowait()
                remaining_batch.append(payload)
            except asyncio.QueueEmpty:
                break
        if remaining_batch:
            await self._write_batch(remaining_batch)

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the writer's counters plus live queue state."""
        stats = self.stats.copy()
        stats["is_running"] = self.is_running
        stats["current_queue_size"] = self.write_queue.qsize() if self.is_running else 0
        return stats
# Process-wide singleton; created lazily by get_batch_writer().
_batch_writer: Optional[BatchDatabaseWriter] = None


def get_batch_writer() -> BatchDatabaseWriter:
    """Return the global BatchDatabaseWriter, creating it on first use."""
    global _batch_writer
    if _batch_writer is None:
        _batch_writer = BatchDatabaseWriter()
    return _batch_writer
async def init_batch_writer():
    """Create (if needed) and start the global batch database writer."""
    await get_batch_writer().start()
async def shutdown_batch_writer():
    """Stop the global batch database writer, flushing pending updates first."""
    await get_batch_writer().stop()

View File

@@ -23,6 +23,8 @@ class StreamLoopManager:
def __init__(self, max_concurrent_streams: int | None = None):
# 流循环任务管理
self.stream_loops: dict[str, asyncio.Task] = {}
# 跟踪流使用的管理器类型
self.stream_management_type: dict[str, str] = {} # stream_id -> "adaptive" or "fallback"
# 统计信息
self.stats: dict[str, Any] = {
@@ -99,7 +101,7 @@ class StreamLoopManager:
logger.info("流循环管理器已停止")
async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
"""启动指定流的循环任务
"""启动指定流的循环任务 - 优化版本使用自适应管理器
Args:
stream_id: 流ID
@@ -113,6 +115,71 @@ class StreamLoopManager:
logger.debug(f"{stream_id} 循环已在运行")
return True
# 使用自适应流管理器获取槽位
use_adaptive = False
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager, StreamPriority
adaptive_manager = get_adaptive_stream_manager()
if adaptive_manager.is_running:
# 确定流优先级
priority = self._determine_stream_priority(stream_id)
# 获取处理槽位
slot_acquired = await adaptive_manager.acquire_stream_slot(
stream_id=stream_id,
priority=priority,
force=force
)
if slot_acquired:
use_adaptive = True
logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})")
else:
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
else:
logger.debug(f"自适应管理器未运行,使用原始方法")
except Exception as e:
logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}")
# 如果自适应管理器失败或未运行,使用回退方案
if not use_adaptive:
if not await self._fallback_acquire_slot(stream_id, force):
logger.debug(f"回退方案也失败: {stream_id}")
return False
# 创建流循环任务
try:
loop_task = asyncio.create_task(
self._stream_loop_worker(stream_id),
name=f"stream_loop_{stream_id}"
)
self.stream_loops[stream_id] = loop_task
# 记录管理器类型
self.stream_management_type[stream_id] = "adaptive" if use_adaptive else "fallback"
# 更新统计信息
self.stats["active_streams"] += 1
self.stats["total_loops"] += 1
logger.info(f"启动流循环任务: {stream_id} (管理器: {'adaptive' if use_adaptive else 'fallback'})")
return True
except Exception as e:
logger.error(f"启动流循环任务失败 {stream_id}: {e}")
# 释放槽位
if use_adaptive:
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
adaptive_manager.release_stream_slot(stream_id)
except:
pass
return False
async def _fallback_acquire_slot(self, stream_id: str, force: bool) -> bool:
"""回退方案:获取槽位(原始方法)"""
# 判断是否需要强制分发
should_force = force or self._should_force_dispatch_for_stream(stream_id)
@@ -149,6 +216,28 @@ class StreamLoopManager:
del self.stream_loops[stream_id]
current_streams -= 1 # 更新当前流数量
return True
def _determine_stream_priority(self, stream_id: str) -> "StreamPriority":
"""确定流优先级"""
try:
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
# 这里可以基于流的历史数据、用户身份等确定优先级
# 简化版本基于流ID的哈希值分配优先级
hash_value = hash(stream_id) % 10
if hash_value >= 8: # 20% 高优先级
return StreamPriority.HIGH
elif hash_value >= 5: # 30% 中等优先级
return StreamPriority.NORMAL
else: # 50% 低优先级
return StreamPriority.LOW
except Exception:
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
return StreamPriority.NORMAL
# 创建流循环任务
try:
task = asyncio.create_task(
@@ -201,13 +290,13 @@ class StreamLoopManager:
logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})")
return True
async def _stream_loop(self, stream_id: str) -> None:
"""单个流的无限循环
async def _stream_loop_worker(self, stream_id: str) -> None:
"""单个流的工作循环 - 优化版本
Args:
stream_id: 流ID
"""
logger.info(f"流循环开始: {stream_id}")
logger.info(f"流循环工作器启动: {stream_id}")
try:
while self.is_running:
@@ -223,6 +312,18 @@ class StreamLoopManager:
unread_count = self._get_unread_count(context)
force_dispatch = self._needs_force_dispatch_for_context(context, unread_count)
# 3. 更新自适应管理器指标
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
adaptive_manager.update_stream_metrics(
stream_id,
message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算
last_activity=time.time()
)
except Exception as e:
logger.debug(f"更新流指标失败: {e}")
has_messages = force_dispatch or await self._has_messages_to_process(context)
if has_messages:
@@ -278,6 +379,24 @@ class StreamLoopManager:
del self.stream_loops[stream_id]
logger.debug(f"清理流循环标记: {stream_id}")
# 根据管理器类型释放相应的槽位
management_type = self.stream_management_type.get(stream_id, "fallback")
if management_type == "adaptive":
# 释放自适应管理器的槽位
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
adaptive_manager.release_stream_slot(stream_id)
logger.debug(f"释放自适应流处理槽位: {stream_id}")
except Exception as e:
logger.debug(f"释放自适应流处理槽位失败: {e}")
else:
logger.debug(f"{stream_id} 使用回退方案,无需释放自适应槽位")
# 清理管理器类型记录
if stream_id in self.stream_management_type:
del self.stream_management_type[stream_id]
logger.info(f"流循环结束: {stream_id}")
async def _get_stream_context(self, stream_id: str) -> Any | None:

View File

@@ -56,6 +56,30 @@ class MessageManager:
self.is_running = True
# 启动批量数据库写入器
try:
from src.chat.message_manager.batch_database_writer import init_batch_writer
await init_batch_writer()
logger.info("📦 批量数据库写入器已启动")
except Exception as e:
logger.error(f"启动批量数据库写入器失败: {e}")
# 启动流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
await init_stream_cache_manager()
logger.info("🗄️ 流缓存管理器已启动")
except Exception as e:
logger.error(f"启动流缓存管理器失败: {e}")
# 启动自适应流管理器
try:
from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
await init_adaptive_stream_manager()
logger.info("🎯 自适应流管理器已启动")
except Exception as e:
logger.error(f"启动自适应流管理器失败: {e}")
# 启动睡眠和唤醒管理器
await self.wakeup_manager.start()
@@ -72,6 +96,30 @@ class MessageManager:
self.is_running = False
# 停止批量数据库写入器
try:
from src.chat.message_manager.batch_database_writer import shutdown_batch_writer
await shutdown_batch_writer()
logger.info("📦 批量数据库写入器已停止")
except Exception as e:
logger.error(f"停止批量数据库写入器失败: {e}")
# 停止流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
await shutdown_stream_cache_manager()
logger.info("🗄️ 流缓存管理器已停止")
except Exception as e:
logger.error(f"停止流缓存管理器失败: {e}")
# 停止自适应流管理器
try:
from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
await shutdown_adaptive_stream_manager()
logger.info("🎯 自适应流管理器已停止")
except Exception as e:
logger.error(f"停止自适应流管理器失败: {e}")
# 停止睡眠和唤醒管理器
await self.wakeup_manager.stop()

View File

@@ -0,0 +1,381 @@
"""
流缓存管理器 - 使用优化版聊天流和智能缓存策略
提供分层缓存和自动清理功能
"""
import asyncio
import time
from typing import Dict, List, Optional, Set
from dataclasses import dataclass
from collections import OrderedDict
from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
logger = get_logger("stream_cache_manager")
@dataclass
class StreamCacheStats:
    """Counters describing the state of the tiered stream cache."""

    hot_cache_size: int = 0
    warm_storage_size: int = 0
    cold_storage_size: int = 0
    total_memory_usage: int = 0  # estimated memory usage, bytes
    cache_hits: int = 0
    cache_misses: int = 0
    evictions: int = 0
    last_cleanup_time: float = 0  # epoch seconds of the last cleanup pass
class TieredStreamCache:
"""分层流缓存管理器"""
    def __init__(
        self,
        max_hot_size: int = 100,
        max_warm_size: int = 500,
        max_cold_size: int = 2000,
        cleanup_interval: float = 300.0,  # run the cleanup pass every 5 minutes
        hot_timeout: float = 1800.0,  # demote hot -> warm after 30 min without access
        warm_timeout: float = 7200.0,  # demote warm -> cold after 2 h without access
        cold_timeout: float = 86400.0,  # evict cold entries after 24 h without access
    ):
        self.max_hot_size = max_hot_size
        self.max_warm_size = max_warm_size
        self.max_cold_size = max_cold_size
        self.cleanup_interval = cleanup_interval
        self.hot_timeout = hot_timeout
        self.warm_timeout = warm_timeout
        self.cold_timeout = cold_timeout
        # Three storage tiers:
        self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict()  # hot tier, LRU-ordered
        self.warm_storage: Dict[str, tuple[OptimizedChatStream, float]] = {}  # warm tier, (stream, last access)
        self.cold_storage: Dict[str, tuple[OptimizedChatStream, float]] = {}  # cold tier, (stream, last access)
        # Hit/miss/eviction counters.
        self.stats = StreamCacheStats()
        # Background cleanup task (created in start()).
        self.cleanup_task: Optional[asyncio.Task] = None
        self.is_running = False
        logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
async def start(self):
"""启动缓存管理器"""
if self.is_running:
logger.warning("缓存管理器已经在运行")
return
self.is_running = True
self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup")
logger.info("分层流缓存管理器已启动")
async def stop(self):
"""停止缓存管理器"""
if not self.is_running:
return
self.is_running = False
if self.cleanup_task and not self.cleanup_task.done():
self.cleanup_task.cancel()
try:
await asyncio.wait_for(self.cleanup_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("缓存清理任务停止超时")
except Exception as e:
logger.error(f"停止缓存清理任务时出错: {e}")
logger.info("分层流缓存管理器已停止")
async def get_or_create_stream(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[Dict] = None,
) -> OptimizedChatStream:
"""获取或创建流 - 优化版本"""
current_time = time.time()
# 1. 检查热缓存
if stream_id in self.hot_cache:
stream = self.hot_cache[stream_id]
# 移动到末尾LRU更新
self.hot_cache.move_to_end(stream_id)
self.stats.cache_hits += 1
logger.debug(f"热缓存命中: {stream_id}")
return stream.create_snapshot()
# 2. 检查温存储
if stream_id in self.warm_storage:
stream, last_access = self.warm_storage[stream_id]
self.warm_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"温缓存命中: {stream_id}")
# 提升到热缓存
await self._promote_to_hot(stream_id, stream)
return stream.create_snapshot()
# 3. 检查冷存储
if stream_id in self.cold_storage:
stream, last_access = self.cold_storage[stream_id]
self.cold_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"冷缓存命中: {stream_id}")
# 提升到温缓存
await self._promote_to_warm(stream_id, stream)
return stream.create_snapshot()
# 4. 缓存未命中,创建新流
self.stats.cache_misses += 1
stream = create_optimized_chat_stream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info,
data=data
)
logger.debug(f"缓存未命中,创建新流: {stream_id}")
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
return stream
async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""添加到热缓存"""
# 检查是否需要驱逐
if len(self.hot_cache) >= self.max_hot_size:
await self._evict_from_hot()
self.hot_cache[stream_id] = stream
self.stats.hot_cache_size = len(self.hot_cache)
async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""提升到热缓存"""
# 从温存储中移除
if stream_id in self.warm_storage:
del self.warm_storage[stream_id]
self.stats.warm_storage_size = len(self.warm_storage)
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
logger.debug(f"{stream_id} 提升到热缓存")
async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
"""提升到温缓存"""
# 从冷存储中移除
if stream_id in self.cold_storage:
del self.cold_storage[stream_id]
self.stats.cold_storage_size = len(self.cold_storage)
# 添加到温存储
if len(self.warm_storage) >= self.max_warm_size:
await self._evict_from_warm()
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
logger.debug(f"{stream_id} 提升到温缓存")
async def _evict_from_hot(self):
"""从热缓存驱逐最久未使用的流"""
if not self.hot_cache:
return
# LRU驱逐
stream_id, stream = self.hot_cache.popitem(last=False)
self.stats.evictions += 1
logger.debug(f"从热缓存驱逐: {stream_id}")
# 移动到温存储
if len(self.warm_storage) < self.max_warm_size:
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
else:
# 温存储也满了,直接删除
logger.debug(f"温存储已满,删除流: {stream_id}")
self.stats.hot_cache_size = len(self.hot_cache)
async def _evict_from_warm(self):
"""从温存储驱逐最久未使用的流"""
if not self.warm_storage:
return
# 找到最久未访问的流
oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
stream, last_access = self.warm_storage.pop(oldest_stream_id)
self.stats.evictions += 1
logger.debug(f"从温存储驱逐: {oldest_stream_id}")
# 移动到冷存储
if len(self.cold_storage) < self.max_cold_size:
current_time = time.time()
self.cold_storage[oldest_stream_id] = (stream, current_time)
self.stats.cold_storage_size = len(self.cold_storage)
else:
# 冷存储也满了,直接删除
logger.debug(f"冷存储已满,删除流: {oldest_stream_id}")
self.stats.warm_storage_size = len(self.warm_storage)
async def _cleanup_loop(self):
"""清理循环"""
logger.info("流缓存清理循环启动")
while self.is_running:
try:
await asyncio.sleep(self.cleanup_interval)
await self._perform_cleanup()
except asyncio.CancelledError:
logger.info("流缓存清理循环被取消")
break
except Exception as e:
logger.error(f"流缓存清理出错: {e}")
logger.info("流缓存清理循环结束")
async def _perform_cleanup(self):
"""执行清理操作"""
current_time = time.time()
cleanup_stats = {
"hot_to_warm": 0,
"warm_to_cold": 0,
"cold_removed": 0,
}
# 1. 检查热缓存超时
hot_to_demote = []
for stream_id, stream in self.hot_cache.items():
# 获取最后访问时间(简化:使用创建时间作为近似)
last_access = getattr(stream, 'last_active_time', stream.create_time)
if current_time - last_access > self.hot_timeout:
hot_to_demote.append(stream_id)
for stream_id in hot_to_demote:
stream = self.hot_cache.pop(stream_id)
current_time_local = time.time()
self.warm_storage[stream_id] = (stream, current_time_local)
cleanup_stats["hot_to_warm"] += 1
# 2. 检查温存储超时
warm_to_demote = []
for stream_id, (stream, last_access) in self.warm_storage.items():
if current_time - last_access > self.warm_timeout:
warm_to_demote.append(stream_id)
for stream_id in warm_to_demote:
stream, last_access = self.warm_storage.pop(stream_id)
self.cold_storage[stream_id] = (stream, last_access)
cleanup_stats["warm_to_cold"] += 1
# 3. 检查冷存储超时
cold_to_remove = []
for stream_id, (stream, last_access) in self.cold_storage.items():
if current_time - last_access > self.cold_timeout:
cold_to_remove.append(stream_id)
for stream_id in cold_to_remove:
self.cold_storage.pop(stream_id)
cleanup_stats["cold_removed"] += 1
# 更新统计信息
self.stats.hot_cache_size = len(self.hot_cache)
self.stats.warm_storage_size = len(self.warm_storage)
self.stats.cold_storage_size = len(self.cold_storage)
self.stats.last_cleanup_time = current_time
# 估算内存使用(粗略估计)
self.stats.total_memory_usage = (
len(self.hot_cache) * 1024 + # 每个热流约1KB
len(self.warm_storage) * 512 + # 每个温流约512B
len(self.cold_storage) * 256 # 每个冷流约256B
)
if sum(cleanup_stats.values()) > 0:
logger.info(
f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, "
f"{cleanup_stats['warm_to_cold']}温→冷, "
f"{cleanup_stats['cold_removed']}冷删除"
)
def get_stats(self) -> StreamCacheStats:
"""获取缓存统计信息"""
# 计算命中率
total_requests = self.stats.cache_hits + self.stats.cache_misses
hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0
stats_copy = StreamCacheStats(
hot_cache_size=self.stats.hot_cache_size,
warm_storage_size=self.stats.warm_storage_size,
cold_storage_size=self.stats.cold_storage_size,
total_memory_usage=self.stats.total_memory_usage,
cache_hits=self.stats.cache_hits,
cache_misses=self.stats.cache_misses,
evictions=self.stats.evictions,
last_cleanup_time=self.stats.last_cleanup_time,
)
# 添加命中率信息
stats_copy.hit_rate = hit_rate
return stats_copy
def clear_cache(self):
"""清空所有缓存"""
self.hot_cache.clear()
self.warm_storage.clear()
self.cold_storage.clear()
self.stats.hot_cache_size = 0
self.stats.warm_storage_size = 0
self.stats.cold_storage_size = 0
self.stats.total_memory_usage = 0
logger.info("所有缓存已清空")
async def get_stream_snapshot(self, stream_id: str) -> Optional[OptimizedChatStream]:
"""获取流的快照(不修改缓存状态)"""
if stream_id in self.hot_cache:
return self.hot_cache[stream_id].create_snapshot()
elif stream_id in self.warm_storage:
return self.warm_storage[stream_id][0].create_snapshot()
elif stream_id in self.cold_storage:
return self.cold_storage[stream_id][0].create_snapshot()
return None
def get_cached_stream_ids(self) -> Set[str]:
"""获取所有缓存的流ID"""
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
# Process-wide singleton instance of the tiered stream cache.
_cache_manager: Optional[TieredStreamCache] = None


def get_stream_cache_manager() -> TieredStreamCache:
    """Return the global TieredStreamCache, creating it on first use."""
    global _cache_manager
    manager = _cache_manager
    if manager is None:
        manager = TieredStreamCache()
        _cache_manager = manager
    return manager
async def init_stream_cache_manager():
    """Create (if needed) and start the global stream cache manager."""
    await get_stream_cache_manager().start()
async def shutdown_stream_cache_manager():
    """Stop the global stream cache manager's background work."""
    await get_stream_cache_manager().stop()

View File

@@ -464,7 +464,7 @@ class ChatManager:
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
) -> ChatStream:
"""获取或创建聊天流
"""获取或创建聊天流 - 优化版本使用缓存管理器
Args:
platform: 平台标识
@@ -478,6 +478,31 @@ class ChatManager:
try:
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 优先使用缓存管理器(优化版本)
try:
from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
cache_manager = get_stream_cache_manager()
if cache_manager.is_running:
optimized_stream = await cache_manager.get_or_create_stream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
# 设置消息上下文
from .message import MessageRecv
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
optimized_stream.set_context(self.last_messages[stream_id])
# 转换为原始ChatStream以保持兼容性
return self._convert_to_original_stream(optimized_stream)
except Exception as e:
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
# 回退到原始方法
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
@@ -634,12 +659,35 @@ class ChatManager:
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库"""
"""保存聊天流到数据库 - 优化版本使用异步批量写入"""
if stream.saved:
return
stream_data_dict = stream.to_dict()
# 尝试使用数据库批量调度
# 优先使用新的批量写入
try:
from src.chat.message_manager.batch_database_writer import get_batch_writer
batch_writer = get_batch_writer()
if batch_writer.is_running:
success = await batch_writer.schedule_stream_update(
stream_id=stream_data_dict["stream_id"],
update_data=ChatManager._prepare_stream_data(stream_data_dict),
priority=1 # 流更新的优先级
)
if success:
stream.saved = True
logger.debug(f"聊天流 {stream.stream_id} 通过批量写入器调度成功")
return
else:
logger.warning(f"批量写入器队列已满,使用原始方法: {stream.stream_id}")
else:
logger.debug(f"批量写入器未运行,使用原始方法: {stream.stream_id}")
except Exception as e:
logger.debug(f"批量写入器保存聊天流失败,使用原始方法: {e}")
# 尝试使用数据库批量调度器回退方案1
try:
from src.common.database.db_batch_scheduler import batch_update, get_batch_session
@@ -657,7 +705,7 @@ class ChatManager:
except (ImportError, Exception) as e:
logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}")
# 回退到原始方法
# 回退到原始方法(最终方案)
async def _db_save_stream_async(s_data_dict: dict):
async with get_db_session() as session:
user_info_d = s_data_dict.get("user_info")
@@ -782,6 +830,46 @@ class ChatManager:
chat_manager = None
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
    """Convert an OptimizedChatStream into a legacy ChatStream for callers
    that still expect the original type.

    Copies timing/energy/persistence state and, when the optimized stream
    has lazily-created context objects, carries them over. On any failure
    a bare ChatStream built from the identity fields is returned (other
    state is lost in that case).
    """
    try:
        # Build the legacy stream from the effective (post copy-on-write)
        # identity fields of the optimized stream.
        original_stream = ChatStream(
            stream_id=optimized_stream.stream_id,
            platform=optimized_stream.platform,
            user_info=optimized_stream._get_effective_user_info(),
            group_info=optimized_stream._get_effective_group_info()
        )
        # Copy over mutable state.
        original_stream.create_time = optimized_stream.create_time
        original_stream.last_active_time = optimized_stream.last_active_time
        original_stream.sleep_pressure = optimized_stream.sleep_pressure
        original_stream.base_interest_energy = optimized_stream.base_interest_energy
        original_stream._focus_energy = optimized_stream._focus_energy
        original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
        original_stream.saved = optimized_stream.saved
        # Carry over context objects only if they were already created.
        if hasattr(optimized_stream, '_stream_context') and optimized_stream._stream_context:
            original_stream.stream_context = optimized_stream._stream_context
        if hasattr(optimized_stream, '_context_manager') and optimized_stream._context_manager:
            original_stream.context_manager = optimized_stream._context_manager
        return original_stream
    except Exception as e:
        logger.error(f"转换OptimizedChatStream失败: {e}")
        # Fallback: identity-only stream; timing/energy state is dropped.
        return ChatStream(
            stream_id=optimized_stream.stream_id,
            platform=optimized_stream.platform,
            user_info=optimized_stream._get_effective_user_info(),
            group_info=optimized_stream._get_effective_group_info()
        )
def get_chat_manager():
global chat_manager
if chat_manager is None:

View File

@@ -0,0 +1,494 @@
"""
优化版聊天流 - 实现写时复制机制
避免不必要的深拷贝开销,提升多流并发性能
"""
import asyncio
import copy
import hashlib
import time
from typing import TYPE_CHECKING, Any, Dict, Optional
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from .message import MessageRecv
install(extra_lines=3)
logger = get_logger("optimized_chat_stream")
class SharedContext:
    """Read-only identity data shared between chat-stream copies.

    All attributes are assigned exactly once in ``__init__``; the final
    assignment sets ``_frozen``, after which any further attribute write
    (other than to ``_frozen`` itself) raises :class:`AttributeError`.
    """

    def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None):
        self.stream_id = stream_id
        self.platform = platform
        self.user_info = user_info
        self.group_info = group_info
        self.create_time = time.time()
        # Must be last: freezes the instance against further writes.
        self._frozen = True

    def __setattr__(self, name, value):
        # During __init__ the _frozen attribute does not exist yet, so
        # getattr() returns False and normal assignment proceeds.
        if getattr(self, '_frozen', False) and name != '_frozen':
            raise AttributeError(f"SharedContext is frozen, cannot modify {name}")
        super().__setattr__(name, value)
class LocalChanges:
    """Records per-instance attribute overrides for copy-on-write streams."""

    def __init__(self):
        self._changes: Dict[str, Any] = {}
        self._dirty = False

    def set_change(self, key: str, value: Any):
        """Record (or overwrite) the override for *key*."""
        self._changes[key] = value
        self._dirty = True

    def get_change(self, key: str, default: Any = None) -> Any:
        """Return the override recorded for *key*, or *default*."""
        return self._changes.get(key, default)

    def has_changes(self) -> bool:
        """True once any override has been recorded (until cleared)."""
        return self._dirty

    def get_changes(self) -> Dict[str, Any]:
        """Return a shallow copy of every recorded override."""
        return dict(self._changes)

    def clear_changes(self):
        """Forget all overrides and reset the dirty flag."""
        self._changes.clear()
        self._dirty = False
class OptimizedChatStream:
    """Chat stream with copy-on-write semantics.

    Identity data (stream_id / platform / user_info / group_info) lives in a
    frozen :class:`SharedContext`; mutations are recorded in a
    :class:`LocalChanges` overlay. Snapshots (:meth:`create_snapshot`) copy
    only the overlay, avoiding deep copies of the whole stream under
    concurrent access.
    """

    def __init__(
        self,
        stream_id: str,
        platform: str,
        user_info: UserInfo,
        group_info: Optional[GroupInfo] = None,
        data: Optional[Dict] = None,
    ):
        # Shared, frozen identity data.
        self._shared_context = SharedContext(
            stream_id=stream_id,
            platform=platform,
            user_info=user_info,
            group_info=group_info
        )
        # Per-instance mutation overlay (copy-on-write target).
        self._local_changes = LocalChanges()
        # Copy-on-write flag; flipped on the first mutation.
        self._copy_on_write = False
        # Energy parameters, restored from persisted data when provided.
        self.base_interest_energy = data.get("base_interest_energy", 0.5) if data else 0.5
        self._focus_energy = data.get("focus_energy", 0.5) if data else 0.5
        self.no_reply_consecutive = 0
        # StreamContext and its manager are created lazily on first access.
        self._stream_context = None
        self._context_manager = None
        # Record creation as the first activity.
        self.update_active_time()
        # Persistence dirty flag (False means "needs saving").
        self.saved = False

    @property
    def stream_id(self) -> str:
        return self._shared_context.stream_id

    @property
    def platform(self) -> str:
        return self._shared_context.platform

    @property
    def user_info(self) -> UserInfo:
        # NOTE(review): this getter returns the shared value even after the
        # setter below records a local override; overrides are only visible
        # via _get_effective_user_info() — confirm the asymmetry is intended.
        return self._shared_context.user_info

    @user_info.setter
    def user_info(self, value: UserInfo):
        """Record a user-info override (triggers copy-on-write)."""
        self._ensure_copy_on_write()
        # SharedContext is frozen, so the new value lives in the overlay.
        self._local_changes.set_change('user_info', value)

    @property
    def group_info(self) -> Optional[GroupInfo]:
        if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
            return self._local_changes.get_change('group_info')
        return self._shared_context.group_info

    @group_info.setter
    def group_info(self, value: Optional[GroupInfo]):
        """Record a group-info override (triggers copy-on-write)."""
        self._ensure_copy_on_write()
        self._local_changes.set_change('group_info', value)

    @property
    def create_time(self) -> float:
        if self._local_changes.has_changes() and 'create_time' in self._local_changes._changes:
            return self._local_changes.get_change('create_time')
        return self._shared_context.create_time

    @property
    def last_active_time(self) -> float:
        # Defaults to create_time until update_active_time() first runs.
        return self._local_changes.get_change('last_active_time', self.create_time)

    @last_active_time.setter
    def last_active_time(self, value: float):
        self._local_changes.set_change('last_active_time', value)
        self.saved = False

    @property
    def sleep_pressure(self) -> float:
        return self._local_changes.get_change('sleep_pressure', 0.0)

    @sleep_pressure.setter
    def sleep_pressure(self, value: float):
        self._local_changes.set_change('sleep_pressure', value)
        self.saved = False

    def _ensure_copy_on_write(self):
        """Mark this instance as diverged from the shared context."""
        if not self._copy_on_write:
            self._copy_on_write = True
            # NOTE(review): despite the log text, no copy actually happens
            # here — overrides accumulate in _local_changes; the flag only
            # records that a first write occurred.
            logger.debug(f"触发写时复制: {self.stream_id}")

    def _get_effective_user_info(self) -> UserInfo:
        """Return the overridden user info if present, else the shared one."""
        if self._local_changes.has_changes() and 'user_info' in self._local_changes._changes:
            return self._local_changes.get_change('user_info')
        return self._shared_context.user_info

    def _get_effective_group_info(self) -> Optional[GroupInfo]:
        """Return the overridden group info if present, else the shared one."""
        if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
            return self._local_changes.get_change('group_info')
        return self._shared_context.group_info

    def update_active_time(self):
        """Stamp the stream as active now (also marks it unsaved)."""
        self.last_active_time = time.time()

    def set_context(self, message: "MessageRecv"):
        """Convert an incoming MessageRecv into a DatabaseMessages record and
        install it as the current message of this stream's context."""
        # Ensure the lazily-created stream context exists.
        if self._stream_context is None:
            self._ensure_copy_on_write()
            self._create_stream_context()
        # Local imports keep module import time low and avoid cycles.
        import json
        from src.common.data_models.database_data_model import DatabaseMessages
        message_info = getattr(message, "message_info", {})
        user_info = getattr(message_info, "user_info", {})
        group_info = getattr(message_info, "group_info", {})
        reply_to = None
        if hasattr(message, "message_segment") and message.message_segment:
            reply_to = self._extract_reply_from_segment(message.message_segment)
        # Every field falls back to a safe default via getattr so that
        # partially-populated messages still produce a valid record.
        db_message = DatabaseMessages(
            message_id=getattr(message, "message_id", ""),
            time=getattr(message, "time", time.time()),
            chat_id=self._generate_chat_id(message_info),
            reply_to=reply_to,
            interest_value=getattr(message, "interest_value", 0.0),
            key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
            if getattr(message, "key_words", None)
            else None,
            key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
            if getattr(message, "key_words_lite", None)
            else None,
            is_mentioned=getattr(message, "is_mentioned", None),
            is_at=getattr(message, "is_at", False),
            is_emoji=getattr(message, "is_emoji", False),
            is_picid=getattr(message, "is_picid", False),
            is_voice=getattr(message, "is_voice", False),
            is_video=getattr(message, "is_video", False),
            is_command=getattr(message, "is_command", False),
            is_notify=getattr(message, "is_notify", False),
            processed_plain_text=getattr(message, "processed_plain_text", ""),
            display_message=getattr(message, "processed_plain_text", ""),
            priority_mode=getattr(message, "priority_mode", None),
            priority_info=json.dumps(getattr(message, "priority_info", None))
            if getattr(message, "priority_info", None)
            else None,
            additional_config=getattr(message_info, "additional_config", None),
            user_id=str(getattr(user_info, "user_id", "")),
            user_nickname=getattr(user_info, "user_nickname", ""),
            user_cardname=getattr(user_info, "user_cardname", None),
            user_platform=getattr(user_info, "platform", ""),
            chat_info_group_id=getattr(group_info, "group_id", None),
            chat_info_group_name=getattr(group_info, "group_name", None),
            chat_info_group_platform=getattr(group_info, "platform", None),
            chat_info_user_id=str(getattr(user_info, "user_id", "")),
            chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
            chat_info_user_cardname=getattr(user_info, "user_cardname", None),
            chat_info_user_platform=getattr(user_info, "platform", ""),
            chat_info_stream_id=self.stream_id,
            chat_info_platform=self.platform,
            chat_info_create_time=self.create_time,
            chat_info_last_active_time=self.last_active_time,
            actions=self._safe_get_actions(message),
            should_reply=getattr(message, "should_reply", False),
        )
        self._stream_context.set_current_message(db_message)
        self._stream_context.priority_mode = getattr(message, "priority_mode", None)
        self._stream_context.priority_info = getattr(message, "priority_info", None)
        logger.debug(
            f"消息数据转移完成 - message_id: {db_message.message_id}, "
            f"chat_id: {db_message.chat_id}, "
            f"interest_value: {db_message.interest_value}"
        )

    def _create_stream_context(self):
        """Create the StreamContext and its single-stream context manager."""
        from src.common.data_models.message_manager_data_model import StreamContext
        from src.plugin_system.base.component_types import ChatMode, ChatType
        # Chat type is inferred from the presence of group info.
        self._stream_context = StreamContext(
            stream_id=self.stream_id,
            chat_type=ChatType.GROUP if self.group_info else ChatType.PRIVATE,
            chat_mode=ChatMode.NORMAL
        )
        from src.chat.message_manager.context_manager import SingleStreamContextManager
        self._context_manager = SingleStreamContextManager(
            stream_id=self.stream_id, context=self._stream_context
        )

    @property
    def stream_context(self):
        """Lazily-created StreamContext for this stream."""
        if self._stream_context is None:
            self._ensure_copy_on_write()
            self._create_stream_context()
        return self._stream_context

    @property
    def context_manager(self):
        """Lazily-created SingleStreamContextManager for this stream."""
        if self._context_manager is None:
            self._ensure_copy_on_write()
            self._create_stream_context()
        return self._context_manager

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, honoring local overrides."""
        user_info = self._get_effective_user_info()
        group_info = self._get_effective_group_info()
        return {
            "stream_id": self.stream_id,
            "platform": self.platform,
            "user_info": user_info.to_dict() if user_info else None,
            "group_info": group_info.to_dict() if group_info else None,
            "create_time": self.create_time,
            "last_active_time": self.last_active_time,
            "sleep_pressure": self.sleep_pressure,
            "focus_energy": self.focus_energy,
            "base_interest_energy": self.base_interest_energy,
            "stream_context_chat_type": self.stream_context.chat_type.value,
            "stream_context_chat_mode": self.stream_context.chat_mode.value,
            "interruption_count": self.stream_context.interruption_count,
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "OptimizedChatStream":
        """Rebuild an instance from a dict produced by :meth:`to_dict`."""
        user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
        group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
        instance = cls(
            stream_id=data["stream_id"],
            platform=data["platform"],
            user_info=user_info,  # type: ignore
            group_info=group_info,
            data=data,
        )
        # Restore persisted stream-context enum fields, if present.
        if "stream_context_chat_type" in data:
            from src.plugin_system.base.component_types import ChatMode, ChatType
            instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
        if "stream_context_chat_mode" in data:
            from src.plugin_system.base.component_types import ChatMode, ChatType
            instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
        # Restore the interruption counter, if present.
        if "interruption_count" in data:
            instance.stream_context.interruption_count = data["interruption_count"]
        return instance

    def _safe_get_actions(self, message: "MessageRecv") -> list | None:
        """Normalize the message's actions field to a list of strings, or
        None when absent/unparseable (never raises)."""
        try:
            actions = getattr(message, "actions", None)
            if actions is None:
                return None
            if isinstance(actions, str):
                # Persisted form may be a JSON string; decode it first.
                try:
                    import json
                    actions = json.loads(actions)
                except json.JSONDecodeError:
                    logger.warning(f"无法解析actions JSON字符串: {actions}")
                    return None
            if isinstance(actions, list):
                # Keep only non-None string entries; empty list becomes None.
                filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
                return filtered_actions if filtered_actions else None
            else:
                logger.warning(f"actions字段类型不支持: {type(actions)}")
                return None
        except Exception as e:
            logger.warning(f"获取actions字段失败: {e}")
            return None

    def _extract_reply_from_segment(self, segment) -> str | None:
        """Recursively search a message segment (or seglist) for the id of
        the message being replied to; None when not found."""
        try:
            if hasattr(segment, "type") and segment.type == "seglist":
                if hasattr(segment, "data") and segment.data:
                    for seg in segment.data:
                        reply_id = self._extract_reply_from_segment(seg)
                        if reply_id:
                            return reply_id
            elif hasattr(segment, "type") and segment.type == "reply":
                return str(segment.data) if segment.data else None
        except Exception as e:
            logger.warning(f"提取reply_to信息失败: {e}")
        return None

    def _generate_chat_id(self, message_info) -> str:
        """Derive a chat id from group info, falling back to a private-chat
        id from user info, then to the stream id itself."""
        try:
            group_info = getattr(message_info, "group_info", None)
            user_info = getattr(message_info, "user_info", None)
            if group_info and hasattr(group_info, "group_id") and group_info.group_id:
                return f"{self.platform}_{group_info.group_id}"
            elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
                return f"{self.platform}_{user_info.user_id}_private"
            else:
                return self.stream_id
        except Exception as e:
            logger.warning(f"生成chat_id失败: {e}")
            return self.stream_id

    @property
    def focus_energy(self) -> float:
        """Cached focus-energy value (refreshed by calculate_focus_energy)."""
        return self._focus_energy

    async def calculate_focus_energy(self) -> float:
        """Recompute focus energy from recent messages via the energy
        manager; on failure the previous cached value is returned."""
        try:
            all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)
            user_id = None
            effective_user_info = self._get_effective_user_info()
            if effective_user_info and hasattr(effective_user_info, "user_id"):
                user_id = str(effective_user_info.user_id)
            from src.chat.energy_system import energy_manager
            energy = await energy_manager.calculate_focus_energy(
                stream_id=self.stream_id, messages=all_messages, user_id=user_id
            )
            # Cache the freshly computed value.
            self._focus_energy = energy
            logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
            return energy
        except Exception as e:
            logger.error(f"获取focus_energy失败: {e}", exc_info=True)
            return self._focus_energy

    @focus_energy.setter
    def focus_energy(self, value: float):
        """Set the cached focus energy, clamped to [0.0, 1.0]."""
        self._focus_energy = max(0.0, min(1.0, value))

    async def _get_user_relationship_score(self) -> float:
        """Fetch the user relationship score from the affinity plugin,
        clamped to [0.0, 1.0]; 0.3 is returned on any failure."""
        try:
            from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
            effective_user_info = self._get_effective_user_info()
            if effective_user_info and hasattr(effective_user_info, "user_id"):
                user_id = str(effective_user_info.user_id)
                relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id)
                logger.debug(f"OptimizedChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
                return max(0.0, min(1.0, relationship_score))
        except Exception as e:
            logger.warning(f"OptimizedChatStream {self.stream_id}: 插件内部关系分计算失败: {e}")
        return 0.3

    def create_snapshot(self) -> "OptimizedChatStream":
        """Create a lightweight snapshot for caching: shares identity data,
        copies only the mutation overlay and scalar state."""
        snapshot = OptimizedChatStream(
            stream_id=self.stream_id,
            platform=self.platform,
            user_info=self._get_effective_user_info(),
            group_info=self._get_effective_group_info()
        )
        # Copy the overlay directly without going through setters, so the
        # snapshot does not itself trigger copy-on-write.
        snapshot._local_changes._changes = self._local_changes.get_changes()
        snapshot._local_changes._dirty = self._local_changes._dirty
        snapshot._focus_energy = self._focus_energy
        snapshot.base_interest_energy = self.base_interest_energy
        snapshot.no_reply_consecutive = self.no_reply_consecutive
        snapshot.saved = self.saved
        return snapshot
# Backward-compatibility factory wrapper.
def create_optimized_chat_stream(
    stream_id: str,
    platform: str,
    user_info: UserInfo,
    group_info: Optional[GroupInfo] = None,
    data: Optional[Dict] = None,
) -> OptimizedChatStream:
    """Build an :class:`OptimizedChatStream`; kept for callers that predate
    direct construction of the class."""
    return OptimizedChatStream(
        stream_id,
        platform,
        user_info,
        group_info=group_info,
        data=data,
    )