Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -373,7 +373,11 @@ class VirtualLogDisplay:
|
||||
|
||||
# 为每个部分应用正确的标签
|
||||
current_len = 0
|
||||
for part, tag_name in zip(parts, tags, strict=False):
|
||||
# Python 3.9 兼容性:不使用 strict=False 参数
|
||||
min_len = min(len(parts), len(tags))
|
||||
for i in range(min_len):
|
||||
part = parts[i]
|
||||
tag_name = tags[i]
|
||||
start_index = f"{start_pos}+{current_len}c"
|
||||
end_index = f"{start_pos}+{current_len + len(part)}c"
|
||||
self.text_widget.tag_add(tag_name, start_index, end_index)
|
||||
|
||||
@@ -1,363 +0,0 @@
|
||||
"""
|
||||
增强记忆系统适配器
|
||||
将增强记忆系统集成到现有MoFox Bot架构中
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
MEMORY_TYPE_LABELS = {
|
||||
MemoryType.PERSONAL_FACT: "个人事实",
|
||||
MemoryType.EVENT: "事件",
|
||||
MemoryType.PREFERENCE: "偏好",
|
||||
MemoryType.OPINION: "观点",
|
||||
MemoryType.RELATIONSHIP: "关系",
|
||||
MemoryType.EMOTION: "情感",
|
||||
MemoryType.KNOWLEDGE: "知识",
|
||||
MemoryType.SKILL: "技能",
|
||||
MemoryType.GOAL: "目标",
|
||||
MemoryType.EXPERIENCE: "经验",
|
||||
MemoryType.CONTEXTUAL: "上下文",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterConfig:
|
||||
"""适配器配置"""
|
||||
|
||||
enable_enhanced_memory: bool = True
|
||||
integration_mode: str = "enhanced_only" # replace, enhanced_only
|
||||
auto_migration: bool = True
|
||||
memory_value_threshold: float = 0.6
|
||||
fusion_threshold: float = 0.85
|
||||
max_retrieval_results: int = 10
|
||||
|
||||
|
||||
class EnhancedMemoryAdapter:
|
||||
"""增强记忆系统适配器"""
|
||||
|
||||
def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
|
||||
self.llm_model = llm_model
|
||||
self.config = config or AdapterConfig()
|
||||
self.integration_layer: MemoryIntegrationLayer | None = None
|
||||
self._initialized = False
|
||||
|
||||
# 统计信息
|
||||
self.adapter_stats = {
|
||||
"total_processed": 0,
|
||||
"enhanced_used": 0,
|
||||
"legacy_used": 0,
|
||||
"hybrid_used": 0,
|
||||
"memories_created": 0,
|
||||
"memories_retrieved": 0,
|
||||
"average_processing_time": 0.0,
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化适配器"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🚀 初始化增强记忆系统适配器...")
|
||||
|
||||
# 转换配置格式
|
||||
integration_config = IntegrationConfig(
|
||||
mode=IntegrationMode(self.config.integration_mode),
|
||||
enable_enhanced_memory=self.config.enable_enhanced_memory,
|
||||
memory_value_threshold=self.config.memory_value_threshold,
|
||||
fusion_threshold=self.config.fusion_threshold,
|
||||
max_retrieval_results=self.config.max_retrieval_results,
|
||||
enable_learning=True, # 启用学习功能
|
||||
)
|
||||
|
||||
# 创建集成层
|
||||
self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config)
|
||||
|
||||
# 初始化集成层
|
||||
await self.integration_layer.initialize()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("✅ 增强记忆系统适配器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 增强记忆系统适配器初始化失败: {e}", exc_info=True)
|
||||
# 如果初始化失败,禁用增强记忆功能
|
||||
self.config.enable_enhanced_memory = False
|
||||
|
||||
async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""处理对话记忆,以上下文为唯一输入"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return {"success": False, "error": "Enhanced memory not available"}
|
||||
|
||||
start_time = time.time()
|
||||
self.adapter_stats["total_processed"] += 1
|
||||
|
||||
try:
|
||||
payload_context: dict[str, Any] = dict(context or {})
|
||||
|
||||
conversation_text = payload_context.get("conversation_text")
|
||||
if not conversation_text:
|
||||
conversation_candidate = (
|
||||
payload_context.get("message_content")
|
||||
or payload_context.get("latest_message")
|
||||
or payload_context.get("raw_text")
|
||||
)
|
||||
if conversation_candidate is not None:
|
||||
conversation_text = str(conversation_candidate)
|
||||
payload_context["conversation_text"] = conversation_text
|
||||
else:
|
||||
conversation_text = ""
|
||||
else:
|
||||
conversation_text = str(conversation_text)
|
||||
|
||||
if "timestamp" not in payload_context:
|
||||
payload_context["timestamp"] = time.time()
|
||||
|
||||
logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text))
|
||||
|
||||
# 使用集成层处理对话
|
||||
result = await self.integration_layer.process_conversation(payload_context)
|
||||
|
||||
# 更新统计
|
||||
processing_time = time.time() - start_time
|
||||
self._update_processing_stats(processing_time)
|
||||
|
||||
if result["success"]:
|
||||
created_count = len(result.get("created_memories", []))
|
||||
self.adapter_stats["memories_created"] += created_count
|
||||
logger.debug(f"对话记忆处理完成,创建 {created_count} 条记忆")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理对话记忆失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return []
|
||||
|
||||
try:
|
||||
limit = limit or self.config.max_retrieval_results
|
||||
memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit)
|
||||
|
||||
self.adapter_stats["memories_retrieved"] += len(memories)
|
||||
logger.debug(f"检索到 {len(memories)} 条相关记忆")
|
||||
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_memory_context_for_prompt(
|
||||
self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
|
||||
) -> str:
|
||||
"""获取用于提示词的记忆上下文"""
|
||||
memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
|
||||
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
# 使用新的记忆格式化器
|
||||
formatter_config = FormatterConfig(
|
||||
include_timestamps=True,
|
||||
include_memory_types=True,
|
||||
include_confidence=False,
|
||||
use_emoji_icons=True,
|
||||
group_by_type=False,
|
||||
max_display_length=150,
|
||||
)
|
||||
|
||||
return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)
|
||||
|
||||
async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
|
||||
"""获取增强记忆系统摘要"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return {"available": False, "reason": "Not initialized or disabled"}
|
||||
|
||||
try:
|
||||
# 获取系统状态
|
||||
status = await self.integration_layer.get_system_status()
|
||||
|
||||
# 获取适配器统计
|
||||
adapter_stats = self.adapter_stats.copy()
|
||||
|
||||
# 获取集成统计
|
||||
integration_stats = self.integration_layer.get_integration_stats()
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
"system_status": status,
|
||||
"adapter_stats": adapter_stats,
|
||||
"integration_stats": integration_stats,
|
||||
"total_memories_created": adapter_stats["memories_created"],
|
||||
"total_memories_retrieved": adapter_stats["memories_retrieved"],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取增强记忆摘要失败: {e}", exc_info=True)
|
||||
return {"available": False, "error": str(e)}
|
||||
|
||||
def _update_processing_stats(self, processing_time: float):
|
||||
"""更新处理统计"""
|
||||
total_processed = self.adapter_stats["total_processed"]
|
||||
if total_processed > 0:
|
||||
current_avg = self.adapter_stats["average_processing_time"]
|
||||
new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
|
||||
self.adapter_stats["average_processing_time"] = new_avg
|
||||
|
||||
def get_adapter_stats(self) -> dict[str, Any]:
|
||||
"""获取适配器统计信息"""
|
||||
return self.adapter_stats.copy()
|
||||
|
||||
async def maintenance(self):
|
||||
"""维护操作"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🔧 增强记忆系统适配器维护...")
|
||||
await self.integration_layer.maintenance()
|
||||
logger.info("✅ 增强记忆系统适配器维护完成")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 增强记忆系统适配器维护失败: {e}", exc_info=True)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭适配器"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🔄 关闭增强记忆系统适配器...")
|
||||
await self.integration_layer.shutdown()
|
||||
self._initialized = False
|
||||
logger.info("✅ 增强记忆系统适配器已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 关闭增强记忆系统适配器失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 全局适配器实例
|
||||
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None
|
||||
|
||||
|
||||
async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
|
||||
"""获取全局增强记忆适配器实例"""
|
||||
global _enhanced_memory_adapter
|
||||
|
||||
if _enhanced_memory_adapter is None:
|
||||
# 从配置中获取适配器配置
|
||||
from src.config.config import global_config
|
||||
|
||||
adapter_config = AdapterConfig(
|
||||
enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True),
|
||||
integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"),
|
||||
auto_migration=getattr(global_config.memory, "enable_memory_migration", True),
|
||||
memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6),
|
||||
fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85),
|
||||
max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10),
|
||||
)
|
||||
|
||||
_enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config)
|
||||
await _enhanced_memory_adapter.initialize()
|
||||
|
||||
return _enhanced_memory_adapter
|
||||
|
||||
|
||||
async def initialize_enhanced_memory_system(llm_model: LLMRequest):
|
||||
"""初始化增强记忆系统"""
|
||||
try:
|
||||
logger.info("🚀 初始化增强记忆系统...")
|
||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||
logger.info("✅ 增强记忆系统初始化完成")
|
||||
return adapter
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
async def process_conversation_with_enhanced_memory(
|
||||
context: dict[str, Any], llm_model: LLMRequest | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
from src.llm_models.utils_model import get_global_llm_model
|
||||
|
||||
llm_model = get_global_llm_model()
|
||||
|
||||
try:
|
||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||
payload_context = dict(context or {})
|
||||
|
||||
if "conversation_text" not in payload_context:
|
||||
conversation_candidate = (
|
||||
payload_context.get("message_content")
|
||||
or payload_context.get("latest_message")
|
||||
or payload_context.get("raw_text")
|
||||
)
|
||||
if conversation_candidate is not None:
|
||||
payload_context["conversation_text"] = str(conversation_candidate)
|
||||
|
||||
return await adapter.process_conversation_memory(payload_context)
|
||||
except Exception as e:
|
||||
logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
async def retrieve_memories_with_enhanced_system(
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
llm_model: LLMRequest | None = None,
|
||||
) -> list[MemoryChunk]:
|
||||
"""使用增强记忆系统检索记忆"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
from src.llm_models.utils_model import get_global_llm_model
|
||||
|
||||
llm_model = get_global_llm_model()
|
||||
|
||||
try:
|
||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||
return await adapter.retrieve_relevant_memories(query, user_id, context, limit)
|
||||
except Exception as e:
|
||||
logger.error(f"使用增强记忆系统检索记忆失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
async def get_memory_context_for_prompt(
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
max_memories: int = 5,
|
||||
llm_model: LLMRequest | None = None,
|
||||
) -> str:
|
||||
"""获取用于提示词的记忆上下文"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
from src.llm_models.utils_model import get_global_llm_model
|
||||
|
||||
llm_model = get_global_llm_model()
|
||||
|
||||
try:
|
||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||
return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories)
|
||||
except Exception as e:
|
||||
logger.error(f"获取记忆上下文失败: {e}", exc_info=True)
|
||||
return ""
|
||||
@@ -1,194 +0,0 @@
|
||||
"""
|
||||
增强记忆系统钩子
|
||||
用于在消息处理过程中自动构建和检索记忆
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EnhancedMemoryHooks:
|
||||
"""增强记忆系统钩子 - 自动处理消息的记忆构建和检索"""
|
||||
|
||||
def __init__(self):
|
||||
self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory
|
||||
self.processed_messages = set() # 避免重复处理
|
||||
|
||||
async def process_message_for_memory(
|
||||
self,
|
||||
message_content: str,
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
处理消息并构建记忆
|
||||
|
||||
Args:
|
||||
message_content: 消息内容
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
message_id: 消息ID
|
||||
context: 上下文信息
|
||||
|
||||
Returns:
|
||||
bool: 是否成功处理
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
if message_id in self.processed_messages:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 确保增强记忆管理器已初始化
|
||||
if not enhanced_memory_manager.is_initialized:
|
||||
await enhanced_memory_manager.initialize()
|
||||
|
||||
# 注入机器人基础人设,帮助记忆构建时避免记录自身信息
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
personality_config = getattr(global_config, "personality", None)
|
||||
bot_context = {}
|
||||
if bot_config is not None:
|
||||
bot_context["bot_name"] = getattr(bot_config, "nickname", None)
|
||||
bot_context["bot_aliases"] = list(getattr(bot_config, "alias_names", []) or [])
|
||||
bot_context["bot_account"] = getattr(bot_config, "qq_account", None)
|
||||
|
||||
if personality_config is not None:
|
||||
bot_context["bot_identity"] = getattr(personality_config, "identity", None)
|
||||
bot_context["bot_personality"] = getattr(personality_config, "personality_core", None)
|
||||
bot_context["bot_personality_side"] = getattr(personality_config, "personality_side", None)
|
||||
|
||||
# 构建上下文
|
||||
memory_context = {
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"timestamp": datetime.now().timestamp(),
|
||||
"message_type": "user_message",
|
||||
**bot_context,
|
||||
**(context or {}),
|
||||
}
|
||||
|
||||
# 处理对话并构建记忆
|
||||
memory_chunks = await enhanced_memory_manager.process_conversation(
|
||||
conversation_text=message_content,
|
||||
context=memory_context,
|
||||
user_id=user_id,
|
||||
timestamp=memory_context["timestamp"],
|
||||
)
|
||||
|
||||
# 标记消息已处理
|
||||
self.processed_messages.add(message_id)
|
||||
|
||||
# 限制处理历史大小
|
||||
if len(self.processed_messages) > 1000:
|
||||
# 移除最旧的500个记录
|
||||
self.processed_messages = set(list(self.processed_messages)[-500:])
|
||||
|
||||
logger.debug(f"为消息 {message_id} 构建了 {len(memory_chunks)} 条记忆")
|
||||
return len(memory_chunks) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息记忆失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_memory_for_response(
|
||||
self,
|
||||
query_text: str,
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
limit: int = 5,
|
||||
extra_context: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
limit: 返回记忆数量限制
|
||||
|
||||
Returns:
|
||||
List[Dict]: 相关记忆列表
|
||||
"""
|
||||
if not self.enabled:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 确保增强记忆管理器已初始化
|
||||
if not enhanced_memory_manager.is_initialized:
|
||||
await enhanced_memory_manager.initialize()
|
||||
|
||||
# 构建查询上下文
|
||||
context = {
|
||||
"chat_id": chat_id,
|
||||
"query_intent": "response_generation",
|
||||
"expected_memory_types": ["personal_fact", "event", "preference", "opinion"],
|
||||
}
|
||||
|
||||
if extra_context:
|
||||
context.update(extra_context)
|
||||
|
||||
# 获取相关记忆
|
||||
enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
|
||||
query_text=query_text, user_id=user_id, context=context, limit=limit
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
results = []
|
||||
for result in enhanced_results:
|
||||
memory_dict = {
|
||||
"content": result.content,
|
||||
"type": result.memory_type,
|
||||
"confidence": result.confidence,
|
||||
"importance": result.importance,
|
||||
"timestamp": result.timestamp,
|
||||
"source": result.source,
|
||||
"relevance": result.relevance_score,
|
||||
"structure": result.structure,
|
||||
}
|
||||
results.append(memory_dict)
|
||||
|
||||
logger.debug(f"为回复查询到 {len(results)} 条相关记忆")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取回复记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def cleanup_old_memories(self):
|
||||
"""清理旧记忆"""
|
||||
try:
|
||||
if enhanced_memory_manager.is_initialized:
|
||||
# 调用增强记忆系统的维护功能
|
||||
await enhanced_memory_manager.enhanced_system.maintenance()
|
||||
logger.debug("增强记忆系统维护完成")
|
||||
except Exception as e:
|
||||
logger.error(f"清理旧记忆失败: {e}")
|
||||
|
||||
def clear_processed_cache(self):
|
||||
"""清除已处理消息的缓存"""
|
||||
self.processed_messages.clear()
|
||||
logger.debug("已清除消息处理缓存")
|
||||
|
||||
def enable(self):
|
||||
"""启用记忆钩子"""
|
||||
self.enabled = True
|
||||
logger.info("增强记忆钩子已启用")
|
||||
|
||||
def disable(self):
|
||||
"""禁用记忆钩子"""
|
||||
self.enabled = False
|
||||
logger.info("增强记忆钩子已禁用")
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
enhanced_memory_hooks = EnhancedMemoryHooks()
|
||||
@@ -1,177 +0,0 @@
|
||||
"""
|
||||
增强记忆系统集成脚本
|
||||
用于在现有系统中无缝集成增强记忆功能
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def process_user_message_memory(
|
||||
message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
处理用户消息并构建记忆
|
||||
|
||||
Args:
|
||||
message_content: 消息内容
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
message_id: 消息ID
|
||||
context: 额外的上下文信息
|
||||
|
||||
Returns:
|
||||
bool: 是否成功构建记忆
|
||||
"""
|
||||
try:
|
||||
success = await enhanced_memory_hooks.process_message_for_memory(
|
||||
message_content=message_content, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.debug(f"成功为消息 {message_id} 构建记忆")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理用户消息记忆失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_relevant_memories_for_response(
|
||||
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
|
||||
Args:
|
||||
query_text: 查询文本(通常是用户的当前消息)
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
limit: 返回记忆数量限制
|
||||
extra_context: 额外上下文信息
|
||||
|
||||
Returns:
|
||||
Dict: 包含记忆信息的字典
|
||||
"""
|
||||
try:
|
||||
memories = await enhanced_memory_hooks.get_memory_for_response(
|
||||
query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context
|
||||
)
|
||||
|
||||
result = {"has_memories": len(memories) > 0, "memories": memories, "memory_count": len(memories)}
|
||||
|
||||
logger.debug(f"为回复获取到 {len(memories)} 条相关记忆")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取回复记忆失败: {e}")
|
||||
return {"has_memories": False, "memories": [], "memory_count": 0}
|
||||
|
||||
|
||||
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化记忆信息用于Prompt
|
||||
|
||||
Args:
|
||||
memories: 记忆信息字典
|
||||
|
||||
Returns:
|
||||
str: 格式化后的记忆文本
|
||||
"""
|
||||
if not memories["has_memories"]:
|
||||
return ""
|
||||
|
||||
memory_lines = ["以下是相关的记忆信息:"]
|
||||
|
||||
for memory in memories["memories"]:
|
||||
content = memory["content"]
|
||||
memory_type = memory["type"]
|
||||
confidence = memory["confidence"]
|
||||
importance = memory["importance"]
|
||||
|
||||
# 根据重要性添加不同的标记
|
||||
importance_marker = "🔥" if importance >= 3 else "⭐" if importance >= 2 else "📝"
|
||||
confidence_marker = "✅" if confidence >= 3 else "⚠️" if confidence >= 2 else "💭"
|
||||
|
||||
memory_line = f"{importance_marker} {content} ({memory_type}, {confidence_marker}置信度)"
|
||||
memory_lines.append(memory_line)
|
||||
|
||||
return "\n".join(memory_lines)
|
||||
|
||||
|
||||
async def cleanup_memory_system():
|
||||
"""清理记忆系统"""
|
||||
try:
|
||||
await enhanced_memory_hooks.cleanup_old_memories()
|
||||
logger.info("记忆系统清理完成")
|
||||
except Exception as e:
|
||||
logger.error(f"记忆系统清理失败: {e}")
|
||||
|
||||
|
||||
def get_memory_system_status() -> dict[str, Any]:
|
||||
"""
|
||||
获取记忆系统状态
|
||||
|
||||
Returns:
|
||||
Dict: 系统状态信息
|
||||
"""
|
||||
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
|
||||
return {
|
||||
"enabled": enhanced_memory_hooks.enabled,
|
||||
"enhanced_system_initialized": enhanced_memory_manager.is_initialized,
|
||||
"processed_messages_count": len(enhanced_memory_hooks.processed_messages),
|
||||
"system_type": "enhanced_memory_system",
|
||||
}
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def remember_message(
|
||||
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
便捷的记忆构建函数
|
||||
|
||||
Args:
|
||||
message: 要记住的消息
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
import uuid
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
return await process_user_message_memory(
|
||||
message_content=message, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
|
||||
)
|
||||
|
||||
|
||||
async def recall_memories(
|
||||
query: str,
|
||||
user_id: str = "default_user",
|
||||
chat_id: str = "default_chat",
|
||||
limit: int = 5,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
便捷的记忆检索函数
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
user_id: 用户ID
|
||||
chat_id: 聊天ID
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
Dict: 记忆信息
|
||||
"""
|
||||
return await get_relevant_memories_for_response(
|
||||
query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context
|
||||
)
|
||||
@@ -1,356 +0,0 @@
|
||||
"""
|
||||
增强重排序器
|
||||
实现文档设计的多维度评分模型
|
||||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class IntentType(Enum):
|
||||
"""对话意图类型"""
|
||||
|
||||
FACT_QUERY = "fact_query" # 事实查询
|
||||
EVENT_RECALL = "event_recall" # 事件回忆
|
||||
PREFERENCE_CHECK = "preference_check" # 偏好检查
|
||||
GENERAL_CHAT = "general_chat" # 一般对话
|
||||
UNKNOWN = "unknown" # 未知意图
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReRankingConfig:
|
||||
"""重排序配置"""
|
||||
|
||||
# 权重配置 (w1 + w2 + w3 + w4 = 1.0)
|
||||
semantic_weight: float = 0.5 # 语义相似度权重
|
||||
recency_weight: float = 0.2 # 时效性权重
|
||||
usage_freq_weight: float = 0.2 # 使用频率权重
|
||||
type_match_weight: float = 0.1 # 类型匹配权重
|
||||
|
||||
# 时效性衰减参数
|
||||
recency_decay_rate: float = 0.1 # 时效性衰减率 (天)
|
||||
|
||||
# 使用频率计算参数
|
||||
freq_log_base: float = 2.0 # 对数底数
|
||||
freq_max_score: float = 5.0 # 最大频率得分
|
||||
|
||||
# 类型匹配权重映射
|
||||
type_match_weights: dict[str, dict[str, float]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化类型匹配权重"""
|
||||
if self.type_match_weights is None:
|
||||
self.type_match_weights = {
|
||||
IntentType.FACT_QUERY.value: {
|
||||
MemoryType.PERSONAL_FACT.value: 1.0,
|
||||
MemoryType.KNOWLEDGE.value: 0.8,
|
||||
MemoryType.PREFERENCE.value: 0.5,
|
||||
MemoryType.EVENT.value: 0.3,
|
||||
"default": 0.3,
|
||||
},
|
||||
IntentType.EVENT_RECALL.value: {
|
||||
MemoryType.EVENT.value: 1.0,
|
||||
MemoryType.EXPERIENCE.value: 0.8,
|
||||
MemoryType.EMOTION.value: 0.6,
|
||||
MemoryType.PERSONAL_FACT.value: 0.5,
|
||||
"default": 0.5,
|
||||
},
|
||||
IntentType.PREFERENCE_CHECK.value: {
|
||||
MemoryType.PREFERENCE.value: 1.0,
|
||||
MemoryType.OPINION.value: 0.8,
|
||||
MemoryType.GOAL.value: 0.6,
|
||||
MemoryType.PERSONAL_FACT.value: 0.4,
|
||||
"default": 0.4,
|
||||
},
|
||||
IntentType.GENERAL_CHAT.value: {"default": 0.8},
|
||||
IntentType.UNKNOWN.value: {"default": 0.8},
|
||||
}
|
||||
|
||||
|
||||
class IntentClassifier:
|
||||
"""轻量级意图识别器"""
|
||||
|
||||
def __init__(self):
|
||||
# 关键词模式匹配规则
|
||||
self.patterns = {
|
||||
IntentType.FACT_QUERY: [
|
||||
# 中文模式
|
||||
"我是",
|
||||
"我的",
|
||||
"我叫",
|
||||
"我在",
|
||||
"我住在",
|
||||
"我的职业",
|
||||
"我的工作",
|
||||
"什么时候",
|
||||
"在哪里",
|
||||
"是什么",
|
||||
"多少",
|
||||
"几岁",
|
||||
"年龄",
|
||||
# 英文模式
|
||||
"what is",
|
||||
"where is",
|
||||
"when is",
|
||||
"how old",
|
||||
"my name",
|
||||
"i am",
|
||||
"i live",
|
||||
],
|
||||
IntentType.EVENT_RECALL: [
|
||||
# 中文模式
|
||||
"记得",
|
||||
"想起",
|
||||
"还记得",
|
||||
"那次",
|
||||
"上次",
|
||||
"之前",
|
||||
"以前",
|
||||
"曾经",
|
||||
"发生过",
|
||||
"经历",
|
||||
"做过",
|
||||
"去过",
|
||||
"见过",
|
||||
# 英文模式
|
||||
"remember",
|
||||
"recall",
|
||||
"last time",
|
||||
"before",
|
||||
"previously",
|
||||
"happened",
|
||||
"experience",
|
||||
],
|
||||
IntentType.PREFERENCE_CHECK: [
|
||||
# 中文模式
|
||||
"喜欢",
|
||||
"不喜欢",
|
||||
"偏好",
|
||||
"爱好",
|
||||
"兴趣",
|
||||
"讨厌",
|
||||
"最爱",
|
||||
"最喜欢",
|
||||
"习惯",
|
||||
"通常",
|
||||
"一般",
|
||||
"倾向于",
|
||||
"更喜欢",
|
||||
# 英文模式
|
||||
"like",
|
||||
"love",
|
||||
"hate",
|
||||
"prefer",
|
||||
"favorite",
|
||||
"usually",
|
||||
"tend to",
|
||||
"interest",
|
||||
],
|
||||
}
|
||||
|
||||
def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
|
||||
"""识别对话意图"""
|
||||
if not query:
|
||||
return IntentType.UNKNOWN
|
||||
|
||||
query_lower = query.lower()
|
||||
|
||||
# 统计各意图的匹配分数
|
||||
intent_scores = dict.fromkeys(IntentType, 0)
|
||||
|
||||
for intent, patterns in self.patterns.items():
|
||||
for pattern in patterns:
|
||||
if pattern in query_lower:
|
||||
intent_scores[intent] += 1
|
||||
|
||||
# 返回得分最高的意图
|
||||
max_score = max(intent_scores.values())
|
||||
if max_score == 0:
|
||||
return IntentType.GENERAL_CHAT
|
||||
|
||||
for intent, score in intent_scores.items():
|
||||
if score == max_score:
|
||||
return intent
|
||||
|
||||
return IntentType.GENERAL_CHAT
|
||||
|
||||
|
||||
class EnhancedReRanker:
|
||||
"""增强重排序器 - 实现文档设计的多维度评分模型"""
|
||||
|
||||
def __init__(self, config: ReRankingConfig | None = None):
|
||||
self.config = config or ReRankingConfig()
|
||||
self.intent_classifier = IntentClassifier()
|
||||
|
||||
# 验证权重和为1.0
|
||||
total_weight = (
|
||||
self.config.semantic_weight
|
||||
+ self.config.recency_weight
|
||||
+ self.config.usage_freq_weight
|
||||
+ self.config.type_match_weight
|
||||
)
|
||||
|
||||
if abs(total_weight - 1.0) > 0.01:
|
||||
logger.warning(f"重排序权重和不为1.0: {total_weight}, 将进行归一化")
|
||||
# 归一化权重
|
||||
self.config.semantic_weight /= total_weight
|
||||
self.config.recency_weight /= total_weight
|
||||
self.config.usage_freq_weight /= total_weight
|
||||
self.config.type_match_weight /= total_weight
|
||||
|
||||
def rerank_memories(
|
||||
self,
|
||||
query: str,
|
||||
candidate_memories: list[tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
|
||||
context: dict[str, Any],
|
||||
limit: int = 10,
|
||||
) -> list[tuple[str, MemoryChunk, float]]:
|
||||
"""
|
||||
对候选记忆进行重排序
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
candidate_memories: 候选记忆列表 [(memory_id, memory, vector_similarity)]
|
||||
context: 上下文信息
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
重排序后的记忆列表 [(memory_id, memory, final_score)]
|
||||
"""
|
||||
if not candidate_memories:
|
||||
return []
|
||||
|
||||
# 识别查询意图
|
||||
intent = self.intent_classifier.classify_intent(query, context)
|
||||
logger.debug(f"识别到查询意图: {intent.value}")
|
||||
|
||||
# 计算每个候选记忆的最终得分
|
||||
scored_memories = []
|
||||
current_time = time.time()
|
||||
|
||||
for memory_id, memory, vector_sim in candidate_memories:
|
||||
try:
|
||||
# 1. 语义相似度得分 (已归一化到[0,1])
|
||||
semantic_score = self._normalize_similarity(vector_sim)
|
||||
|
||||
# 2. 时效性得分
|
||||
recency_score = self._calculate_recency_score(memory, current_time)
|
||||
|
||||
# 3. 使用频率得分
|
||||
usage_freq_score = self._calculate_usage_frequency_score(memory)
|
||||
|
||||
# 4. 类型匹配得分
|
||||
type_match_score = self._calculate_type_match_score(memory, intent)
|
||||
|
||||
# 计算最终得分
|
||||
final_score = (
|
||||
self.config.semantic_weight * semantic_score
|
||||
+ self.config.recency_weight * recency_score
|
||||
+ self.config.usage_freq_weight * usage_freq_score
|
||||
+ self.config.type_match_weight * type_match_score
|
||||
)
|
||||
|
||||
scored_memories.append((memory_id, memory, final_score))
|
||||
|
||||
# 记录调试信息
|
||||
logger.debug(
|
||||
f"记忆评分 {memory_id[:8]}: semantic={semantic_score:.3f}, "
|
||||
f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, "
|
||||
f"type={type_match_score:.3f}, final={final_score:.3f}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算记忆 {memory_id} 得分时出错: {e}")
|
||||
# 使用向量相似度作为后备得分
|
||||
scored_memories.append((memory_id, memory, vector_sim))
|
||||
|
||||
# 按最终得分降序排序
|
||||
scored_memories.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# 返回前N个结果
|
||||
result = scored_memories[:limit]
|
||||
|
||||
highest_score = result[0][2] if result else 0.0
|
||||
logger.info(
|
||||
f"重排序完成: 候选={len(candidate_memories)}, 返回={len(result)}, "
|
||||
f"意图={intent.value}, 最高分={highest_score:.3f}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _normalize_similarity(self, raw_similarity: float) -> float:
|
||||
"""归一化相似度到[0,1]区间"""
|
||||
# 假设原始相似度已经在[-1,1]或[0,1]区间
|
||||
if raw_similarity < 0:
|
||||
return (raw_similarity + 1) / 2 # 从[-1,1]映射到[0,1]
|
||||
return min(1.0, max(0.0, raw_similarity)) # 确保在[0,1]区间
|
||||
|
||||
def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float:
|
||||
"""
|
||||
计算时效性得分
|
||||
公式: Recency = 1 / (1 + decay_rate * days_old)
|
||||
"""
|
||||
last_accessed = memory.metadata.last_accessed or memory.metadata.created_at
|
||||
days_old = (current_time - last_accessed) / (24 * 3600) # 转换为天数
|
||||
|
||||
if days_old < 0:
|
||||
days_old = 0 # 处理时间异常
|
||||
|
||||
score = 1 / (1 + self.config.recency_decay_rate * days_old)
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float:
|
||||
"""
|
||||
计算使用频率得分
|
||||
公式: Usage_Freq = min(1.0, log2(access_count + 1) / max_score)
|
||||
"""
|
||||
access_count = memory.metadata.access_count
|
||||
if access_count <= 0:
|
||||
return 0.0
|
||||
|
||||
log_count = math.log2(access_count + 1)
|
||||
score = log_count / self.config.freq_max_score
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float:
|
||||
"""计算类型匹配得分"""
|
||||
memory_type = memory.memory_type.value
|
||||
intent_value = intent.value
|
||||
|
||||
# 获取对应意图的类型权重映射
|
||||
type_weights = self.config.type_match_weights.get(intent_value, {})
|
||||
|
||||
# 查找具体类型的权重,如果没有则使用默认权重
|
||||
score = type_weights.get(memory_type, type_weights.get("default", 0.8))
|
||||
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
|
||||
# Module-level default reranker, shared by rerank_candidate_memories()
# whenever the caller does not supply a custom config.
default_reranker = EnhancedReRanker()
|
||||
|
||||
|
||||
def rerank_candidate_memories(
    query: str,
    candidate_memories: list[tuple[str, MemoryChunk, float]],
    context: dict[str, Any],
    limit: int = 10,
    config: ReRankingConfig | None = None,
) -> list[tuple[str, MemoryChunk, float]]:
    """Convenience wrapper around ``EnhancedReRanker.rerank_memories``.

    A truthy ``config`` builds a dedicated reranker instance; otherwise the
    shared module-level ``default_reranker`` is reused.
    """
    reranker = EnhancedReRanker(config) if config else default_reranker
    return reranker.rerank_memories(query, candidate_memories, context, limit)
|
||||
@@ -1,245 +0,0 @@
|
||||
"""
|
||||
增强记忆系统集成层
|
||||
现在只管理新的增强记忆系统,旧系统已被完全移除
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class IntegrationMode(Enum):
    """How the enhanced memory system is wired into the bot."""

    REPLACE = "replace"  # fully replace the existing memory system
    ENHANCED_ONLY = "enhanced_only"  # run only the enhanced memory system
|
||||
|
||||
|
||||
@dataclass
class IntegrationConfig:
    """Tunable settings for the memory integration layer."""

    mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY  # wiring strategy
    enable_enhanced_memory: bool = True  # master switch for the subsystem
    memory_value_threshold: float = 0.6  # minimum value score to keep a memory
    fusion_threshold: float = 0.85  # similarity needed to fuse two memories
    max_retrieval_results: int = 10  # cap on memories returned per query
    enable_learning: bool = True  # allow the system to keep learning
|
||||
|
||||
|
||||
class MemoryIntegrationLayer:
    """Integration layer for the memory system.

    Now manages only the enhanced memory system; the legacy system has been
    fully removed. Provides lazy, lock-guarded initialization, conversation
    processing, retrieval, status reporting and lifecycle management.
    """

    def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
        """Store dependencies and prepare (but do not start) the subsystem.

        Args:
            llm_model: LLM client injected into the enhanced memory system.
            config: Optional integration settings; defaults to ``IntegrationConfig()``.
        """
        self.llm_model = llm_model
        self.config = config or IntegrationConfig()

        # Only the enhanced memory system is managed here.
        self.enhanced_memory: EnhancedMemorySystem | None = None

        # Aggregate counters exposed via get_integration_stats().
        self.integration_stats = {
            "total_queries": 0,
            "enhanced_queries": 0,
            "memory_creations": 0,
            "average_response_time": 0.0,
            "success_rate": 0.0,
        }

        # Guards lazy initialization against concurrent callers.
        self._initialization_lock = asyncio.Lock()
        self._initialized = False

    async def initialize(self):
        """Initialize the layer exactly once (double-checked under the lock).

        Raises:
            Exception: re-raises any failure from subsystem initialization.
        """
        if self._initialized:
            return

        async with self._initialization_lock:
            # Re-check after acquiring the lock: another coroutine may have
            # finished initialization while we were waiting.
            if self._initialized:
                return

            logger.info("🚀 开始初始化增强记忆系统集成层...")

            try:
                if self.config.enable_enhanced_memory:
                    await self._initialize_enhanced_memory()

                self._initialized = True
                logger.info("✅ 增强记忆系统集成层初始化完成")

            except Exception as e:
                logger.error(f"❌ 集成层初始化失败: {e}", exc_info=True)
                raise

    async def _initialize_enhanced_memory(self):
        """Build and start the EnhancedMemorySystem from global + layer config."""
        try:
            logger.debug("初始化增强记忆系统...")

            # Imported here to avoid a hard import cycle at module load time.
            from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig

            memory_config = MemorySystemConfig.from_global_config()

            # Integration-level settings override the global defaults.
            memory_config.memory_value_threshold = self.config.memory_value_threshold
            memory_config.fusion_similarity_threshold = self.config.fusion_threshold
            memory_config.final_recall_limit = self.config.max_retrieval_results

            self.enhanced_memory = EnhancedMemorySystem(config=memory_config)

            # Inject the externally supplied LLM client, if any.
            if self.llm_model is not None:
                self.enhanced_memory.llm_model = self.llm_model

            await self.enhanced_memory.initialize()
            logger.info("✅ 增强记忆系统初始化完成")

        except Exception as e:
            logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
            raise

    async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
        """Process conversation memory using only the supplied context.

        Returns the enhanced system's result dict; on failure a
        ``{"success": False, "error": ...}`` dict is returned instead of raising.
        """
        if not self._initialized or not self.enhanced_memory:
            return {"success": False, "error": "Memory system not available"}

        start_time = time.time()
        self.integration_stats["total_queries"] += 1
        self.integration_stats["enhanced_queries"] += 1

        try:
            payload_context = dict(context or {})
            conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
            logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text))

            # Delegate directly to the enhanced memory system.
            result = await self.enhanced_memory.process_conversation_memory(payload_context)

            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, result.get("success", False))

            if result.get("success"):
                created_count = len(result.get("created_memories", []))
                self.integration_stats["memory_creations"] += created_count
                logger.debug(f"对话处理完成,创建 {created_count} 条记忆")

            return result

        except Exception as e:
            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, False)
            logger.error(f"处理对话记忆失败: {e}", exc_info=True)
            return {"success": False, "error": str(e)}

    async def retrieve_relevant_memories(
        self,
        query: str,
        user_id: str | None = None,
        context: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> list[MemoryChunk]:
        """Retrieve memories relevant to ``query``; returns [] on any failure."""
        if not self._initialized or not self.enhanced_memory:
            return []

        try:
            limit = limit or self.config.max_retrieval_results
            # NOTE(review): user_id is accepted but deliberately not forwarded
            # (retrieval is global) — confirm this is intended before relying on it.
            memories = await self.enhanced_memory.retrieve_relevant_memories(
                query=query, user_id=None, context=context or {}, limit=limit
            )

            memory_count = len(memories)
            logger.debug(f"检索到 {memory_count} 条相关记忆")
            return memories

        except Exception as e:
            logger.error(f"检索相关记忆失败: {e}", exc_info=True)
            return []

    async def get_system_status(self) -> dict[str, Any]:
        """Return a status snapshot of the layer and the enhanced subsystem."""
        if not self._initialized:
            return {"status": "not_initialized"}

        try:
            enhanced_status = {}
            if self.enhanced_memory:
                enhanced_status = await self.enhanced_memory.get_system_status()

            return {
                "status": "initialized",
                "mode": self.config.mode.value,
                "enhanced_memory": enhanced_status,
                "integration_stats": self.integration_stats.copy(),
            }

        except Exception as e:
            logger.error(f"获取系统状态失败: {e}", exc_info=True)
            return {"status": "error", "error": str(e)}

    def get_integration_stats(self) -> dict[str, Any]:
        """Return a shallow copy of the integration statistics."""
        return self.integration_stats.copy()

    def _update_response_stats(self, processing_time: float, success: bool):
        """Fold one request's latency and outcome into the running stats.

        Bug fix: the previous implementation updated ``success_rate`` only on
        success, so failures never lowered the recorded rate. The rate is now
        recomputed on every call (success contributes 1, failure contributes 0).
        """
        total_queries = self.integration_stats["total_queries"]
        if total_queries > 0:
            # Incremental running average of response time.
            current_avg = self.integration_stats["average_response_time"]
            new_avg = (current_avg * (total_queries - 1) + processing_time) / total_queries
            self.integration_stats["average_response_time"] = new_avg

            # Incremental running success rate, counting failures too.
            current_success_rate = self.integration_stats["success_rate"]
            successes = current_success_rate * (total_queries - 1) + (1.0 if success else 0.0)
            self.integration_stats["success_rate"] = successes / total_queries

    async def maintenance(self):
        """Run one maintenance pass on the enhanced subsystem; never raises."""
        if not self._initialized:
            return

        try:
            logger.info("🔧 执行记忆系统集成层维护...")

            if self.enhanced_memory:
                await self.enhanced_memory.maintenance()

            logger.info("✅ 记忆系统集成层维护完成")

        except Exception as e:
            logger.error(f"❌ 集成层维护失败: {e}", exc_info=True)

    async def shutdown(self):
        """Shut down the enhanced subsystem and mark the layer uninitialized."""
        if not self._initialized:
            return

        try:
            logger.info("🔄 关闭记忆系统集成层...")

            if self.enhanced_memory:
                await self.enhanced_memory.shutdown()

            self._initialized = False
            logger.info("✅ 记忆系统集成层已关闭")

        except Exception as e:
            logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)
|
||||
@@ -1,526 +0,0 @@
|
||||
"""
|
||||
记忆系统集成钩子
|
||||
提供与现有MoFox Bot系统的无缝集成点
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_adapter import (
|
||||
get_memory_context_for_prompt,
|
||||
process_conversation_with_enhanced_memory,
|
||||
retrieve_memories_with_enhanced_system,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class HookResult:
    """Outcome of a single integration-hook invocation."""

    success: bool  # whether the hook completed without error
    data: Any = None  # hook-specific payload, if any
    error: str | None = None  # error message when success is False
    processing_time: float = 0.0  # wall-clock seconds spent in the hook
|
||||
|
||||
|
||||
class MemoryIntegrationHooks:
    """Wires the enhanced memory system into the bot's extension points.

    Each ``_register_*`` method probes an optional subsystem (event bus,
    message manager, chat stream, reply generator, knowledge base, prompt
    builder, async task manager) via a guarded import and attaches the
    matching ``_on_*`` handler when the subsystem exposes a registration
    API. All failures are logged and swallowed so the bot keeps running
    even when no memory hooks can be installed.
    """

    def __init__(self):
        # True once register_hooks() has completed (successfully or not raised).
        self.hooks_registered = False
        # Per-category execution counters plus a running average latency.
        self.hook_stats = {
            "message_processing_hooks": 0,
            "memory_retrieval_hooks": 0,
            "prompt_enhancement_hooks": 0,
            "total_hook_executions": 0,
            "average_hook_time": 0.0,
        }

    async def register_hooks(self):
        """Register every hook category once; subsequent calls are no-ops."""
        if self.hooks_registered:
            return

        try:
            logger.info("🔗 注册记忆系统集成钩子...")

            # Message-processing hooks (memory creation on new messages).
            await self._register_message_processing_hooks()

            # Memory-retrieval hooks (recall before generating a reply).
            await self._register_memory_retrieval_hooks()

            # Prompt-enhancement hooks.
            await self._register_prompt_enhancement_hooks()

            # Periodic system-maintenance hooks.
            await self._register_maintenance_hooks()

            self.hooks_registered = True
            logger.info("✅ 记忆系统集成钩子注册完成")

        except Exception as e:
            logger.error(f"❌ 注册记忆系统集成钩子失败: {e}", exc_info=True)

    async def _register_message_processing_hooks(self):
        """Register hooks that create memories from processed messages."""
        try:
            # Hook 1: create memories after a message has been processed.
            await self._register_post_message_hook()

            # Hook 2: process memories when a chat stream is saved.
            await self._register_chat_stream_hook()

            logger.debug("消息处理钩子注册完成")

        except Exception as e:
            logger.error(f"注册消息处理钩子失败: {e}")

    async def _register_memory_retrieval_hooks(self):
        """Register hooks that retrieve memories before generating output."""
        try:
            # Hook 1: retrieve relevant memories before generating a reply.
            await self._register_pre_response_hook()

            # Hook 2: enrich context before querying the knowledge base.
            await self._register_knowledge_query_hook()

            logger.debug("记忆检索钩子注册完成")

        except Exception as e:
            logger.error(f"注册记忆检索钩子失败: {e}")

    async def _register_prompt_enhancement_hooks(self):
        """Register hooks that inject memory context into prompt building."""
        try:
            # Hook 1: enhance prompt construction with memory context.
            await self._register_prompt_building_hook()

            logger.debug("提示词增强钩子注册完成")

        except Exception as e:
            logger.error(f"注册提示词增强钩子失败: {e}")

    async def _register_maintenance_hooks(self):
        """Register the periodic maintenance hook."""
        try:
            # Hook 1: memory-system maintenance during system maintenance.
            await self._register_system_maintenance_hook()

            logger.debug("系统维护钩子注册完成")

        except Exception as e:
            logger.error(f"注册系统维护钩子失败: {e}")

    async def _register_post_message_hook(self):
        """Attach the post-message handler to whichever backends are available."""
        try:
            # Registration targets depend on the actual system architecture;
            # each candidate is probed with a guarded import below.

            # Try the event system first.
            try:
                from src.plugin_system.base.component_types import EventType
                from src.plugin_system.core.event_manager import event_manager

                # Subscribe to the post-processing event.
                event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
                logger.debug("已注册到事件系统的消息处理钩子")

            except ImportError:
                logger.debug("事件系统不可用,跳过事件钩子注册")

            # Then try the message manager.
            try:
                from src.chat.message_manager import message_manager

                # Only if the manager supports hook registration.
                if hasattr(message_manager, "register_post_process_hook"):
                    message_manager.register_post_process_hook(self._on_message_processed_hook)
                    logger.debug("已注册到消息管理器的处理钩子")

            except ImportError:
                logger.debug("消息管理器不可用,跳过消息管理器钩子注册")

        except Exception as e:
            logger.error(f"注册消息后处理钩子失败: {e}")

    async def _register_chat_stream_hook(self):
        """Attach the save handler to the chat stream manager, if present."""
        try:
            try:
                from src.chat.message_receive.chat_stream import get_chat_manager

                chat_manager = get_chat_manager()
                if hasattr(chat_manager, "register_save_hook"):
                    chat_manager.register_save_hook(self._on_chat_stream_save_hook)
                    logger.debug("已注册到聊天流管理器的保存钩子")

            except ImportError:
                logger.debug("聊天流管理器不可用,跳过聊天流钩子注册")

        except Exception as e:
            logger.error(f"注册聊天流钩子失败: {e}")

    async def _register_pre_response_hook(self):
        """Attach the pre-response handler to the reply generator, if present."""
        try:
            try:
                from src.chat.replyer.default_generator import default_generator

                if hasattr(default_generator, "register_pre_generation_hook"):
                    default_generator.register_pre_generation_hook(self._on_pre_response_hook)
                    logger.debug("已注册到回复生成器的前置钩子")

            except ImportError:
                logger.debug("回复生成器不可用,跳过回复前钩子注册")

        except Exception as e:
            logger.error(f"注册回复前钩子失败: {e}")

    async def _register_knowledge_query_hook(self):
        """Attach the query enhancer to the knowledge base, if present."""
        try:
            try:
                from src.chat.knowledge.knowledge_lib import knowledge_manager

                if hasattr(knowledge_manager, "register_query_enhancer"):
                    knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook)
                    logger.debug("已注册到知识库的查询增强钩子")

            except ImportError:
                logger.debug("知识库系统不可用,跳过知识库钩子注册")

        except Exception as e:
            logger.error(f"注册知识库查询钩子失败: {e}")

    async def _register_prompt_building_hook(self):
        """Attach the prompt enhancer to the prompt manager, if present."""
        try:
            try:
                from src.chat.utils.prompt import prompt_manager

                if hasattr(prompt_manager, "register_enhancer"):
                    prompt_manager.register_enhancer(self._on_prompt_building_hook)
                    logger.debug("已注册到提示词管理器的增强钩子")

            except ImportError:
                logger.debug("提示词系统不可用,跳过提示词钩子注册")

        except Exception as e:
            logger.error(f"注册提示词构建钩子失败: {e}")

    async def _register_system_maintenance_hook(self):
        """Schedule the periodic maintenance task, if a task manager exists."""
        try:
            try:
                from src.manager.async_task_manager import async_task_manager

                # Register the recurring maintenance task.
                async_task_manager.add_task(MemoryMaintenanceTask())
                logger.debug("已注册到系统维护器的定期任务")

            except ImportError:
                logger.debug("异步任务管理器不可用,跳过系统维护钩子注册")

        except Exception as e:
            logger.error(f"注册系统维护钩子失败: {e}")

    # --- Hook handler methods ---

    async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
        """Event-system adapter; delegates to the generic message hook."""
        return await self._on_message_processed_hook(event_data)

    async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
        """Create memories from one processed message's data."""
        start_time = time.time()

        try:
            self.hook_stats["message_processing_hooks"] += 1

            # Extract the fields needed to build a memory.
            message_info = message_data.get("message_info", {})
            user_info = message_info.get("user_info", {})
            conversation_text = message_data.get("processed_plain_text", "")

            if not conversation_text:
                return HookResult(success=True, data="No conversation text")

            user_id = str(user_info.get("user_id", "unknown"))
            context = {
                "chat_id": message_data.get("chat_id"),
                "message_type": message_data.get("message_type", "normal"),
                "platform": message_info.get("platform", "unknown"),
                "interest_value": message_data.get("interest_value", 0.0),
                "keywords": message_data.get("key_words", []),
                "timestamp": message_data.get("time", time.time()),
            }

            # Delegate the conversation to the enhanced memory system.
            memory_context = dict(context)
            memory_context["conversation_text"] = conversation_text
            memory_context["user_id"] = user_id

            result = await process_conversation_with_enhanced_memory(memory_context)

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            if result["success"]:
                logger.debug(f"消息处理钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
                return HookResult(success=True, data=result, processing_time=processing_time)
            else:
                logger.warning(f"消息处理钩子执行失败: {result.get('error')}")
                return HookResult(success=False, error=result.get("error"), processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
        """Create memories from the recent messages of a saved chat stream."""
        start_time = time.time()

        try:
            self.hook_stats["message_processing_hooks"] += 1

            # Pull conversation data out of the stream payload.
            stream_context = chat_stream_data.get("stream_context", {})
            user_id = stream_context.get("user_id", "unknown")
            messages = stream_context.get("messages", [])

            if not messages:
                return HookResult(success=True, data="No messages to process")

            # Build a transcript from the most recent messages only.
            conversation_parts = []
            for msg in messages[-10:]:  # only the latest 10 messages
                text = msg.get("processed_plain_text", "")
                if text:
                    conversation_parts.append(f"{msg.get('user_nickname', 'User')}: {text}")

            conversation_text = "\n".join(conversation_parts)
            if not conversation_text:
                return HookResult(success=True, data="No conversation text")

            context = {
                "chat_id": chat_stream_data.get("chat_id"),
                "stream_id": chat_stream_data.get("stream_id"),
                "platform": chat_stream_data.get("platform", "unknown"),
                "message_count": len(messages),
                "timestamp": time.time(),
            }

            # Delegate the transcript to the enhanced memory system.
            memory_context = dict(context)
            memory_context["conversation_text"] = conversation_text
            memory_context["user_id"] = user_id

            result = await process_conversation_with_enhanced_memory(memory_context)

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            if result["success"]:
                logger.debug(f"聊天流保存钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆")
                return HookResult(success=True, data=result, processing_time=processing_time)
            else:
                logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}")
                return HookResult(success=False, error=result.get("error"), processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
        """Inject retrieved memories into the reply-generation payload (mutates it)."""
        start_time = time.time()

        try:
            self.hook_stats["memory_retrieval_hooks"] += 1

            # Extract the retrieval query.
            query = response_data.get("query", "")
            user_id = response_data.get("user_id", "unknown")
            context = response_data.get("context", {})

            if not query:
                return HookResult(success=True, data="No query provided")

            # Retrieve relevant memories.
            memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5)

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            # Attach memories and a formatted context blob to the payload.
            response_data["enhanced_memories"] = memories
            response_data["enhanced_memory_context"] = await get_memory_context_for_prompt(
                query, user_id, context, max_memories=5
            )

            logger.debug(f"回复前钩子执行成功,检索到 {len(memories)} 条记忆")
            return HookResult(success=True, data=memories, processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
        """Enrich a knowledge-base query with memory context (mutates the payload)."""
        start_time = time.time()

        try:
            self.hook_stats["memory_retrieval_hooks"] += 1

            query = query_data.get("query", "")
            user_id = query_data.get("user_id", "unknown")
            context = query_data.get("context", {})

            if not query:
                return HookResult(success=True, data="No query provided")

            # Fetch memory context to augment the knowledge query.
            memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3)

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            # Attach the memory context to the query payload.
            query_data["enhanced_memory_context"] = memory_context

            logger.debug("知识库查询钩子执行成功")
            return HookResult(success=True, data=memory_context, processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
        """Append memory context to the base prompt (mutates the payload)."""
        start_time = time.time()

        try:
            self.hook_stats["prompt_enhancement_hooks"] += 1

            query = prompt_data.get("query", "")
            user_id = prompt_data.get("user_id", "unknown")
            context = prompt_data.get("context", {})
            base_prompt = prompt_data.get("base_prompt", "")

            if not query:
                return HookResult(success=True, data="No query provided")

            # Fetch the memory context for this query.
            memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5)

            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)

            # Build the enhanced prompt by appending the memory section.
            enhanced_prompt = base_prompt
            if memory_context:
                enhanced_prompt += f"\n\n### 相关记忆上下文 ###\n{memory_context}\n"

            # Attach both forms to the payload for downstream consumers.
            prompt_data["enhanced_prompt"] = enhanced_prompt
            prompt_data["memory_context"] = memory_context

            logger.debug("提示词构建钩子执行成功")
            return HookResult(success=True, data=enhanced_prompt, processing_time=processing_time)

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"提示词构建钩子执行异常: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    def _update_hook_stats(self, processing_time: float):
        """Fold one hook execution's latency into the running average."""
        self.hook_stats["total_hook_executions"] += 1

        total_executions = self.hook_stats["total_hook_executions"]
        if total_executions > 0:
            # Incremental running average over all executions.
            current_avg = self.hook_stats["average_hook_time"]
            new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
            self.hook_stats["average_hook_time"] = new_avg

    def get_hook_stats(self) -> dict[str, Any]:
        """Return a shallow copy of the hook statistics."""
        return self.hook_stats.copy()
|
||||
|
||||
|
||||
class MemoryMaintenanceTask:
    """Periodic maintenance job for the enhanced memory system."""

    def __init__(self):
        # Identifier and cadence consumed by the async task manager.
        self.task_name = "enhanced_memory_maintenance"
        self.interval = 3600  # run once per hour

    async def execute(self):
        """Run one maintenance pass; errors are logged, never propagated."""
        try:
            logger.info("🔧 执行增强记忆系统维护任务...")

            # Resolve the adapter lazily so the module stays optional.
            try:
                from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter

                if _enhanced_memory_adapter:
                    await _enhanced_memory_adapter.maintenance()
                    logger.info("✅ 增强记忆系统维护任务完成")
                else:
                    logger.debug("增强记忆适配器未初始化,跳过维护")
            except Exception as e:
                logger.error(f"增强记忆系统维护失败: {e}")

        except Exception as e:
            logger.error(f"执行维护任务时发生异常: {e}", exc_info=True)

    def get_interval(self) -> int:
        """Return the execution interval in seconds."""
        return self.interval

    def get_task_name(self) -> str:
        """Return the task's registered name."""
        return self.task_name
|
||||
|
||||
|
||||
# Global singleton holding the registered hooks instance (created lazily by
# get_memory_integration_hooks()).
_memory_hooks: MemoryIntegrationHooks | None = None
|
||||
|
||||
|
||||
async def get_memory_integration_hooks() -> MemoryIntegrationHooks:
    """Return the process-wide hooks singleton, registering hooks on first call."""
    global _memory_hooks

    if _memory_hooks is None:
        _memory_hooks = MemoryIntegrationHooks()
        # NOTE(review): the global is assigned before register_hooks() runs, so
        # a failed registration leaves a half-initialized singleton behind —
        # confirm whether that is intended.
        await _memory_hooks.register_hooks()

    return _memory_hooks
|
||||
|
||||
|
||||
async def initialize_memory_integration_hooks():
    """Eagerly create and register the hooks singleton.

    Returns:
        The MemoryIntegrationHooks instance, or None if initialization failed.
    """
    try:
        logger.info("🚀 初始化记忆集成钩子...")
        hooks = await get_memory_integration_hooks()
        logger.info("✅ 记忆集成钩子初始化完成")
        return hooks
    except Exception as e:
        logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True)
        return None
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,875 +0,0 @@
|
||||
"""
|
||||
向量数据库存储接口
|
||||
为记忆系统提供高效的向量存储和语义搜索能力
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 尝试导入FAISS,如果不可用则使用简单替代
|
||||
try:
|
||||
import faiss
|
||||
|
||||
FAISS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FAISS_AVAILABLE = False
|
||||
logger.warning("FAISS not available, using simple vector storage")
|
||||
|
||||
|
||||
@dataclass
class VectorStorageConfig:
    """Settings for the memory vector store and its index."""

    dimension: int = 1024  # embedding dimensionality
    similarity_threshold: float = 0.8  # minimum similarity for a match
    index_type: str = "flat"  # one of: flat, ivf, hnsw
    max_index_size: int = 100000  # soft cap on indexed vectors
    storage_path: str = "data/memory_vectors"  # persistence directory
    auto_save_interval: int = 10  # auto-save every N write operations
    enable_compression: bool = True  # compress persisted payloads
|
||||
|
||||
|
||||
class VectorStorageManager:
|
||||
"""向量存储管理器"""
|
||||
|
||||
def __init__(self, config: VectorStorageConfig | None = None):
    """Set up caches, stats, and the vector index from the given config.

    The configured dimension is overridden by the embedding model's
    configured dimension when the two differ.
    """
    self.config = config or VectorStorageConfig()

    # Prefer the dimension declared by the embedding-model config.
    resolved_dimension = resolve_embedding_dimension(self.config.dimension)
    if resolved_dimension and resolved_dimension != self.config.dimension:
        logger.info(
            "向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
            resolved_dimension,
            self.config.dimension,
        )
        self.config.dimension = resolved_dimension
    self.storage_path = Path(self.config.storage_path)
    self.storage_path.mkdir(parents=True, exist_ok=True)

    # Vector index and its bidirectional id <-> slot mappings.
    self.vector_index = None
    self.memory_id_to_index = {}  # memory_id -> vector index
    self.index_to_memory_id = {}  # vector index -> memory_id

    # In-memory caches of chunks and their raw vectors.
    self.memory_cache: dict[str, MemoryChunk] = {}
    self.vector_cache: dict[str, list[float]] = {}

    # Operational statistics.
    self.storage_stats = {
        "total_vectors": 0,
        "index_build_time": 0.0,
        "average_search_time": 0.0,
        "cache_hit_rate": 0.0,
        "total_searches": 0,
        "cache_hits": 0,
    }

    # Reentrant lock guarding index/cache mutation.
    self._lock = threading.RLock()
    self._operation_count = 0

    # Build the vector index now.
    self._initialize_index()

    # Embedding client; created lazily by initialize_embedding_model().
    self.embedding_model: LLMRequest | None = None
|
||||
|
||||
def _initialize_index(self):
    """Create the vector index, preferring FAISS with a pure-Python fallback."""
    try:
        if FAISS_AVAILABLE:
            if self.config.index_type == "flat":
                # Exact inner-product search.
                self.vector_index = faiss.IndexFlatIP(self.config.dimension)
            elif self.config.index_type == "ivf":
                # Inverted-file index; nlist scales with configured capacity.
                quantizer = faiss.IndexFlatIP(self.config.dimension)
                nlist = min(100, max(1, self.config.max_index_size // 1000))
                self.vector_index = faiss.IndexIVFFlat(quantizer, self.config.dimension, nlist)
            elif self.config.index_type == "hnsw":
                # Graph-based approximate search, 32 links per node.
                self.vector_index = faiss.IndexHNSWFlat(self.config.dimension, 32)
                self.vector_index.hnsw.efConstruction = 40
            else:
                # Unknown type: fall back to the exact flat index.
                self.vector_index = faiss.IndexFlatIP(self.config.dimension)
        else:
            # FAISS unavailable: use the simple pure-Python index.
            self.vector_index = SimpleVectorIndex(self.config.dimension)

        logger.info(f"✅ 向量索引初始化完成,类型: {self.config.index_type}")

    except Exception as e:
        logger.error(f"❌ 向量索引初始化失败: {e}")
        # Keep the manager usable by falling back to the simple index.
        self.vector_index = SimpleVectorIndex(self.config.dimension)
|
||||
|
||||
async def initialize_embedding_model(self):
    """Lazily construct the embedding LLM client; safe to call repeatedly."""
    if self.embedding_model is None:
        self.embedding_model = LLMRequest(
            model_set=model_config.model_task_config.embedding, request_type="memory.embedding"
        )
        logger.info("✅ 嵌入模型初始化完成")
|
||||
|
||||
async def generate_query_embedding(self, query_text: str) -> list[float] | None:
    """Embed a recall query into a normalized vector.

    Returns None for empty input, an empty or mis-sized model response,
    or on any exception.
    """
    if not query_text:
        logger.warning("查询文本为空,无法生成向量")
        return None

    try:
        await self.initialize_embedding_model()

        logger.debug(f"开始生成查询向量,文本: '{query_text[:50]}{'...' if len(query_text) > 50 else ''}'")

        embedding, _ = await self.embedding_model.get_embedding(query_text)
        if not embedding:
            logger.warning("嵌入模型返回空向量")
            return None

        logger.debug(f"生成的向量维度: {len(embedding)}, 期望维度: {self.config.dimension}")

        # A dimension mismatch would corrupt the index — reject it.
        if len(embedding) != self.config.dimension:
            logger.error("查询向量维度不匹配: 期望 %d, 实际 %d", self.config.dimension, len(embedding))
            return None

        # Normalize so inner-product search behaves like cosine similarity.
        normalized_vector = self._normalize_vector(embedding)
        logger.debug(f"查询向量生成成功,向量范围: [{min(normalized_vector):.4f}, {max(normalized_vector):.4f}]")
        return normalized_vector

    except Exception as exc:
        logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
        return None
|
||||
|
||||
async def store_memories(self, memories: list[MemoryChunk]):
    """Store memory chunks: cache them, embed where needed, and index.

    Memories that already carry an embedding are indexed immediately;
    the rest are batched through the embedding model. Triggers an
    auto-save once enough operations have accumulated.
    """
    if not memories:
        return

    start_time = time.time()

    try:
        # Ensure the embedding model is initialized before any generation.
        await self.initialize_embedding_model()

        # (memory_id, text) pairs that still need an embedding.
        memory_texts = []

        for memory in memories:
            # Cache up-front so downstream steps can resolve the id.
            self.memory_cache[memory.memory_id] = memory
            if memory.embedding is None:
                # No embedding yet — queue for batched generation below.
                text = self._prepare_embedding_text(memory)
                memory_texts.append((memory.memory_id, text))
            else:
                # Embedding already present — index it directly.
                await self._add_single_memory(memory, memory.embedding)

        # Batch-generate the missing embeddings and index them.
        if memory_texts:
            await self._batch_generate_and_store_embeddings(memory_texts)

        # Auto-save check: persist every `auto_save_interval` operations.
        self._operation_count += len(memories)
        if self._operation_count >= self.config.auto_save_interval:
            await self.save_storage()
            self._operation_count = 0

        storage_time = time.time() - start_time
        logger.debug(f"向量存储完成,{len(memories)} 条记忆,耗时 {storage_time:.3f}秒")

    except Exception as e:
        logger.error(f"❌ 向量存储失败: {e}", exc_info=True)
|
||||
|
||||
def _prepare_embedding_text(self, memory: MemoryChunk) -> str:
|
||||
"""准备用于嵌入的文本,仅使用自然语言展示内容"""
|
||||
display_text = (memory.display or "").strip()
|
||||
if display_text:
|
||||
return display_text
|
||||
|
||||
fallback_text = (memory.text_content or "").strip()
|
||||
if fallback_text:
|
||||
return fallback_text
|
||||
|
||||
subjects = "、".join(s.strip() for s in memory.subjects if s and isinstance(s, str))
|
||||
predicate = (memory.content.predicate or "").strip()
|
||||
|
||||
obj = memory.content.object
|
||||
if isinstance(obj, dict):
|
||||
object_parts = []
|
||||
for key, value in obj.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, (list, tuple)):
|
||||
preview = "、".join(str(item) for item in value[:3])
|
||||
object_parts.append(f"{key}:{preview}")
|
||||
else:
|
||||
object_parts.append(f"{key}:{value}")
|
||||
object_text = ", ".join(object_parts)
|
||||
else:
|
||||
object_text = str(obj or "").strip()
|
||||
|
||||
composite_parts = [part for part in [subjects, predicate, object_text] if part]
|
||||
if composite_parts:
|
||||
return " ".join(composite_parts)
|
||||
|
||||
logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id)
|
||||
return memory.memory_id
|
||||
|
||||
async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
    """Batch-generate embeddings for (memory_id, text) pairs and index them.

    Pairs whose generation failed or returned a wrong dimension are
    silently skipped; the memory must already be in ``self.memory_cache``.
    """
    if not memory_texts:
        return

    try:
        texts = [text for _, text in memory_texts]
        memory_ids = [memory_id for memory_id, _ in memory_texts]

        # Concurrent generation (bounded inside _batch_generate_embeddings).
        embeddings = await self._batch_generate_embeddings(memory_ids, texts)

        # Index every successfully generated, dimension-correct vector.
        for memory_id, embedding in embeddings.items():
            if embedding and len(embedding) == self.config.dimension:
                memory = self.memory_cache.get(memory_id)
                if memory:
                    await self._add_single_memory(memory, embedding)

    except Exception as e:
        logger.error(f"❌ 批量生成嵌入向量失败: {e}")
|
||||
|
||||
async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
    """Concurrently generate embeddings, at most 4 requests in flight.

    Returns a dict mapping memory_id -> vector; failed or
    dimension-mismatched generations map to an empty list.
    """
    if not texts:
        return {}

    results: dict[str, list[float]] = {}

    try:
        # Bound concurrency at 4 (fewer when there are fewer texts).
        semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))

        async def generate_embedding(memory_id: str, text: str) -> None:
            # Mutating `results` from many tasks is safe: asyncio runs
            # all tasks on a single thread.
            async with semaphore:
                try:
                    embedding, _ = await self.embedding_model.get_embedding(text)
                    if embedding and len(embedding) == self.config.dimension:
                        results[memory_id] = embedding
                    else:
                        logger.warning(
                            "嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。",
                            self.config.dimension,
                            len(embedding) if embedding else 0,
                            memory_id,
                        )
                        results[memory_id] = []
                except Exception as exc:
                    logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
                    results[memory_id] = []

        tasks = [
            asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
        ]
        # return_exceptions=True: one failed task never cancels the rest.
        await asyncio.gather(*tasks, return_exceptions=True)

    except Exception as e:
        logger.error(f"❌ 批量生成嵌入向量失败: {e}")
        # Guarantee an entry for every requested id even after a hard failure.
        for memory_id in memory_ids:
            results.setdefault(memory_id, [])

    return results
|
||||
|
||||
async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
    """Add one memory and its embedding to caches, index, and id mappings.

    Held under ``self._lock`` because the FAISS index and the two id<->index
    mapping dicts must stay mutually consistent.
    """
    with self._lock:
        try:
            # Normalize so inner-product search approximates cosine similarity.
            if embedding:
                embedding = self._normalize_vector(embedding)

            # Cache the chunk and its raw vector.
            self.memory_cache[memory.memory_id] = memory
            self.vector_cache[memory.memory_id] = embedding

            # Keep the chunk's own embedding field in sync with the store.
            memory.set_embedding(embedding)

            # Insert into the vector index.
            if hasattr(self.vector_index, "add"):
                # FAISS index: expects a 2-D float32 array.
                if isinstance(embedding, np.ndarray):
                    vector_array = embedding.reshape(1, -1).astype("float32")
                else:
                    vector_array = np.array([embedding], dtype="float32")

                # IVF indexes must be trained before the first add.
                if self.config.index_type == "ivf" and self.vector_index.ntotal == 0:
                    logger.debug("训练IVF索引...")
                    self.vector_index.train(vector_array)

                self.vector_index.add(vector_array)
                # FAISS assigns sequential ids; the new row is the last one.
                index_id = self.vector_index.ntotal - 1

            else:
                # Simple (non-FAISS) fallback index.
                index_id = self.vector_index.add_vector(embedding)

            # Maintain both directions of the id<->index mapping.
            self.memory_id_to_index[memory.memory_id] = index_id
            self.index_to_memory_id[index_id] = memory.memory_id

            # Bookkeeping.
            self.storage_stats["total_vectors"] += 1

        except Exception as e:
            logger.error(f"❌ 添加记忆到向量存储失败: {e}")
|
||||
|
||||
def _normalize_vector(self, vector: list[float]) -> list[float]:
|
||||
"""L2归一化向量"""
|
||||
if not vector:
|
||||
return vector
|
||||
|
||||
try:
|
||||
vector_array = np.array(vector, dtype=np.float32)
|
||||
norm = np.linalg.norm(vector_array)
|
||||
if norm == 0:
|
||||
return vector
|
||||
|
||||
normalized = vector_array / norm
|
||||
return normalized.tolist()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"向量归一化失败: {e}")
|
||||
return vector
|
||||
|
||||
async def search_similar_memories(
    self,
    query_vector: list[float] | None = None,
    *,
    query_text: str | None = None,
    limit: int = 10,
    scope_id: str | None = None,
) -> list[tuple[str, float]]:
    """Search for memories similar to a query vector or query text.

    Either *query_vector* or *query_text* must be provided; text is embedded
    on the fly. *scope_id* restricts results to one user's memories
    ("global"/"global_memory" and empty values mean no restriction).
    Returns at most *limit* (memory_id, similarity) pairs; similarity is
    clamped to [0, 1]. Returns [] on any failure (logged, never raised).
    """
    start_time = time.time()

    try:
        logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}")

        if query_vector is None:
            if not query_text:
                logger.warning("查询向量和查询文本都为空")
                return []

            query_vector = await self.generate_query_embedding(query_text)
            if not query_vector:
                logger.warning("查询向量生成失败")
                return []

        # Resolve the scope filter; "global" scopes disable filtering.
        scope_filter: str | None = None
        if isinstance(scope_id, str):
            normalized_scope = scope_id.strip().lower()
            if normalized_scope and normalized_scope not in {"global", "global_memory"}:
                scope_filter = scope_id
        elif scope_id:
            scope_filter = str(scope_id)

        # Normalize the query vector so inner product == cosine similarity.
        query_vector = self._normalize_vector(query_vector)

        logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}")

        # Check vector index state before searching.
        if not self.vector_index:
            logger.error("向量索引未初始化")
            return []

        # Count vectors in whichever index implementation is active.
        total_vectors = 0
        if hasattr(self.vector_index, "ntotal"):
            total_vectors = self.vector_index.ntotal
        elif hasattr(self.vector_index, "vectors"):
            total_vectors = len(self.vector_index.vectors)

        logger.debug(f"向量索引中实际向量数: {total_vectors}")

        if total_vectors == 0:
            logger.warning("向量索引为空,无法执行搜索")
            return []

        # Execute the vector search under the storage lock.
        with self._lock:
            if hasattr(self.vector_index, "search"):
                # FAISS index: query must be a 2-D float32 array.
                if isinstance(query_vector, np.ndarray):
                    query_array = query_vector.reshape(1, -1).astype("float32")
                else:
                    query_array = np.array([query_vector], dtype="float32")

                if self.config.index_type == "ivf" and self.vector_index.ntotal > 0:
                    # nprobe balances recall vs speed for IVF search.
                    nprobe = min(self.vector_index.nlist, 10)
                    self.vector_index.nprobe = nprobe
                    logger.debug(f"IVF搜索参数: nprobe={nprobe}")

                search_limit = min(limit, total_vectors)
                logger.debug(f"执行FAISS搜索,搜索限制: {search_limit}")

                distances, indices = self.vector_index.search(query_array, search_limit)
                distances = distances.flatten().tolist()
                indices = indices.flatten().tolist()

                logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引")
            else:
                # Simple fallback index returns (index, score) pairs directly.
                logger.debug("使用简单向量索引执行搜索")
                results = self.vector_index.search(query_vector, limit)
                distances = [score for _, score in results]
                indices = [idx for idx, _ in results]
                logger.debug(f"简单索引搜索结果: {len(results)} 个结果")

        # Post-process search results: resolve ids, apply scope filter, clamp.
        results = []
        valid_results = 0
        invalid_indices = 0
        filtered_by_scope = 0

        for distance, index in zip(distances, indices, strict=False):
            if index == -1:  # FAISS marker for "no result in this slot"
                invalid_indices += 1
                continue

            memory_id = self.index_to_memory_id.get(index)
            if not memory_id:
                logger.debug(f"索引 {index} 没有对应的记忆ID")
                invalid_indices += 1
                continue

            if scope_filter:
                memory = self.memory_cache.get(memory_id)
                if memory and str(memory.user_id) != scope_filter:
                    filtered_by_scope += 1
                    continue

            # NOTE(review): assumes the raw score is already a similarity in
            # roughly [0, 1] (inner product of normalized vectors) — confirm.
            similarity = max(0.0, min(1.0, distance))  # clamp to 0-1
            results.append((memory_id, similarity))
            valid_results += 1

        logger.debug(
            f"搜索结果处理: 总距离={len(distances)}, 有效结果={valid_results}, "
            f"无效索引={invalid_indices}, 作用域过滤={filtered_by_scope}"
        )

        # Update running statistics (incremental average of search time).
        search_time = time.time() - start_time
        self.storage_stats["total_searches"] += 1
        self.storage_stats["average_search_time"] = (
            self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time
        ) / self.storage_stats["total_searches"]

        final_results = results[:limit]
        logger.info(
            f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' "
            f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果"
        )

        return final_results

    except Exception as e:
        logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
        return []
|
||||
|
||||
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
    """Look up a memory chunk in the local cache by its id.

    NOTE(review): accounting is asymmetric — hits bump ``cache_hits``
    while misses bump ``total_searches``; confirm this is intended.
    """
    try:
        memory = self.memory_cache[memory_id]
    except KeyError:
        self.storage_stats["total_searches"] += 1
        return None

    self.storage_stats["cache_hits"] += 1
    return memory
|
||||
|
||||
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
    """Replace the embedding of an already-indexed memory.

    The old vector is removed from the index when the index supports it;
    otherwise the new vector is simply appended and the id mappings are
    repointed (the stale row stays in the index but becomes unreachable
    through the mappings).
    """
    with self._lock:
        try:
            if memory_id not in self.memory_id_to_index:
                logger.warning(f"记忆 {memory_id} 不存在于向量索引中")
                return

            # Old index position for this memory.
            old_index = self.memory_id_to_index[memory_id]

            # Remove the old vector when supported (flat FAISS indexes
            # expose remove_ids; HNSW does not).
            if hasattr(self.vector_index, "remove_ids"):
                try:
                    self.vector_index.remove_ids(np.array([old_index]))
                except Exception:
                    # Was a bare `except:`; narrowed so KeyboardInterrupt /
                    # SystemExit are no longer swallowed.
                    logger.warning("无法删除旧向量,将直接添加新向量")

            # Normalize the replacement so search semantics stay consistent.
            new_embedding = self._normalize_vector(new_embedding)

            # Append the new vector to whichever index is active.
            if hasattr(self.vector_index, "add"):
                if isinstance(new_embedding, np.ndarray):
                    vector_array = new_embedding.reshape(1, -1).astype("float32")
                else:
                    vector_array = np.array([new_embedding], dtype="float32")

                self.vector_index.add(vector_array)
                new_index = self.vector_index.ntotal - 1
            else:
                new_index = self.vector_index.add_vector(new_embedding)

            # Repoint both directions of the id<->index mapping.
            self.memory_id_to_index[memory_id] = new_index
            self.index_to_memory_id[new_index] = memory_id

            # Refresh the raw vector cache.
            self.vector_cache[memory_id] = new_embedding

            # Keep the memory object itself in sync, if cached.
            memory = self.memory_cache.get(memory_id)
            if memory:
                memory.set_embedding(new_embedding)

            logger.debug(f"更新记忆 {memory_id} 的嵌入向量")

        except Exception as e:
            logger.error(f"❌ 更新记忆嵌入向量失败: {e}")
|
||||
|
||||
async def delete_memory(self, memory_id: str):
    """Remove a memory's vector, cache entries, and id mappings.

    Best-effort: if the underlying index cannot delete by id, only the
    bookkeeping (mappings, caches, stats) is removed.
    """
    with self._lock:
        try:
            if memory_id not in self.memory_id_to_index:
                return  # unknown id — nothing to do

            index = self.memory_id_to_index[memory_id]

            # Physically remove from the vector index when supported.
            if hasattr(self.vector_index, "remove_ids"):
                try:
                    self.vector_index.remove_ids(np.array([index]))
                except Exception:
                    # Was a bare `except:`; narrowed so KeyboardInterrupt /
                    # SystemExit are no longer swallowed.
                    logger.warning("无法从向量索引中删除,仅从缓存中移除")

            # Drop both directions of the id<->index mapping.
            del self.memory_id_to_index[memory_id]
            if index in self.index_to_memory_id:
                del self.index_to_memory_id[index]

            # Evict the cached chunk and its raw vector.
            self.memory_cache.pop(memory_id, None)
            self.vector_cache.pop(memory_id, None)

            # Keep the counter non-negative even if stats have drifted.
            self.storage_stats["total_vectors"] = max(0, self.storage_stats["total_vectors"] - 1)

            logger.debug(f"删除记忆 {memory_id}")

        except Exception as e:
            logger.error(f"❌ 删除记忆失败: {e}")
|
||||
|
||||
async def save_storage(self):
    """Persist caches, id mappings, the FAISS index, and stats to disk.

    Writes JSON files under ``self.storage_path`` plus an optional
    ``vector_index.faiss``. Failures are logged, never raised.
    """
    try:
        logger.info("正在保存向量存储...")

        # Memory chunks, serialized via their own to_dict().
        cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()}

        cache_file = self.storage_path / "memory_cache.json"
        with open(cache_file, "w", encoding="utf-8") as f:
            f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8"))

        # Raw vectors, keyed by memory id.
        vector_cache_file = self.storage_path / "vector_cache.json"
        with open(vector_cache_file, "w", encoding="utf-8") as f:
            f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8"))

        # id<->index mappings (keys coerced to str for JSON).
        mapping_file = self.storage_path / "id_mapping.json"
        mapping_data = {
            "memory_id_to_index": {
                str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items()
            },
            "index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()},
        }
        with open(mapping_file, "w", encoding="utf-8") as f:
            f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8"))

        # Save the FAISS index when available.
        # NOTE(review): FAISS index objects normally don't expose a `save`
        # attribute — this hasattr guard may never pass; verify intent.
        if FAISS_AVAILABLE and hasattr(self.vector_index, "save"):
            index_file = self.storage_path / "vector_index.faiss"
            faiss.write_index(self.vector_index, str(index_file))

        # Persist running statistics.
        stats_file = self.storage_path / "storage_stats.json"
        with open(stats_file, "w", encoding="utf-8") as f:
            f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8"))

        logger.info("✅ 向量存储保存完成")

    except Exception as e:
        logger.error(f"❌ 保存向量存储失败: {e}")
|
||||
|
||||
async def load_storage(self):
    """Load caches, mappings, and the FAISS index from disk.

    Missing files are skipped silently; if no compatible FAISS index file
    can be loaded, the index is rebuilt from the vector cache. Failures
    are logged, never raised.
    """
    try:
        logger.info("正在加载向量存储...")

        # Memory chunks.
        cache_file = self.storage_path / "memory_cache.json"
        if cache_file.exists():
            with open(cache_file, encoding="utf-8") as f:
                cache_data = orjson.loads(f.read())

            self.memory_cache = {
                memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items()
            }

        # Raw vectors.
        vector_cache_file = self.storage_path / "vector_cache.json"
        if vector_cache_file.exists():
            with open(vector_cache_file, encoding="utf-8") as f:
                self.vector_cache = orjson.loads(f.read())

        # id<->index mappings (JSON keys are strings; indices back to int).
        mapping_file = self.storage_path / "id_mapping.json"
        if mapping_file.exists():
            with open(mapping_file, encoding="utf-8") as f:
                mapping_data = orjson.loads(f.read())
            raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
            self.memory_id_to_index = {
                str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items()
            }

            raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
            self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()}

        # FAISS index file, if present and compatible.
        index_loaded = False
        if FAISS_AVAILABLE:
            index_file = self.storage_path / "vector_index.faiss"
            if index_file.exists():
                try:
                    loaded_index = faiss.read_index(str(index_file))
                    # Exact-type match guards against loading an index built
                    # with a different configuration (e.g. flat vs IVF).
                    if type(loaded_index) == type(self.vector_index):
                        self.vector_index = loaded_index
                        index_loaded = True
                        logger.info("✅ FAISS索引文件加载完成")
                    else:
                        logger.warning("索引类型不匹配,重新构建索引")
                except Exception as e:
                    logger.warning(f"加载FAISS索引失败: {e},重新构建")
            else:
                logger.info("FAISS索引文件不存在,将重新构建")

        # Rebuild when no index was loaded but vectors exist on disk.
        if not index_loaded and self.vector_cache:
            logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引")
            await self._rebuild_index()

        # Running statistics.
        stats_file = self.storage_path / "storage_stats.json"
        if stats_file.exists():
            with open(stats_file, encoding="utf-8") as f:
                self.storage_stats = orjson.loads(f.read())

        # Recompute the vector count from the authoritative mapping.
        self.storage_stats["total_vectors"] = len(self.memory_id_to_index)

        logger.info(f"✅ 向量存储加载完成,{self.storage_stats['total_vectors']} 个向量")

    except Exception as e:
        logger.error(f"❌ 加载向量存储失败: {e}")
|
||||
|
||||
async def _rebuild_index(self):
    """Rebuild the vector index from ``self.vector_cache``.

    Re-initializes the index, drops old mappings, then re-adds every
    cached vector whose dimension matches the configuration. Failures
    are logged, never raised.
    """
    try:
        logger.info(f"正在重建向量索引...向量数量: {len(self.vector_cache)}")

        # Start from a fresh, empty index.
        self._initialize_index()

        # Old mappings reference positions in the discarded index.
        self.memory_id_to_index.clear()
        self.index_to_memory_id.clear()

        if not self.vector_cache:
            logger.warning("没有向量缓存数据,跳过重建")
            return

        # Collect only dimension-correct vectors, normalized.
        memory_ids = []
        vectors = []

        for memory_id, embedding in self.vector_cache.items():
            if embedding and len(embedding) == self.config.dimension:
                memory_ids.append(memory_id)
                vectors.append(self._normalize_vector(embedding))
            else:
                logger.debug(f"跳过无效向量: {memory_id}, 维度: {len(embedding) if embedding else 0}")

        if not vectors:
            logger.warning("没有有效的向量数据")
            return

        logger.info(f"准备重建 {len(vectors)} 个向量到索引")

        # Bulk-add vectors to the index.
        if hasattr(self.vector_index, "add"):
            # FAISS index: single 2-D float32 batch.
            vector_array = np.array(vectors, dtype="float32")

            # IVF indexes need training before vectors can be added.
            if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"):
                logger.info("训练IVF索引...")
                self.vector_index.train(vector_array)

            self.vector_index.add(vector_array)

            # FAISS rows are sequential, so position i == index id i.
            for i, memory_id in enumerate(memory_ids):
                self.memory_id_to_index[memory_id] = i
                self.index_to_memory_id[i] = memory_id

        else:
            # Simple fallback index assigns its own ids per vector.
            for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors, strict=False)):
                index_id = self.vector_index.add_vector(vector)
                self.memory_id_to_index[memory_id] = index_id
                self.index_to_memory_id[index_id] = memory_id

        # Bookkeeping.
        self.storage_stats["total_vectors"] = len(self.memory_id_to_index)

        final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index))
        logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}")

    except Exception as e:
        logger.error(f"❌ 重建向量索引失败: {e}", exc_info=True)
|
||||
|
||||
async def optimize_storage(self):
    """Run maintenance: prune stale references, rebuild a large index,
    and refresh the derived cache-hit-rate statistic.
    """
    try:
        logger.info("开始向量存储优化...")

        # Drop id<->index mappings whose memory no longer exists.
        self._cleanup_invalid_references()

        # Rebuild the index when it is large (defragmentation).
        if self.storage_stats["total_vectors"] > 1000:
            await self._rebuild_index()

        # Refresh the derived cache hit rate.
        if self.storage_stats["total_searches"] > 0:
            self.storage_stats["cache_hit_rate"] = (
                self.storage_stats["cache_hits"] / self.storage_stats["total_searches"]
            )

        logger.info("✅ 向量存储优化完成")

    except Exception as e:
        logger.error(f"❌ 向量存储优化失败: {e}")
|
||||
|
||||
def _cleanup_invalid_references(self):
|
||||
"""清理无效引用"""
|
||||
with self._lock:
|
||||
# 清理无效的memory_id到index的映射
|
||||
valid_memory_ids = set(self.memory_cache.keys())
|
||||
invalid_memory_ids = set(self.memory_id_to_index.keys()) - valid_memory_ids
|
||||
|
||||
for memory_id in invalid_memory_ids:
|
||||
index = self.memory_id_to_index[memory_id]
|
||||
del self.memory_id_to_index[memory_id]
|
||||
if index in self.index_to_memory_id:
|
||||
del self.index_to_memory_id[index]
|
||||
|
||||
if invalid_memory_ids:
|
||||
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
|
||||
|
||||
def get_storage_stats(self) -> dict[str, Any]:
    """Return a snapshot of storage statistics with a derived cache hit rate.

    The internal stats dict is copied, so callers may mutate the result.
    """
    snapshot = dict(self.storage_stats)
    searches = snapshot["total_searches"]
    snapshot["cache_hit_rate"] = snapshot["cache_hits"] / searches if searches > 0 else 0.0
    return snapshot
|
||||
|
||||
|
||||
class SimpleVectorIndex:
    """Pure-Python fallback vector index used when FAISS is unavailable.

    Stores vectors in parallel lists and performs brute-force cosine
    similarity search — O(n) per query, intended for small collections.
    """

    def __init__(self, dimension: int):
        # Vectors and their assigned ids are kept in two parallel lists.
        self.dimension = dimension
        self.vectors: list[list[float]] = []
        self.vector_ids: list[int] = []
        self.next_id = 0

    def add_vector(self, vector: list[float]) -> int:
        """Append *vector* and return its newly assigned integer id.

        Raises ValueError when the vector's length does not match the
        configured dimension.
        """
        if len(vector) != self.dimension:
            raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")

        assigned_id = self.next_id
        self.next_id = assigned_id + 1
        # Store a copy so later caller mutations cannot corrupt the index.
        self.vectors.append(list(vector))
        self.vector_ids.append(assigned_id)
        return assigned_id

    def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
        """Return up to *limit* (id, cosine-similarity) pairs, best first.

        Raises ValueError when the query's length does not match the
        configured dimension.
        """
        if len(query_vector) != self.dimension:
            raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")

        scored = [
            (vid, self._calculate_cosine_similarity(query_vector, vec))
            for vid, vec in zip(self.vector_ids, self.vectors)
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:limit]

    def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
        """Cosine similarity of two vectors; 0.0 on zero norm or any failure."""
        try:
            dot = sum(a * b for a, b in zip(v1, v2))
            norm_a = sum(a * a for a in v1) ** 0.5
            norm_b = sum(b * b for b in v2) ** 0.5
            if norm_a == 0 or norm_b == 0:
                return 0.0
            return dot / (norm_a * norm_b)
        except Exception:
            return 0.0

    @property
    def ntotal(self) -> int:
        """Total number of stored vectors (mirrors the FAISS attribute)."""
        return len(self.vectors)
|
||||
731
src/chat/memory_system/hippocampus_sampler.py
Normal file
731
src/chat/memory_system/hippocampus_sampler.py
Normal file
@@ -0,0 +1,731 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
海马体双峰分布采样器
|
||||
基于旧版海马体的采样策略,适配新版记忆系统
|
||||
实现低消耗、高效率的记忆采样模式
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
build_readable_messages,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class HippocampusSampleConfig:
    """Configuration for hippocampus bimodal-distribution sampling."""
    # Bimodal distribution parameters
    recent_mean_hours: float = 12.0  # mean of the "recent" component (hours)
    recent_std_hours: float = 8.0  # std-dev of the "recent" component (hours)
    recent_weight: float = 0.7  # mixture weight of the "recent" component

    distant_mean_hours: float = 48.0  # mean of the "distant" component (hours)
    distant_std_hours: float = 24.0  # std-dev of the "distant" component (hours)
    distant_weight: float = 0.3  # mixture weight of the "distant" component

    # Sampling parameters
    total_samples: int = 50  # total number of time samples per cycle
    sample_interval: int = 1800  # minimum seconds between sampling cycles
    max_sample_length: int = 30  # max messages fetched per sample
    batch_size: int = 5  # batch size for processing

    @classmethod
    def from_global_config(cls) -> 'HippocampusSampleConfig':
        """Build a sampling config from ``global_config.memory``.

        NOTE(review): ``hippocampus_distribution_config`` is assumed to be a
        6-element sequence ordered (recent_mean, recent_std, recent_weight,
        distant_mean, distant_std, distant_weight) — confirm against the
        config schema. Also note max_sample_length and batch_size both read
        ``hippocampus_batch_size``; verify that is intentional.
        """
        config = global_config.memory.hippocampus_distribution_config
        return cls(
            recent_mean_hours=config[0],
            recent_std_hours=config[1],
            recent_weight=config[2],
            distant_mean_hours=config[3],
            distant_std_hours=config[4],
            distant_weight=config[5],
            total_samples=global_config.memory.hippocampus_sample_size,
            sample_interval=global_config.memory.hippocampus_sample_interval,
            max_sample_length=global_config.memory.hippocampus_batch_size,
            batch_size=global_config.memory.hippocampus_batch_size,
        )
|
||||
|
||||
|
||||
class HippocampusSampler:
|
||||
"""海马体双峰分布采样器"""
|
||||
|
||||
def __init__(self, memory_system=None):
    """Create a sampler bound to *memory_system* (may be None until wired)."""
    self.memory_system = memory_system
    self.config = HippocampusSampleConfig.from_global_config()
    self.last_sample_time = 0  # unix timestamp of the last completed cycle
    self.is_running = False

    # LLM used for memory building; created lazily in initialize().
    self.memory_builder_model: Optional[LLMRequest] = None

    # Statistics
    self.sample_count = 0  # sampling cycles attempted
    self.success_count = 0  # samples that produced at least one memory
    self.last_sample_results: List[Dict[str, Any]] = []  # recent cycle summaries (bounded to 10)
|
||||
|
||||
async def initialize(self):
    """Initialize the sampler: build the LLM and start background sampling.

    Raises RuntimeError when no "utils" model task configuration exists;
    other initialization failures are logged and re-raised.
    """
    try:
        # Resolve the LLM task configuration for memory building.
        from src.config.config import model_config
        task_config = getattr(model_config.model_task_config, "utils", None)
        if task_config:
            self.memory_builder_model = LLMRequest(
                model_set=task_config,
                request_type="memory.hippocampus_build"
            )
            # Fire-and-forget background sampling loop.
            # NOTE(review): the task handle is not retained, so it cannot be
            # cancelled later and may be garbage-collected — consider keeping
            # a reference.
            asyncio.create_task(self.start_background_sampling())
            logger.info("✅ 海马体采样器初始化成功")
        else:
            raise RuntimeError("未找到记忆构建模型配置")

    except Exception as e:
        logger.error(f"❌ 海马体采样器初始化失败: {e}")
        raise
|
||||
|
||||
def generate_time_samples(self) -> List[datetime]:
    """Draw sampling instants from a bimodal (recent/distant) normal mixture.

    Returns timestamps strictly in the past, sorted oldest-first.
    """
    cfg = self.config

    # Per-component sample counts (each component gets at least one draw).
    n_recent = max(1, int(cfg.total_samples * cfg.recent_weight))
    n_distant = max(1, cfg.total_samples - n_recent)

    # Hour offsets drawn from the two normal components.
    offsets = np.concatenate(
        [
            np.random.normal(loc=cfg.recent_mean_hours, scale=cfg.recent_std_hours, size=n_recent),
            np.random.normal(loc=cfg.distant_mean_hours, scale=cfg.distant_std_hours, size=n_distant),
        ]
    )

    # abs() keeps every sampled instant in the past relative to now.
    now = datetime.now()
    sample_times = [now - timedelta(hours=abs(hours)) for hours in offsets]

    # Oldest first.
    return sorted(sample_times)
|
||||
|
||||
async def collect_message_samples(self, target_timestamp: float) -> Optional[List[Dict[str, Any]]]:
    """Collect message samples near *target_timestamp* (unix seconds).

    Finds an anchor message inside a random 5–30 minute window, then pulls
    up to ``max_sample_length`` messages from the same chat, filtering out
    messages already memorized twice. Up to 3 attempts, shifting the window
    2 minutes earlier each time. Returns None when nothing usable is found
    or an error occurs (logged).
    """
    try:
        # Random time window: 5-30 minutes.
        time_window_seconds = random.randint(300, 1800)

        # Up to 3 attempts, moving earlier on each miss.
        for attempt in range(3):
            timestamp_start = target_timestamp
            timestamp_end = target_timestamp + time_window_seconds

            # Fetch a single message as the anchor for chat selection.
            anchor_messages = await get_raw_msg_by_timestamp(
                timestamp_start=timestamp_start,
                timestamp_end=timestamp_end,
                limit=1,
                limit_mode="earliest",
            )

            if not anchor_messages:
                target_timestamp -= 120  # shift 2 minutes earlier and retry
                continue

            anchor_message = anchor_messages[0]
            chat_id = anchor_message.get("chat_id")

            if not chat_id:
                continue

            # Fetch multiple messages from the anchor's chat.
            messages = await get_raw_msg_by_timestamp_with_chat(
                timestamp_start=timestamp_start,
                timestamp_end=timestamp_end,
                limit=self.config.max_sample_length,
                limit_mode="earliest",
                chat_id=chat_id,
            )

            if messages and len(messages) >= 2:  # need at least 2 messages
                # Skip messages that have already been memorized enough.
                filtered_messages = [
                    msg for msg in messages
                    if msg.get("memorized_times", 0) < 2  # memorize at most twice
                ]

                if filtered_messages:
                    logger.debug(f"成功收集 {len(filtered_messages)} 条消息样本")
                    return filtered_messages

            target_timestamp -= 120  # shift earlier and retry

        logger.debug(f"时间戳 {target_timestamp} 附近未找到有效消息样本")
        return None

    except Exception as e:
        logger.error(f"收集消息样本失败: {e}")
        return None
|
||||
|
||||
async def build_memory_from_samples(self, messages: List[Dict[str, Any]], target_timestamp: float) -> Optional[str]:
    """Build memories from sampled *messages* via the memory system.

    Returns a short summary string when memories were built, otherwise
    None (also on missing dependencies or errors, which are logged).
    """
    if not messages or not self.memory_system or not self.memory_builder_model:
        return None

    try:
        # Render the raw messages into a readable transcript.
        readable_text = await build_readable_messages(
            messages,
            merge_messages=True,
            timestamp_mode="normal_no_YMD",
            replace_bot_name=False,
        )

        if not readable_text:
            logger.warning("无法从消息样本生成可读文本")
            return None

        # Prefix the current date so the model has temporal grounding.
        current_date = f"当前日期: {datetime.now().isoformat()}"
        input_text = f"{current_date}\n{readable_text}"

        logger.debug(f"开始构建记忆,文本长度: {len(input_text)}")

        # Context marking this as a hippocampus sample that bypasses
        # the normal value threshold.
        context = {
            "user_id": "hippocampus_sampler",
            "timestamp": time.time(),
            "source": "hippocampus_sampling",
            "message_count": len(messages),
            "sample_mode": "bimodal_distribution",
            "is_hippocampus_sample": True,  # mark as hippocampus sample
            "bypass_value_threshold": True,  # skip value-threshold check
            "hippocampus_sample_time": target_timestamp,  # original sample time
        }

        # Build memories, bypassing the normal build-interval check.
        memories = await self.memory_system.build_memory_from_conversation(
            conversation_text=input_text,
            context=context,
            timestamp=time.time(),
            bypass_interval=True  # hippocampus sampler ignores build intervals
        )

        if memories:
            memory_count = len(memories)
            self.success_count += 1

            # Record a summary of this sampling result.
            result = {
                "timestamp": time.time(),
                "memory_count": memory_count,
                "message_count": len(messages),
                "text_preview": readable_text[:100] + "..." if len(readable_text) > 100 else readable_text,
                "memory_types": [m.memory_type.value for m in memories],
            }
            self.last_sample_results.append(result)

            # Keep only the 10 most recent summaries.
            if len(self.last_sample_results) > 10:
                self.last_sample_results.pop(0)

            logger.info(f"✅ 海马体采样成功构建 {memory_count} 条记忆")
            return f"构建{memory_count}条记忆"
        else:
            logger.debug("海马体采样未生成有效记忆")
            return None

    except Exception as e:
        logger.error(f"海马体采样构建记忆失败: {e}")
        return None
|
||||
|
||||
async def perform_sampling_cycle(self) -> Dict[str, Any]:
|
||||
"""执行一次完整的采样周期(优化版:批量融合构建)"""
|
||||
if not self.should_sample():
|
||||
return {"status": "skipped", "reason": "interval_not_met"}
|
||||
|
||||
start_time = time.time()
|
||||
self.sample_count += 1
|
||||
|
||||
try:
|
||||
# 生成时间采样点
|
||||
time_samples = self.generate_time_samples()
|
||||
logger.debug(f"生成 {len(time_samples)} 个时间采样点")
|
||||
|
||||
# 记录时间采样点(调试用)
|
||||
readable_timestamps = [
|
||||
translate_timestamp_to_human_readable(int(ts.timestamp()), mode="normal")
|
||||
for ts in time_samples[:5] # 只显示前5个
|
||||
]
|
||||
logger.debug(f"时间采样点示例: {readable_timestamps}")
|
||||
|
||||
# 第一步:批量收集所有消息样本
|
||||
logger.debug("开始批量收集消息样本...")
|
||||
collected_messages = await self._collect_all_message_samples(time_samples)
|
||||
|
||||
if not collected_messages:
|
||||
logger.info("未收集到有效消息样本,跳过本次采样")
|
||||
self.last_sample_time = time.time()
|
||||
return {
|
||||
"status": "success",
|
||||
"sample_count": self.sample_count,
|
||||
"success_count": self.success_count,
|
||||
"processed_samples": len(time_samples),
|
||||
"successful_builds": 0,
|
||||
"duration": time.time() - start_time,
|
||||
"samples_generated": len(time_samples),
|
||||
"message": "未收集到有效消息样本",
|
||||
}
|
||||
|
||||
logger.info(f"收集到 {len(collected_messages)} 组消息样本")
|
||||
|
||||
# 第二步:融合和去重消息
|
||||
logger.debug("开始融合和去重消息...")
|
||||
fused_messages = await self._fuse_and_deduplicate_messages(collected_messages)
|
||||
|
||||
if not fused_messages:
|
||||
logger.info("消息融合后为空,跳过记忆构建")
|
||||
self.last_sample_time = time.time()
|
||||
return {
|
||||
"status": "success",
|
||||
"sample_count": self.sample_count,
|
||||
"success_count": self.success_count,
|
||||
"processed_samples": len(time_samples),
|
||||
"successful_builds": 0,
|
||||
"duration": time.time() - start_time,
|
||||
"samples_generated": len(time_samples),
|
||||
"message": "消息融合后为空",
|
||||
}
|
||||
|
||||
logger.info(f"融合后得到 {len(fused_messages)} 组有效消息")
|
||||
|
||||
# 第三步:一次性构建记忆
|
||||
logger.debug("开始批量构建记忆...")
|
||||
build_result = await self._build_batch_memory(fused_messages, time_samples)
|
||||
|
||||
# 更新最后采样时间
|
||||
self.last_sample_time = time.time()
|
||||
|
||||
duration = time.time() - start_time
|
||||
result = {
|
||||
"status": "success",
|
||||
"sample_count": self.sample_count,
|
||||
"success_count": self.success_count,
|
||||
"processed_samples": len(time_samples),
|
||||
"successful_builds": build_result.get("memory_count", 0),
|
||||
"duration": duration,
|
||||
"samples_generated": len(time_samples),
|
||||
"messages_collected": len(collected_messages),
|
||||
"messages_fused": len(fused_messages),
|
||||
"optimization_mode": "batch_fusion",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"✅ 海马体采样周期完成(批量融合模式) | "
|
||||
f"采样点: {len(time_samples)} | "
|
||||
f"收集消息: {len(collected_messages)} | "
|
||||
f"融合消息: {len(fused_messages)} | "
|
||||
f"构建记忆: {build_result.get('memory_count', 0)} | "
|
||||
f"耗时: {duration:.2f}s"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 海马体采样周期失败: {e}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"sample_count": self.sample_count,
|
||||
"duration": time.time() - start_time,
|
||||
}
|
||||
|
||||
async def _collect_all_message_samples(self, time_samples: List[datetime]) -> List[List[Dict[str, Any]]]:
|
||||
"""批量收集所有时间点的消息样本"""
|
||||
collected_messages = []
|
||||
max_concurrent = min(5, len(time_samples)) # 提高并发数到5
|
||||
|
||||
for i in range(0, len(time_samples), max_concurrent):
|
||||
batch = time_samples[i:i + max_concurrent]
|
||||
tasks = []
|
||||
|
||||
# 创建并发收集任务
|
||||
for timestamp in batch:
|
||||
target_ts = timestamp.timestamp()
|
||||
task = self.collect_message_samples(target_ts)
|
||||
tasks.append(task)
|
||||
|
||||
# 执行并发收集
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理收集结果
|
||||
for result in results:
|
||||
if isinstance(result, list) and result:
|
||||
collected_messages.append(result)
|
||||
elif isinstance(result, Exception):
|
||||
logger.debug(f"消息收集异常: {result}")
|
||||
|
||||
# 批次间短暂延迟
|
||||
if i + max_concurrent < len(time_samples):
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return collected_messages
|
||||
|
||||
async def _fuse_and_deduplicate_messages(self, collected_messages: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
|
||||
"""融合和去重消息样本"""
|
||||
if not collected_messages:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 展平所有消息
|
||||
all_messages = []
|
||||
for message_group in collected_messages:
|
||||
all_messages.extend(message_group)
|
||||
|
||||
logger.debug(f"展开后总消息数: {len(all_messages)}")
|
||||
|
||||
# 去重逻辑:基于消息内容和时间戳
|
||||
unique_messages = []
|
||||
seen_hashes = set()
|
||||
|
||||
for message in all_messages:
|
||||
# 创建消息哈希用于去重
|
||||
content = message.get("processed_plain_text", "") or message.get("display_message", "")
|
||||
timestamp = message.get("time", 0)
|
||||
chat_id = message.get("chat_id", "")
|
||||
|
||||
# 简单哈希:内容前50字符 + 时间戳(精确到分钟) + 聊天ID
|
||||
hash_key = f"{content[:50]}_{int(timestamp//60)}_{chat_id}"
|
||||
|
||||
if hash_key not in seen_hashes and len(content.strip()) > 10:
|
||||
seen_hashes.add(hash_key)
|
||||
unique_messages.append(message)
|
||||
|
||||
logger.debug(f"去重后消息数: {len(unique_messages)}")
|
||||
|
||||
# 按时间排序
|
||||
unique_messages.sort(key=lambda x: x.get("time", 0))
|
||||
|
||||
# 按聊天ID分组重新组织
|
||||
chat_groups = {}
|
||||
for message in unique_messages:
|
||||
chat_id = message.get("chat_id", "unknown")
|
||||
if chat_id not in chat_groups:
|
||||
chat_groups[chat_id] = []
|
||||
chat_groups[chat_id].append(message)
|
||||
|
||||
# 合并相邻时间范围内的消息
|
||||
fused_groups = []
|
||||
for chat_id, messages in chat_groups.items():
|
||||
fused_groups.extend(self._merge_adjacent_messages(messages))
|
||||
|
||||
logger.debug(f"融合后消息组数: {len(fused_groups)}")
|
||||
return fused_groups
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息融合失败: {e}")
|
||||
# 返回原始消息组作为备选
|
||||
return collected_messages[:5] # 限制返回数量
|
||||
|
||||
def _merge_adjacent_messages(self, messages: List[Dict[str, Any]], time_gap: int = 1800) -> List[List[Dict[str, Any]]]:
|
||||
"""合并时间间隔内的消息"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
merged_groups = []
|
||||
current_group = [messages[0]]
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
current_time = messages[i].get("time", 0)
|
||||
prev_time = current_group[-1].get("time", 0)
|
||||
|
||||
# 如果时间间隔小于阈值,合并到当前组
|
||||
if current_time - prev_time <= time_gap:
|
||||
current_group.append(messages[i])
|
||||
else:
|
||||
# 否则开始新组
|
||||
merged_groups.append(current_group)
|
||||
current_group = [messages[i]]
|
||||
|
||||
# 添加最后一组
|
||||
merged_groups.append(current_group)
|
||||
|
||||
# 过滤掉只有一条消息的组(除非内容较长)
|
||||
result_groups = []
|
||||
for group in merged_groups:
|
||||
if len(group) > 1 or any(len(msg.get("processed_plain_text", "")) > 100 for msg in group):
|
||||
result_groups.append(group)
|
||||
|
||||
return result_groups
|
||||
|
||||
async def _build_batch_memory(self, fused_messages: List[List[Dict[str, Any]]], time_samples: List[datetime]) -> Dict[str, Any]:
|
||||
"""批量构建记忆"""
|
||||
if not fused_messages:
|
||||
return {"memory_count": 0, "memories": []}
|
||||
|
||||
try:
|
||||
total_memories = []
|
||||
total_memory_count = 0
|
||||
|
||||
# 构建融合后的文本
|
||||
batch_input_text = await self._build_fused_conversation_text(fused_messages)
|
||||
|
||||
if not batch_input_text:
|
||||
logger.warning("无法构建融合文本,尝试单独构建")
|
||||
# 备选方案:分别构建
|
||||
return await self._fallback_individual_build(fused_messages)
|
||||
|
||||
# 创建批量上下文
|
||||
batch_context = {
|
||||
"user_id": "hippocampus_batch_sampler",
|
||||
"timestamp": time.time(),
|
||||
"source": "hippocampus_batch_sampling",
|
||||
"message_groups_count": len(fused_messages),
|
||||
"total_messages": sum(len(group) for group in fused_messages),
|
||||
"sample_count": len(time_samples),
|
||||
"is_hippocampus_sample": True,
|
||||
"bypass_value_threshold": True,
|
||||
"optimization_mode": "batch_fusion",
|
||||
}
|
||||
|
||||
logger.debug(f"批量构建记忆,文本长度: {len(batch_input_text)}")
|
||||
|
||||
# 一次性构建记忆
|
||||
memories = await self.memory_system.build_memory_from_conversation(
|
||||
conversation_text=batch_input_text,
|
||||
context=batch_context,
|
||||
timestamp=time.time(),
|
||||
bypass_interval=True
|
||||
)
|
||||
|
||||
if memories:
|
||||
memory_count = len(memories)
|
||||
self.success_count += 1
|
||||
total_memory_count += memory_count
|
||||
total_memories.extend(memories)
|
||||
|
||||
logger.info(f"✅ 批量海马体采样成功构建 {memory_count} 条记忆")
|
||||
else:
|
||||
logger.debug("批量海马体采样未生成有效记忆")
|
||||
|
||||
# 记录采样结果
|
||||
result = {
|
||||
"timestamp": time.time(),
|
||||
"memory_count": total_memory_count,
|
||||
"message_groups_count": len(fused_messages),
|
||||
"total_messages": sum(len(group) for group in fused_messages),
|
||||
"text_preview": batch_input_text[:200] + "..." if len(batch_input_text) > 200 else batch_input_text,
|
||||
"memory_types": [m.memory_type.value for m in total_memories],
|
||||
}
|
||||
|
||||
self.last_sample_results.append(result)
|
||||
|
||||
# 限制结果历史长度
|
||||
if len(self.last_sample_results) > 10:
|
||||
self.last_sample_results.pop(0)
|
||||
|
||||
return {
|
||||
"memory_count": total_memory_count,
|
||||
"memories": total_memories,
|
||||
"result": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量构建记忆失败: {e}")
|
||||
return {"memory_count": 0, "error": str(e)}
|
||||
|
||||
async def _build_fused_conversation_text(self, fused_messages: List[List[Dict[str, Any]]]) -> str:
|
||||
"""构建融合后的对话文本"""
|
||||
try:
|
||||
# 添加批次标识
|
||||
current_date = f"海马体批量采样 - {datetime.now().isoformat()}\n"
|
||||
conversation_parts = [current_date]
|
||||
|
||||
for group_idx, message_group in enumerate(fused_messages):
|
||||
if not message_group:
|
||||
continue
|
||||
|
||||
# 为每个消息组添加分隔符
|
||||
group_header = f"\n=== 对话片段 {group_idx + 1} ==="
|
||||
conversation_parts.append(group_header)
|
||||
|
||||
# 构建可读消息
|
||||
group_text = await build_readable_messages(
|
||||
message_group,
|
||||
merge_messages=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
replace_bot_name=False,
|
||||
)
|
||||
|
||||
if group_text and len(group_text.strip()) > 10:
|
||||
conversation_parts.append(group_text.strip())
|
||||
|
||||
return "\n".join(conversation_parts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建融合文本失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _fallback_individual_build(self, fused_messages: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
|
||||
"""备选方案:单独构建每个消息组"""
|
||||
total_memories = []
|
||||
total_count = 0
|
||||
|
||||
for group in fused_messages[:5]: # 限制最多5组
|
||||
try:
|
||||
memories = await self.build_memory_from_samples(group, time.time())
|
||||
if memories:
|
||||
total_memories.extend(memories)
|
||||
total_count += len(memories)
|
||||
except Exception as e:
|
||||
logger.debug(f"单独构建失败: {e}")
|
||||
|
||||
return {
|
||||
"memory_count": total_count,
|
||||
"memories": total_memories,
|
||||
"fallback_mode": True
|
||||
}
|
||||
|
||||
async def process_sample_timestamp(self, target_timestamp: float) -> Optional[str]:
|
||||
"""处理单个时间戳采样(保留作为备选方法)"""
|
||||
try:
|
||||
# 收集消息样本
|
||||
messages = await self.collect_message_samples(target_timestamp)
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 构建记忆
|
||||
result = await self.build_memory_from_samples(messages, target_timestamp)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"处理时间戳采样失败 {target_timestamp}: {e}")
|
||||
return None
|
||||
|
||||
def should_sample(self) -> bool:
|
||||
"""检查是否应该进行采样"""
|
||||
current_time = time.time()
|
||||
|
||||
# 检查时间间隔
|
||||
if current_time - self.last_sample_time < self.config.sample_interval:
|
||||
return False
|
||||
|
||||
# 检查是否已初始化
|
||||
if not self.memory_builder_model:
|
||||
logger.warning("海马体采样器未初始化")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def start_background_sampling(self):
|
||||
"""启动后台采样"""
|
||||
if self.is_running:
|
||||
logger.warning("海马体后台采样已在运行")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
logger.info("🚀 启动海马体后台采样任务")
|
||||
|
||||
try:
|
||||
while self.is_running:
|
||||
try:
|
||||
# 执行采样周期
|
||||
result = await self.perform_sampling_cycle()
|
||||
|
||||
# 如果是跳过状态,短暂睡眠
|
||||
if result.get("status") == "skipped":
|
||||
await asyncio.sleep(60) # 1分钟后重试
|
||||
else:
|
||||
# 正常等待下一个采样间隔
|
||||
await asyncio.sleep(self.config.sample_interval)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"海马体后台采样异常: {e}")
|
||||
await asyncio.sleep(300) # 异常时等待5分钟
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("海马体后台采样任务被取消")
|
||||
finally:
|
||||
self.is_running = False
|
||||
|
||||
def stop_background_sampling(self):
|
||||
"""停止后台采样"""
|
||||
self.is_running = False
|
||||
logger.info("🛑 停止海马体后台采样任务")
|
||||
|
||||
def get_sampling_stats(self) -> Dict[str, Any]:
|
||||
"""获取采样统计信息"""
|
||||
success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
|
||||
|
||||
# 计算最近的平均数据
|
||||
recent_avg_messages = 0
|
||||
recent_avg_memory_count = 0
|
||||
if self.last_sample_results:
|
||||
recent_results = self.last_sample_results[-5:] # 最近5次
|
||||
recent_avg_messages = sum(r.get("total_messages", 0) for r in recent_results) / len(recent_results)
|
||||
recent_avg_memory_count = sum(r.get("memory_count", 0) for r in recent_results) / len(recent_results)
|
||||
|
||||
return {
|
||||
"is_running": self.is_running,
|
||||
"sample_count": self.sample_count,
|
||||
"success_count": self.success_count,
|
||||
"success_rate": f"{success_rate:.1f}%",
|
||||
"last_sample_time": self.last_sample_time,
|
||||
"optimization_mode": "batch_fusion", # 显示优化模式
|
||||
"performance_metrics": {
|
||||
"avg_messages_per_sample": f"{recent_avg_messages:.1f}",
|
||||
"avg_memories_per_sample": f"{recent_avg_memory_count:.1f}",
|
||||
"fusion_efficiency": f"{(recent_avg_messages/max(recent_avg_memory_count, 1)):.1f}x" if recent_avg_messages > 0 else "N/A"
|
||||
},
|
||||
"config": {
|
||||
"sample_interval": self.config.sample_interval,
|
||||
"total_samples": self.config.total_samples,
|
||||
"recent_weight": f"{self.config.recent_weight:.1%}",
|
||||
"distant_weight": f"{self.config.distant_weight:.1%}",
|
||||
"max_concurrent": 5, # 批量模式并发数
|
||||
"fusion_time_gap": "30分钟", # 消息融合时间间隔
|
||||
},
|
||||
"recent_results": self.last_sample_results[-5:], # 最近5次结果
|
||||
}
|
||||
|
||||
|
||||
# 全局海马体采样器实例
|
||||
_hippocampus_sampler: Optional[HippocampusSampler] = None
|
||||
|
||||
|
||||
def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
|
||||
"""获取全局海马体采样器实例"""
|
||||
global _hippocampus_sampler
|
||||
if _hippocampus_sampler is None:
|
||||
_hippocampus_sampler = HippocampusSampler(memory_system)
|
||||
return _hippocampus_sampler
|
||||
|
||||
|
||||
async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
|
||||
"""初始化全局海马体采样器"""
|
||||
sampler = get_hippocampus_sampler(memory_system)
|
||||
await sampler.initialize()
|
||||
return sampler
|
||||
@@ -19,6 +19,12 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||
# 简化的记忆采样模式枚举
|
||||
class MemorySamplingMode(Enum):
|
||||
"""记忆采样模式"""
|
||||
HIPPOCAMPUS = "hippocampus" # 海马体模式:定时任务采样
|
||||
IMMEDIATE = "immediate" # 即时模式:回复后立即采样
|
||||
ALL = "all" # 所有模式:同时使用海马体和即时采样
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -148,6 +154,9 @@ class MemorySystem:
|
||||
# 记忆指纹缓存,用于快速检测重复记忆
|
||||
self._memory_fingerprints: dict[str, str] = {}
|
||||
|
||||
# 海马体采样器
|
||||
self.hippocampus_sampler = None
|
||||
|
||||
logger.info("MemorySystem 初始化开始")
|
||||
|
||||
async def initialize(self):
|
||||
@@ -249,6 +258,16 @@ class MemorySystem:
|
||||
|
||||
self.query_planner = MemoryQueryPlanner(planner_model, default_limit=self.config.final_recall_limit)
|
||||
|
||||
# 初始化海马体采样器
|
||||
if global_config.memory.enable_hippocampus_sampling:
|
||||
try:
|
||||
from .hippocampus_sampler import initialize_hippocampus_sampler
|
||||
self.hippocampus_sampler = await initialize_hippocampus_sampler(self)
|
||||
logger.info("✅ 海马体采样器初始化成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"海马体采样器初始化失败: {e}")
|
||||
self.hippocampus_sampler = None
|
||||
|
||||
# 统一存储已经自动加载数据,无需额外加载
|
||||
logger.info("✅ 简化版记忆系统初始化完成")
|
||||
|
||||
@@ -283,14 +302,14 @@ class MemorySystem:
|
||||
|
||||
try:
|
||||
# 使用统一存储检索相似记忆
|
||||
filters = {"user_id": user_id} if user_id else None
|
||||
search_results = await self.unified_storage.search_similar_memories(
|
||||
query_text=query_text, limit=limit, scope_id=user_id
|
||||
query_text=query_text, limit=limit, filters=filters
|
||||
)
|
||||
|
||||
# 转换为记忆对象
|
||||
memories = []
|
||||
for memory_id, similarity_score in search_results:
|
||||
memory = self.unified_storage.get_memory_by_id(memory_id)
|
||||
for memory, similarity_score in search_results:
|
||||
if memory:
|
||||
memory.update_access() # 更新访问信息
|
||||
memories.append(memory)
|
||||
@@ -302,7 +321,7 @@ class MemorySystem:
|
||||
return []
|
||||
|
||||
async def build_memory_from_conversation(
|
||||
self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None
|
||||
self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None, bypass_interval: bool = False
|
||||
) -> list[MemoryChunk]:
|
||||
"""从对话中构建记忆
|
||||
|
||||
@@ -310,6 +329,7 @@ class MemorySystem:
|
||||
conversation_text: 对话文本
|
||||
context: 上下文信息
|
||||
timestamp: 时间戳,默认为当前时间
|
||||
bypass_interval: 是否绕过构建间隔检查(海马体采样器专用)
|
||||
|
||||
Returns:
|
||||
构建的记忆块列表
|
||||
@@ -328,7 +348,8 @@ class MemorySystem:
|
||||
min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0))
|
||||
current_time = time.time()
|
||||
|
||||
if build_scope_key and min_interval > 0:
|
||||
# 构建间隔检查(海马体采样器可以绕过)
|
||||
if build_scope_key and min_interval > 0 and not bypass_interval:
|
||||
last_time = self._last_memory_build_times.get(build_scope_key)
|
||||
if last_time and (current_time - last_time) < min_interval:
|
||||
remaining = min_interval - (current_time - last_time)
|
||||
@@ -340,18 +361,35 @@ class MemorySystem:
|
||||
|
||||
build_marker_time = current_time
|
||||
self._last_memory_build_times[build_scope_key] = current_time
|
||||
elif bypass_interval:
|
||||
# 海马体采样模式:不更新构建时间记录,避免影响即时模式
|
||||
logger.debug("海马体采样模式:绕过构建间隔检查")
|
||||
|
||||
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
|
||||
|
||||
logger.debug("开始构建记忆,文本长度: %d", len(conversation_text))
|
||||
|
||||
# 1. 信息价值评估
|
||||
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
||||
# 1. 信息价值评估(海马体采样器可以绕过)
|
||||
if not bypass_interval and not context.get("bypass_value_threshold", False):
|
||||
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
||||
|
||||
if value_score < self.config.memory_value_threshold:
|
||||
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
|
||||
self.status = original_status
|
||||
return []
|
||||
if value_score < self.config.memory_value_threshold:
|
||||
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
|
||||
self.status = original_status
|
||||
return []
|
||||
else:
|
||||
# 海马体采样器:使用默认价值分数或简单评估
|
||||
value_score = 0.6 # 默认中等价值
|
||||
if context.get("is_hippocampus_sample", False):
|
||||
# 对海马体样本进行简单价值评估
|
||||
if len(conversation_text) > 100: # 长文本可能有更多信息
|
||||
value_score = 0.7
|
||||
elif len(conversation_text) > 50:
|
||||
value_score = 0.6
|
||||
else:
|
||||
value_score = 0.5
|
||||
|
||||
logger.debug(f"海马体采样模式:使用价值评分 {value_score:.2f}")
|
||||
|
||||
# 2. 构建记忆块(所有记忆统一使用 global 作用域,实现完全共享)
|
||||
memory_chunks = await self.memory_builder.build_memories(
|
||||
@@ -469,7 +507,7 @@ class MemorySystem:
|
||||
continue
|
||||
search_tasks.append(
|
||||
self.unified_storage.search_similar_memories(
|
||||
query_text=display_text, limit=8, scope_id=GLOBAL_MEMORY_SCOPE
|
||||
query_text=display_text, limit=8, filters={"user_id": GLOBAL_MEMORY_SCOPE}
|
||||
)
|
||||
)
|
||||
|
||||
@@ -512,12 +550,70 @@ class MemorySystem:
|
||||
return existing_candidates
|
||||
|
||||
async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
|
||||
"""对外暴露的对话记忆处理接口,支持海马体、即时、所有三种采样模式"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
context = dict(context or {})
|
||||
|
||||
# 获取配置的采样模式
|
||||
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate')
|
||||
current_mode = MemorySamplingMode(sampling_mode)
|
||||
|
||||
logger.debug(f"使用记忆采样模式: {current_mode.value}")
|
||||
|
||||
# 根据采样模式处理记忆
|
||||
if current_mode == MemorySamplingMode.HIPPOCAMPUS:
|
||||
# 海马体模式:仅后台定时采样,不立即处理
|
||||
return {
|
||||
"success": True,
|
||||
"created_memories": [],
|
||||
"memory_count": 0,
|
||||
"processing_time": time.time() - start_time,
|
||||
"status": self.status.value,
|
||||
"processing_mode": "hippocampus",
|
||||
"message": "海马体模式:记忆将由后台定时任务采样处理",
|
||||
}
|
||||
|
||||
elif current_mode == MemorySamplingMode.IMMEDIATE:
|
||||
# 即时模式:立即处理记忆构建
|
||||
return await self._process_immediate_memory(context, start_time)
|
||||
|
||||
elif current_mode == MemorySamplingMode.ALL:
|
||||
# 所有模式:同时进行即时处理和海马体采样
|
||||
immediate_result = await self._process_immediate_memory(context, start_time)
|
||||
|
||||
# 海马体采样器会在后台继续处理,这里只是记录
|
||||
if self.hippocampus_sampler:
|
||||
immediate_result["processing_mode"] = "all_modes"
|
||||
immediate_result["hippocampus_status"] = "background_sampling_enabled"
|
||||
immediate_result["message"] = "所有模式:即时处理已完成,海马体采样将在后台继续"
|
||||
else:
|
||||
immediate_result["processing_mode"] = "immediate_fallback"
|
||||
immediate_result["hippocampus_status"] = "not_available"
|
||||
immediate_result["message"] = "海马体采样器不可用,回退到即时模式"
|
||||
|
||||
return immediate_result
|
||||
|
||||
else:
|
||||
# 默认回退到即时模式
|
||||
logger.warning(f"未知的采样模式 {sampling_mode},回退到即时模式")
|
||||
return await self._process_immediate_memory(context, start_time)
|
||||
|
||||
except Exception as e:
|
||||
processing_time = time.time() - start_time
|
||||
logger.error(f"对话记忆处理失败: {e}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"processing_time": processing_time,
|
||||
"status": self.status.value,
|
||||
"processing_mode": "error",
|
||||
}
|
||||
|
||||
async def _process_immediate_memory(self, context: dict[str, Any], start_time: float) -> dict[str, Any]:
|
||||
"""即时记忆处理的辅助方法"""
|
||||
try:
|
||||
conversation_candidate = (
|
||||
context.get("conversation_text")
|
||||
or context.get("message_content")
|
||||
@@ -537,6 +633,23 @@ class MemorySystem:
|
||||
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
|
||||
normalized_context.setdefault("conversation_text", conversation_text)
|
||||
|
||||
# 检查信息价值阈值
|
||||
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
||||
threshold = getattr(global_config.memory, 'precision_memory_reply_threshold', 0.5)
|
||||
|
||||
if value_score < threshold:
|
||||
logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建")
|
||||
return {
|
||||
"success": True,
|
||||
"created_memories": [],
|
||||
"memory_count": 0,
|
||||
"processing_time": time.time() - start_time,
|
||||
"status": self.status.value,
|
||||
"processing_mode": "immediate",
|
||||
"skip_reason": f"value_score_{value_score:.2f}_below_threshold_{threshold}",
|
||||
"value_score": value_score,
|
||||
}
|
||||
|
||||
memories = await self.build_memory_from_conversation(
|
||||
conversation_text=conversation_text, context=normalized_context, timestamp=timestamp
|
||||
)
|
||||
@@ -550,12 +663,20 @@ class MemorySystem:
|
||||
"memory_count": memory_count,
|
||||
"processing_time": processing_time,
|
||||
"status": self.status.value,
|
||||
"processing_mode": "immediate",
|
||||
"value_score": value_score,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
processing_time = time.time() - start_time
|
||||
logger.error(f"对话记忆处理失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e), "processing_time": processing_time, "status": self.status.value}
|
||||
logger.error(f"即时记忆处理失败: {e}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"processing_time": processing_time,
|
||||
"status": self.status.value,
|
||||
"processing_mode": "immediate_error",
|
||||
}
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self,
|
||||
@@ -1372,11 +1493,53 @@ class MemorySystem:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 记忆系统维护失败: {e}", exc_info=True)
|
||||
|
||||
def start_hippocampus_sampling(self):
|
||||
"""启动海马体采样"""
|
||||
if self.hippocampus_sampler:
|
||||
asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
|
||||
logger.info("🚀 海马体后台采样已启动")
|
||||
else:
|
||||
logger.warning("海马体采样器未初始化,无法启动采样")
|
||||
|
||||
def stop_hippocampus_sampling(self):
|
||||
"""停止海马体采样"""
|
||||
if self.hippocampus_sampler:
|
||||
self.hippocampus_sampler.stop_background_sampling()
|
||||
logger.info("🛑 海马体后台采样已停止")
|
||||
|
||||
def get_system_stats(self) -> dict[str, Any]:
|
||||
"""获取系统统计信息"""
|
||||
base_stats = {
|
||||
"status": self.status.value,
|
||||
"total_memories": self.total_memories,
|
||||
"last_build_time": self.last_build_time,
|
||||
"last_retrieval_time": self.last_retrieval_time,
|
||||
"config": asdict(self.config),
|
||||
}
|
||||
|
||||
# 添加海马体采样器统计
|
||||
if self.hippocampus_sampler:
|
||||
base_stats["hippocampus_sampler"] = self.hippocampus_sampler.get_sampling_stats()
|
||||
|
||||
# 添加存储统计
|
||||
if self.unified_storage:
|
||||
try:
|
||||
storage_stats = self.unified_storage.get_storage_stats()
|
||||
base_stats["storage_stats"] = storage_stats
|
||||
except Exception as e:
|
||||
logger.debug(f"获取存储统计失败: {e}")
|
||||
|
||||
return base_stats
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭系统(简化版)"""
|
||||
try:
|
||||
logger.info("正在关闭简化记忆系统...")
|
||||
|
||||
# 停止海马体采样
|
||||
if self.hippocampus_sampler:
|
||||
self.hippocampus_sampler.stop_background_sampling()
|
||||
|
||||
# 保存统一存储数据
|
||||
if self.unified_storage:
|
||||
await self.unified_storage.cleanup()
|
||||
@@ -1456,4 +1619,10 @@ async def initialize_memory_system(llm_model: LLMRequest | None = None):
|
||||
if memory_system is None:
|
||||
memory_system = MemorySystem(llm_model=llm_model)
|
||||
await memory_system.initialize()
|
||||
|
||||
# 根据配置启动海马体采样
|
||||
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate')
|
||||
if sampling_mode in ['hippocampus', 'all']:
|
||||
memory_system.start_hippocampus_sampling()
|
||||
|
||||
return memory_system
|
||||
|
||||
489
src/chat/message_manager/adaptive_stream_manager.py
Normal file
489
src/chat/message_manager/adaptive_stream_manager.py
Normal file
@@ -0,0 +1,489 @@
|
||||
"""
|
||||
自适应流管理器 - 动态并发限制和异步流池管理
|
||||
根据系统负载和流优先级动态调整并发限制
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import psutil
|
||||
import time
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("adaptive_stream_manager")
|
||||
|
||||
|
||||
class StreamPriority(Enum):
|
||||
"""流优先级"""
|
||||
LOW = 1
|
||||
NORMAL = 2
|
||||
HIGH = 3
|
||||
CRITICAL = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMetrics:
|
||||
"""系统指标"""
|
||||
cpu_usage: float = 0.0
|
||||
memory_usage: float = 0.0
|
||||
active_coroutines: int = 0
|
||||
event_loop_lag: float = 0.0
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamMetrics:
|
||||
"""流指标"""
|
||||
stream_id: str
|
||||
priority: StreamPriority
|
||||
message_rate: float = 0.0 # 消息速率(消息/分钟)
|
||||
response_time: float = 0.0 # 平均响应时间
|
||||
last_activity: float = field(default_factory=time.time)
|
||||
consecutive_failures: int = 0
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class AdaptiveStreamManager:
|
||||
"""自适应流管理器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_concurrent_limit: int = 50,
|
||||
max_concurrent_limit: int = 200,
|
||||
min_concurrent_limit: int = 10,
|
||||
metrics_window: float = 60.0, # 指标窗口时间
|
||||
adjustment_interval: float = 30.0, # 调整间隔
|
||||
cpu_threshold_high: float = 0.8, # CPU高负载阈值
|
||||
cpu_threshold_low: float = 0.3, # CPU低负载阈值
|
||||
memory_threshold_high: float = 0.85, # 内存高负载阈值
|
||||
):
|
||||
self.base_concurrent_limit = base_concurrent_limit
|
||||
self.max_concurrent_limit = max_concurrent_limit
|
||||
self.min_concurrent_limit = min_concurrent_limit
|
||||
self.metrics_window = metrics_window
|
||||
self.adjustment_interval = adjustment_interval
|
||||
self.cpu_threshold_high = cpu_threshold_high
|
||||
self.cpu_threshold_low = cpu_threshold_low
|
||||
self.memory_threshold_high = memory_threshold_high
|
||||
|
||||
# 当前状态
|
||||
self.current_limit = base_concurrent_limit
|
||||
self.active_streams: Set[str] = set()
|
||||
self.pending_streams: Set[str] = set()
|
||||
self.stream_metrics: Dict[str, StreamMetrics] = {}
|
||||
|
||||
# 异步信号量
|
||||
self.semaphore = asyncio.Semaphore(base_concurrent_limit)
|
||||
self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量
|
||||
|
||||
# 系统监控
|
||||
self.system_metrics: List[SystemMetrics] = []
|
||||
self.last_adjustment_time = 0.0
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
"total_requests": 0,
|
||||
"accepted_requests": 0,
|
||||
"rejected_requests": 0,
|
||||
"priority_accepts": 0,
|
||||
"limit_adjustments": 0,
|
||||
"avg_concurrent_streams": 0,
|
||||
"peak_concurrent_streams": 0,
|
||||
}
|
||||
|
||||
# 监控任务
|
||||
self.monitor_task: Optional[asyncio.Task] = None
|
||||
self.adjustment_task: Optional[asyncio.Task] = None
|
||||
self.is_running = False
|
||||
|
||||
logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")
|
||||
|
||||
async def start(self):
|
||||
"""启动自适应管理器"""
|
||||
if self.is_running:
|
||||
logger.warning("自适应流管理器已经在运行")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.monitor_task = asyncio.create_task(self._system_monitor_loop(), name="system_monitor")
|
||||
self.adjustment_task = asyncio.create_task(self._adjustment_loop(), name="limit_adjustment")
|
||||
logger.info("自适应流管理器已启动")
|
||||
|
||||
async def stop(self):
|
||||
"""停止自适应管理器"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = False
|
||||
|
||||
# 停止监控任务
|
||||
if self.monitor_task and not self.monitor_task.done():
|
||||
self.monitor_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self.monitor_task, timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("系统监控任务停止超时")
|
||||
except Exception as e:
|
||||
logger.error(f"停止系统监控任务时出错: {e}")
|
||||
|
||||
if self.adjustment_task and not self.adjustment_task.done():
|
||||
self.adjustment_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self.adjustment_task, timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("限制调整任务停止超时")
|
||||
except Exception as e:
|
||||
logger.error(f"停止限制调整任务时出错: {e}")
|
||||
|
||||
logger.info("自适应流管理器已停止")
|
||||
|
||||
async def acquire_stream_slot(
    self,
    stream_id: str,
    priority: StreamPriority = StreamPriority.NORMAL,
    force: bool = False
) -> bool:
    """Acquire a processing slot for a stream.

    Args:
        stream_id: stream identifier.
        priority: scheduling priority; HIGH/CRITICAL use a dedicated semaphore.
        force: allow breaking through the soft limit.

    Returns:
        bool: True if a slot was granted.
    """
    # Manager not started: fail open so message flow is never blocked.
    if not self.is_running:
        logger.warning(f"自适应流管理器未运行,直接允许流 {stream_id}")
        return True

    self.stats["total_requests"] += 1
    current_time = time.time()

    # Track per-stream metrics; create an entry on first sight.
    if stream_id not in self.stream_metrics:
        self.stream_metrics[stream_id] = StreamMetrics(
            stream_id=stream_id,
            priority=priority
        )
    self.stream_metrics[stream_id].last_activity = current_time

    # Re-entrant request for an already-active stream is a no-op success.
    if stream_id in self.active_streams:
        logger.debug(f"流 {stream_id} 已经在活跃列表中")
        return True

    # High/critical streams go through the dedicated priority semaphore.
    if priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
        return await self._acquire_priority_slot(stream_id, priority, force)

    # Escalate to forced dispatch when the stream's backlog demands it.
    if not force and self._should_force_dispatch(stream_id):
        force = True
        logger.info(f"流 {stream_id} 消息积压严重,强制分发")

    # Non-blocking acquire: only attempt when a permit is available.
    # FIX: the previous implementation used asyncio.wait_for(acquire(),
    # timeout=0.001), which can leak a permit when the timeout fires just
    # after the acquire succeeds. Checking locked() first avoids that
    # race: there is no await between the check and an uncontended
    # acquire() in a single-threaded event loop.
    if not self.semaphore.locked():
        try:
            await self.semaphore.acquire()
        except Exception as e:
            logger.warning(f"获取常规槽位时出错: {e}")
        else:
            self.active_streams.add(stream_id)
            self.stats["accepted_requests"] += 1
            logger.debug(f"流 {stream_id} 获取常规槽位成功 (当前活跃: {len(self.active_streams)})")
            return True
    else:
        logger.debug(f"常规信号量已满: {stream_id}")

    # Last resort: break through the soft limit when forced.
    if force:
        return await self._force_acquire_slot(stream_id)

    # No slot available.
    self.stats["rejected_requests"] += 1
    logger.debug(f"流 {stream_id} 获取槽位失败,当前限制: {self.current_limit}, 活跃流: {len(self.active_streams)}")
    return False
|
||||
|
||||
async def _acquire_priority_slot(self, stream_id: str, priority: StreamPriority, force: bool) -> bool:
    """Acquire a slot from the small priority semaphore (HIGH/CRITICAL only).

    FIX: uses a locked() pre-check instead of wait_for(acquire(), timeout),
    so a permit cannot be leaked by a timeout racing a successful acquire.
    CRITICAL streams (or force=True) fall back to forced acquisition when
    the priority semaphore is exhausted.
    """
    if not self.priority_semaphore.locked():
        try:
            await self.priority_semaphore.acquire()
        except Exception as e:
            logger.warning(f"获取优先级槽位时出错: {e}")
        else:
            self.active_streams.add(stream_id)
            self.stats["priority_accepts"] += 1
            self.stats["accepted_requests"] += 1
            logger.debug(f"流 {stream_id} 获取优先级槽位成功 (优先级: {priority.name})")
            return True
    else:
        logger.debug(f"优先级信号量已满: {stream_id}")

    # Priority slots exhausted: force through when allowed.
    if force or priority == StreamPriority.CRITICAL:
        return await self._force_acquire_slot(stream_id)

    return False
|
||||
|
||||
async def _force_acquire_slot(self, stream_id: str) -> bool:
    """Force-admit a stream past the soft limit, bounded by the hard cap."""
    # The hard ceiling is never breached, even for forced dispatch.
    if len(self.active_streams) >= self.max_concurrent_limit:
        logger.warning(f"达到最大并发限制 {self.max_concurrent_limit},无法为流 {stream_id} 强制分发")
        return False

    # Admit without touching any semaphore: this stream bypasses the
    # soft limit entirely.
    self.active_streams.add(stream_id)
    self.stats["accepted_requests"] += 1
    logger.warning(f"流 {stream_id} 突破并发限制强制分发 (当前活跃: {len(self.active_streams)})")
    return True
|
||||
|
||||
def release_stream_slot(self, stream_id: str):
    """Remove a stream from the active set and return its permit."""
    if stream_id not in self.active_streams:
        return

    self.active_streams.discard(stream_id)

    # HIGH/CRITICAL streams took their permit from the priority semaphore,
    # everything else from the regular one.
    # NOTE(review): force-acquired streams never took a permit at all, yet
    # a release is still performed here — over time this can inflate
    # capacity beyond the configured limit; confirm against
    # _force_acquire_slot.
    metrics = self.stream_metrics.get(stream_id)
    is_priority = bool(metrics) and metrics.priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]
    if is_priority:
        self.priority_semaphore.release()
    else:
        self.semaphore.release()

    logger.debug(f"流 {stream_id} 释放槽位 (当前活跃: {len(self.active_streams)})")
|
||||
|
||||
def _should_force_dispatch(self, stream_id: str) -> bool:
    """Heuristic: should this stream bypass the normal admission path?

    Simplified backlog check based on priority and recent responsiveness
    rather than an actual message-queue depth.
    """
    metrics = self.stream_metrics.get(stream_id)
    if metrics is None:
        return False

    # High-priority streams are always eligible for forced dispatch.
    if metrics.priority == StreamPriority.HIGH:
        return True

    # Recently active (within 5 minutes) but slow (>5s response time)
    # streams are likely backed up and also qualify.
    recently_active = time.time() - metrics.last_activity < 300
    is_slow = metrics.response_time > 5.0
    return recently_active and is_slow
|
||||
|
||||
async def _system_monitor_loop(self):
    """Background loop: sample system metrics every five seconds."""
    logger.info("系统监控循环启动")

    while self.is_running:
        try:
            await asyncio.sleep(5.0)  # sampling interval
            await self._collect_system_metrics()
        except asyncio.CancelledError:
            logger.info("系统监控循环被取消")
            break
        except Exception as e:
            # Keep sampling even if one pass fails.
            logger.error(f"系统监控出错: {e}")

    logger.info("系统监控循环结束")
|
||||
|
||||
async def _collect_system_metrics(self):
    """Sample CPU / memory / coroutine / event-loop metrics and record them.

    Also trims the metrics window to ``self.metrics_window`` seconds and
    refreshes the rolling concurrency statistics. Failures are logged,
    never raised.

    Fixes: the two bare ``except:`` clauses (which also swallowed
    CancelledError/KeyboardInterrupt) are narrowed, and the unused
    ``loop`` local from get_running_loop() is removed.
    """
    try:
        # CPU / memory utilisation as 0..1 fractions.
        cpu_usage = psutil.cpu_percent(interval=None) / 100.0
        memory_usage = psutil.virtual_memory().percent / 100.0

        # Number of live asyncio tasks; 0 if it cannot be determined.
        try:
            active_coroutines = len(asyncio.all_tasks())
        except RuntimeError:
            # all_tasks() raises RuntimeError when there is no running loop.
            active_coroutines = 0

        # Event-loop lag: how long a zero-duration sleep takes to resume.
        event_loop_lag = 0.0
        try:
            start_time = time.time()
            await asyncio.sleep(0)
            event_loop_lag = time.time() - start_time
        except Exception:
            pass  # best-effort metric only

        metrics = SystemMetrics(
            cpu_usage=cpu_usage,
            memory_usage=memory_usage,
            active_coroutines=active_coroutines,
            event_loop_lag=event_loop_lag,
            timestamp=time.time()
        )
        self.system_metrics.append(metrics)

        # Drop samples that fell out of the sliding window.
        cutoff_time = time.time() - self.metrics_window
        self.system_metrics = [
            m for m in self.system_metrics
            if m.timestamp > cutoff_time
        ]

        # Exponential moving average / peak of the active-stream count.
        self.stats["avg_concurrent_streams"] = (
            self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1
        )
        self.stats["peak_concurrent_streams"] = max(
            self.stats["peak_concurrent_streams"],
            len(self.active_streams)
        )

    except Exception as e:
        logger.error(f"收集系统指标失败: {e}")
|
||||
|
||||
async def _adjustment_loop(self):
    """Background loop: re-tune the concurrency limit on a fixed interval."""
    logger.info("限制调整循环启动")

    while self.is_running:
        try:
            await asyncio.sleep(self.adjustment_interval)
            await self._adjust_concurrent_limit()
        except asyncio.CancelledError:
            logger.info("限制调整循环被取消")
            break
        except Exception as e:
            # A failed adjustment pass must not kill the loop.
            logger.error(f"限制调整出错: {e}")

    logger.info("限制调整循环结束")
|
||||
|
||||
async def _adjust_concurrent_limit(self):
    """Re-tune the soft concurrency limit from recent system metrics."""
    if not self.system_metrics:
        return

    # Rate-limit adjustments to one per interval.
    now = time.time()
    if now - self.last_adjustment_time < self.adjustment_interval:
        return

    # Average over (at most) the ten newest samples.
    recent = self.system_metrics[-10:]
    if not recent:
        return
    sample_count = len(recent)
    avg_cpu = sum(m.cpu_usage for m in recent) / sample_count
    avg_memory = sum(m.memory_usage for m in recent) / sample_count
    avg_coroutines = sum(m.active_coroutines for m in recent) / sample_count

    previous_limit = self.current_limit
    factor = 1.0

    # CPU pressure: shrink 20% above the high watermark,
    # grow 20% below the low watermark.
    if avg_cpu > self.cpu_threshold_high:
        factor *= 0.8
    elif avg_cpu < self.cpu_threshold_low:
        factor *= 1.2

    # Memory pressure: shrink a further 30%.
    if avg_memory > self.memory_threshold_high:
        factor *= 0.7

    # Too many live coroutines: shave another 10%.
    if avg_coroutines > 1000:
        factor *= 0.9

    # Clamp the candidate limit to the configured bounds.
    candidate = int(self.current_limit * factor)
    candidate = max(self.min_concurrent_limit, min(self.max_concurrent_limit, candidate))

    if candidate != self.current_limit:
        await self._adjust_semaphore(self.current_limit, candidate)
        self.current_limit = candidate
        self.stats["limit_adjustments"] += 1
        self.last_adjustment_time = now

        logger.info(
            f"并发限制调整: {previous_limit} -> {candidate} "
            f"(CPU: {avg_cpu:.2f}, 内存: {avg_memory:.2f}, 协程: {avg_coroutines:.0f})"
        )
|
||||
|
||||
async def _adjust_semaphore(self, old_limit: int, new_limit: int):
|
||||
"""调整信号量大小"""
|
||||
if new_limit > old_limit:
|
||||
# 增加信号量槽位
|
||||
for _ in range(new_limit - old_limit):
|
||||
self.semaphore.release()
|
||||
elif new_limit < old_limit:
|
||||
# 减少信号量槽位(通过等待槽位被释放)
|
||||
reduction = old_limit - new_limit
|
||||
for _ in range(reduction):
|
||||
try:
|
||||
await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
|
||||
except:
|
||||
# 如果无法立即获取,说明当前使用量接近限制
|
||||
break
|
||||
|
||||
def update_stream_metrics(self, stream_id: str, **kwargs):
    """Copy known attribute overrides onto the tracked metrics entry.

    Unknown attribute names and unknown stream ids are silently ignored.
    """
    metrics = self.stream_metrics.get(stream_id)
    if metrics is None:
        return

    for attr, new_value in kwargs.items():
        if hasattr(metrics, attr):
            setattr(metrics, attr, new_value)
|
||||
|
||||
def get_stats(self) -> Dict:
    """Return a snapshot of counters plus derived runtime figures.

    The snapshot is a copy; callers may mutate it freely.
    """
    snapshot = self.stats.copy()
    latest = self.system_metrics[-1] if self.system_metrics else None
    snapshot.update({
        "current_limit": self.current_limit,
        "active_streams": len(self.active_streams),
        "pending_streams": len(self.pending_streams),
        "is_running": self.is_running,
        "system_cpu": latest.cpu_usage if latest else 0,
        "system_memory": latest.memory_usage if latest else 0,
    })

    # Derived acceptance ratio; 0 when nothing was requested yet.
    total = snapshot["total_requests"]
    snapshot["acceptance_rate"] = snapshot["accepted_requests"] / total if total > 0 else 0

    return snapshot
|
||||
|
||||
|
||||
# 全局自适应管理器实例
|
||||
_adaptive_manager: Optional[AdaptiveStreamManager] = None
|
||||
|
||||
|
||||
def get_adaptive_stream_manager() -> AdaptiveStreamManager:
    """Return the process-wide AdaptiveStreamManager, creating it lazily."""
    global _adaptive_manager
    if _adaptive_manager is None:
        # First access: build the singleton with default settings.
        _adaptive_manager = AdaptiveStreamManager()
    return _adaptive_manager
|
||||
|
||||
|
||||
async def init_adaptive_stream_manager():
    """Create (if needed) and start the global adaptive stream manager."""
    await get_adaptive_stream_manager().start()
|
||||
|
||||
|
||||
async def shutdown_adaptive_stream_manager():
    """Stop the global adaptive stream manager (no-op if never started)."""
    await get_adaptive_stream_manager().stop()
|
||||
348
src/chat/message_manager/batch_database_writer.py
Normal file
348
src/chat/message_manager/batch_database_writer.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
异步批量数据库写入器
|
||||
优化频繁的数据库写入操作,减少I/O阻塞
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("batch_database_writer")
|
||||
|
||||
|
||||
@dataclass
class StreamUpdatePayload:
    """A single pending ChatStreams upsert for one stream."""
    # Primary key of the stream row being updated.
    stream_id: str
    # Column -> value mapping applied in the upsert.
    update_data: Dict[str, Any]
    # Larger numbers are written first when a batch is sorted.
    priority: int = 0
    # Enqueue time; used to keep only the newest update per stream.
    timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class BatchDatabaseWriter:
    """Asynchronous batched database writer for ChatStreams upserts.

    Buffers updates in an asyncio queue and flushes them in batches from
    a background task, merging multiple updates for the same stream so
    only the newest one is written. Reduces I/O pressure from frequent
    small writes.

    Fixes over the previous version:
    - the dialect-specific upsert construction was duplicated three
      times (batch path, direct path, fallback); it is now factored into
      _build_upsert_stmt;
    - stop() previously let asyncio.CancelledError (a BaseException
      since Python 3.8) escape the ``except Exception`` handler when
      awaiting the cancelled writer task.
    """

    def __init__(self, batch_size: int = 50, flush_interval: float = 5.0, max_queue_size: int = 1000):
        """Initialise the batch writer.

        Args:
            batch_size: maximum number of payloads per flush.
            flush_interval: maximum seconds to wait before flushing.
            max_queue_size: queue capacity; further updates are dropped.
        """
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.max_queue_size = max_queue_size

        # Pending payloads waiting for the next flush.
        self.write_queue: asyncio.Queue[StreamUpdatePayload] = asyncio.Queue(maxsize=max_queue_size)

        # Background flush-task state.
        self.is_running = False
        self.writer_task: Optional[asyncio.Task] = None

        # Counters reported by get_stats().
        self.stats = {
            "total_writes": 0,    # payloads accepted into the queue
            "batch_writes": 0,    # completed flush operations
            "failed_writes": 0,   # batches that raised
            "queue_size": 0,
            "avg_batch_size": 0,  # exponential moving average
            "last_flush_time": 0,
        }

        # Reserved for priority-grouped batching (currently unused).
        self.priority_batches: Dict[int, List[StreamUpdatePayload]] = defaultdict(list)

        logger.info(f"批量数据库写入器初始化完成 (batch_size={batch_size}, interval={flush_interval}s)")

    async def start(self):
        """Start the background flush task (idempotent)."""
        if self.is_running:
            logger.warning("批量写入器已经在运行")
            return

        self.is_running = True
        self.writer_task = asyncio.create_task(self._batch_writer_loop(), name="batch_database_writer")
        logger.info("批量数据库写入器已启动")

    async def stop(self):
        """Stop the writer, flushing whatever is still queued."""
        if not self.is_running:
            return

        self.is_running = False

        if self.writer_task and not self.writer_task.done():
            try:
                # Drain remaining payloads before tearing the task down.
                await self._flush_all_batches()
                self.writer_task.cancel()
                await asyncio.wait_for(self.writer_task, timeout=10.0)
            except asyncio.CancelledError:
                pass  # expected outcome of the cancel() above
            except asyncio.TimeoutError:
                logger.warning("批量写入器停止超时")
            except Exception as e:
                logger.error(f"停止批量写入器时出错: {e}")

        logger.info("批量数据库写入器已停止")

    async def schedule_stream_update(
        self,
        stream_id: str,
        update_data: Dict[str, Any],
        priority: int = 0
    ) -> bool:
        """Queue a stream update for batched writing.

        Args:
            stream_id: stream identifier (upsert key).
            update_data: column -> value mapping to write.
            priority: larger numbers are written first within a batch.

        Returns:
            bool: True if the update was queued (or written directly
            because the writer is not running), False if it was dropped.
        """
        try:
            # Writer stopped: degrade to a synchronous write so the
            # update is not lost.
            if not self.is_running:
                logger.warning("批量写入器未运行,直接写入数据库")
                await self._direct_write(stream_id, update_data)
                return True

            payload = StreamUpdatePayload(
                stream_id=stream_id,
                update_data=update_data,
                priority=priority
            )

            # Non-blocking enqueue; a full queue drops the update.
            try:
                self.write_queue.put_nowait(payload)
                self.stats["total_writes"] += 1
                self.stats["queue_size"] = self.write_queue.qsize()
                return True
            except asyncio.QueueFull:
                logger.warning(f"写入队列已满,丢弃低优先级更新: stream_id={stream_id}")
                return False

        except Exception as e:
            logger.error(f"调度流更新失败: {e}")
            return False

    async def _batch_writer_loop(self):
        """Main loop: collect a batch, write it, repeat until stopped."""
        logger.info("批量写入循环启动")

        while self.is_running:
            try:
                batch = await self._collect_batch()

                if batch:
                    await self._write_batch(batch)

                self.stats["queue_size"] = self.write_queue.qsize()

            except asyncio.CancelledError:
                logger.info("批量写入循环被取消")
                break
            except Exception as e:
                logger.error(f"批量写入循环出错: {e}")
                # Back off briefly so a persistent error cannot spin.
                await asyncio.sleep(1.0)

        # Drain anything left before exiting.
        await self._flush_all_batches()
        logger.info("批量写入循环结束")

    async def _collect_batch(self) -> List[StreamUpdatePayload]:
        """Collect up to batch_size payloads, waiting at most flush_interval."""
        batch = []
        deadline = time.time() + self.flush_interval

        while len(batch) < self.batch_size and time.time() < deadline:
            try:
                remaining_time = max(0, deadline - time.time())
                if remaining_time == 0:
                    break

                payload = await asyncio.wait_for(
                    self.write_queue.get(),
                    timeout=remaining_time
                )
                batch.append(payload)

            except asyncio.TimeoutError:
                break

        return batch

    async def _write_batch(self, batch: List[StreamUpdatePayload]):
        """Write a batch: merge per-stream, upsert, fall back to singles."""
        if not batch:
            return

        start_time = time.time()

        try:
            # High priority first, then oldest first.
            batch.sort(key=lambda x: (-x.priority, x.timestamp))

            # Keep only the newest payload per stream id.
            merged_updates = {}
            for payload in batch:
                existing = merged_updates.get(payload.stream_id)
                if existing is None or payload.timestamp > existing.timestamp:
                    merged_updates[payload.stream_id] = payload

            await self._batch_write_to_database(list(merged_updates.values()))

            self.stats["batch_writes"] += 1
            self.stats["avg_batch_size"] = (
                self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1
            )  # sliding average
            self.stats["last_flush_time"] = start_time

            logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")

        except Exception as e:
            self.stats["failed_writes"] += 1
            logger.error(f"批量写入失败: {e}")
            # Degrade to one-by-one writes so a single bad row cannot
            # sink the whole batch.
            for payload in batch:
                try:
                    await self._direct_write(payload.stream_id, payload.update_data)
                except Exception as single_e:
                    logger.error(f"单个写入也失败: {single_e}")

    @staticmethod
    def _build_upsert_stmt(stream_id: str, update_data: Dict[str, Any]):
        """Build a dialect-specific ChatStreams upsert statement.

        Unknown database types fall back to SQLite syntax, matching the
        original behaviour.
        """
        if global_config.database.database_type == "mysql":
            from sqlalchemy.dialects.mysql import insert as mysql_insert
            stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
            return stmt.on_duplicate_key_update(
                **{key: value for key, value in update_data.items() if key != "stream_id"}
            )

        # SQLite, and the default for any other configured type.
        from sqlalchemy.dialects.sqlite import insert as sqlite_insert
        stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
        return stmt.on_conflict_do_update(
            index_elements=["stream_id"],
            set_=update_data
        )

    async def _batch_write_to_database(self, payloads: List[StreamUpdatePayload]):
        """Upsert all payloads inside one session and a single commit."""
        async with get_db_session() as session:
            for payload in payloads:
                stmt = self._build_upsert_stmt(payload.stream_id, payload.update_data)
                await session.execute(stmt)
            await session.commit()

    async def _direct_write(self, stream_id: str, update_data: Dict[str, Any]):
        """Immediate single-row upsert (fallback path)."""
        async with get_db_session() as session:
            await session.execute(self._build_upsert_stmt(stream_id, update_data))
            await session.commit()

    async def _flush_all_batches(self):
        """Drain the queue and write everything that is still pending."""
        remaining_batch = []
        while not self.write_queue.empty():
            try:
                remaining_batch.append(self.write_queue.get_nowait())
            except asyncio.QueueEmpty:
                break

        if remaining_batch:
            await self._write_batch(remaining_batch)

    def get_stats(self) -> Dict[str, Any]:
        """Return a copy of the writer's counters plus live queue state."""
        stats = self.stats.copy()
        stats["is_running"] = self.is_running
        stats["current_queue_size"] = self.write_queue.qsize() if self.is_running else 0
        return stats
|
||||
|
||||
|
||||
# 全局批量写入器实例
|
||||
_batch_writer: Optional[BatchDatabaseWriter] = None
|
||||
|
||||
|
||||
def get_batch_writer() -> BatchDatabaseWriter:
    """Return the process-wide BatchDatabaseWriter, creating it lazily."""
    global _batch_writer
    if _batch_writer is None:
        # First access: build the singleton with default settings.
        _batch_writer = BatchDatabaseWriter()
    return _batch_writer
|
||||
|
||||
|
||||
async def init_batch_writer():
    """Create (if needed) and start the global batch database writer."""
    await get_batch_writer().start()
|
||||
|
||||
|
||||
async def shutdown_batch_writer():
    """Stop the global batch database writer (no-op if never started)."""
    await get_batch_writer().stop()
|
||||
@@ -23,6 +23,8 @@ class StreamLoopManager:
|
||||
def __init__(self, max_concurrent_streams: int | None = None):
|
||||
# 流循环任务管理
|
||||
self.stream_loops: dict[str, asyncio.Task] = {}
|
||||
# 跟踪流使用的管理器类型
|
||||
self.stream_management_type: dict[str, str] = {} # stream_id -> "adaptive" or "fallback"
|
||||
|
||||
# 统计信息
|
||||
self.stats: dict[str, Any] = {
|
||||
@@ -99,7 +101,7 @@ class StreamLoopManager:
|
||||
logger.info("流循环管理器已停止")
|
||||
|
||||
async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
|
||||
"""启动指定流的循环任务
|
||||
"""启动指定流的循环任务 - 优化版本使用自适应管理器
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
@@ -113,6 +115,71 @@ class StreamLoopManager:
|
||||
logger.debug(f"流 {stream_id} 循环已在运行")
|
||||
return True
|
||||
|
||||
# 使用自适应流管理器获取槽位
|
||||
use_adaptive = False
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager, StreamPriority
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
|
||||
if adaptive_manager.is_running:
|
||||
# 确定流优先级
|
||||
priority = self._determine_stream_priority(stream_id)
|
||||
|
||||
# 获取处理槽位
|
||||
slot_acquired = await adaptive_manager.acquire_stream_slot(
|
||||
stream_id=stream_id,
|
||||
priority=priority,
|
||||
force=force
|
||||
)
|
||||
|
||||
if slot_acquired:
|
||||
use_adaptive = True
|
||||
logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})")
|
||||
else:
|
||||
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
|
||||
else:
|
||||
logger.debug(f"自适应管理器未运行,使用原始方法")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}")
|
||||
|
||||
# 如果自适应管理器失败或未运行,使用回退方案
|
||||
if not use_adaptive:
|
||||
if not await self._fallback_acquire_slot(stream_id, force):
|
||||
logger.debug(f"回退方案也失败: {stream_id}")
|
||||
return False
|
||||
|
||||
# 创建流循环任务
|
||||
try:
|
||||
loop_task = asyncio.create_task(
|
||||
self._stream_loop_worker(stream_id),
|
||||
name=f"stream_loop_{stream_id}"
|
||||
)
|
||||
self.stream_loops[stream_id] = loop_task
|
||||
# 记录管理器类型
|
||||
self.stream_management_type[stream_id] = "adaptive" if use_adaptive else "fallback"
|
||||
|
||||
# 更新统计信息
|
||||
self.stats["active_streams"] += 1
|
||||
self.stats["total_loops"] += 1
|
||||
|
||||
logger.info(f"启动流循环任务: {stream_id} (管理器: {'adaptive' if use_adaptive else 'fallback'})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动流循环任务失败 {stream_id}: {e}")
|
||||
# 释放槽位
|
||||
if use_adaptive:
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.release_stream_slot(stream_id)
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
async def _fallback_acquire_slot(self, stream_id: str, force: bool) -> bool:
|
||||
"""回退方案:获取槽位(原始方法)"""
|
||||
# 判断是否需要强制分发
|
||||
should_force = force or self._should_force_dispatch_for_stream(stream_id)
|
||||
|
||||
@@ -149,6 +216,28 @@ class StreamLoopManager:
|
||||
del self.stream_loops[stream_id]
|
||||
current_streams -= 1 # 更新当前流数量
|
||||
|
||||
return True
|
||||
|
||||
def _determine_stream_priority(self, stream_id: str) -> "StreamPriority":
    """Assign a priority bucket to a stream (~20% HIGH, 30% NORMAL, 50% LOW).

    Fix: the builtin hash() of a str is salted per interpreter run
    (PYTHONHASHSEED), so the same stream could land in a different
    bucket after every restart. zlib.crc32 is a stable digest, keeping a
    stream's priority consistent across restarts while preserving the
    original bucket distribution. Also removes the duplicated
    StreamPriority import in the exception path.
    """
    from src.chat.message_manager.adaptive_stream_manager import StreamPriority

    try:
        import zlib

        bucket = zlib.crc32(stream_id.encode("utf-8")) % 10

        if bucket >= 8:    # top 20% of buckets
            return StreamPriority.HIGH
        if bucket >= 5:    # next 30%
            return StreamPriority.NORMAL
        return StreamPriority.LOW  # remaining 50%

    except Exception:
        # Any unexpected failure degrades to the default priority.
        return StreamPriority.NORMAL
|
||||
|
||||
# 创建流循环任务
|
||||
try:
|
||||
task = asyncio.create_task(
|
||||
@@ -201,13 +290,13 @@ class StreamLoopManager:
|
||||
logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})")
|
||||
return True
|
||||
|
||||
async def _stream_loop(self, stream_id: str) -> None:
|
||||
"""单个流的无限循环
|
||||
async def _stream_loop_worker(self, stream_id: str) -> None:
|
||||
"""单个流的工作循环 - 优化版本
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
"""
|
||||
logger.info(f"流循环开始: {stream_id}")
|
||||
logger.info(f"流循环工作器启动: {stream_id}")
|
||||
|
||||
try:
|
||||
while self.is_running:
|
||||
@@ -223,6 +312,18 @@ class StreamLoopManager:
|
||||
unread_count = self._get_unread_count(context)
|
||||
force_dispatch = self._needs_force_dispatch_for_context(context, unread_count)
|
||||
|
||||
# 3. 更新自适应管理器指标
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.update_stream_metrics(
|
||||
stream_id,
|
||||
message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算
|
||||
last_activity=time.time()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"更新流指标失败: {e}")
|
||||
|
||||
has_messages = force_dispatch or await self._has_messages_to_process(context)
|
||||
|
||||
if has_messages:
|
||||
@@ -278,6 +379,24 @@ class StreamLoopManager:
|
||||
del self.stream_loops[stream_id]
|
||||
logger.debug(f"清理流循环标记: {stream_id}")
|
||||
|
||||
# 根据管理器类型释放相应的槽位
|
||||
management_type = self.stream_management_type.get(stream_id, "fallback")
|
||||
if management_type == "adaptive":
|
||||
# 释放自适应管理器的槽位
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.release_stream_slot(stream_id)
|
||||
logger.debug(f"释放自适应流处理槽位: {stream_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"释放自适应流处理槽位失败: {e}")
|
||||
else:
|
||||
logger.debug(f"流 {stream_id} 使用回退方案,无需释放自适应槽位")
|
||||
|
||||
# 清理管理器类型记录
|
||||
if stream_id in self.stream_management_type:
|
||||
del self.stream_management_type[stream_id]
|
||||
|
||||
logger.info(f"流循环结束: {stream_id}")
|
||||
|
||||
async def _get_stream_context(self, stream_id: str) -> Any | None:
|
||||
|
||||
@@ -56,6 +56,30 @@ class MessageManager:
|
||||
|
||||
self.is_running = True
|
||||
|
||||
# 启动批量数据库写入器
|
||||
try:
|
||||
from src.chat.message_manager.batch_database_writer import init_batch_writer
|
||||
await init_batch_writer()
|
||||
logger.info("📦 批量数据库写入器已启动")
|
||||
except Exception as e:
|
||||
logger.error(f"启动批量数据库写入器失败: {e}")
|
||||
|
||||
# 启动流缓存管理器
|
||||
try:
|
||||
from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
|
||||
await init_stream_cache_manager()
|
||||
logger.info("🗄️ 流缓存管理器已启动")
|
||||
except Exception as e:
|
||||
logger.error(f"启动流缓存管理器失败: {e}")
|
||||
|
||||
# 启动自适应流管理器
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
|
||||
await init_adaptive_stream_manager()
|
||||
logger.info("🎯 自适应流管理器已启动")
|
||||
except Exception as e:
|
||||
logger.error(f"启动自适应流管理器失败: {e}")
|
||||
|
||||
# 启动睡眠和唤醒管理器
|
||||
await self.wakeup_manager.start()
|
||||
|
||||
@@ -72,6 +96,30 @@ class MessageManager:
|
||||
|
||||
self.is_running = False
|
||||
|
||||
# 停止批量数据库写入器
|
||||
try:
|
||||
from src.chat.message_manager.batch_database_writer import shutdown_batch_writer
|
||||
await shutdown_batch_writer()
|
||||
logger.info("📦 批量数据库写入器已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止批量数据库写入器失败: {e}")
|
||||
|
||||
# 停止流缓存管理器
|
||||
try:
|
||||
from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
|
||||
await shutdown_stream_cache_manager()
|
||||
logger.info("🗄️ 流缓存管理器已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止流缓存管理器失败: {e}")
|
||||
|
||||
# 停止自适应流管理器
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
|
||||
await shutdown_adaptive_stream_manager()
|
||||
logger.info("🎯 自适应流管理器已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止自适应流管理器失败: {e}")
|
||||
|
||||
# 停止睡眠和唤醒管理器
|
||||
await self.wakeup_manager.stop()
|
||||
|
||||
|
||||
381
src/chat/message_manager/stream_cache_manager.py
Normal file
381
src/chat/message_manager/stream_cache_manager.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
流缓存管理器 - 使用优化版聊天流和智能缓存策略
|
||||
提供分层缓存和自动清理功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Set
|
||||
from dataclasses import dataclass
|
||||
from collections import OrderedDict
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
|
||||
|
||||
logger = get_logger("stream_cache_manager")
|
||||
|
||||
|
||||
@dataclass
class StreamCacheStats:
    """Runtime counters for the tiered stream cache."""
    # Entry counts per tier.
    hot_cache_size: int = 0
    warm_storage_size: int = 0
    cold_storage_size: int = 0
    # Estimated memory footprint in bytes.
    total_memory_usage: int = 0
    # Lookup outcomes and eviction count.
    cache_hits: int = 0
    cache_misses: int = 0
    evictions: int = 0
    # time.time() of the most recent cleanup pass.
    last_cleanup_time: float = 0
|
||||
|
||||
|
||||
class TieredStreamCache:
|
||||
"""分层流缓存管理器"""
|
||||
|
||||
def __init__(
    self,
    max_hot_size: int = 100,
    max_warm_size: int = 500,
    max_cold_size: int = 2000,
    cleanup_interval: float = 300.0,
    hot_timeout: float = 1800.0,
    warm_timeout: float = 7200.0,
    cold_timeout: float = 86400.0,
):
    """Configure the three cache tiers and their demotion timeouts.

    Defaults: cleanup runs every 5 minutes; hot entries demote to warm
    after 30 minutes idle, warm to cold after 2 hours, and cold entries
    are dropped after 24 hours without access.
    """
    # Capacity limits per tier.
    self.max_hot_size = max_hot_size
    self.max_warm_size = max_warm_size
    self.max_cold_size = max_cold_size

    # Demotion / cleanup timing.
    self.cleanup_interval = cleanup_interval
    self.hot_timeout = hot_timeout
    self.warm_timeout = warm_timeout
    self.cold_timeout = cold_timeout

    # Tier storage: hot is an LRU (OrderedDict); warm/cold map
    # stream_id -> (stream, last_access_time).
    self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict()
    self.warm_storage: Dict[str, tuple[OptimizedChatStream, float]] = {}
    self.cold_storage: Dict[str, tuple[OptimizedChatStream, float]] = {}

    # Hit/miss/eviction counters.
    self.stats = StreamCacheStats()

    # Periodic cleanup-task state.
    self.cleanup_task: Optional[asyncio.Task] = None
    self.is_running = False

    logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
|
||||
|
||||
async def start(self):
    """Launch the periodic cleanup task (idempotent)."""
    if self.is_running:
        logger.warning("缓存管理器已经在运行")
        return

    self.is_running = True
    self.cleanup_task = asyncio.create_task(
        self._cleanup_loop(), name="stream_cache_cleanup"
    )
    logger.info("分层流缓存管理器已启动")
|
||||
|
||||
async def stop(self):
|
||||
"""停止缓存管理器"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = False
|
||||
|
||||
if self.cleanup_task and not self.cleanup_task.done():
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self.cleanup_task, timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("缓存清理任务停止超时")
|
||||
except Exception as e:
|
||||
logger.error(f"停止缓存清理任务时出错: {e}")
|
||||
|
||||
logger.info("分层流缓存管理器已停止")
|
||||
|
||||
async def get_or_create_stream(
|
||||
self,
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: Optional[Dict] = None,
|
||||
) -> OptimizedChatStream:
|
||||
"""获取或创建流 - 优化版本"""
|
||||
current_time = time.time()
|
||||
|
||||
# 1. 检查热缓存
|
||||
if stream_id in self.hot_cache:
|
||||
stream = self.hot_cache[stream_id]
|
||||
# 移动到末尾(LRU更新)
|
||||
self.hot_cache.move_to_end(stream_id)
|
||||
self.stats.cache_hits += 1
|
||||
logger.debug(f"热缓存命中: {stream_id}")
|
||||
return stream.create_snapshot()
|
||||
|
||||
# 2. 检查温存储
|
||||
if stream_id in self.warm_storage:
|
||||
stream, last_access = self.warm_storage[stream_id]
|
||||
self.warm_storage[stream_id] = (stream, current_time)
|
||||
self.stats.cache_hits += 1
|
||||
logger.debug(f"温缓存命中: {stream_id}")
|
||||
# 提升到热缓存
|
||||
await self._promote_to_hot(stream_id, stream)
|
||||
return stream.create_snapshot()
|
||||
|
||||
# 3. 检查冷存储
|
||||
if stream_id in self.cold_storage:
|
||||
stream, last_access = self.cold_storage[stream_id]
|
||||
self.cold_storage[stream_id] = (stream, current_time)
|
||||
self.stats.cache_hits += 1
|
||||
logger.debug(f"冷缓存命中: {stream_id}")
|
||||
# 提升到温缓存
|
||||
await self._promote_to_warm(stream_id, stream)
|
||||
return stream.create_snapshot()
|
||||
|
||||
# 4. 缓存未命中,创建新流
|
||||
self.stats.cache_misses += 1
|
||||
stream = create_optimized_chat_stream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
data=data
|
||||
)
|
||||
logger.debug(f"缓存未命中,创建新流: {stream_id}")
|
||||
|
||||
# 添加到热缓存
|
||||
await self._add_to_hot(stream_id, stream)
|
||||
|
||||
return stream
|
||||
|
||||
async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
|
||||
"""添加到热缓存"""
|
||||
# 检查是否需要驱逐
|
||||
if len(self.hot_cache) >= self.max_hot_size:
|
||||
await self._evict_from_hot()
|
||||
|
||||
self.hot_cache[stream_id] = stream
|
||||
self.stats.hot_cache_size = len(self.hot_cache)
|
||||
|
||||
async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
|
||||
"""提升到热缓存"""
|
||||
# 从温存储中移除
|
||||
if stream_id in self.warm_storage:
|
||||
del self.warm_storage[stream_id]
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
|
||||
# 添加到热缓存
|
||||
await self._add_to_hot(stream_id, stream)
|
||||
logger.debug(f"流 {stream_id} 提升到热缓存")
|
||||
|
||||
async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
|
||||
"""提升到温缓存"""
|
||||
# 从冷存储中移除
|
||||
if stream_id in self.cold_storage:
|
||||
del self.cold_storage[stream_id]
|
||||
self.stats.cold_storage_size = len(self.cold_storage)
|
||||
|
||||
# 添加到温存储
|
||||
if len(self.warm_storage) >= self.max_warm_size:
|
||||
await self._evict_from_warm()
|
||||
|
||||
current_time = time.time()
|
||||
self.warm_storage[stream_id] = (stream, current_time)
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
logger.debug(f"流 {stream_id} 提升到温缓存")
|
||||
|
||||
async def _evict_from_hot(self):
|
||||
"""从热缓存驱逐最久未使用的流"""
|
||||
if not self.hot_cache:
|
||||
return
|
||||
|
||||
# LRU驱逐
|
||||
stream_id, stream = self.hot_cache.popitem(last=False)
|
||||
self.stats.evictions += 1
|
||||
logger.debug(f"从热缓存驱逐: {stream_id}")
|
||||
|
||||
# 移动到温存储
|
||||
if len(self.warm_storage) < self.max_warm_size:
|
||||
current_time = time.time()
|
||||
self.warm_storage[stream_id] = (stream, current_time)
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
else:
|
||||
# 温存储也满了,直接删除
|
||||
logger.debug(f"温存储已满,删除流: {stream_id}")
|
||||
|
||||
self.stats.hot_cache_size = len(self.hot_cache)
|
||||
|
||||
async def _evict_from_warm(self):
|
||||
"""从温存储驱逐最久未使用的流"""
|
||||
if not self.warm_storage:
|
||||
return
|
||||
|
||||
# 找到最久未访问的流
|
||||
oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
|
||||
stream, last_access = self.warm_storage.pop(oldest_stream_id)
|
||||
self.stats.evictions += 1
|
||||
logger.debug(f"从温存储驱逐: {oldest_stream_id}")
|
||||
|
||||
# 移动到冷存储
|
||||
if len(self.cold_storage) < self.max_cold_size:
|
||||
current_time = time.time()
|
||||
self.cold_storage[oldest_stream_id] = (stream, current_time)
|
||||
self.stats.cold_storage_size = len(self.cold_storage)
|
||||
else:
|
||||
# 冷存储也满了,直接删除
|
||||
logger.debug(f"冷存储已满,删除流: {oldest_stream_id}")
|
||||
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""清理循环"""
|
||||
logger.info("流缓存清理循环启动")
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
await self._perform_cleanup()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("流缓存清理循环被取消")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"流缓存清理出错: {e}")
|
||||
|
||||
logger.info("流缓存清理循环结束")
|
||||
|
||||
async def _perform_cleanup(self):
|
||||
"""执行清理操作"""
|
||||
current_time = time.time()
|
||||
cleanup_stats = {
|
||||
"hot_to_warm": 0,
|
||||
"warm_to_cold": 0,
|
||||
"cold_removed": 0,
|
||||
}
|
||||
|
||||
# 1. 检查热缓存超时
|
||||
hot_to_demote = []
|
||||
for stream_id, stream in self.hot_cache.items():
|
||||
# 获取最后访问时间(简化:使用创建时间作为近似)
|
||||
last_access = getattr(stream, 'last_active_time', stream.create_time)
|
||||
if current_time - last_access > self.hot_timeout:
|
||||
hot_to_demote.append(stream_id)
|
||||
|
||||
for stream_id in hot_to_demote:
|
||||
stream = self.hot_cache.pop(stream_id)
|
||||
current_time_local = time.time()
|
||||
self.warm_storage[stream_id] = (stream, current_time_local)
|
||||
cleanup_stats["hot_to_warm"] += 1
|
||||
|
||||
# 2. 检查温存储超时
|
||||
warm_to_demote = []
|
||||
for stream_id, (stream, last_access) in self.warm_storage.items():
|
||||
if current_time - last_access > self.warm_timeout:
|
||||
warm_to_demote.append(stream_id)
|
||||
|
||||
for stream_id in warm_to_demote:
|
||||
stream, last_access = self.warm_storage.pop(stream_id)
|
||||
self.cold_storage[stream_id] = (stream, last_access)
|
||||
cleanup_stats["warm_to_cold"] += 1
|
||||
|
||||
# 3. 检查冷存储超时
|
||||
cold_to_remove = []
|
||||
for stream_id, (stream, last_access) in self.cold_storage.items():
|
||||
if current_time - last_access > self.cold_timeout:
|
||||
cold_to_remove.append(stream_id)
|
||||
|
||||
for stream_id in cold_to_remove:
|
||||
self.cold_storage.pop(stream_id)
|
||||
cleanup_stats["cold_removed"] += 1
|
||||
|
||||
# 更新统计信息
|
||||
self.stats.hot_cache_size = len(self.hot_cache)
|
||||
self.stats.warm_storage_size = len(self.warm_storage)
|
||||
self.stats.cold_storage_size = len(self.cold_storage)
|
||||
self.stats.last_cleanup_time = current_time
|
||||
|
||||
# 估算内存使用(粗略估计)
|
||||
self.stats.total_memory_usage = (
|
||||
len(self.hot_cache) * 1024 + # 每个热流约1KB
|
||||
len(self.warm_storage) * 512 + # 每个温流约512B
|
||||
len(self.cold_storage) * 256 # 每个冷流约256B
|
||||
)
|
||||
|
||||
if sum(cleanup_stats.values()) > 0:
|
||||
logger.info(
|
||||
f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, "
|
||||
f"{cleanup_stats['warm_to_cold']}温→冷, "
|
||||
f"{cleanup_stats['cold_removed']}冷删除"
|
||||
)
|
||||
|
||||
def get_stats(self) -> StreamCacheStats:
|
||||
"""获取缓存统计信息"""
|
||||
# 计算命中率
|
||||
total_requests = self.stats.cache_hits + self.stats.cache_misses
|
||||
hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0
|
||||
|
||||
stats_copy = StreamCacheStats(
|
||||
hot_cache_size=self.stats.hot_cache_size,
|
||||
warm_storage_size=self.stats.warm_storage_size,
|
||||
cold_storage_size=self.stats.cold_storage_size,
|
||||
total_memory_usage=self.stats.total_memory_usage,
|
||||
cache_hits=self.stats.cache_hits,
|
||||
cache_misses=self.stats.cache_misses,
|
||||
evictions=self.stats.evictions,
|
||||
last_cleanup_time=self.stats.last_cleanup_time,
|
||||
)
|
||||
|
||||
# 添加命中率信息
|
||||
stats_copy.hit_rate = hit_rate
|
||||
|
||||
return stats_copy
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空所有缓存"""
|
||||
self.hot_cache.clear()
|
||||
self.warm_storage.clear()
|
||||
self.cold_storage.clear()
|
||||
|
||||
self.stats.hot_cache_size = 0
|
||||
self.stats.warm_storage_size = 0
|
||||
self.stats.cold_storage_size = 0
|
||||
self.stats.total_memory_usage = 0
|
||||
|
||||
logger.info("所有缓存已清空")
|
||||
|
||||
async def get_stream_snapshot(self, stream_id: str) -> Optional[OptimizedChatStream]:
|
||||
"""获取流的快照(不修改缓存状态)"""
|
||||
if stream_id in self.hot_cache:
|
||||
return self.hot_cache[stream_id].create_snapshot()
|
||||
elif stream_id in self.warm_storage:
|
||||
return self.warm_storage[stream_id][0].create_snapshot()
|
||||
elif stream_id in self.cold_storage:
|
||||
return self.cold_storage[stream_id][0].create_snapshot()
|
||||
return None
|
||||
|
||||
def get_cached_stream_ids(self) -> Set[str]:
|
||||
"""获取所有缓存的流ID"""
|
||||
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
|
||||
|
||||
|
||||
# 全局缓存管理器实例
|
||||
_cache_manager: Optional[TieredStreamCache] = None
|
||||
|
||||
|
||||
def get_stream_cache_manager() -> TieredStreamCache:
    """Return the process-wide TieredStreamCache, creating it lazily on first use."""
    global _cache_manager
    if _cache_manager is None:
        _cache_manager = TieredStreamCache()
    return _cache_manager
|
||||
|
||||
|
||||
async def init_stream_cache_manager():
    """Create (if needed) and start the global stream cache manager."""
    await get_stream_cache_manager().start()
|
||||
|
||||
|
||||
async def shutdown_stream_cache_manager():
    """Stop the global stream cache manager's background cleanup task."""
    await get_stream_cache_manager().stop()
|
||||
@@ -464,7 +464,7 @@ class ChatManager:
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream:
|
||||
"""获取或创建聊天流
|
||||
"""获取或创建聊天流 - 优化版本使用缓存管理器
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
@@ -478,6 +478,31 @@ class ChatManager:
|
||||
try:
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
|
||||
# 优先使用缓存管理器(优化版本)
|
||||
try:
|
||||
from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
|
||||
cache_manager = get_stream_cache_manager()
|
||||
|
||||
if cache_manager.is_running:
|
||||
optimized_stream = await cache_manager.get_or_create_stream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info
|
||||
)
|
||||
|
||||
# 设置消息上下文
|
||||
from .message import MessageRecv
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||
optimized_stream.set_context(self.last_messages[stream_id])
|
||||
|
||||
# 转换为原始ChatStream以保持兼容性
|
||||
return self._convert_to_original_stream(optimized_stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
|
||||
|
||||
# 回退到原始方法
|
||||
# 检查内存中是否存在
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
@@ -634,12 +659,35 @@ class ChatManager:
|
||||
|
||||
@staticmethod
|
||||
async def _save_stream(stream: ChatStream):
|
||||
"""保存聊天流到数据库"""
|
||||
"""保存聊天流到数据库 - 优化版本使用异步批量写入"""
|
||||
if stream.saved:
|
||||
return
|
||||
stream_data_dict = stream.to_dict()
|
||||
|
||||
# 尝试使用数据库批量调度器
|
||||
# 优先使用新的批量写入器
|
||||
try:
|
||||
from src.chat.message_manager.batch_database_writer import get_batch_writer
|
||||
|
||||
batch_writer = get_batch_writer()
|
||||
if batch_writer.is_running:
|
||||
success = await batch_writer.schedule_stream_update(
|
||||
stream_id=stream_data_dict["stream_id"],
|
||||
update_data=ChatManager._prepare_stream_data(stream_data_dict),
|
||||
priority=1 # 流更新的优先级
|
||||
)
|
||||
if success:
|
||||
stream.saved = True
|
||||
logger.debug(f"聊天流 {stream.stream_id} 通过批量写入器调度成功")
|
||||
return
|
||||
else:
|
||||
logger.warning(f"批量写入器队列已满,使用原始方法: {stream.stream_id}")
|
||||
else:
|
||||
logger.debug(f"批量写入器未运行,使用原始方法: {stream.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"批量写入器保存聊天流失败,使用原始方法: {e}")
|
||||
|
||||
# 尝试使用数据库批量调度器(回退方案1)
|
||||
try:
|
||||
from src.common.database.db_batch_scheduler import batch_update, get_batch_session
|
||||
|
||||
@@ -657,7 +705,7 @@ class ChatManager:
|
||||
except (ImportError, Exception) as e:
|
||||
logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}")
|
||||
|
||||
# 回退到原始方法
|
||||
# 回退到原始方法(最终方案)
|
||||
async def _db_save_stream_async(s_data_dict: dict):
|
||||
async with get_db_session() as session:
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
@@ -782,6 +830,46 @@ class ChatManager:
|
||||
chat_manager = None
|
||||
|
||||
|
||||
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
    """Convert an OptimizedChatStream into a legacy ChatStream for compatibility.

    Copies identity, timing/energy state and (when present) the lazily
    created context objects onto a fresh ChatStream.  On any failure a bare
    ChatStream carrying only identity fields is returned instead of raising.
    """
    try:
        # Build the legacy stream from the *effective* (override-aware) infos
        original_stream = ChatStream(
            stream_id=optimized_stream.stream_id,
            platform=optimized_stream.platform,
            user_info=optimized_stream._get_effective_user_info(),
            group_info=optimized_stream._get_effective_group_info()
        )

        # Copy mutable runtime state field by field
        original_stream.create_time = optimized_stream.create_time
        original_stream.last_active_time = optimized_stream.last_active_time
        original_stream.sleep_pressure = optimized_stream.sleep_pressure
        original_stream.base_interest_energy = optimized_stream.base_interest_energy
        original_stream._focus_energy = optimized_stream._focus_energy
        original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
        original_stream.saved = optimized_stream.saved

        # Carry over the private context objects only when they were created
        if hasattr(optimized_stream, '_stream_context') and optimized_stream._stream_context:
            original_stream.stream_context = optimized_stream._stream_context

        if hasattr(optimized_stream, '_context_manager') and optimized_stream._context_manager:
            original_stream.context_manager = optimized_stream._context_manager

        return original_stream

    except Exception as e:
        logger.error(f"转换OptimizedChatStream失败: {e}")
        # Fallback: minimal stream with identity information only
        return ChatStream(
            stream_id=optimized_stream.stream_id,
            platform=optimized_stream.platform,
            user_info=optimized_stream._get_effective_user_info(),
            group_info=optimized_stream._get_effective_group_info()
        )
|
||||
|
||||
|
||||
def get_chat_manager():
|
||||
global chat_manager
|
||||
if chat_manager is None:
|
||||
|
||||
494
src/chat/message_receive/optimized_chat_stream.py
Normal file
494
src/chat/message_receive/optimized_chat_stream.py
Normal file
@@ -0,0 +1,494 @@
|
||||
"""
|
||||
优化版聊天流 - 实现写时复制机制
|
||||
避免不必要的深拷贝开销,提升多流并发性能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .message import MessageRecv
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("optimized_chat_stream")
|
||||
|
||||
|
||||
class SharedContext:
    """Read-only identity data shared between copy-on-write chat streams.

    All attributes are assigned once in ``__init__``; afterwards the instance
    is frozen and any further assignment (except to ``_frozen`` itself)
    raises ``AttributeError``.
    """

    def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None):
        # _frozen does not exist yet, so these assignments pass the guard.
        self.stream_id = stream_id
        self.platform = platform
        self.user_info = user_info
        self.group_info = group_info
        self.create_time = time.time()
        # Flipping this on seals the instance.
        self._frozen = True

    def __setattr__(self, name, value):
        if getattr(self, '_frozen', False) and name != '_frozen':
            raise AttributeError(f"SharedContext is frozen, cannot modify {name}")
        super().__setattr__(name, value)
|
||||
|
||||
|
||||
class LocalChanges:
    """Per-instance override tracker layered on top of shared context data.

    Stores key/value overrides in a private dict and keeps a dirty flag that
    is set on the first recorded change and reset by ``clear_changes``.
    """

    def __init__(self):
        self._changes: Dict[str, Any] = {}
        self._dirty = False

    def set_change(self, key: str, value: Any):
        """Record (or overwrite) the override for *key* and mark as dirty."""
        self._dirty = True
        self._changes[key] = value

    def get_change(self, key: str, default: Any = None) -> Any:
        """Return the override recorded for *key*, or *default* when absent."""
        return self._changes.get(key, default)

    def has_changes(self) -> bool:
        """Return True once at least one override has been recorded."""
        return self._dirty

    def get_changes(self) -> Dict[str, Any]:
        """Return a shallow copy of all recorded overrides."""
        return dict(self._changes)

    def clear_changes(self):
        """Drop every override and reset the dirty flag."""
        # clear() in place (not rebinding) so external references stay valid
        self._changes.clear()
        self._dirty = False
|
||||
|
||||
|
||||
class OptimizedChatStream:
    """Chat stream with copy-on-write semantics.

    Immutable identity data lives in a SharedContext that snapshots can
    share; mutations are recorded in a LocalChanges overlay instead of
    deep-copying the whole stream.  StreamContext / context manager objects
    are created lazily on first access.
    """

    def __init__(
        self,
        stream_id: str,
        platform: str,
        user_info: UserInfo,
        group_info: Optional[GroupInfo] = None,
        data: Optional[Dict] = None,
    ):
        # Shared, frozen identity data
        self._shared_context = SharedContext(
            stream_id=stream_id,
            platform=platform,
            user_info=user_info,
            group_info=group_info
        )

        # Per-instance mutation overlay
        self._local_changes = LocalChanges()

        # Copy-on-write latch (set on first mutation)
        self._copy_on_write = False

        # Base energy parameters; default 0.5 when no persisted data is given
        self.base_interest_energy = data.get("base_interest_energy", 0.5) if data else 0.5
        self._focus_energy = data.get("focus_energy", 0.5) if data else 0.5
        self.no_reply_consecutive = 0

        # StreamContext and its manager are created lazily
        self._stream_context = None
        self._context_manager = None

        # Stamp initial activity time (also sets saved=False via the setter)
        self.update_active_time()

        # Persistence flag: False means there are unsaved changes
        self.saved = False

    @property
    def stream_id(self) -> str:
        return self._shared_context.stream_id

    @property
    def platform(self) -> str:
        return self._shared_context.platform

    @property
    def user_info(self) -> UserInfo:
        # NOTE(review): unlike `group_info`, this getter does NOT consult the
        # LocalChanges overlay, so a value written through the setter below is
        # only visible via _get_effective_user_info() — confirm intended.
        return self._shared_context.user_info

    @user_info.setter
    def user_info(self, value: UserInfo):
        """Record a user-info override (triggers copy-on-write)."""
        self._ensure_copy_on_write()
        # SharedContext is frozen, so the new value goes into the overlay
        self._local_changes.set_change('user_info', value)

    @property
    def group_info(self) -> Optional[GroupInfo]:
        # Overlay wins over the shared context when an override exists
        if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
            return self._local_changes.get_change('group_info')
        return self._shared_context.group_info

    @group_info.setter
    def group_info(self, value: Optional[GroupInfo]):
        """Record a group-info override (triggers copy-on-write)."""
        self._ensure_copy_on_write()
        self._local_changes.set_change('group_info', value)

    @property
    def create_time(self) -> float:
        # An override is possible in principle, though nothing sets one today
        if self._local_changes.has_changes() and 'create_time' in self._local_changes._changes:
            return self._local_changes.get_change('create_time')
        return self._shared_context.create_time

    @property
    def last_active_time(self) -> float:
        # Defaults to the creation time until the first activity update
        return self._local_changes.get_change('last_active_time', self.create_time)

    @last_active_time.setter
    def last_active_time(self, value: float):
        self._local_changes.set_change('last_active_time', value)
        self.saved = False

    @property
    def sleep_pressure(self) -> float:
        return self._local_changes.get_change('sleep_pressure', 0.0)

    @sleep_pressure.setter
    def sleep_pressure(self, value: float):
        self._local_changes.set_change('sleep_pressure', value)
        self.saved = False

    def _ensure_copy_on_write(self):
        """Latch the copy-on-write flag on first mutation (log only; data
        separation is handled by the LocalChanges overlay)."""
        if not self._copy_on_write:
            self._copy_on_write = True
            logger.debug(f"触发写时复制: {self.stream_id}")

    def _get_effective_user_info(self) -> UserInfo:
        """Return the user info with any local override applied."""
        if self._local_changes.has_changes() and 'user_info' in self._local_changes._changes:
            return self._local_changes.get_change('user_info')
        return self._shared_context.user_info

    def _get_effective_group_info(self) -> Optional[GroupInfo]:
        """Return the group info with any local override applied."""
        if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
            return self._local_changes.get_change('group_info')
        return self._shared_context.group_info

    def update_active_time(self):
        """Stamp the current time as the last activity time."""
        self.last_active_time = time.time()

    def set_context(self, message: "MessageRecv"):
        """Convert an incoming MessageRecv into a DatabaseMessages record and
        install it as the current message of this stream's StreamContext."""
        # Lazily create the stream context if it does not exist yet
        if self._stream_context is None:
            self._ensure_copy_on_write()
            self._create_stream_context()

        import json
        from src.common.data_models.database_data_model import DatabaseMessages

        # Tolerate partially populated messages: every field is getattr'd
        message_info = getattr(message, "message_info", {})
        user_info = getattr(message_info, "user_info", {})
        group_info = getattr(message_info, "group_info", {})

        reply_to = None
        if hasattr(message, "message_segment") and message.message_segment:
            reply_to = self._extract_reply_from_segment(message.message_segment)

        db_message = DatabaseMessages(
            message_id=getattr(message, "message_id", ""),
            time=getattr(message, "time", time.time()),
            chat_id=self._generate_chat_id(message_info),
            reply_to=reply_to,
            interest_value=getattr(message, "interest_value", 0.0),
            key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
            if getattr(message, "key_words", None)
            else None,
            key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
            if getattr(message, "key_words_lite", None)
            else None,
            is_mentioned=getattr(message, "is_mentioned", None),
            is_at=getattr(message, "is_at", False),
            is_emoji=getattr(message, "is_emoji", False),
            is_picid=getattr(message, "is_picid", False),
            is_voice=getattr(message, "is_voice", False),
            is_video=getattr(message, "is_video", False),
            is_command=getattr(message, "is_command", False),
            is_notify=getattr(message, "is_notify", False),
            processed_plain_text=getattr(message, "processed_plain_text", ""),
            display_message=getattr(message, "processed_plain_text", ""),
            priority_mode=getattr(message, "priority_mode", None),
            priority_info=json.dumps(getattr(message, "priority_info", None))
            if getattr(message, "priority_info", None)
            else None,
            additional_config=getattr(message_info, "additional_config", None),
            user_id=str(getattr(user_info, "user_id", "")),
            user_nickname=getattr(user_info, "user_nickname", ""),
            user_cardname=getattr(user_info, "user_cardname", None),
            user_platform=getattr(user_info, "platform", ""),
            chat_info_group_id=getattr(group_info, "group_id", None),
            chat_info_group_name=getattr(group_info, "group_name", None),
            chat_info_group_platform=getattr(group_info, "platform", None),
            chat_info_user_id=str(getattr(user_info, "user_id", "")),
            chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
            chat_info_user_cardname=getattr(user_info, "user_cardname", None),
            chat_info_user_platform=getattr(user_info, "platform", ""),
            chat_info_stream_id=self.stream_id,
            chat_info_platform=self.platform,
            chat_info_create_time=self.create_time,
            chat_info_last_active_time=self.last_active_time,
            actions=self._safe_get_actions(message),
            should_reply=getattr(message, "should_reply", False),
        )

        self._stream_context.set_current_message(db_message)
        self._stream_context.priority_mode = getattr(message, "priority_mode", None)
        self._stream_context.priority_info = getattr(message, "priority_info", None)

        logger.debug(
            f"消息数据转移完成 - message_id: {db_message.message_id}, "
            f"chat_id: {db_message.chat_id}, "
            f"interest_value: {db_message.interest_value}"
        )

    def _create_stream_context(self):
        """Create the StreamContext and its single-stream context manager."""
        from src.common.data_models.message_manager_data_model import StreamContext
        from src.plugin_system.base.component_types import ChatMode, ChatType

        # Chat type follows the presence of group info; mode starts NORMAL
        self._stream_context = StreamContext(
            stream_id=self.stream_id,
            chat_type=ChatType.GROUP if self.group_info else ChatType.PRIVATE,
            chat_mode=ChatMode.NORMAL
        )

        from src.chat.message_manager.context_manager import SingleStreamContextManager
        self._context_manager = SingleStreamContextManager(
            stream_id=self.stream_id, context=self._stream_context
        )

    @property
    def stream_context(self):
        """StreamContext accessor (creates it lazily on first use)."""
        if self._stream_context is None:
            self._ensure_copy_on_write()
            self._create_stream_context()
        return self._stream_context

    @property
    def context_manager(self):
        """Context-manager accessor (creates it lazily on first use)."""
        if self._context_manager is None:
            self._ensure_copy_on_write()
            self._create_stream_context()
        return self._context_manager

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the stream, applying local overrides to user/group info."""
        user_info = self._get_effective_user_info()
        group_info = self._get_effective_group_info()

        return {
            "stream_id": self.stream_id,
            "platform": self.platform,
            "user_info": user_info.to_dict() if user_info else None,
            "group_info": group_info.to_dict() if group_info else None,
            "create_time": self.create_time,
            "last_active_time": self.last_active_time,
            "sleep_pressure": self.sleep_pressure,
            "focus_energy": self.focus_energy,
            "base_interest_energy": self.base_interest_energy,
            # Accessing stream_context here forces its lazy creation
            "stream_context_chat_type": self.stream_context.chat_type.value,
            "stream_context_chat_mode": self.stream_context.chat_mode.value,
            "interruption_count": self.stream_context.interruption_count,
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "OptimizedChatStream":
        """Rebuild an instance from a ``to_dict``-shaped dictionary."""
        user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
        group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None

        instance = cls(
            stream_id=data["stream_id"],
            platform=data["platform"],
            user_info=user_info,  # type: ignore
            group_info=group_info,
            data=data,
        )

        # Restore persisted stream-context state, when present
        if "stream_context_chat_type" in data:
            from src.plugin_system.base.component_types import ChatMode, ChatType
            instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
        if "stream_context_chat_mode" in data:
            from src.plugin_system.base.component_types import ChatMode, ChatType
            instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])

        # Restore the interruption counter, when present
        if "interruption_count" in data:
            instance.stream_context.interruption_count = data["interruption_count"]

        return instance

    def _safe_get_actions(self, message: "MessageRecv") -> list | None:
        """Return the message's `actions` as a list of strings, or None.

        Accepts either a JSON-encoded string or a list; anything else (or a
        parse failure) yields None rather than raising.
        """
        try:
            actions = getattr(message, "actions", None)
            if actions is None:
                return None

            if isinstance(actions, str):
                try:
                    import json
                    actions = json.loads(actions)
                except json.JSONDecodeError:
                    logger.warning(f"无法解析actions JSON字符串: {actions}")
                    return None

            if isinstance(actions, list):
                # Keep only non-None string entries; empty list collapses to None
                filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
                return filtered_actions if filtered_actions else None
            else:
                logger.warning(f"actions字段类型不支持: {type(actions)}")
                return None

        except Exception as e:
            logger.warning(f"获取actions字段失败: {e}")
            return None

    def _extract_reply_from_segment(self, segment) -> str | None:
        """Recursively search a message segment tree for a `reply` segment and
        return its data as a string, or None when absent or on error."""
        try:
            if hasattr(segment, "type") and segment.type == "seglist":
                if hasattr(segment, "data") and segment.data:
                    # Depth-first: first reply found wins
                    for seg in segment.data:
                        reply_id = self._extract_reply_from_segment(seg)
                        if reply_id:
                            return reply_id
            elif hasattr(segment, "type") and segment.type == "reply":
                return str(segment.data) if segment.data else None
        except Exception as e:
            logger.warning(f"提取reply_to信息失败: {e}")
        return None

    def _generate_chat_id(self, message_info) -> str:
        """Derive a chat_id from group info (preferred) or user info; fall
        back to this stream's own ID on missing data or error."""
        try:
            group_info = getattr(message_info, "group_info", None)
            user_info = getattr(message_info, "user_info", None)

            if group_info and hasattr(group_info, "group_id") and group_info.group_id:
                return f"{self.platform}_{group_info.group_id}"
            elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
                return f"{self.platform}_{user_info.user_id}_private"
            else:
                return self.stream_id
        except Exception as e:
            logger.warning(f"生成chat_id失败: {e}")
            return self.stream_id

    @property
    def focus_energy(self) -> float:
        """Return the cached focus_energy (updated by calculate_focus_energy)."""
        return self._focus_energy

    async def calculate_focus_energy(self) -> float:
        """Recompute focus energy from recent messages via the energy manager.

        Caches the result in ``_focus_energy``; on failure returns the last
        cached value instead of raising.
        """
        try:
            all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)

            user_id = None
            effective_user_info = self._get_effective_user_info()
            if effective_user_info and hasattr(effective_user_info, "user_id"):
                user_id = str(effective_user_info.user_id)

            from src.chat.energy_system import energy_manager

            energy = await energy_manager.calculate_focus_energy(
                stream_id=self.stream_id, messages=all_messages, user_id=user_id
            )

            self._focus_energy = energy

            logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
            return energy

        except Exception as e:
            logger.error(f"获取focus_energy失败: {e}", exc_info=True)
            return self._focus_energy

    @focus_energy.setter
    def focus_energy(self, value: float):
        """Set focus_energy, clamped to the [0.0, 1.0] range."""
        self._focus_energy = max(0.0, min(1.0, value))

    async def _get_user_relationship_score(self) -> float:
        """Return the user relationship score in [0, 1]; 0.3 on any failure."""
        try:
            from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system

            effective_user_info = self._get_effective_user_info()
            if effective_user_info and hasattr(effective_user_info, "user_id"):
                user_id = str(effective_user_info.user_id)
                relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id)
                logger.debug(f"OptimizedChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
                return max(0.0, min(1.0, relationship_score))

        except Exception as e:
            logger.warning(f"OptimizedChatStream {self.stream_id}: 插件内部关系分计算失败: {e}")

        # Neutral default when no user id is available or scoring failed
        return 0.3

    def create_snapshot(self) -> "OptimizedChatStream":
        """Create a cheap state snapshot for caching.

        Builds a fresh instance from the effective infos and copies the
        override overlay and energy counters — no deep copy of contexts.
        """
        snapshot = OptimizedChatStream(
            stream_id=self.stream_id,
            platform=self.platform,
            user_info=self._get_effective_user_info(),
            group_info=self._get_effective_group_info()
        )

        # Copy the overlay directly without triggering copy-on-write
        snapshot._local_changes._changes = self._local_changes.get_changes()
        snapshot._local_changes._dirty = self._local_changes._dirty
        snapshot._focus_energy = self._focus_energy
        snapshot.base_interest_energy = self.base_interest_energy
        snapshot.no_reply_consecutive = self.no_reply_consecutive
        snapshot.saved = self.saved

        return snapshot
|
||||
|
||||
|
||||
# 为了向后兼容,创建一个工厂函数
|
||||
def create_optimized_chat_stream(
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: Optional[Dict] = None,
|
||||
) -> OptimizedChatStream:
|
||||
"""创建优化版聊天流实例"""
|
||||
return OptimizedChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
data=data
|
||||
)
|
||||
@@ -337,6 +337,41 @@ class MemoryConfig(ValidatedConfigBase):
|
||||
# 休眠机制
|
||||
dormant_threshold_days: int = Field(default=90, description="休眠状态判定天数")
|
||||
|
||||
# === 混合记忆系统配置 ===
|
||||
# 采样模式配置
|
||||
memory_sampling_mode: Literal["adaptive", "hippocampus", "precision"] = Field(
|
||||
default="adaptive", description="记忆采样模式:adaptive(自适应),hippocampus(海马体双峰采样),precision(精准记忆)"
|
||||
)
|
||||
|
||||
# 海马体双峰采样配置
|
||||
enable_hippocampus_sampling: bool = Field(default=True, description="启用海马体双峰采样策略")
|
||||
hippocampus_sample_interval: int = Field(default=1800, description="海马体采样间隔(秒,默认30分钟)")
|
||||
hippocampus_sample_size: int = Field(default=30, description="海马体每次采样的消息数量")
|
||||
hippocampus_batch_size: int = Field(default=5, description="海马体每批处理的记忆数量")
|
||||
|
||||
# 双峰分布配置 [近期均值, 近期标准差, 近期权重, 远期均值, 远期标准差, 远期权重]
|
||||
hippocampus_distribution_config: list[float] = Field(
|
||||
default=[12.0, 8.0, 0.7, 48.0, 24.0, 0.3],
|
||||
description="海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]"
|
||||
)
|
||||
|
||||
# 自适应采样配置
|
||||
adaptive_sampling_enabled: bool = Field(default=True, description="启用自适应采样策略")
|
||||
adaptive_sampling_threshold: float = Field(default=0.8, description="自适应采样负载阈值(0-1)")
|
||||
adaptive_sampling_check_interval: int = Field(default=300, description="自适应采样检查间隔(秒)")
|
||||
adaptive_sampling_max_concurrent_builds: int = Field(default=3, description="自适应采样最大并发记忆构建数")
|
||||
|
||||
# 精准记忆配置(现有系统的增强版本)
|
||||
precision_memory_reply_threshold: float = Field(
|
||||
default=0.6, description="精准记忆回复触发阈值(对话价值评分超过此值时触发记忆构建)"
|
||||
)
|
||||
precision_memory_max_builds_per_hour: int = Field(default=10, description="精准记忆每小时最大构建数量")
|
||||
|
||||
# 混合系统优化配置
|
||||
memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡")
|
||||
memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流")
|
||||
memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列")
|
||||
|
||||
|
||||
class MoodConfig(ValidatedConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
@@ -21,76 +21,69 @@ logger = get_logger(__name__)
|
||||
|
||||
class ColdStartTask(AsyncTask):
|
||||
"""
|
||||
冷启动任务,专门用于处理那些在白名单里,但从未与机器人发生过交互的用户。
|
||||
它的核心职责是“破冰”,主动创建聊天流并发起第一次问候。
|
||||
“冷启动”任务,在机器人启动时执行一次。
|
||||
它的核心职责是“唤醒”那些因重启而“沉睡”的聊天流,确保它们能够接收主动思考。
|
||||
对于在白名单中但从未有过记录的全新用户,它也会发起第一次“破冰”问候。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, bot_start_time: float):
|
||||
super().__init__(task_name="ColdStartTask")
|
||||
self.chat_manager = get_chat_manager()
|
||||
self.executor = ProactiveThinkerExecutor()
|
||||
self.bot_start_time = bot_start_time
|
||||
|
||||
async def run(self):
|
||||
"""任务主循环,周期性地检查是否有需要“破冰”的新用户。"""
|
||||
logger.info("冷启动任务已启动,将周期性检查白名单中的新朋友。")
|
||||
# 初始等待一段时间,确保其他服务(如数据库)完全启动
|
||||
await asyncio.sleep(100)
|
||||
"""任务主逻辑,在启动后执行一次白名单扫描。"""
|
||||
logger.info("冷启动任务已启动,将在短暂延迟后开始唤醒沉睡的聊天流...")
|
||||
await asyncio.sleep(30) # 延迟以确保所有服务和聊天流已从数据库加载完毕
|
||||
|
||||
while True:
|
||||
try:
|
||||
#开始就先暂停一小时,等bot聊一会再说()
|
||||
await asyncio.sleep(3600)
|
||||
logger.info("【冷启动】开始扫描白名单,寻找从未聊过的用户...")
|
||||
try:
|
||||
logger.info("【冷启动】开始扫描白名单,唤醒沉睡的聊天流...")
|
||||
|
||||
# 从全局配置中获取私聊白名单
|
||||
enabled_private_chats = global_config.proactive_thinking.enabled_private_chats
|
||||
if not enabled_private_chats:
|
||||
logger.debug("【冷启动】私聊白名单为空,任务暂停一小时。")
|
||||
await asyncio.sleep(3600) # 白名单为空时,没必要频繁检查
|
||||
continue
|
||||
enabled_private_chats = global_config.proactive_thinking.enabled_private_chats
|
||||
if not enabled_private_chats:
|
||||
logger.debug("【冷启动】私聊白名单为空,任务结束。")
|
||||
return
|
||||
|
||||
# 遍历白名单中的每一个用户
|
||||
for chat_id in enabled_private_chats:
|
||||
try:
|
||||
platform, user_id_str = chat_id.split(":")
|
||||
user_id = int(user_id_str)
|
||||
for chat_id in enabled_private_chats:
|
||||
try:
|
||||
platform, user_id_str = chat_id.split(":")
|
||||
user_id = int(user_id_str)
|
||||
|
||||
# 【核心逻辑】使用 chat_api 检查该用户是否已经存在聊天流(ChatStream)
|
||||
# 如果返回了 ChatStream 对象,说明已经聊过天了,不是本次任务的目标
|
||||
if chat_api.get_stream_by_user_id(user_id_str, platform):
|
||||
continue # 跳过已存在的用户
|
||||
should_wake_up = False
|
||||
stream = chat_api.get_stream_by_user_id(user_id_str, platform)
|
||||
|
||||
logger.info(f"【冷启动】发现白名单新用户 {chat_id},准备发起第一次问候。")
|
||||
if not stream:
|
||||
should_wake_up = True
|
||||
logger.info(f"【冷启动】发现全新用户 {chat_id},准备发起第一次问候。")
|
||||
elif stream.last_active_time < self.bot_start_time:
|
||||
should_wake_up = True
|
||||
logger.info(f"【冷启动】发现沉睡的聊天流 {chat_id} (最后活跃于 {datetime.fromtimestamp(stream.last_active_time)}),准备唤醒。")
|
||||
|
||||
# 【增强体验】尝试从关系数据库中获取该用户的昵称
|
||||
# 这样打招呼时可以更亲切,而不是只知道一个冷冰冰的ID
|
||||
if should_wake_up:
|
||||
person_id = person_api.get_person_id(platform, user_id)
|
||||
nickname = await person_api.get_person_value(person_id, "nickname")
|
||||
|
||||
# 如果数据库里有昵称,就用数据库里的;如果没有,就用 "用户+ID" 作为备用
|
||||
user_nickname = nickname or f"用户{user_id}"
|
||||
|
||||
# 创建 UserInfo 对象,这是创建聊天流的必要信息
|
||||
user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname)
|
||||
|
||||
# 【关键步骤】主动创建聊天流。
|
||||
# 创建后,该用户就进入了机器人的“好友列表”,后续将由 ProactiveThinkingTask 接管
|
||||
# 使用 get_or_create_stream 来安全地获取或创建流
|
||||
stream = await self.chat_manager.get_or_create_stream(platform, user_info)
|
||||
|
||||
await self.executor.execute(stream_id=stream.stream_id, start_mode="cold_start")
|
||||
logger.info(f"【冷启动】已为新用户 {chat_id} (昵称: {user_nickname}) 创建聊天流并发送问候。")
|
||||
formatted_stream_id = f"{stream.user_info.platform}:{stream.user_info.user_id}:private"
|
||||
await self.executor.execute(stream_id=formatted_stream_id, start_mode="cold_start")
|
||||
logger.info(f"【冷启动】已为用户 {chat_id} (昵称: {user_nickname}) 发送唤醒/问候消息。")
|
||||
|
||||
except ValueError:
|
||||
logger.warning(f"【冷启动】白名单条目格式错误或用户ID无效,已跳过: {chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"【冷启动】处理用户 {chat_id} 时发生未知错误: {e}", exc_info=True)
|
||||
except ValueError:
|
||||
logger.warning(f"【冷启动】白名单条目格式错误或用户ID无效,已跳过: {chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"【冷启动】处理用户 {chat_id} 时发生未知错误: {e}", exc_info=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("冷启动任务被正常取消。")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"【冷启动】任务出现严重错误,将在5分钟后重试: {e}", exc_info=True)
|
||||
await asyncio.sleep(300)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("冷启动任务被正常取消。")
|
||||
except Exception as e:
|
||||
logger.error(f"【冷启动】任务出现严重错误: {e}", exc_info=True)
|
||||
finally:
|
||||
logger.info("【冷启动】任务执行完毕。")
|
||||
|
||||
|
||||
class ProactiveThinkingTask(AsyncTask):
|
||||
@@ -222,13 +215,15 @@ class ProactiveThinkerEventHandler(BaseEventHandler):
|
||||
logger.info("检测到插件启动事件,正在初始化【主动思考】")
|
||||
# 检查总开关
|
||||
if global_config.proactive_thinking.enable:
|
||||
bot_start_time = time.time() # 记录“诞生时刻”
|
||||
|
||||
# 启动负责“日常唤醒”的核心任务
|
||||
proactive_task = ProactiveThinkingTask()
|
||||
await async_task_manager.add_task(proactive_task)
|
||||
|
||||
# 检查“冷启动”功能的独立开关
|
||||
if global_config.proactive_thinking.enable_cold_start:
|
||||
cold_start_task = ColdStartTask()
|
||||
cold_start_task = ColdStartTask(bot_start_time)
|
||||
await async_task_manager.add_task(cold_start_task)
|
||||
|
||||
else:
|
||||
|
||||
@@ -80,7 +80,7 @@ class ProactiveThinkerExecutor:
|
||||
plan_prompt = self._build_plan_prompt(context, start_mode, topic, reason)
|
||||
|
||||
is_success, response, _, _ = await llm_api.generate_with_model(
|
||||
prompt=plan_prompt, model_config=model_config.model_task_config.utils
|
||||
prompt=plan_prompt, model_config=model_config.model_task_config.replyer
|
||||
)
|
||||
|
||||
if is_success and response:
|
||||
@@ -140,7 +140,7 @@ class ProactiveThinkerExecutor:
|
||||
else "今天没有日程安排。"
|
||||
)
|
||||
|
||||
recent_messages = await message_api.get_recent_messages(stream_id, limit=10)
|
||||
recent_messages = await message_api.get_recent_messages(stream.stream_id)
|
||||
recent_chat_history = (
|
||||
await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无"
|
||||
)
|
||||
@@ -158,12 +158,12 @@ class ProactiveThinkerExecutor:
|
||||
)
|
||||
|
||||
# 2. 构建基础上下文
|
||||
mood_state = "暂时没有"
|
||||
if global_config.mood.enable_mood:
|
||||
try:
|
||||
mood_state = mood_manager.get_mood_by_chat_id(stream.stream_id).mood_state
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪失败,原因:{e}")
|
||||
mood_state = "暂时没有"
|
||||
base_context = {
|
||||
"schedule_context": schedule_context,
|
||||
"recent_chat_history": recent_chat_history,
|
||||
@@ -281,30 +281,48 @@ class ProactiveThinkerExecutor:
|
||||
# 构建通用尾部
|
||||
prompt += """
|
||||
# 决策指令
|
||||
请综合以上所有信息,做出决策。你的决策需要以JSON格式输出,包含以下字段:
|
||||
请综合以上所有信息,以稳定、真实、拟人的方式做出决策。你的决策需要以JSON格式输出,包含以下字段:
|
||||
- `should_reply`: bool, 是否应该发起对话。
|
||||
- `topic`: str, 如果 `should_reply` 为 true,你打算聊什么话题?(例如:问候一下今天的日程、关心一下昨天的某件事、分享一个你自己的趣事等)
|
||||
- `topic`: str, 如果 `should_reply` 为 true,你打算聊什么话题?
|
||||
- `reason`: str, 做出此决策的简要理由。
|
||||
|
||||
# 决策原则
|
||||
- **避免打扰**: 如果你最近(尤其是在最近的几次决策中)已经主动发起过对话,请倾向于选择“不回复”,除非有非常重要和紧急的事情。
|
||||
- **如果上下文中只有你的消息而没有别人的消息**:选择不回复,以防刷屏或者打扰到别人虽然第一
|
||||
- **谨慎对待未回复的对话**: 在发起新话题前,请检查【最近的聊天摘要】。如果最后一条消息是你自己发送的,请仔细评估等待的时间和上下文,判断再次主动发起对话是否礼貌和自然。如果等待时间很短(例如几分钟或半小时内),通常应该选择“不回复”。
|
||||
- **优先利用上下文**: 优先从【情境分析】中已有的信息(如最近的聊天摘要、你的日程、你对Ta的关系印象)寻找自然的话题切入点。
|
||||
- **简单问候作为备选**: 如果上下文中没有合适的话题,可以生成一个简单、真诚的日常问候(例如“在忙吗?”,“下午好呀~”)。
|
||||
- **避免抽象**: 避免创造过于复杂、抽象或需要对方思考很久才能明白的话题。目标是轻松、自然地开启对话。
|
||||
- **避免过于频繁**: 如果你最近(尤其是在最近的几次决策中)已经主动发起过对话,请倾向于选择“不回复”,除非有非常重要和紧急的事情。
|
||||
- **如果上下文中只有你的消息而没有别人的消息**:选择不回复,以防刷屏或者打扰到别人
|
||||
|
||||
|
||||
---
|
||||
示例1 (应该回复):
|
||||
示例1 (基于上下文):
|
||||
{{
|
||||
"should_reply": true,
|
||||
"topic": "提醒大家今天下午有'项目会议'的日程",
|
||||
"reason": "现在是上午,下午有个重要会议,我觉得应该主动提醒一下大家,这会显得我很贴心。"
|
||||
"topic": "关心一下Ta昨天提到的那个项目进展如何了",
|
||||
"reason": "用户昨天在聊天中提到了一个重要的项目,现在主动关心一下进展,会显得很体贴,也能自然地开启对话。"
|
||||
}}
|
||||
|
||||
示例2 (不应回复):
|
||||
示例2 (简单问候):
|
||||
{{
|
||||
"should_reply": true,
|
||||
"topic": "打个招呼,问问Ta现在在忙些什么",
|
||||
"reason": "最近没有聊天记录,日程也很常规,没有特别的切入点。一个简单的日常问候是最安全和自然的方式来重新连接。"
|
||||
}}
|
||||
|
||||
示例3 (不应回复 - 过于频繁):
|
||||
{{
|
||||
"should_reply": false,
|
||||
"topic": null,
|
||||
"reason": "虽然群里很活跃,但现在是深夜,而且最近的聊天话题我也不熟悉,没有合适的理由去打扰大家。"
|
||||
}}
|
||||
|
||||
示例4 (不应回复 - 等待回应):
|
||||
{{
|
||||
"should_reply": false,
|
||||
"topic": null,
|
||||
"reason": "我注意到上一条消息是我几分钟前主动发送的,对方可能正在忙。为了表现出耐心和体贴,我现在最好保持安静,等待对方的回应。"
|
||||
}}
|
||||
---
|
||||
|
||||
请输出你的决策:
|
||||
@@ -369,10 +387,18 @@ class ProactiveThinkerExecutor:
|
||||
|
||||
# 决策上下文
|
||||
- **决策理由**: {reason}
|
||||
- **你和Ta的关系**:
|
||||
|
||||
# 情境分析
|
||||
1. **你的日程**:
|
||||
{context["schedule_context"]}
|
||||
2. **你和Ta的关系**:
|
||||
- 简短印象: {relationship["short_impression"]}
|
||||
- 详细印象: {relationship["impression"]}
|
||||
- 好感度: {relationship["attitude"]}/100
|
||||
3. **最近的聊天摘要**:
|
||||
{context["recent_chat_history"]}
|
||||
4. **你最近的相关动作**:
|
||||
{context["action_history_context"]}
|
||||
|
||||
# 对话指引
|
||||
- 你的目标是“破冰”,让对话自然地开始。
|
||||
@@ -400,6 +426,7 @@ class ProactiveThinkerExecutor:
|
||||
|
||||
# 对话指引
|
||||
- 你决定和Ta聊聊关于“{topic}”的话题。
|
||||
- **重要**: 在开始你的话题前,必须先用一句通用的、礼貌的开场白进行问候(例如:“在吗?”、“上午好!”、“晚上好呀~”),然后再自然地衔接你的话题,确保整个回复在一条消息内流畅、自然、像人类的说话方式。
|
||||
- 请结合以上所有情境信息,自然地开启对话。
|
||||
- 你的语气应该符合你的人设({context["mood_state"]})以及你对Ta的好感度。
|
||||
"""
|
||||
@@ -437,6 +464,7 @@ class ProactiveThinkerExecutor:
|
||||
|
||||
# 对话指引
|
||||
- 你决定和大家聊聊关于“{topic}”的话题。
|
||||
- **重要**: 在开始你的话题前,必须先用一句通用的、礼貌的开场白进行问候(例如:“哈喽,大家好呀~”、“下午好!”),然后再自然地衔接你的话题,确保整个回复在一条消息内流畅、自然、像人类的说话方式。
|
||||
- 你的语气应该更活泼、更具包容性,以吸引更多群成员参与讨论。你的语气应该符合你的人设)。
|
||||
- 请结合以上所有情境信息,自然地开启对话。
|
||||
- 可以分享你的看法、提出相关问题,或者开个合适的玩笑。
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "7.1.5"
|
||||
version = "7.1.6"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
@@ -208,6 +208,19 @@ max_context_emojis = 30 # 每次随机传递给LLM的表情包详细描述的最
|
||||
enable_memory = true # 是否启用记忆系统
|
||||
memory_build_interval = 600 # 记忆构建间隔(秒)。间隔越低,学习越频繁,但可能产生更多冗余信息
|
||||
|
||||
# === 记忆采样系统配置 ===
|
||||
memory_sampling_mode = "immediate" # 记忆采样模式:hippocampus(海马体定时采样),immediate(即时采样),all(所有模式)
|
||||
|
||||
# 海马体双峰采样配置
|
||||
enable_hippocampus_sampling = true # 启用海马体双峰采样策略
|
||||
hippocampus_sample_interval = 1800 # 海马体采样间隔(秒,默认30分钟)
|
||||
hippocampus_sample_size = 30 # 海马体采样样本数量
|
||||
hippocampus_batch_size = 10 # 海马体批量处理大小
|
||||
hippocampus_distribution_config = [12.0, 8.0, 0.7, 48.0, 24.0, 0.3] # 海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]
|
||||
|
||||
# 即时采样配置
|
||||
precision_memory_reply_threshold = 0.5 # 精准记忆回复阈值(0-1),高于此值的对话将立即构建记忆
|
||||
|
||||
min_memory_length = 10 # 最小记忆长度
|
||||
max_memory_length = 500 # 最大记忆长度
|
||||
memory_value_threshold = 0.5 # 记忆价值阈值,低于该值的记忆会被丢弃
|
||||
|
||||
Reference in New Issue
Block a user