minecraft1024a
2025-10-04 12:10:11 +08:00
23 changed files with 3030 additions and 5283 deletions

View File

@@ -373,7 +373,11 @@ class VirtualLogDisplay:
             # Apply the correct tag to each part
             current_len = 0
-            for part, tag_name in zip(parts, tags, strict=False):
+            # Python 3.9 compatibility: do not use the strict=False parameter
+            min_len = min(len(parts), len(tags))
+            for i in range(min_len):
+                part = parts[i]
+                tag_name = tags[i]
                 start_index = f"{start_pos}+{current_len}c"
                 end_index = f"{start_pos}+{current_len + len(part)}c"
                 self.text_widget.tag_add(tag_name, start_index, end_index)
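For context on the change above: the strict keyword of zip() only exists on Python 3.10+, so merely passing strict=False raises a TypeError on 3.9. A minimal sketch of the failure mode and the index-based workaround (illustrative, not part of the commit):

parts = ["INFO", " ", "ready"]
tags = ["level", "space"]  # deliberately shorter

# On Python 3.9 the next line raises: TypeError: zip() takes no keyword arguments
# pairs = list(zip(parts, tags, strict=False))

# Version-agnostic equivalent of zip(..., strict=False): stop at the shorter input.
min_len = min(len(parts), len(tags))
for i in range(min_len):
    print(parts[i], tags[i])  # pairs ("INFO", "level") and (" ", "space")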

View File

@@ -1,363 +0,0 @@
"""
增强记忆系统适配器
将增强记忆系统集成到现有MoFox Bot架构中
"""
import time
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
MEMORY_TYPE_LABELS = {
MemoryType.PERSONAL_FACT: "个人事实",
MemoryType.EVENT: "事件",
MemoryType.PREFERENCE: "偏好",
MemoryType.OPINION: "观点",
MemoryType.RELATIONSHIP: "关系",
MemoryType.EMOTION: "情感",
MemoryType.KNOWLEDGE: "知识",
MemoryType.SKILL: "技能",
MemoryType.GOAL: "目标",
MemoryType.EXPERIENCE: "经验",
MemoryType.CONTEXTUAL: "上下文",
}
@dataclass
class AdapterConfig:
"""适配器配置"""
enable_enhanced_memory: bool = True
integration_mode: str = "enhanced_only" # replace, enhanced_only
auto_migration: bool = True
memory_value_threshold: float = 0.6
fusion_threshold: float = 0.85
max_retrieval_results: int = 10
class EnhancedMemoryAdapter:
"""增强记忆系统适配器"""
def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
self.llm_model = llm_model
self.config = config or AdapterConfig()
self.integration_layer: MemoryIntegrationLayer | None = None
self._initialized = False
# 统计信息
self.adapter_stats = {
"total_processed": 0,
"enhanced_used": 0,
"legacy_used": 0,
"hybrid_used": 0,
"memories_created": 0,
"memories_retrieved": 0,
"average_processing_time": 0.0,
}
async def initialize(self):
"""初始化适配器"""
if self._initialized:
return
try:
logger.info("🚀 初始化增强记忆系统适配器...")
# 转换配置格式
integration_config = IntegrationConfig(
mode=IntegrationMode(self.config.integration_mode),
enable_enhanced_memory=self.config.enable_enhanced_memory,
memory_value_threshold=self.config.memory_value_threshold,
fusion_threshold=self.config.fusion_threshold,
max_retrieval_results=self.config.max_retrieval_results,
enable_learning=True, # 启用学习功能
)
# 创建集成层
self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config)
# 初始化集成层
await self.integration_layer.initialize()
self._initialized = True
logger.info("✅ 增强记忆系统适配器初始化完成")
except Exception as e:
logger.error(f"❌ 增强记忆系统适配器初始化失败: {e}", exc_info=True)
# 如果初始化失败,禁用增强记忆功能
self.config.enable_enhanced_memory = False
    async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
        """Process conversation memory; the context is the only input."""
        if not self._initialized or not self.config.enable_enhanced_memory:
            return {"success": False, "error": "Enhanced memory not available"}
        start_time = time.time()
        self.adapter_stats["total_processed"] += 1
        try:
            payload_context: dict[str, Any] = dict(context or {})
            conversation_text = payload_context.get("conversation_text")
            if not conversation_text:
                conversation_candidate = (
                    payload_context.get("message_content")
                    or payload_context.get("latest_message")
                    or payload_context.get("raw_text")
                )
                if conversation_candidate is not None:
                    conversation_text = str(conversation_candidate)
                    payload_context["conversation_text"] = conversation_text
                else:
                    conversation_text = ""
            else:
                conversation_text = str(conversation_text)
            if "timestamp" not in payload_context:
                payload_context["timestamp"] = time.time()
            logger.debug("Adapter received a memory-build request, text length=%d", len(conversation_text))
            # Process the conversation through the integration layer
            result = await self.integration_layer.process_conversation(payload_context)
            # Update statistics
            processing_time = time.time() - start_time
            self._update_processing_stats(processing_time)
            if result["success"]:
                created_count = len(result.get("created_memories", []))
                self.adapter_stats["memories_created"] += created_count
                logger.debug(f"Conversation memory processed, created {created_count} memories")
            return result
        except Exception as e:
            logger.error(f"Failed to process conversation memory: {e}", exc_info=True)
            return {"success": False, "error": str(e)}
    async def retrieve_relevant_memories(
        self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
    ) -> list[MemoryChunk]:
        """Retrieve relevant memories."""
        if not self._initialized or not self.config.enable_enhanced_memory:
            return []
        try:
            limit = limit or self.config.max_retrieval_results
            # NOTE: user_id is deliberately not forwarded; retrieval runs against the global scope.
            memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit)
            self.adapter_stats["memories_retrieved"] += len(memories)
            logger.debug(f"Retrieved {len(memories)} relevant memories")
            return memories
        except Exception as e:
            logger.error(f"Failed to retrieve relevant memories: {e}", exc_info=True)
            return []

    async def get_memory_context_for_prompt(
        self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
    ) -> str:
        """Build the memory context used in prompts."""
        memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
        if not memories:
            return ""
        # Use the new memory formatter
        formatter_config = FormatterConfig(
            include_timestamps=True,
            include_memory_types=True,
            include_confidence=False,
            use_emoji_icons=True,
            group_by_type=False,
            max_display_length=150,
        )
        return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)

    async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
        """Return a summary of the enhanced memory system."""
        if not self._initialized or not self.config.enable_enhanced_memory:
            return {"available": False, "reason": "Not initialized or disabled"}
        try:
            # System status
            status = await self.integration_layer.get_system_status()
            # Adapter statistics
            adapter_stats = self.adapter_stats.copy()
            # Integration statistics
            integration_stats = self.integration_layer.get_integration_stats()
            return {
                "available": True,
                "system_status": status,
                "adapter_stats": adapter_stats,
                "integration_stats": integration_stats,
                "total_memories_created": adapter_stats["memories_created"],
                "total_memories_retrieved": adapter_stats["memories_retrieved"],
            }
        except Exception as e:
            logger.error(f"Failed to get the enhanced memory summary: {e}", exc_info=True)
            return {"available": False, "error": str(e)}
    def _update_processing_stats(self, processing_time: float):
        """Update the running processing statistics."""
        total_processed = self.adapter_stats["total_processed"]
        if total_processed > 0:
            current_avg = self.adapter_stats["average_processing_time"]
            new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
            self.adapter_stats["average_processing_time"] = new_avg
    def get_adapter_stats(self) -> dict[str, Any]:
        """Return a copy of the adapter statistics."""
        return self.adapter_stats.copy()

    async def maintenance(self):
        """Run maintenance."""
        if not self._initialized:
            return
        try:
            logger.info("🔧 Running enhanced memory system adapter maintenance...")
            await self.integration_layer.maintenance()
            logger.info("✅ Enhanced memory system adapter maintenance finished")
        except Exception as e:
            logger.error(f"❌ Enhanced memory system adapter maintenance failed: {e}", exc_info=True)

    async def shutdown(self):
        """Shut down the adapter."""
        if not self._initialized:
            return
        try:
            logger.info("🔄 Shutting down the enhanced memory system adapter...")
            await self.integration_layer.shutdown()
            self._initialized = False
            logger.info("✅ Enhanced memory system adapter shut down")
        except Exception as e:
            logger.error(f"❌ Failed to shut down the enhanced memory system adapter: {e}", exc_info=True)


# Global adapter instance
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None


async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
    """Return the global enhanced memory adapter instance."""
    global _enhanced_memory_adapter
    if _enhanced_memory_adapter is None:
        # Read the adapter configuration from the global config
        from src.config.config import global_config

        adapter_config = AdapterConfig(
            enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True),
            integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"),
            auto_migration=getattr(global_config.memory, "enable_memory_migration", True),
            memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6),
            fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85),
            max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10),
        )
        _enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config)
        await _enhanced_memory_adapter.initialize()
    return _enhanced_memory_adapter


async def initialize_enhanced_memory_system(llm_model: LLMRequest):
    """Initialize the enhanced memory system."""
    try:
        logger.info("🚀 Initializing the enhanced memory system...")
        adapter = await get_enhanced_memory_adapter(llm_model)
        logger.info("✅ Enhanced memory system initialized")
        return adapter
    except Exception as e:
        logger.error(f"❌ Enhanced memory system initialization failed: {e}", exc_info=True)
        return None
async def process_conversation_with_enhanced_memory(
    context: dict[str, Any], llm_model: LLMRequest | None = None
) -> dict[str, Any]:
    """Process a conversation with the enhanced memory system; the context should carry conversation_text etc."""
    if not llm_model:
        # Fall back to the default LLM model
        from src.llm_models.utils_model import get_global_llm_model

        llm_model = get_global_llm_model()
    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
        payload_context = dict(context or {})
        if "conversation_text" not in payload_context:
            conversation_candidate = (
                payload_context.get("message_content")
                or payload_context.get("latest_message")
                or payload_context.get("raw_text")
            )
            if conversation_candidate is not None:
                payload_context["conversation_text"] = str(conversation_candidate)
        return await adapter.process_conversation_memory(payload_context)
    except Exception as e:
        logger.error(f"Failed to process the conversation with the enhanced memory system: {e}", exc_info=True)
        return {"success": False, "error": str(e)}


async def retrieve_memories_with_enhanced_system(
    query: str,
    user_id: str,
    context: dict[str, Any] | None = None,
    limit: int = 10,
    llm_model: LLMRequest | None = None,
) -> list[MemoryChunk]:
    """Retrieve memories with the enhanced memory system."""
    if not llm_model:
        # Fall back to the default LLM model
        from src.llm_models.utils_model import get_global_llm_model

        llm_model = get_global_llm_model()
    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
        return await adapter.retrieve_relevant_memories(query, user_id, context, limit)
    except Exception as e:
        logger.error(f"Failed to retrieve memories with the enhanced memory system: {e}", exc_info=True)
        return []


async def get_memory_context_for_prompt(
    query: str,
    user_id: str,
    context: dict[str, Any] | None = None,
    max_memories: int = 5,
    llm_model: LLMRequest | None = None,
) -> str:
    """Build the memory context used in prompts."""
    if not llm_model:
        # Fall back to the default LLM model
        from src.llm_models.utils_model import get_global_llm_model

        llm_model = get_global_llm_model()
    try:
        adapter = await get_enhanced_memory_adapter(llm_model)
        return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories)
    except Exception as e:
        logger.error(f"Failed to build the memory context: {e}", exc_info=True)
        return ""

View File

@@ -1,194 +0,0 @@
"""
增强记忆系统钩子
用于在消息处理过程中自动构建和检索记忆
"""
from datetime import datetime
from typing import Any
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger(__name__)
class EnhancedMemoryHooks:
"""增强记忆系统钩子 - 自动处理消息的记忆构建和检索"""
def __init__(self):
self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory
self.processed_messages = set() # 避免重复处理
    async def process_message_for_memory(
        self,
        message_content: str,
        user_id: str,
        chat_id: str,
        message_id: str,
        context: dict[str, Any] | None = None,
    ) -> bool:
        """
        Process a message and build memories from it.

        Args:
            message_content: message text
            user_id: user ID
            chat_id: chat ID
            message_id: message ID
            context: extra context

        Returns:
            bool: whether the message was processed successfully
        """
        if not self.enabled:
            return False
        if message_id in self.processed_messages:
            return False
        try:
            # Make sure the enhanced memory manager is initialized
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()
            # Inject the bot's base persona so memory building can avoid recording the bot itself
            bot_config = getattr(global_config, "bot", None)
            personality_config = getattr(global_config, "personality", None)
            bot_context = {}
            if bot_config is not None:
                bot_context["bot_name"] = getattr(bot_config, "nickname", None)
                bot_context["bot_aliases"] = list(getattr(bot_config, "alias_names", []) or [])
                bot_context["bot_account"] = getattr(bot_config, "qq_account", None)
            if personality_config is not None:
                bot_context["bot_identity"] = getattr(personality_config, "identity", None)
                bot_context["bot_personality"] = getattr(personality_config, "personality_core", None)
                bot_context["bot_personality_side"] = getattr(personality_config, "personality_side", None)
            # Build the context
            memory_context = {
                "chat_id": chat_id,
                "message_id": message_id,
                "timestamp": datetime.now().timestamp(),
                "message_type": "user_message",
                **bot_context,
                **(context or {}),
            }
            # Process the conversation and build memories
            memory_chunks = await enhanced_memory_manager.process_conversation(
                conversation_text=message_content,
                context=memory_context,
                user_id=user_id,
                timestamp=memory_context["timestamp"],
            )
            # Mark the message as processed
            self.processed_messages.add(message_id)
            # Keep the processed-message history bounded
            if len(self.processed_messages) > 1000:
                # NOTE: sets are unordered, so this keeps an arbitrary 500 entries,
                # not the most recent ones; see the order-preserving sketch below.
                self.processed_messages = set(list(self.processed_messages)[-500:])
            logger.debug(f"Built {len(memory_chunks)} memories for message {message_id}")
            return len(memory_chunks) > 0
        except Exception as e:
            logger.error(f"Failed to process the message for memory: {e}")
            return False
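Because set has no insertion order, the trim above cannot actually drop the oldest entries as the original comment claimed. A minimal order-preserving alternative, assuming a bounded dedup cache is all that is needed:

from collections import OrderedDict

class BoundedSeenCache:
    """Remembers the most recent `capacity` IDs, evicting the oldest first."""

    def __init__(self, capacity: int = 1000):
        self.capacity = capacity
        self._seen: "OrderedDict[str, None]" = OrderedDict()

    def add(self, message_id: str) -> bool:
        """Return False if the ID was already seen, True if newly recorded."""
        if message_id in self._seen:
            return False
        self._seen[message_id] = None
        if len(self._seen) > self.capacity:
            self._seen.popitem(last=False)  # evict the oldest entry
        return True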
    async def get_memory_for_response(
        self,
        query_text: str,
        user_id: str,
        chat_id: str,
        limit: int = 5,
        extra_context: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """
        Fetch relevant memories for a reply.

        Args:
            query_text: query text
            user_id: user ID
            chat_id: chat ID
            limit: maximum number of memories to return
            extra_context: extra context

        Returns:
            List[Dict]: list of relevant memories
        """
        if not self.enabled:
            return []
        try:
            # Make sure the enhanced memory manager is initialized
            if not enhanced_memory_manager.is_initialized:
                await enhanced_memory_manager.initialize()
            # Build the query context
            context = {
                "chat_id": chat_id,
                "query_intent": "response_generation",
                "expected_memory_types": ["personal_fact", "event", "preference", "opinion"],
            }
            if extra_context:
                context.update(extra_context)
            # Fetch relevant memories
            enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
                query_text=query_text, user_id=user_id, context=context, limit=limit
            )
            # Convert to plain dictionaries
            results = []
            for result in enhanced_results:
                memory_dict = {
                    "content": result.content,
                    "type": result.memory_type,
                    "confidence": result.confidence,
                    "importance": result.importance,
                    "timestamp": result.timestamp,
                    "source": result.source,
                    "relevance": result.relevance_score,
                    "structure": result.structure,
                }
                results.append(memory_dict)
            logger.debug(f"Fetched {len(results)} relevant memories for the reply")
            return results
        except Exception as e:
            logger.error(f"Failed to fetch memories for the reply: {e}")
            return []

    async def cleanup_old_memories(self):
        """Clean up old memories."""
        try:
            if enhanced_memory_manager.is_initialized:
                # Run the enhanced memory system's maintenance routine
                await enhanced_memory_manager.enhanced_system.maintenance()
                logger.debug("Enhanced memory system maintenance finished")
        except Exception as e:
            logger.error(f"Failed to clean up old memories: {e}")

    def clear_processed_cache(self):
        """Clear the processed-message cache."""
        self.processed_messages.clear()
        logger.debug("Cleared the message processing cache")

    def enable(self):
        """Enable the memory hooks."""
        self.enabled = True
        logger.info("Enhanced memory hooks enabled")

    def disable(self):
        """Disable the memory hooks."""
        self.enabled = False
        logger.info("Enhanced memory hooks disabled")


# Global instance
enhanced_memory_hooks = EnhancedMemoryHooks()
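A minimal sketch of driving the global hook instance from a message pipeline (IDs and text are made up for illustration):

import asyncio

async def on_incoming_message() -> None:
    built = await enhanced_memory_hooks.process_message_for_memory(
        message_content="My birthday is on March 3rd",
        user_id="user-1",
        chat_id="chat-42",
        message_id="msg-0001",
    )
    if built:
        memories = await enhanced_memory_hooks.get_memory_for_response(
            query_text="When is my birthday?", user_id="user-1", chat_id="chat-42"
        )
        print([m["content"] for m in memories])

# asyncio.run(on_incoming_message())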

View File

@@ -1,177 +0,0 @@
"""
增强记忆系统集成脚本
用于在现有系统中无缝集成增强记忆功能
"""
from typing import Any
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
from src.common.logger import get_logger
logger = get_logger(__name__)
async def process_user_message_memory(
message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
) -> bool:
"""
处理用户消息并构建记忆
Args:
message_content: 消息内容
user_id: 用户ID
chat_id: 聊天ID
message_id: 消息ID
context: 额外的上下文信息
Returns:
bool: 是否成功构建记忆
"""
try:
success = await enhanced_memory_hooks.process_message_for_memory(
message_content=message_content, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
)
if success:
logger.debug(f"成功为消息 {message_id} 构建记忆")
return success
except Exception as e:
logger.error(f"处理用户消息记忆失败: {e}")
return False
async def get_relevant_memories_for_response(
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
) -> dict[str, Any]:
"""
为回复获取相关记忆
Args:
query_text: 查询文本(通常是用户的当前消息)
user_id: 用户ID
chat_id: 聊天ID
limit: 返回记忆数量限制
extra_context: 额外上下文信息
Returns:
Dict: 包含记忆信息的字典
"""
try:
memories = await enhanced_memory_hooks.get_memory_for_response(
query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context
)
result = {"has_memories": len(memories) > 0, "memories": memories, "memory_count": len(memories)}
logger.debug(f"为回复获取到 {len(memories)} 条相关记忆")
return result
except Exception as e:
logger.error(f"获取回复记忆失败: {e}")
return {"has_memories": False, "memories": [], "memory_count": 0}
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
    """
    Format memory information for a prompt.

    Args:
        memories: memory information dictionary

    Returns:
        str: formatted memory text
    """
    if not memories["has_memories"]:
        return ""
    memory_lines = ["Here is the relevant memory information:"]
    for memory in memories["memories"]:
        content = memory["content"]
        memory_type = memory["type"]
        confidence = memory["confidence"]
        importance = memory["importance"]
        # Add different markers depending on importance
        importance_marker = "🔥" if importance >= 3 else "" if importance >= 2 else "📝"
        confidence_marker = "" if confidence >= 3 else "⚠️" if confidence >= 2 else "💭"
        memory_line = f"{importance_marker} {content} ({memory_type}, {confidence_marker}confidence)"
        memory_lines.append(memory_line)
    return "\n".join(memory_lines)
async def cleanup_memory_system():
    """Clean up the memory system."""
    try:
        await enhanced_memory_hooks.cleanup_old_memories()
        logger.info("Memory system cleanup finished")
    except Exception as e:
        logger.error(f"Memory system cleanup failed: {e}")


def get_memory_system_status() -> dict[str, Any]:
    """
    Get the memory system status.

    Returns:
        Dict: system status information
    """
    from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager

    return {
        "enabled": enhanced_memory_hooks.enabled,
        "enhanced_system_initialized": enhanced_memory_manager.is_initialized,
        "processed_messages_count": len(enhanced_memory_hooks.processed_messages),
        "system_type": "enhanced_memory_system",
    }


# Convenience functions
async def remember_message(
    message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
) -> bool:
    """
    Convenience function for building a memory.

    Args:
        message: message to remember
        user_id: user ID
        chat_id: chat ID
        context: extra context

    Returns:
        bool: whether it succeeded
    """
    import uuid

    message_id = str(uuid.uuid4())
    return await process_user_message_memory(
        message_content=message, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
    )


async def recall_memories(
    query: str,
    user_id: str = "default_user",
    chat_id: str = "default_chat",
    limit: int = 5,
    context: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Convenience function for retrieving memories.

    Args:
        query: query text
        user_id: user ID
        chat_id: chat ID
        limit: maximum number of results
        context: extra context

    Returns:
        Dict: memory information
    """
    return await get_relevant_memories_for_response(
        query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context
    )
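The two convenience wrappers pair naturally in a remember-then-recall round trip; a sketch using the default IDs defined above:

import asyncio

async def round_trip() -> None:
    await remember_message("I adopted a cat named Mochi")
    recalled = await recall_memories("What is my cat's name?")
    if recalled["has_memories"]:
        print(format_memories_for_prompt(recalled))

# asyncio.run(round_trip())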

View File

@@ -1,356 +0,0 @@
"""
增强重排序器
实现文档设计的多维度评分模型
"""
import math
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
logger = get_logger(__name__)
class IntentType(Enum):
"""对话意图类型"""
FACT_QUERY = "fact_query" # 事实查询
EVENT_RECALL = "event_recall" # 事件回忆
PREFERENCE_CHECK = "preference_check" # 偏好检查
GENERAL_CHAT = "general_chat" # 一般对话
UNKNOWN = "unknown" # 未知意图
@dataclass
class ReRankingConfig:
"""重排序配置"""
# 权重配置 (w1 + w2 + w3 + w4 = 1.0)
semantic_weight: float = 0.5 # 语义相似度权重
recency_weight: float = 0.2 # 时效性权重
usage_freq_weight: float = 0.2 # 使用频率权重
type_match_weight: float = 0.1 # 类型匹配权重
# 时效性衰减参数
recency_decay_rate: float = 0.1 # 时效性衰减率 (天)
# 使用频率计算参数
freq_log_base: float = 2.0 # 对数底数
freq_max_score: float = 5.0 # 最大频率得分
# 类型匹配权重映射
type_match_weights: dict[str, dict[str, float]] = None
def __post_init__(self):
"""初始化类型匹配权重"""
if self.type_match_weights is None:
self.type_match_weights = {
IntentType.FACT_QUERY.value: {
MemoryType.PERSONAL_FACT.value: 1.0,
MemoryType.KNOWLEDGE.value: 0.8,
MemoryType.PREFERENCE.value: 0.5,
MemoryType.EVENT.value: 0.3,
"default": 0.3,
},
IntentType.EVENT_RECALL.value: {
MemoryType.EVENT.value: 1.0,
MemoryType.EXPERIENCE.value: 0.8,
MemoryType.EMOTION.value: 0.6,
MemoryType.PERSONAL_FACT.value: 0.5,
"default": 0.5,
},
IntentType.PREFERENCE_CHECK.value: {
MemoryType.PREFERENCE.value: 1.0,
MemoryType.OPINION.value: 0.8,
MemoryType.GOAL.value: 0.6,
MemoryType.PERSONAL_FACT.value: 0.4,
"default": 0.4,
},
IntentType.GENERAL_CHAT.value: {"default": 0.8},
IntentType.UNKNOWN.value: {"default": 0.8},
}
class IntentClassifier:
    """Lightweight intent classifier."""

    def __init__(self):
        # Keyword pattern-matching rules (the Chinese literals are functional match
        # patterns for Chinese user input and are kept as-is)
        self.patterns = {
            IntentType.FACT_QUERY: [
                # Chinese patterns
                "我是",
                "我的",
                "我叫",
                "我在",
                "我住在",
                "我的职业",
                "我的工作",
                "什么时候",
                "在哪里",
                "是什么",
                "多少",
                "几岁",
                "年龄",
                # English patterns
                "what is",
                "where is",
                "when is",
                "how old",
                "my name",
                "i am",
                "i live",
            ],
            IntentType.EVENT_RECALL: [
                # Chinese patterns
                "记得",
                "想起",
                "还记得",
                "那次",
                "上次",
                "之前",
                "以前",
                "曾经",
                "发生过",
                "经历",
                "做过",
                "去过",
                "见过",
                # English patterns
                "remember",
                "recall",
                "last time",
                "before",
                "previously",
                "happened",
                "experience",
            ],
            IntentType.PREFERENCE_CHECK: [
                # Chinese patterns
                "喜欢",
                "不喜欢",
                "偏好",
                "爱好",
                "兴趣",
                "讨厌",
                "最爱",
                "最喜欢",
                "习惯",
                "通常",
                "一般",
                "倾向于",
                "更喜欢",
                # English patterns
                "like",
                "love",
                "hate",
                "prefer",
                "favorite",
                "usually",
                "tend to",
                "interest",
            ],
        }

    def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
        """Classify the conversation intent."""
        if not query:
            return IntentType.UNKNOWN
        query_lower = query.lower()
        # Count pattern matches per intent
        intent_scores = dict.fromkeys(IntentType, 0)
        for intent, patterns in self.patterns.items():
            for pattern in patterns:
                if pattern in query_lower:
                    intent_scores[intent] += 1
        # Return the intent with the highest score
        max_score = max(intent_scores.values())
        if max_score == 0:
            return IntentType.GENERAL_CHAT
        for intent, score in intent_scores.items():
            if score == max_score:
                return intent
        return IntentType.GENERAL_CHAT
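The classifier is a pure substring counter, so it is easy to probe directly; a small sketch (example queries are made up):

classifier = IntentClassifier()

print(classifier.classify_intent("Do you remember last time we went hiking?", {}))
# IntentType.EVENT_RECALL  ("remember" and "last time" both match)
print(classifier.classify_intent("I love spicy food", {}))
# IntentType.PREFERENCE_CHECK
print(classifier.classify_intent("nice weather today", {}))
# IntentType.GENERAL_CHAT  (no pattern matches)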
class EnhancedReRanker:
    """Enhanced re-ranker - implements the multi-dimensional scoring model from the design document."""

    def __init__(self, config: ReRankingConfig | None = None):
        self.config = config or ReRankingConfig()
        self.intent_classifier = IntentClassifier()
        # Verify that the weights sum to 1.0
        total_weight = (
            self.config.semantic_weight
            + self.config.recency_weight
            + self.config.usage_freq_weight
            + self.config.type_match_weight
        )
        if abs(total_weight - 1.0) > 0.01:
            logger.warning(f"Re-ranking weights do not sum to 1.0: {total_weight}, normalizing")
            # Normalize the weights
            self.config.semantic_weight /= total_weight
            self.config.recency_weight /= total_weight
            self.config.usage_freq_weight /= total_weight
            self.config.type_match_weight /= total_weight

    def rerank_memories(
        self,
        query: str,
        candidate_memories: list[tuple[str, MemoryChunk, float]],  # (memory_id, memory, vector_similarity)
        context: dict[str, Any],
        limit: int = 10,
    ) -> list[tuple[str, MemoryChunk, float]]:
        """
        Re-rank the candidate memories.

        Args:
            query: query text
            candidate_memories: candidates as [(memory_id, memory, vector_similarity)]
            context: context information
            limit: maximum number of results

        Returns:
            Re-ranked memories as [(memory_id, memory, final_score)]
        """
        if not candidate_memories:
            return []
        # Classify the query intent
        intent = self.intent_classifier.classify_intent(query, context)
        logger.debug(f"Detected query intent: {intent.value}")
        # Compute the final score of every candidate
        scored_memories = []
        current_time = time.time()
        for memory_id, memory, vector_sim in candidate_memories:
            try:
                # 1. Semantic similarity score (normalized to [0,1])
                semantic_score = self._normalize_similarity(vector_sim)
                # 2. Recency score
                recency_score = self._calculate_recency_score(memory, current_time)
                # 3. Usage frequency score
                usage_freq_score = self._calculate_usage_frequency_score(memory)
                # 4. Type match score
                type_match_score = self._calculate_type_match_score(memory, intent)
                # Final weighted score
                final_score = (
                    self.config.semantic_weight * semantic_score
                    + self.config.recency_weight * recency_score
                    + self.config.usage_freq_weight * usage_freq_score
                    + self.config.type_match_weight * type_match_score
                )
                scored_memories.append((memory_id, memory, final_score))
                # Debug logging
                logger.debug(
                    f"Memory score {memory_id[:8]}: semantic={semantic_score:.3f}, "
                    f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, "
                    f"type={type_match_score:.3f}, final={final_score:.3f}"
                )
            except Exception as e:
                logger.error(f"Failed to score memory {memory_id}: {e}")
                # Fall back to the raw vector similarity
                scored_memories.append((memory_id, memory, vector_sim))
        # Sort by final score, descending
        scored_memories.sort(key=lambda x: x[2], reverse=True)
        # Return the top N results
        result = scored_memories[:limit]
        highest_score = result[0][2] if result else 0.0
        logger.info(
            f"Re-ranking finished: candidates={len(candidate_memories)}, returned={len(result)}, "
            f"intent={intent.value}, top score={highest_score:.3f}"
        )
        return result

    def _normalize_similarity(self, raw_similarity: float) -> float:
        """Normalize a similarity value into [0,1]."""
        # The raw similarity is assumed to lie in [-1,1] or [0,1]
        if raw_similarity < 0:
            return (raw_similarity + 1) / 2  # map from [-1,1] into [0,1]
        return min(1.0, max(0.0, raw_similarity))  # clamp into [0,1]
    def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float:
        """
        Compute the recency score.
        Formula: Recency = 1 / (1 + decay_rate * days_old)
        """
        last_accessed = memory.metadata.last_accessed or memory.metadata.created_at
        days_old = (current_time - last_accessed) / (24 * 3600)  # convert to days
        if days_old < 0:
            days_old = 0  # guard against clock skew
        score = 1 / (1 + self.config.recency_decay_rate * days_old)
        return min(1.0, max(0.0, score))

    def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float:
        """
        Compute the usage frequency score.
        Formula: Usage_Freq = min(1.0, log2(access_count + 1) / max_score)
        """
        access_count = memory.metadata.access_count
        if access_count <= 0:
            return 0.0
        log_count = math.log2(access_count + 1)
        score = log_count / self.config.freq_max_score
        return min(1.0, max(0.0, score))
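Plugging the defaults into the two formulas above gives a feel for their shape (decay_rate = 0.1, freq_max_score = 5.0):

import math

# Recency = 1 / (1 + 0.1 * days_old): gentle hyperbolic decay.
for days in (0, 1, 10, 30):
    print(days, round(1 / (1 + 0.1 * days), 3))  # 1.0, 0.909, 0.5, 0.25

# Usage_Freq = min(1.0, log2(n + 1) / 5): saturates at 31 accesses, since log2(32) = 5.
for n in (1, 3, 7, 31, 100):
    print(n, round(min(1.0, math.log2(n + 1) / 5.0), 3))  # 0.2, 0.4, 0.6, 1.0, 1.0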
    def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float:
        """Compute the type match score."""
        memory_type = memory.memory_type.value
        intent_value = intent.value
        # Look up the weight table for this intent
        type_weights = self.config.type_match_weights.get(intent_value, {})
        # Use the specific type's weight, falling back to the default
        score = type_weights.get(memory_type, type_weights.get("default", 0.8))
        return min(1.0, max(0.0, score))


# Default re-ranker instance
default_reranker = EnhancedReRanker()


def rerank_candidate_memories(
    query: str,
    candidate_memories: list[tuple[str, MemoryChunk, float]],
    context: dict[str, Any],
    limit: int = 10,
    config: ReRankingConfig | None = None,
) -> list[tuple[str, MemoryChunk, float]]:
    """
    Convenience function: re-rank candidate memories.
    """
    if config:
        reranker = EnhancedReRanker(config)
    else:
        reranker = default_reranker
    return reranker.rerank_memories(query, candidate_memories, context, limit)
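Wiring the convenience function into a retrieval pipeline looks roughly like this; vector_hits is assumed to come from a vector search that yields (memory_id, MemoryChunk, similarity) triples:

# Hypothetical call site: emphasize recency over semantics for a recall query.
config = ReRankingConfig(
    semantic_weight=0.4,
    recency_weight=0.3,
    usage_freq_weight=0.2,
    type_match_weight=0.1,
)
ranked = rerank_candidate_memories(
    query="do you remember my last trip?",
    candidate_memories=vector_hits,  # [(memory_id, memory, similarity), ...]
    context={},
    limit=5,
    config=config,
)
for memory_id, memory, score in ranked:
    print(f"{score:.3f} {memory_id}")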

View File

@@ -1,245 +0,0 @@
"""
增强记忆系统集成层
现在只管理新的增强记忆系统,旧系统已被完全移除
"""
import asyncio
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
class IntegrationMode(Enum):
"""集成模式"""
REPLACE = "replace" # 完全替换现有记忆系统
ENHANCED_ONLY = "enhanced_only" # 仅使用增强记忆系统
@dataclass
class IntegrationConfig:
"""集成配置"""
mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY
enable_enhanced_memory: bool = True
memory_value_threshold: float = 0.6
fusion_threshold: float = 0.85
max_retrieval_results: int = 10
enable_learning: bool = True
class MemoryIntegrationLayer:
"""记忆系统集成层 - 现在只管理增强记忆系统"""
def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
self.llm_model = llm_model
self.config = config or IntegrationConfig()
# 只初始化增强记忆系统
self.enhanced_memory: EnhancedMemorySystem | None = None
# 集成统计
self.integration_stats = {
"total_queries": 0,
"enhanced_queries": 0,
"memory_creations": 0,
"average_response_time": 0.0,
"success_rate": 0.0,
}
# 初始化锁
self._initialization_lock = asyncio.Lock()
self._initialized = False
async def initialize(self):
"""初始化集成层"""
if self._initialized:
return
async with self._initialization_lock:
if self._initialized:
return
logger.info("🚀 开始初始化增强记忆系统集成层...")
try:
# 初始化增强记忆系统
if self.config.enable_enhanced_memory:
await self._initialize_enhanced_memory()
self._initialized = True
logger.info("✅ 增强记忆系统集成层初始化完成")
except Exception as e:
logger.error(f"❌ 集成层初始化失败: {e}", exc_info=True)
raise
async def _initialize_enhanced_memory(self):
"""初始化增强记忆系统"""
try:
logger.debug("初始化增强记忆系统...")
# 创建增强记忆系统配置
from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig
memory_config = MemorySystemConfig.from_global_config()
# 使用集成配置覆盖部分值
memory_config.memory_value_threshold = self.config.memory_value_threshold
memory_config.fusion_similarity_threshold = self.config.fusion_threshold
memory_config.final_recall_limit = self.config.max_retrieval_results
# 创建增强记忆系统
self.enhanced_memory = EnhancedMemorySystem(config=memory_config)
# 如果外部提供了LLM模型注入到系统中
if self.llm_model is not None:
self.enhanced_memory.llm_model = self.llm_model
# 初始化系统
await self.enhanced_memory.initialize()
logger.info("✅ 增强记忆系统初始化完成")
except Exception as e:
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
raise
    async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
        """Process conversation memory using only the context information."""
        if not self._initialized or not self.enhanced_memory:
            return {"success": False, "error": "Memory system not available"}
        start_time = time.time()
        self.integration_stats["total_queries"] += 1
        self.integration_stats["enhanced_queries"] += 1
        try:
            payload_context = dict(context or {})
            conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
            logger.debug("Integration layer received a memory-build request, text length=%d", len(conversation_text))
            # Hand off directly to the enhanced memory system
            result = await self.enhanced_memory.process_conversation_memory(payload_context)
            # Update statistics
            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, result.get("success", False))
            if result.get("success"):
                created_count = len(result.get("created_memories", []))
                self.integration_stats["memory_creations"] += created_count
                logger.debug(f"Conversation processed, created {created_count} memories")
            return result
        except Exception as e:
            processing_time = time.time() - start_time
            self._update_response_stats(processing_time, False)
            logger.error(f"Failed to process conversation memory: {e}", exc_info=True)
            return {"success": False, "error": str(e)}

    async def retrieve_relevant_memories(
        self,
        query: str,
        user_id: str | None = None,
        context: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> list[MemoryChunk]:
        """Retrieve relevant memories."""
        if not self._initialized or not self.enhanced_memory:
            return []
        try:
            limit = limit or self.config.max_retrieval_results
            # NOTE: user_id is deliberately dropped here as well; the search is global.
            memories = await self.enhanced_memory.retrieve_relevant_memories(
                query=query, user_id=None, context=context or {}, limit=limit
            )
            memory_count = len(memories)
            logger.debug(f"Retrieved {memory_count} relevant memories")
            return memories
        except Exception as e:
            logger.error(f"Failed to retrieve relevant memories: {e}", exc_info=True)
            return []

    async def get_system_status(self) -> dict[str, Any]:
        """Get the system status."""
        if not self._initialized:
            return {"status": "not_initialized"}
        try:
            enhanced_status = {}
            if self.enhanced_memory:
                enhanced_status = await self.enhanced_memory.get_system_status()
            return {
                "status": "initialized",
                "mode": self.config.mode.value,
                "enhanced_memory": enhanced_status,
                "integration_stats": self.integration_stats.copy(),
            }
        except Exception as e:
            logger.error(f"Failed to get the system status: {e}", exc_info=True)
            return {"status": "error", "error": str(e)}

    def get_integration_stats(self) -> dict[str, Any]:
        """Get the integration statistics."""
        return self.integration_stats.copy()

    def _update_response_stats(self, processing_time: float, success: bool):
        """Update the response statistics."""
        total_queries = self.integration_stats["total_queries"]
        if total_queries > 0:
            # Update the running average response time
            current_avg = self.integration_stats["average_response_time"]
            new_avg = (current_avg * (total_queries - 1) + processing_time) / total_queries
            self.integration_stats["average_response_time"] = new_avg
            # Update the success rate (failures count as 0, so the rate also decays on failure;
            # the original only updated on success, which could never lower the rate)
            increment = 1 if success else 0
            current_success_rate = self.integration_stats["success_rate"]
            new_success_rate = (current_success_rate * (total_queries - 1) + increment) / total_queries
            self.integration_stats["success_rate"] = new_success_rate

    async def maintenance(self):
        """Run maintenance."""
        if not self._initialized:
            return
        try:
            logger.info("🔧 Running memory system integration layer maintenance...")
            if self.enhanced_memory:
                await self.enhanced_memory.maintenance()
            logger.info("✅ Memory system integration layer maintenance finished")
        except Exception as e:
            logger.error(f"❌ Integration layer maintenance failed: {e}", exc_info=True)

    async def shutdown(self):
        """Shut down the integration layer."""
        if not self._initialized:
            return
        try:
            logger.info("🔄 Shutting down the memory system integration layer...")
            if self.enhanced_memory:
                await self.enhanced_memory.shutdown()
            self._initialized = False
            logger.info("✅ Memory system integration layer shut down")
        except Exception as e:
            logger.error(f"❌ Failed to shut down the integration layer: {e}", exc_info=True)

View File

@@ -1,526 +0,0 @@
"""
记忆系统集成钩子
提供与现有MoFox Bot系统的无缝集成点
"""
import time
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.enhanced_memory_adapter import (
get_memory_context_for_prompt,
process_conversation_with_enhanced_memory,
retrieve_memories_with_enhanced_system,
)
from src.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class HookResult:
"""钩子执行结果"""
success: bool
data: Any = None
error: str | None = None
processing_time: float = 0.0
class MemoryIntegrationHooks:
"""记忆系统集成钩子"""
def __init__(self):
self.hooks_registered = False
self.hook_stats = {
"message_processing_hooks": 0,
"memory_retrieval_hooks": 0,
"prompt_enhancement_hooks": 0,
"total_hook_executions": 0,
"average_hook_time": 0.0,
}
async def register_hooks(self):
"""注册所有集成钩子"""
if self.hooks_registered:
return
try:
logger.info("🔗 注册记忆系统集成钩子...")
# 注册消息处理钩子
await self._register_message_processing_hooks()
# 注册记忆检索钩子
await self._register_memory_retrieval_hooks()
# 注册提示词增强钩子
await self._register_prompt_enhancement_hooks()
# 注册系统维护钩子
await self._register_maintenance_hooks()
self.hooks_registered = True
logger.info("✅ 记忆系统集成钩子注册完成")
except Exception as e:
logger.error(f"❌ 注册记忆系统集成钩子失败: {e}", exc_info=True)
async def _register_message_processing_hooks(self):
"""注册消息处理钩子"""
try:
# 钩子1: 在消息处理后创建记忆
await self._register_post_message_hook()
# 钩子2: 在聊天流保存时处理记忆
await self._register_chat_stream_hook()
logger.debug("消息处理钩子注册完成")
except Exception as e:
logger.error(f"注册消息处理钩子失败: {e}")
async def _register_memory_retrieval_hooks(self):
"""注册记忆检索钩子"""
try:
# 钩子1: 在生成回复前检索相关记忆
await self._register_pre_response_hook()
# 钩子2: 在知识库查询前增强上下文
await self._register_knowledge_query_hook()
logger.debug("记忆检索钩子注册完成")
except Exception as e:
logger.error(f"注册记忆检索钩子失败: {e}")
async def _register_prompt_enhancement_hooks(self):
"""注册提示词增强钩子"""
try:
# 钩子1: 增强提示词构建
await self._register_prompt_building_hook()
logger.debug("提示词增强钩子注册完成")
except Exception as e:
logger.error(f"注册提示词增强钩子失败: {e}")
async def _register_maintenance_hooks(self):
"""注册系统维护钩子"""
try:
# 钩子1: 系统维护时的记忆系统维护
await self._register_system_maintenance_hook()
logger.debug("系统维护钩子注册完成")
except Exception as e:
logger.error(f"注册系统维护钩子失败: {e}")
async def _register_post_message_hook(self):
"""注册消息后处理钩子"""
try:
# 这里需要根据实际的系统架构来注册钩子
# 以下是一个示例实现,需要根据实际的插件系统或事件系统来调整
# 尝试注册到事件系统
try:
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager
# 注册消息后处理事件
event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
logger.debug("已注册到事件系统的消息处理钩子")
except ImportError:
logger.debug("事件系统不可用,跳过事件钩子注册")
# 尝试注册到消息管理器
try:
from src.chat.message_manager import message_manager
# 如果消息管理器支持钩子注册
if hasattr(message_manager, "register_post_process_hook"):
message_manager.register_post_process_hook(self._on_message_processed_hook)
logger.debug("已注册到消息管理器的处理钩子")
except ImportError:
logger.debug("消息管理器不可用,跳过消息管理器钩子注册")
except Exception as e:
logger.error(f"注册消息后处理钩子失败: {e}")
    async def _register_chat_stream_hook(self):
        """Register the chat stream hook."""
        try:
            # Try the chat stream manager
            try:
                from src.chat.message_receive.chat_stream import get_chat_manager

                chat_manager = get_chat_manager()
                if hasattr(chat_manager, "register_save_hook"):
                    chat_manager.register_save_hook(self._on_chat_stream_save_hook)
                    logger.debug("Registered the save hook with the chat stream manager")
            except ImportError:
                logger.debug("Chat stream manager unavailable, skipping chat stream hook registration")
        except Exception as e:
            logger.error(f"Failed to register the chat stream hook: {e}")

    async def _register_pre_response_hook(self):
        """Register the pre-response hook."""
        try:
            # Try the reply generator
            try:
                from src.chat.replyer.default_generator import default_generator

                if hasattr(default_generator, "register_pre_generation_hook"):
                    default_generator.register_pre_generation_hook(self._on_pre_response_hook)
                    logger.debug("Registered the pre-generation hook with the reply generator")
            except ImportError:
                logger.debug("Reply generator unavailable, skipping pre-response hook registration")
        except Exception as e:
            logger.error(f"Failed to register the pre-response hook: {e}")

    async def _register_knowledge_query_hook(self):
        """Register the knowledge-base query hook."""
        try:
            # Try the knowledge-base system
            try:
                from src.chat.knowledge.knowledge_lib import knowledge_manager

                if hasattr(knowledge_manager, "register_query_enhancer"):
                    knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook)
                    logger.debug("Registered the query enhancer with the knowledge base")
            except ImportError:
                logger.debug("Knowledge-base system unavailable, skipping knowledge-base hook registration")
        except Exception as e:
            logger.error(f"Failed to register the knowledge-base query hook: {e}")

    async def _register_prompt_building_hook(self):
        """Register the prompt building hook."""
        try:
            # Try the prompt system
            try:
                from src.chat.utils.prompt import prompt_manager

                if hasattr(prompt_manager, "register_enhancer"):
                    prompt_manager.register_enhancer(self._on_prompt_building_hook)
                    logger.debug("Registered the enhancer with the prompt manager")
            except ImportError:
                logger.debug("Prompt system unavailable, skipping prompt hook registration")
        except Exception as e:
            logger.error(f"Failed to register the prompt building hook: {e}")

    async def _register_system_maintenance_hook(self):
        """Register the system maintenance hook."""
        try:
            # Try the system maintenance scheduler
            try:
                from src.manager.async_task_manager import async_task_manager

                # Register the periodic maintenance task
                async_task_manager.add_task(MemoryMaintenanceTask())
                logger.debug("Registered the periodic task with the maintenance scheduler")
            except ImportError:
                logger.debug("Async task manager unavailable, skipping system maintenance hook registration")
        except Exception as e:
            logger.error(f"Failed to register the system maintenance hook: {e}")
    # Hook handler methods
    async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
        """Message-processed handler for the event system."""
        return await self._on_message_processed_hook(event_data)

    async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
        """Post-message hook."""
        start_time = time.time()
        try:
            self.hook_stats["message_processing_hooks"] += 1
            # Extract the required information
            message_info = message_data.get("message_info", {})
            user_info = message_info.get("user_info", {})
            conversation_text = message_data.get("processed_plain_text", "")
            if not conversation_text:
                return HookResult(success=True, data="No conversation text")
            user_id = str(user_info.get("user_id", "unknown"))
            context = {
                "chat_id": message_data.get("chat_id"),
                "message_type": message_data.get("message_type", "normal"),
                "platform": message_info.get("platform", "unknown"),
                "interest_value": message_data.get("interest_value", 0.0),
                "keywords": message_data.get("key_words", []),
                "timestamp": message_data.get("time", time.time()),
            }
            # Process the conversation with the enhanced memory system
            memory_context = dict(context)
            memory_context["conversation_text"] = conversation_text
            memory_context["user_id"] = user_id
            result = await process_conversation_with_enhanced_memory(memory_context)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            if result["success"]:
                logger.debug(f"Post-message hook succeeded, created {len(result.get('created_memories', []))} memories")
                return HookResult(success=True, data=result, processing_time=processing_time)
            else:
                logger.warning(f"Post-message hook failed: {result.get('error')}")
                return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Post-message hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
        """Chat stream save hook."""
        start_time = time.time()
        try:
            self.hook_stats["message_processing_hooks"] += 1
            # Extract the conversation from the chat stream data
            stream_context = chat_stream_data.get("stream_context", {})
            user_id = stream_context.get("user_id", "unknown")
            messages = stream_context.get("messages", [])
            if not messages:
                return HookResult(success=True, data="No messages to process")
            # Build the conversation text
            conversation_parts = []
            for msg in messages[-10:]:  # only the 10 most recent messages
                text = msg.get("processed_plain_text", "")
                if text:
                    conversation_parts.append(f"{msg.get('user_nickname', 'User')}: {text}")
            conversation_text = "\n".join(conversation_parts)
            if not conversation_text:
                return HookResult(success=True, data="No conversation text")
            context = {
                "chat_id": chat_stream_data.get("chat_id"),
                "stream_id": chat_stream_data.get("stream_id"),
                "platform": chat_stream_data.get("platform", "unknown"),
                "message_count": len(messages),
                "timestamp": time.time(),
            }
            # Process the conversation with the enhanced memory system
            memory_context = dict(context)
            memory_context["conversation_text"] = conversation_text
            memory_context["user_id"] = user_id
            result = await process_conversation_with_enhanced_memory(memory_context)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            if result["success"]:
                logger.debug(f"Chat stream save hook succeeded, created {len(result.get('created_memories', []))} memories")
                return HookResult(success=True, data=result, processing_time=processing_time)
            else:
                logger.warning(f"Chat stream save hook failed: {result.get('error')}")
                return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Chat stream save hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)
    async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
        """Pre-response hook."""
        start_time = time.time()
        try:
            self.hook_stats["memory_retrieval_hooks"] += 1
            # Extract the query information
            query = response_data.get("query", "")
            user_id = response_data.get("user_id", "unknown")
            context = response_data.get("context", {})
            if not query:
                return HookResult(success=True, data="No query provided")
            # Retrieve relevant memories
            memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            # Attach the memories to the response data
            response_data["enhanced_memories"] = memories
            response_data["enhanced_memory_context"] = await get_memory_context_for_prompt(
                query, user_id, context, max_memories=5
            )
            logger.debug(f"Pre-response hook succeeded, retrieved {len(memories)} memories")
            return HookResult(success=True, data=memories, processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Pre-response hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
        """Knowledge-base query hook."""
        start_time = time.time()
        try:
            self.hook_stats["memory_retrieval_hooks"] += 1
            query = query_data.get("query", "")
            user_id = query_data.get("user_id", "unknown")
            context = query_data.get("context", {})
            if not query:
                return HookResult(success=True, data="No query provided")
            # Fetch the memory context and enrich the query
            memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            # Attach the memory context to the query data
            query_data["enhanced_memory_context"] = memory_context
            logger.debug("Knowledge-base query hook succeeded")
            return HookResult(success=True, data=memory_context, processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Knowledge-base query hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
        """Prompt building hook."""
        start_time = time.time()
        try:
            self.hook_stats["prompt_enhancement_hooks"] += 1
            query = prompt_data.get("query", "")
            user_id = prompt_data.get("user_id", "unknown")
            context = prompt_data.get("context", {})
            base_prompt = prompt_data.get("base_prompt", "")
            if not query:
                return HookResult(success=True, data="No query provided")
            # Fetch the memory context
            memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5)
            processing_time = time.time() - start_time
            self._update_hook_stats(processing_time)
            # Build the enhanced prompt
            enhanced_prompt = base_prompt
            if memory_context:
                enhanced_prompt += f"\n\n### Relevant memory context ###\n{memory_context}\n"
            # Attach the enhanced prompt to the data
            prompt_data["enhanced_prompt"] = enhanced_prompt
            prompt_data["memory_context"] = memory_context
            logger.debug("Prompt building hook succeeded")
            return HookResult(success=True, data=enhanced_prompt, processing_time=processing_time)
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Prompt building hook raised an exception: {e}", exc_info=True)
            return HookResult(success=False, error=str(e), processing_time=processing_time)

    def _update_hook_stats(self, processing_time: float):
        """Update the hook statistics."""
        self.hook_stats["total_hook_executions"] += 1
        total_executions = self.hook_stats["total_hook_executions"]
        if total_executions > 0:
            current_avg = self.hook_stats["average_hook_time"]
            new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
            self.hook_stats["average_hook_time"] = new_avg

    def get_hook_stats(self) -> dict[str, Any]:
        """Get the hook statistics."""
        return self.hook_stats.copy()


class MemoryMaintenanceTask:
    """Memory system maintenance task."""

    def __init__(self):
        self.task_name = "enhanced_memory_maintenance"
        self.interval = 3600  # run once per hour

    async def execute(self):
        """Run the maintenance task."""
        try:
            logger.info("🔧 Running the enhanced memory system maintenance task...")
            # Fetch the adapter instance
            try:
                from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter

                if _enhanced_memory_adapter:
                    await _enhanced_memory_adapter.maintenance()
                    logger.info("✅ Enhanced memory system maintenance task finished")
                else:
                    logger.debug("Enhanced memory adapter not initialized, skipping maintenance")
            except Exception as e:
                logger.error(f"Enhanced memory system maintenance failed: {e}")
        except Exception as e:
            logger.error(f"Unexpected error while running the maintenance task: {e}", exc_info=True)

    def get_interval(self) -> int:
        """Return the execution interval."""
        return self.interval

    def get_task_name(self) -> str:
        """Return the task name."""
        return self.task_name


# Global hooks instance
_memory_hooks: MemoryIntegrationHooks | None = None


async def get_memory_integration_hooks() -> MemoryIntegrationHooks:
    """Return the global memory integration hooks instance."""
    global _memory_hooks
    if _memory_hooks is None:
        _memory_hooks = MemoryIntegrationHooks()
        await _memory_hooks.register_hooks()
    return _memory_hooks


async def initialize_memory_integration_hooks():
    """Initialize the memory integration hooks."""
    try:
        logger.info("🚀 Initializing the memory integration hooks...")
        hooks = await get_memory_integration_hooks()
        logger.info("✅ Memory integration hooks initialized")
        return hooks
    except Exception as e:
        logger.error(f"❌ Failed to initialize the memory integration hooks: {e}", exc_info=True)
        return None

File diff suppressed because it is too large

View File

@@ -1,875 +0,0 @@
"""
向量数据库存储接口
为记忆系统提供高效的向量存储和语义搜索能力
"""
import asyncio
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import orjson
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.config_helpers import resolve_embedding_dimension
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
# 尝试导入FAISS如果不可用则使用简单替代
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
logger.warning("FAISS not available, using simple vector storage")
@dataclass
class VectorStorageConfig:
"""向量存储配置"""
dimension: int = 1024
similarity_threshold: float = 0.8
index_type: str = "flat" # flat, ivf, hnsw
max_index_size: int = 100000
storage_path: str = "data/memory_vectors"
auto_save_interval: int = 10 # 每N次操作自动保存
enable_compression: bool = True
class VectorStorageManager:
"""向量存储管理器"""
def __init__(self, config: VectorStorageConfig | None = None):
self.config = config or VectorStorageConfig()
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
if resolved_dimension and resolved_dimension != self.config.dimension:
logger.info(
"向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
resolved_dimension,
self.config.dimension,
)
self.config.dimension = resolved_dimension
self.storage_path = Path(self.config.storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
# 向量索引
self.vector_index = None
self.memory_id_to_index = {} # memory_id -> vector index
self.index_to_memory_id = {} # vector index -> memory_id
# 内存缓存
self.memory_cache: dict[str, MemoryChunk] = {}
self.vector_cache: dict[str, list[float]] = {}
# 统计信息
self.storage_stats = {
"total_vectors": 0,
"index_build_time": 0.0,
"average_search_time": 0.0,
"cache_hit_rate": 0.0,
"total_searches": 0,
"cache_hits": 0,
}
# 线程锁
self._lock = threading.RLock()
self._operation_count = 0
# 初始化索引
self._initialize_index()
# 嵌入模型
self.embedding_model: LLMRequest = None
def _initialize_index(self):
"""初始化向量索引"""
try:
if FAISS_AVAILABLE:
if self.config.index_type == "flat":
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
elif self.config.index_type == "ivf":
quantizer = faiss.IndexFlatIP(self.config.dimension)
nlist = min(100, max(1, self.config.max_index_size // 1000))
self.vector_index = faiss.IndexIVFFlat(quantizer, self.config.dimension, nlist)
elif self.config.index_type == "hnsw":
self.vector_index = faiss.IndexHNSWFlat(self.config.dimension, 32)
self.vector_index.hnsw.efConstruction = 40
else:
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
else:
# 简单的向量存储实现
self.vector_index = SimpleVectorIndex(self.config.dimension)
logger.info(f"✅ 向量索引初始化完成,类型: {self.config.index_type}")
except Exception as e:
logger.error(f"❌ 向量索引初始化失败: {e}")
# 回退到简单实现
self.vector_index = SimpleVectorIndex(self.config.dimension)
    async def initialize_embedding_model(self):
        """Initialize the embedding model."""
        if self.embedding_model is None:
            self.embedding_model = LLMRequest(
                model_set=model_config.model_task_config.embedding, request_type="memory.embedding"
            )
            logger.info("✅ Embedding model initialized")

    async def generate_query_embedding(self, query_text: str) -> list[float] | None:
        """Generate a query vector for memory recall."""
        if not query_text:
            logger.warning("Query text is empty, cannot generate a vector")
            return None
        try:
            await self.initialize_embedding_model()
            logger.debug(f"Generating query vector, text: '{query_text[:50]}{'...' if len(query_text) > 50 else ''}'")
            embedding, _ = await self.embedding_model.get_embedding(query_text)
            if not embedding:
                logger.warning("The embedding model returned an empty vector")
                return None
            logger.debug(f"Generated vector dimension: {len(embedding)}, expected: {self.config.dimension}")
            if len(embedding) != self.config.dimension:
                logger.error("Query vector dimension mismatch: expected %d, got %d", self.config.dimension, len(embedding))
                return None
            normalized_vector = self._normalize_vector(embedding)
            logger.debug(f"Query vector generated, value range: [{min(normalized_vector):.4f}, {max(normalized_vector):.4f}]")
            return normalized_vector
        except Exception as exc:
            logger.error(f"❌ Failed to generate the query vector: {exc}", exc_info=True)
            return None

    async def store_memories(self, memories: list[MemoryChunk]):
        """Store memory vectors."""
        if not memories:
            return
        start_time = time.time()
        try:
            # Make sure the embedding model is initialized
            await self.initialize_embedding_model()
            # Collect texts that still need embeddings
            memory_texts = []
            for memory in memories:
                # Cache the memory up front so later steps can access it
                self.memory_cache[memory.memory_id] = memory
                if memory.embedding is None:
                    # No embedding yet; it has to be generated
                    text = self._prepare_embedding_text(memory)
                    memory_texts.append((memory.memory_id, text))
                else:
                    # Embedding already present; use it directly
                    await self._add_single_memory(memory, memory.embedding)
            # Batch-generate the missing embeddings
            if memory_texts:
                await self._batch_generate_and_store_embeddings(memory_texts)
            # Auto-save check
            self._operation_count += len(memories)
            if self._operation_count >= self.config.auto_save_interval:
                await self.save_storage()
                self._operation_count = 0
            storage_time = time.time() - start_time
            logger.debug(f"Vector storage finished: {len(memories)} memories in {storage_time:.3f}s")
        except Exception as e:
            logger.error(f"❌ Vector storage failed: {e}", exc_info=True)

    def _prepare_embedding_text(self, memory: MemoryChunk) -> str:
        """Prepare the text used for embedding, using only the natural-language display content."""
        display_text = (memory.display or "").strip()
        if display_text:
            return display_text
        fallback_text = (memory.text_content or "").strip()
        if fallback_text:
            return fallback_text
        subjects = ", ".join(s.strip() for s in memory.subjects if s and isinstance(s, str))
        predicate = (memory.content.predicate or "").strip()
        obj = memory.content.object
        if isinstance(obj, dict):
            object_parts = []
            for key, value in obj.items():
                if value is None:
                    continue
                if isinstance(value, (list, tuple)):
                    preview = ", ".join(str(item) for item in value[:3])
                    object_parts.append(f"{key}:{preview}")
                else:
                    object_parts.append(f"{key}:{value}")
            object_text = ", ".join(object_parts)
        else:
            object_text = str(obj or "").strip()
        composite_parts = [part for part in [subjects, predicate, object_text] if part]
        if composite_parts:
            return " ".join(composite_parts)
        logger.debug("Memory %s has no usable display text; using a placeholder as the embedding input", memory.memory_id)
        return memory.memory_id
    async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
        """Batch-generate and store embeddings."""
        if not memory_texts:
            return
        try:
            texts = [text for _, text in memory_texts]
            memory_ids = [memory_id for memory_id, _ in memory_texts]
            # Batch-generate the embeddings
            embeddings = await self._batch_generate_embeddings(memory_ids, texts)
            # Store the vectors and memories
            for memory_id, embedding in embeddings.items():
                if embedding and len(embedding) == self.config.dimension:
                    memory = self.memory_cache.get(memory_id)
                    if memory:
                        await self._add_single_memory(memory, embedding)
        except Exception as e:
            logger.error(f"❌ Batch embedding generation failed: {e}")

    async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
        """Batch-generate embeddings with bounded concurrency."""
        if not texts:
            return {}
        results: dict[str, list[float]] = {}
        try:
            semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))

            async def generate_embedding(memory_id: str, text: str) -> None:
                async with semaphore:
                    try:
                        embedding, _ = await self.embedding_model.get_embedding(text)
                        if embedding and len(embedding) == self.config.dimension:
                            results[memory_id] = embedding
                        else:
                            logger.warning(
                                "Embedding dimension mismatch: expected %d, got %d (memory_id=%s). Check the embedding model configuration model_config.model_task_config.embedding.embedding_dimension or the LPMM task definition.",
                                self.config.dimension,
                                len(embedding) if embedding else 0,
                                memory_id,
                            )
                            results[memory_id] = []
                    except Exception as exc:
                        logger.warning("Failed to generate the embedding for memory %s: %s", memory_id, exc)
                        results[memory_id] = []

            tasks = [
                asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
            ]
            await asyncio.gather(*tasks, return_exceptions=True)
        except Exception as e:
            logger.error(f"❌ Batch embedding generation failed: {e}")
            for memory_id in memory_ids:
                results.setdefault(memory_id, [])
        return results

    async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
        """Add a single memory to the vector storage."""
        with self._lock:
            try:
                # Normalize the vector
                if embedding:
                    embedding = self._normalize_vector(embedding)
                # Add to the caches
                self.memory_cache[memory.memory_id] = memory
                self.vector_cache[memory.memory_id] = embedding
                # Update the memory's embedding
                memory.set_embedding(embedding)
                # Add to the vector index
                if hasattr(self.vector_index, "add"):
                    # FAISS index
                    if isinstance(embedding, np.ndarray):
                        vector_array = embedding.reshape(1, -1).astype("float32")
                    else:
                        vector_array = np.array([embedding], dtype="float32")
                    # Special-case the IVF index
                    if self.config.index_type == "ivf" and self.vector_index.ntotal == 0:
                        # An IVF index must be trained first
                        logger.debug("Training the IVF index...")
                        self.vector_index.train(vector_array)
                    self.vector_index.add(vector_array)
                    index_id = self.vector_index.ntotal - 1
                else:
                    # Simple index
                    index_id = self.vector_index.add_vector(embedding)
                # Update the mappings
                self.memory_id_to_index[memory.memory_id] = index_id
                self.index_to_memory_id[index_id] = memory.memory_id
                # Update statistics
                self.storage_stats["total_vectors"] += 1
            except Exception as e:
                logger.error(f"❌ Failed to add the memory to the vector storage: {e}")
    def _normalize_vector(self, vector: list[float]) -> list[float]:
        """L2-normalize a vector."""
        if not vector:
            return vector
        try:
            vector_array = np.array(vector, dtype=np.float32)
            norm = np.linalg.norm(vector_array)
            if norm == 0:
                return vector
            normalized = vector_array / norm
            return normalized.tolist()
        except Exception as e:
            logger.warning(f"Vector normalization failed: {e}")
            return vector
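L2 normalization is what makes the IndexFlatIP indices above behave as cosine-similarity search: for unit vectors, the inner product equals the cosine of the angle between them. A small self-contained check (requires numpy and faiss-cpu):

import numpy as np
import faiss

d = 4
index = faiss.IndexFlatIP(d)  # inner-product index, as in _initialize_index

rng = np.random.default_rng(0)
vecs = rng.normal(size=(3, d)).astype("float32")
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # L2-normalize the rows
index.add(vecs)

query = vecs[0:1]  # the first stored vector itself
scores, ids = index.search(query, 3)
print(ids[0], scores[0])  # best hit is id 0 with score ~1.0 (cosine of 0 degrees)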
async def search_similar_memories(
self,
query_vector: list[float] | None = None,
*,
query_text: str | None = None,
limit: int = 10,
scope_id: str | None = None,
) -> list[tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()
try:
logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}")
if query_vector is None:
if not query_text:
logger.warning("查询向量和查询文本都为空")
return []
query_vector = await self.generate_query_embedding(query_text)
if not query_vector:
logger.warning("查询向量生成失败")
return []
scope_filter: str | None = None
if isinstance(scope_id, str):
normalized_scope = scope_id.strip().lower()
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
scope_filter = scope_id
elif scope_id:
scope_filter = str(scope_id)
# 规范化查询向量
query_vector = self._normalize_vector(query_vector)
logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}")
# 检查向量索引状态
if not self.vector_index:
logger.error("向量索引未初始化")
return []
total_vectors = 0
if hasattr(self.vector_index, "ntotal"):
total_vectors = self.vector_index.ntotal
elif hasattr(self.vector_index, "vectors"):
total_vectors = len(self.vector_index.vectors)
logger.debug(f"向量索引中实际向量数: {total_vectors}")
if total_vectors == 0:
logger.warning("向量索引为空,无法执行搜索")
return []
# 执行向量搜索
with self._lock:
if hasattr(self.vector_index, "search"):
# FAISS索引
if isinstance(query_vector, np.ndarray):
query_array = query_vector.reshape(1, -1).astype("float32")
else:
query_array = np.array([query_vector], dtype="float32")
if self.config.index_type == "ivf" and self.vector_index.ntotal > 0:
# 设置IVF搜索参数
nprobe = min(self.vector_index.nlist, 10)
self.vector_index.nprobe = nprobe
logger.debug(f"IVF搜索参数: nprobe={nprobe}")
search_limit = min(limit, total_vectors)
logger.debug(f"执行FAISS搜索搜索限制: {search_limit}")
distances, indices = self.vector_index.search(query_array, search_limit)
distances = distances.flatten().tolist()
indices = indices.flatten().tolist()
logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引")
else:
# 简单索引
logger.debug("使用简单向量索引执行搜索")
results = self.vector_index.search(query_vector, limit)
distances = [score for _, score in results]
indices = [idx for idx, _ in results]
logger.debug(f"简单索引搜索结果: {len(results)} 个结果")
# 处理搜索结果
results = []
valid_results = 0
invalid_indices = 0
filtered_by_scope = 0
for distance, index in zip(distances, indices, strict=False):
if index == -1: # FAISS的无效索引标记
invalid_indices += 1
continue
memory_id = self.index_to_memory_id.get(index)
if not memory_id:
logger.debug(f"索引 {index} 没有对应的记忆ID")
invalid_indices += 1
continue
if scope_filter:
memory = self.memory_cache.get(memory_id)
if memory and str(memory.user_id) != scope_filter:
filtered_by_scope += 1
continue
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
results.append((memory_id, similarity))
valid_results += 1
logger.debug(
f"搜索结果处理: 总距离={len(distances)}, 有效结果={valid_results}, "
f"无效索引={invalid_indices}, 作用域过滤={filtered_by_scope}"
)
# 更新统计
search_time = time.time() - start_time
self.storage_stats["total_searches"] += 1
self.storage_stats["average_search_time"] = (
self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time
) / self.storage_stats["total_searches"]
final_results = results[:limit]
logger.info(
f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' "
f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果"
)
return final_results
except Exception as e:
logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
return []
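    # A minimal caller-side sketch of the retrieval API above (hypothetical
    # `storage` instance; illustrative only, not part of this file):
    #
    #     hits = await storage.search_similar_memories(query_text="用户喜欢的食物", limit=5)
    #     for memory_id, score in hits:
    #         memory = await storage.get_memory_by_id(memory_id)
    #         if memory:
    #             print(memory_id, f"{score:.2f}")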
    async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
        """根据ID获取记忆"""
        # 命中与未命中都计入一次查询,保证缓存命中率统计不会超过100%
        self.storage_stats["total_searches"] += 1

        # 先检查缓存
        if memory_id in self.memory_cache:
            self.storage_stats["cache_hits"] += 1
            return self.memory_cache[memory_id]

        return None
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
"""更新记忆的嵌入向量"""
with self._lock:
try:
if memory_id not in self.memory_id_to_index:
logger.warning(f"记忆 {memory_id} 不存在于向量索引中")
return
# 获取旧索引
old_index = self.memory_id_to_index[memory_id]
# 删除旧向量(如果支持)
if hasattr(self.vector_index, "remove_ids"):
try:
self.vector_index.remove_ids(np.array([old_index]))
                    except Exception:
logger.warning("无法删除旧向量,将直接添加新向量")
# 规范化新向量
new_embedding = self._normalize_vector(new_embedding)
# 添加新向量
if hasattr(self.vector_index, "add"):
if isinstance(new_embedding, np.ndarray):
vector_array = new_embedding.reshape(1, -1).astype("float32")
else:
vector_array = np.array([new_embedding], dtype="float32")
self.vector_index.add(vector_array)
new_index = self.vector_index.ntotal - 1
else:
new_index = self.vector_index.add_vector(new_embedding)
# 更新映射关系
self.memory_id_to_index[memory_id] = new_index
self.index_to_memory_id[new_index] = memory_id
# 更新缓存
self.vector_cache[memory_id] = new_embedding
# 更新记忆对象
memory = self.memory_cache.get(memory_id)
if memory:
memory.set_embedding(new_embedding)
logger.debug(f"更新记忆 {memory_id} 的嵌入向量")
except Exception as e:
logger.error(f"❌ 更新记忆嵌入向量失败: {e}")
async def delete_memory(self, memory_id: str):
"""删除记忆"""
with self._lock:
try:
if memory_id not in self.memory_id_to_index:
return
# 获取索引
index = self.memory_id_to_index[memory_id]
# 从向量索引中删除(如果支持)
if hasattr(self.vector_index, "remove_ids"):
try:
self.vector_index.remove_ids(np.array([index]))
                    except Exception:
logger.warning("无法从向量索引中删除,仅从缓存中移除")
# 删除映射关系
del self.memory_id_to_index[memory_id]
if index in self.index_to_memory_id:
del self.index_to_memory_id[index]
# 从缓存中删除
self.memory_cache.pop(memory_id, None)
self.vector_cache.pop(memory_id, None)
# 更新统计
self.storage_stats["total_vectors"] = max(0, self.storage_stats["total_vectors"] - 1)
logger.debug(f"删除记忆 {memory_id}")
except Exception as e:
logger.error(f"❌ 删除记忆失败: {e}")
async def save_storage(self):
"""保存向量存储到文件"""
try:
logger.info("正在保存向量存储...")
# 保存记忆缓存
cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()}
cache_file = self.storage_path / "memory_cache.json"
with open(cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
# 保存向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
with open(vector_cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8"))
# 保存映射关系
mapping_file = self.storage_path / "id_mapping.json"
mapping_data = {
"memory_id_to_index": {
str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items()
},
"index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()},
}
with open(mapping_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8"))
            # 保存FAISS索引faiss.Index 没有 save() 方法,这里按索引类型判断后用 faiss.write_index 写盘)
            if FAISS_AVAILABLE and self.vector_index is not None and not isinstance(self.vector_index, SimpleVectorIndex):
                index_file = self.storage_path / "vector_index.faiss"
                faiss.write_index(self.vector_index, str(index_file))
# 保存统计信息
stats_file = self.storage_path / "storage_stats.json"
with open(stats_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8"))
logger.info("✅ 向量存储保存完成")
except Exception as e:
logger.error(f"❌ 保存向量存储失败: {e}")
async def load_storage(self):
"""从文件加载向量存储"""
try:
logger.info("正在加载向量存储...")
# 加载记忆缓存
cache_file = self.storage_path / "memory_cache.json"
if cache_file.exists():
with open(cache_file, encoding="utf-8") as f:
cache_data = orjson.loads(f.read())
self.memory_cache = {
memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items()
}
# 加载向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
if vector_cache_file.exists():
with open(vector_cache_file, encoding="utf-8") as f:
self.vector_cache = orjson.loads(f.read())
# 加载映射关系
mapping_file = self.storage_path / "id_mapping.json"
if mapping_file.exists():
with open(mapping_file, encoding="utf-8") as f:
mapping_data = orjson.loads(f.read())
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items()
}
raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()}
# 加载FAISS索引如果可用
index_loaded = False
if FAISS_AVAILABLE:
index_file = self.storage_path / "vector_index.faiss"
if index_file.exists():
try:
loaded_index = faiss.read_index(str(index_file))
# 如果索引类型匹配,则替换
                        if type(loaded_index) is type(self.vector_index):
self.vector_index = loaded_index
index_loaded = True
logger.info("✅ FAISS索引文件加载完成")
else:
logger.warning("索引类型不匹配,重新构建索引")
except Exception as e:
logger.warning(f"加载FAISS索引失败: {e},重新构建")
else:
logger.info("FAISS索引文件不存在将重新构建")
# 如果索引没有成功加载且有向量数据,则重建索引
if not index_loaded and self.vector_cache:
logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引")
await self._rebuild_index()
# 加载统计信息
stats_file = self.storage_path / "storage_stats.json"
if stats_file.exists():
with open(stats_file, encoding="utf-8") as f:
self.storage_stats = orjson.loads(f.read())
# 更新向量计数
self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
logger.info(f"✅ 向量存储加载完成,{self.storage_stats['total_vectors']} 个向量")
except Exception as e:
logger.error(f"❌ 加载向量存储失败: {e}")
async def _rebuild_index(self):
"""重建向量索引"""
try:
logger.info(f"正在重建向量索引...向量数量: {len(self.vector_cache)}")
# 重新初始化索引
self._initialize_index()
# 清空映射关系
self.memory_id_to_index.clear()
self.index_to_memory_id.clear()
if not self.vector_cache:
logger.warning("没有向量缓存数据,跳过重建")
return
# 准备向量数据
memory_ids = []
vectors = []
for memory_id, embedding in self.vector_cache.items():
if embedding and len(embedding) == self.config.dimension:
memory_ids.append(memory_id)
vectors.append(self._normalize_vector(embedding))
else:
logger.debug(f"跳过无效向量: {memory_id}, 维度: {len(embedding) if embedding else 0}")
if not vectors:
logger.warning("没有有效的向量数据")
return
logger.info(f"准备重建 {len(vectors)} 个向量到索引")
# 批量添加向量到FAISS索引
if hasattr(self.vector_index, "add"):
# FAISS索引
vector_array = np.array(vectors, dtype="float32")
# 特殊处理IVF索引
if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"):
logger.info("训练IVF索引...")
self.vector_index.train(vector_array)
# 添加向量
self.vector_index.add(vector_array)
# 重建映射关系
for i, memory_id in enumerate(memory_ids):
self.memory_id_to_index[memory_id] = i
self.index_to_memory_id[i] = memory_id
else:
# 简单索引
for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors, strict=False)):
index_id = self.vector_index.add_vector(vector)
self.memory_id_to_index[memory_id] = index_id
self.index_to_memory_id[index_id] = memory_id
# 更新统计
self.storage_stats["total_vectors"] = len(self.memory_id_to_index)
final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index))
logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}")
except Exception as e:
logger.error(f"❌ 重建向量索引失败: {e}", exc_info=True)
async def optimize_storage(self):
"""优化存储"""
try:
logger.info("开始向量存储优化...")
# 清理无效引用
self._cleanup_invalid_references()
            # 向量较多时整体重建索引,顺带清理无效条目带来的碎片
            if self.storage_stats["total_vectors"] > 1000:
await self._rebuild_index()
# 更新缓存命中率
if self.storage_stats["total_searches"] > 0:
self.storage_stats["cache_hit_rate"] = (
self.storage_stats["cache_hits"] / self.storage_stats["total_searches"]
)
logger.info("✅ 向量存储优化完成")
except Exception as e:
logger.error(f"❌ 向量存储优化失败: {e}")
def _cleanup_invalid_references(self):
"""清理无效引用"""
with self._lock:
# 清理无效的memory_id到index的映射
valid_memory_ids = set(self.memory_cache.keys())
invalid_memory_ids = set(self.memory_id_to_index.keys()) - valid_memory_ids
for memory_id in invalid_memory_ids:
index = self.memory_id_to_index[memory_id]
del self.memory_id_to_index[memory_id]
if index in self.index_to_memory_id:
del self.index_to_memory_id[index]
if invalid_memory_ids:
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
def get_storage_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
stats = self.storage_stats.copy()
if stats["total_searches"] > 0:
stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_searches"]
else:
stats["cache_hit_rate"] = 0.0
return stats
class SimpleVectorIndex:
"""简单的向量索引实现当FAISS不可用时的替代方案"""
def __init__(self, dimension: int):
self.dimension = dimension
self.vectors: list[list[float]] = []
self.vector_ids: list[int] = []
self.next_id = 0
def add_vector(self, vector: list[float]) -> int:
"""添加向量"""
if len(vector) != self.dimension:
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
vector_id = self.next_id
self.vectors.append(vector.copy())
self.vector_ids.append(vector_id)
self.next_id += 1
return vector_id
def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
"""搜索相似向量"""
if len(query_vector) != self.dimension:
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
results = []
for i, vector in enumerate(self.vectors):
similarity = self._calculate_cosine_similarity(query_vector, vector)
results.append((self.vector_ids[i], similarity))
# 按相似度排序
results.sort(key=lambda x: x[1], reverse=True)
return results[:limit]
def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
"""计算余弦相似度"""
try:
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
norm1 = sum(x * x for x in v1) ** 0.5
norm2 = sum(x * x for x in v2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
except Exception:
return 0.0
@property
def ntotal(self) -> int:
"""向量总数"""
return len(self.vectors)
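# A minimal sanity-check sketch for the fallback index above (toy 4-dimensional
# vectors; illustrative only):
#
#     idx = SimpleVectorIndex(dimension=4)
#     a = idx.add_vector([1.0, 0.0, 0.0, 0.0])
#     idx.add_vector([0.0, 1.0, 0.0, 0.0])
#     hits = idx.search([0.9, 0.1, 0.0, 0.0], limit=2)
#     assert hits[0][0] == a  # (vector_id, cosine similarity), best match first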

View File

@@ -0,0 +1,731 @@
# -*- coding: utf-8 -*-
"""
海马体双峰分布采样器
基于旧版海马体的采样策略,适配新版记忆系统
实现低消耗、高效率的记忆采样模式
"""
import asyncio
import random
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
import numpy as np
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
build_readable_messages,
get_raw_msg_by_timestamp_with_chat,
)
from src.chat.utils.utils import translate_timestamp_to_human_readable
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@dataclass
class HippocampusSampleConfig:
"""海马体采样配置"""
# 双峰分布参数
recent_mean_hours: float = 12.0 # 近期分布均值(小时)
recent_std_hours: float = 8.0 # 近期分布标准差(小时)
recent_weight: float = 0.7 # 近期分布权重
distant_mean_hours: float = 48.0 # 远期分布均值(小时)
distant_std_hours: float = 24.0 # 远期分布标准差(小时)
distant_weight: float = 0.3 # 远期分布权重
# 采样参数
total_samples: int = 50 # 总采样数
sample_interval: int = 1800 # 采样间隔(秒)
max_sample_length: int = 30 # 每次采样的最大消息数量
batch_size: int = 5 # 批处理大小
@classmethod
def from_global_config(cls) -> 'HippocampusSampleConfig':
"""从全局配置创建海马体采样配置"""
config = global_config.memory.hippocampus_distribution_config
return cls(
recent_mean_hours=config[0],
recent_std_hours=config[1],
recent_weight=config[2],
distant_mean_hours=config[3],
distant_std_hours=config[4],
distant_weight=config[5],
total_samples=global_config.memory.hippocampus_sample_size,
sample_interval=global_config.memory.hippocampus_sample_interval,
max_sample_length=global_config.memory.hippocampus_batch_size,
batch_size=global_config.memory.hippocampus_batch_size,
)
class HippocampusSampler:
"""海马体双峰分布采样器"""
def __init__(self, memory_system=None):
self.memory_system = memory_system
self.config = HippocampusSampleConfig.from_global_config()
self.last_sample_time = 0
self.is_running = False
# 记忆构建模型
self.memory_builder_model: Optional[LLMRequest] = None
# 统计信息
self.sample_count = 0
self.success_count = 0
self.last_sample_results: List[Dict[str, Any]] = []
async def initialize(self):
"""初始化采样器"""
try:
# 初始化LLM模型
from src.config.config import model_config
task_config = getattr(model_config.model_task_config, "utils", None)
if task_config:
self.memory_builder_model = LLMRequest(
model_set=task_config,
request_type="memory.hippocampus_build"
)
asyncio.create_task(self.start_background_sampling())
logger.info("✅ 海马体采样器初始化成功")
else:
raise RuntimeError("未找到记忆构建模型配置")
except Exception as e:
logger.error(f"❌ 海马体采样器初始化失败: {e}")
raise
def generate_time_samples(self) -> List[datetime]:
"""生成双峰分布的时间采样点"""
# 计算每个分布的样本数
recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight))
distant_samples = max(1, self.config.total_samples - recent_samples)
# 生成两个正态分布的小时偏移
recent_offsets = np.random.normal(
loc=self.config.recent_mean_hours,
scale=self.config.recent_std_hours,
size=recent_samples
)
distant_offsets = np.random.normal(
loc=self.config.distant_mean_hours,
scale=self.config.distant_std_hours,
size=distant_samples
)
# 合并两个分布的偏移
all_offsets = np.concatenate([recent_offsets, distant_offsets])
# 转换为时间戳(使用绝对值确保时间点在过去)
base_time = datetime.now()
timestamps = [
base_time - timedelta(hours=abs(offset))
for offset in all_offsets
]
# 按时间排序(从最早到最近)
return sorted(timestamps)
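    # A minimal sketch of what the bimodal split above yields with the
    # default-style parameters (70% of 50 samples ~ N(12h, 8h), 30% ~ N(48h, 24h);
    # offsets are folded into the past via abs()):
    #
    #     recent = np.random.normal(12.0, 8.0, size=35)
    #     distant = np.random.normal(48.0, 24.0, size=15)
    #     offsets = np.abs(np.concatenate([recent, distant]))  # hours back
    #     # most points cluster around the last day; a long tail reaches ~3 days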
async def collect_message_samples(self, target_timestamp: float) -> Optional[List[Dict[str, Any]]]:
"""收集指定时间戳附近的消息样本"""
try:
# 随机时间窗口5-30分钟
time_window_seconds = random.randint(300, 1800)
# 尝试3次获取消息
for attempt in range(3):
timestamp_start = target_timestamp
timestamp_end = target_timestamp + time_window_seconds
# 获取单条消息作为锚点
anchor_messages = await get_raw_msg_by_timestamp(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=1,
limit_mode="earliest",
)
if not anchor_messages:
target_timestamp -= 120 # 向前调整2分钟
continue
anchor_message = anchor_messages[0]
chat_id = anchor_message.get("chat_id")
if not chat_id:
continue
# 获取同聊天的多条消息
messages = await get_raw_msg_by_timestamp_with_chat(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=self.config.max_sample_length,
limit_mode="earliest",
chat_id=chat_id,
)
if messages and len(messages) >= 2: # 至少需要2条消息
# 过滤掉已经记忆过的消息
filtered_messages = [
msg for msg in messages
if msg.get("memorized_times", 0) < 2 # 最多记忆2次
]
if filtered_messages:
logger.debug(f"成功收集 {len(filtered_messages)} 条消息样本")
return filtered_messages
target_timestamp -= 120 # 向前调整再试
logger.debug(f"时间戳 {target_timestamp} 附近未找到有效消息样本")
return None
except Exception as e:
logger.error(f"收集消息样本失败: {e}")
return None
async def build_memory_from_samples(self, messages: List[Dict[str, Any]], target_timestamp: float) -> Optional[str]:
"""从消息样本构建记忆"""
if not messages or not self.memory_system or not self.memory_builder_model:
return None
try:
# 构建可读消息文本
readable_text = await build_readable_messages(
messages,
merge_messages=True,
timestamp_mode="normal_no_YMD",
replace_bot_name=False,
)
if not readable_text:
logger.warning("无法从消息样本生成可读文本")
return None
# 添加当前日期信息
current_date = f"当前日期: {datetime.now().isoformat()}"
input_text = f"{current_date}\n{readable_text}"
logger.debug(f"开始构建记忆,文本长度: {len(input_text)}")
# 构建上下文
context = {
"user_id": "hippocampus_sampler",
"timestamp": time.time(),
"source": "hippocampus_sampling",
"message_count": len(messages),
"sample_mode": "bimodal_distribution",
"is_hippocampus_sample": True, # 标识为海马体样本
"bypass_value_threshold": True, # 绕过价值阈值检查
"hippocampus_sample_time": target_timestamp, # 记录样本时间
}
# 使用记忆系统构建记忆(绕过构建间隔检查)
memories = await self.memory_system.build_memory_from_conversation(
conversation_text=input_text,
context=context,
timestamp=time.time(),
bypass_interval=True # 海马体采样器绕过构建间隔限制
)
if memories:
memory_count = len(memories)
self.success_count += 1
# 记录采样结果
result = {
"timestamp": time.time(),
"memory_count": memory_count,
"message_count": len(messages),
"text_preview": readable_text[:100] + "..." if len(readable_text) > 100 else readable_text,
"memory_types": [m.memory_type.value for m in memories],
}
self.last_sample_results.append(result)
# 限制结果历史长度
if len(self.last_sample_results) > 10:
self.last_sample_results.pop(0)
logger.info(f"✅ 海马体采样成功构建 {memory_count} 条记忆")
return f"构建{memory_count}条记忆"
else:
logger.debug("海马体采样未生成有效记忆")
return None
except Exception as e:
logger.error(f"海马体采样构建记忆失败: {e}")
return None
async def perform_sampling_cycle(self) -> Dict[str, Any]:
"""执行一次完整的采样周期(优化版:批量融合构建)"""
if not self.should_sample():
return {"status": "skipped", "reason": "interval_not_met"}
start_time = time.time()
self.sample_count += 1
try:
# 生成时间采样点
time_samples = self.generate_time_samples()
logger.debug(f"生成 {len(time_samples)} 个时间采样点")
# 记录时间采样点(调试用)
readable_timestamps = [
translate_timestamp_to_human_readable(int(ts.timestamp()), mode="normal")
for ts in time_samples[:5] # 只显示前5个
]
logger.debug(f"时间采样点示例: {readable_timestamps}")
# 第一步:批量收集所有消息样本
logger.debug("开始批量收集消息样本...")
collected_messages = await self._collect_all_message_samples(time_samples)
if not collected_messages:
logger.info("未收集到有效消息样本,跳过本次采样")
self.last_sample_time = time.time()
return {
"status": "success",
"sample_count": self.sample_count,
"success_count": self.success_count,
"processed_samples": len(time_samples),
"successful_builds": 0,
"duration": time.time() - start_time,
"samples_generated": len(time_samples),
"message": "未收集到有效消息样本",
}
logger.info(f"收集到 {len(collected_messages)} 组消息样本")
# 第二步:融合和去重消息
logger.debug("开始融合和去重消息...")
fused_messages = await self._fuse_and_deduplicate_messages(collected_messages)
if not fused_messages:
logger.info("消息融合后为空,跳过记忆构建")
self.last_sample_time = time.time()
return {
"status": "success",
"sample_count": self.sample_count,
"success_count": self.success_count,
"processed_samples": len(time_samples),
"successful_builds": 0,
"duration": time.time() - start_time,
"samples_generated": len(time_samples),
"message": "消息融合后为空",
}
logger.info(f"融合后得到 {len(fused_messages)} 组有效消息")
# 第三步:一次性构建记忆
logger.debug("开始批量构建记忆...")
build_result = await self._build_batch_memory(fused_messages, time_samples)
# 更新最后采样时间
self.last_sample_time = time.time()
duration = time.time() - start_time
result = {
"status": "success",
"sample_count": self.sample_count,
"success_count": self.success_count,
"processed_samples": len(time_samples),
"successful_builds": build_result.get("memory_count", 0),
"duration": duration,
"samples_generated": len(time_samples),
"messages_collected": len(collected_messages),
"messages_fused": len(fused_messages),
"optimization_mode": "batch_fusion",
}
logger.info(
f"✅ 海马体采样周期完成(批量融合模式) | "
f"采样点: {len(time_samples)} | "
f"收集消息: {len(collected_messages)} | "
f"融合消息: {len(fused_messages)} | "
f"构建记忆: {build_result.get('memory_count', 0)} | "
f"耗时: {duration:.2f}s"
)
return result
except Exception as e:
logger.error(f"❌ 海马体采样周期失败: {e}")
return {
"status": "error",
"error": str(e),
"sample_count": self.sample_count,
"duration": time.time() - start_time,
}
async def _collect_all_message_samples(self, time_samples: List[datetime]) -> List[List[Dict[str, Any]]]:
"""批量收集所有时间点的消息样本"""
collected_messages = []
max_concurrent = min(5, len(time_samples)) # 提高并发数到5
for i in range(0, len(time_samples), max_concurrent):
batch = time_samples[i:i + max_concurrent]
tasks = []
# 创建并发收集任务
for timestamp in batch:
target_ts = timestamp.timestamp()
task = self.collect_message_samples(target_ts)
tasks.append(task)
# 执行并发收集
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理收集结果
for result in results:
if isinstance(result, list) and result:
collected_messages.append(result)
elif isinstance(result, Exception):
logger.debug(f"消息收集异常: {result}")
# 批次间短暂延迟
if i + max_concurrent < len(time_samples):
await asyncio.sleep(0.5)
return collected_messages
async def _fuse_and_deduplicate_messages(self, collected_messages: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
"""融合和去重消息样本"""
if not collected_messages:
return []
try:
# 展平所有消息
all_messages = []
for message_group in collected_messages:
all_messages.extend(message_group)
logger.debug(f"展开后总消息数: {len(all_messages)}")
# 去重逻辑:基于消息内容和时间戳
unique_messages = []
seen_hashes = set()
for message in all_messages:
# 创建消息哈希用于去重
content = message.get("processed_plain_text", "") or message.get("display_message", "")
timestamp = message.get("time", 0)
chat_id = message.get("chat_id", "")
# 简单哈希内容前50字符 + 时间戳(精确到分钟) + 聊天ID
hash_key = f"{content[:50]}_{int(timestamp//60)}_{chat_id}"
if hash_key not in seen_hashes and len(content.strip()) > 10:
seen_hashes.add(hash_key)
unique_messages.append(message)
logger.debug(f"去重后消息数: {len(unique_messages)}")
# 按时间排序
unique_messages.sort(key=lambda x: x.get("time", 0))
# 按聊天ID分组重新组织
chat_groups = {}
for message in unique_messages:
chat_id = message.get("chat_id", "unknown")
if chat_id not in chat_groups:
chat_groups[chat_id] = []
chat_groups[chat_id].append(message)
# 合并相邻时间范围内的消息
fused_groups = []
for chat_id, messages in chat_groups.items():
fused_groups.extend(self._merge_adjacent_messages(messages))
logger.debug(f"融合后消息组数: {len(fused_groups)}")
return fused_groups
except Exception as e:
logger.error(f"消息融合失败: {e}")
# 返回原始消息组作为备选
return collected_messages[:5] # 限制返回数量
def _merge_adjacent_messages(self, messages: List[Dict[str, Any]], time_gap: int = 1800) -> List[List[Dict[str, Any]]]:
"""合并时间间隔内的消息"""
if not messages:
return []
merged_groups = []
current_group = [messages[0]]
for i in range(1, len(messages)):
current_time = messages[i].get("time", 0)
prev_time = current_group[-1].get("time", 0)
# 如果时间间隔小于阈值,合并到当前组
if current_time - prev_time <= time_gap:
current_group.append(messages[i])
else:
# 否则开始新组
merged_groups.append(current_group)
current_group = [messages[i]]
# 添加最后一组
merged_groups.append(current_group)
# 过滤掉只有一条消息的组(除非内容较长)
result_groups = []
for group in merged_groups:
if len(group) > 1 or any(len(msg.get("processed_plain_text", "")) > 100 for msg in group):
result_groups.append(group)
return result_groups
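    # A minimal sketch of the merge behaviour with the default 1800s gap and toy
    # messages (illustrative `sampler` instance):
    #
    #     msgs = [{"time": 0}, {"time": 600}, {"time": 3000}]
    #     groups = sampler._merge_adjacent_messages(msgs)
    #     # -> [[{"time": 0}, {"time": 600}]]; the lone {"time": 3000} group is
    #     #    dropped because single-message groups need >100 chars of text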
async def _build_batch_memory(self, fused_messages: List[List[Dict[str, Any]]], time_samples: List[datetime]) -> Dict[str, Any]:
"""批量构建记忆"""
if not fused_messages:
return {"memory_count": 0, "memories": []}
try:
total_memories = []
total_memory_count = 0
# 构建融合后的文本
batch_input_text = await self._build_fused_conversation_text(fused_messages)
if not batch_input_text:
logger.warning("无法构建融合文本,尝试单独构建")
# 备选方案:分别构建
return await self._fallback_individual_build(fused_messages)
# 创建批量上下文
batch_context = {
"user_id": "hippocampus_batch_sampler",
"timestamp": time.time(),
"source": "hippocampus_batch_sampling",
"message_groups_count": len(fused_messages),
"total_messages": sum(len(group) for group in fused_messages),
"sample_count": len(time_samples),
"is_hippocampus_sample": True,
"bypass_value_threshold": True,
"optimization_mode": "batch_fusion",
}
logger.debug(f"批量构建记忆,文本长度: {len(batch_input_text)}")
# 一次性构建记忆
memories = await self.memory_system.build_memory_from_conversation(
conversation_text=batch_input_text,
context=batch_context,
timestamp=time.time(),
bypass_interval=True
)
if memories:
memory_count = len(memories)
self.success_count += 1
total_memory_count += memory_count
total_memories.extend(memories)
logger.info(f"✅ 批量海马体采样成功构建 {memory_count} 条记忆")
else:
logger.debug("批量海马体采样未生成有效记忆")
# 记录采样结果
result = {
"timestamp": time.time(),
"memory_count": total_memory_count,
"message_groups_count": len(fused_messages),
"total_messages": sum(len(group) for group in fused_messages),
"text_preview": batch_input_text[:200] + "..." if len(batch_input_text) > 200 else batch_input_text,
"memory_types": [m.memory_type.value for m in total_memories],
}
self.last_sample_results.append(result)
# 限制结果历史长度
if len(self.last_sample_results) > 10:
self.last_sample_results.pop(0)
return {
"memory_count": total_memory_count,
"memories": total_memories,
"result": result
}
except Exception as e:
logger.error(f"批量构建记忆失败: {e}")
return {"memory_count": 0, "error": str(e)}
async def _build_fused_conversation_text(self, fused_messages: List[List[Dict[str, Any]]]) -> str:
"""构建融合后的对话文本"""
try:
# 添加批次标识
current_date = f"海马体批量采样 - {datetime.now().isoformat()}\n"
conversation_parts = [current_date]
for group_idx, message_group in enumerate(fused_messages):
if not message_group:
continue
# 为每个消息组添加分隔符
group_header = f"\n=== 对话片段 {group_idx + 1} ==="
conversation_parts.append(group_header)
# 构建可读消息
group_text = await build_readable_messages(
message_group,
merge_messages=True,
timestamp_mode="normal_no_YMD",
replace_bot_name=False,
)
if group_text and len(group_text.strip()) > 10:
conversation_parts.append(group_text.strip())
return "\n".join(conversation_parts)
except Exception as e:
logger.error(f"构建融合文本失败: {e}")
return ""
    async def _fallback_individual_build(self, fused_messages: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
        """备选方案:单独构建每个消息组"""
        total_count = 0
        for group in fused_messages[:5]:  # 限制最多5组
            try:
                # build_memory_from_samples 返回状态字符串(如 "构建N条记忆"),而非记忆列表,
                # 因此这里按成功构建的消息组计数
                result = await self.build_memory_from_samples(group, time.time())
                if result:
                    total_count += 1
            except Exception as e:
                logger.debug(f"单独构建失败: {e}")
        return {
            "memory_count": total_count,
            "memories": [],
            "fallback_mode": True
        }
async def process_sample_timestamp(self, target_timestamp: float) -> Optional[str]:
"""处理单个时间戳采样(保留作为备选方法)"""
try:
# 收集消息样本
messages = await self.collect_message_samples(target_timestamp)
if not messages:
return None
# 构建记忆
result = await self.build_memory_from_samples(messages, target_timestamp)
return result
except Exception as e:
logger.debug(f"处理时间戳采样失败 {target_timestamp}: {e}")
return None
def should_sample(self) -> bool:
"""检查是否应该进行采样"""
current_time = time.time()
# 检查时间间隔
if current_time - self.last_sample_time < self.config.sample_interval:
return False
# 检查是否已初始化
if not self.memory_builder_model:
logger.warning("海马体采样器未初始化")
return False
return True
async def start_background_sampling(self):
"""启动后台采样"""
if self.is_running:
logger.warning("海马体后台采样已在运行")
return
self.is_running = True
logger.info("🚀 启动海马体后台采样任务")
try:
while self.is_running:
try:
# 执行采样周期
result = await self.perform_sampling_cycle()
# 如果是跳过状态,短暂睡眠
if result.get("status") == "skipped":
await asyncio.sleep(60) # 1分钟后重试
else:
# 正常等待下一个采样间隔
await asyncio.sleep(self.config.sample_interval)
except Exception as e:
logger.error(f"海马体后台采样异常: {e}")
await asyncio.sleep(300) # 异常时等待5分钟
except asyncio.CancelledError:
logger.info("海马体后台采样任务被取消")
finally:
self.is_running = False
def stop_background_sampling(self):
"""停止后台采样"""
self.is_running = False
logger.info("🛑 停止海马体后台采样任务")
def get_sampling_stats(self) -> Dict[str, Any]:
"""获取采样统计信息"""
success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
# 计算最近的平均数据
recent_avg_messages = 0
recent_avg_memory_count = 0
if self.last_sample_results:
recent_results = self.last_sample_results[-5:] # 最近5次
recent_avg_messages = sum(r.get("total_messages", 0) for r in recent_results) / len(recent_results)
recent_avg_memory_count = sum(r.get("memory_count", 0) for r in recent_results) / len(recent_results)
return {
"is_running": self.is_running,
"sample_count": self.sample_count,
"success_count": self.success_count,
"success_rate": f"{success_rate:.1f}%",
"last_sample_time": self.last_sample_time,
"optimization_mode": "batch_fusion", # 显示优化模式
"performance_metrics": {
"avg_messages_per_sample": f"{recent_avg_messages:.1f}",
"avg_memories_per_sample": f"{recent_avg_memory_count:.1f}",
"fusion_efficiency": f"{(recent_avg_messages/max(recent_avg_memory_count, 1)):.1f}x" if recent_avg_messages > 0 else "N/A"
},
"config": {
"sample_interval": self.config.sample_interval,
"total_samples": self.config.total_samples,
"recent_weight": f"{self.config.recent_weight:.1%}",
"distant_weight": f"{self.config.distant_weight:.1%}",
"max_concurrent": 5, # 批量模式并发数
"fusion_time_gap": "30分钟", # 消息融合时间间隔
},
"recent_results": self.last_sample_results[-5:], # 最近5次结果
}
# 全局海马体采样器实例
_hippocampus_sampler: Optional[HippocampusSampler] = None
def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
"""获取全局海马体采样器实例"""
global _hippocampus_sampler
if _hippocampus_sampler is None:
_hippocampus_sampler = HippocampusSampler(memory_system)
return _hippocampus_sampler
async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
"""初始化全局海马体采样器"""
sampler = get_hippocampus_sampler(memory_system)
await sampler.initialize()
return sampler
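# A minimal wiring sketch, assuming an already-initialized MemorySystem instance:
#
#     sampler = await initialize_hippocampus_sampler(memory_system)
#     stats = sampler.get_sampling_stats()  # success_rate, recent_results, ...
#     sampler.stop_background_sampling()    # stop the background loop on shutdown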

View File

@@ -19,6 +19,12 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio
 from src.chat.memory_system.memory_chunk import MemoryChunk
 from src.chat.memory_system.memory_fusion import MemoryFusionEngine
 from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
+# 简化的记忆采样模式枚举
+class MemorySamplingMode(Enum):
+    """记忆采样模式"""
+    HIPPOCAMPUS = "hippocampus"  # 海马体模式:定时任务采样
+    IMMEDIATE = "immediate"  # 即时模式:回复后立即采样
+    ALL = "all"  # 所有模式:同时使用海马体和即时采样
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
@@ -148,6 +154,9 @@ class MemorySystem:
         # 记忆指纹缓存,用于快速检测重复记忆
         self._memory_fingerprints: dict[str, str] = {}
+
+        # 海马体采样器
+        self.hippocampus_sampler = None
 
         logger.info("MemorySystem 初始化开始")
 
     async def initialize(self):
@@ -249,6 +258,16 @@ class MemorySystem:
         self.query_planner = MemoryQueryPlanner(planner_model, default_limit=self.config.final_recall_limit)
+
+        # 初始化海马体采样器
+        if global_config.memory.enable_hippocampus_sampling:
+            try:
+                from .hippocampus_sampler import initialize_hippocampus_sampler
+                self.hippocampus_sampler = await initialize_hippocampus_sampler(self)
+                logger.info("✅ 海马体采样器初始化成功")
+            except Exception as e:
+                logger.warning(f"海马体采样器初始化失败: {e}")
+                self.hippocampus_sampler = None
 
         # 统一存储已经自动加载数据,无需额外加载
         logger.info("✅ 简化版记忆系统初始化完成")
@@ -283,14 +302,14 @@ class MemorySystem:
         try:
             # 使用统一存储检索相似记忆
+            filters = {"user_id": user_id} if user_id else None
             search_results = await self.unified_storage.search_similar_memories(
-                query_text=query_text, limit=limit, scope_id=user_id
+                query_text=query_text, limit=limit, filters=filters
             )
 
             # 转换为记忆对象
             memories = []
-            for memory_id, similarity_score in search_results:
-                memory = self.unified_storage.get_memory_by_id(memory_id)
+            for memory, similarity_score in search_results:
                 if memory:
                     memory.update_access()  # 更新访问信息
                     memories.append(memory)
@@ -302,7 +321,7 @@ class MemorySystem:
             return []
 
     async def build_memory_from_conversation(
-        self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None
+        self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None, bypass_interval: bool = False
     ) -> list[MemoryChunk]:
         """从对话中构建记忆
@@ -310,6 +329,7 @@ class MemorySystem:
             conversation_text: 对话文本
             context: 上下文信息
             timestamp: 时间戳,默认为当前时间
+            bypass_interval: 是否绕过构建间隔检查(海马体采样器专用)
 
         Returns:
             构建的记忆块列表
@@ -328,7 +348,8 @@ class MemorySystem:
         min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0))
         current_time = time.time()
 
-        if build_scope_key and min_interval > 0:
+        # 构建间隔检查(海马体采样器可以绕过)
+        if build_scope_key and min_interval > 0 and not bypass_interval:
             last_time = self._last_memory_build_times.get(build_scope_key)
             if last_time and (current_time - last_time) < min_interval:
                 remaining = min_interval - (current_time - last_time)
@@ -340,18 +361,35 @@ class MemorySystem:
             build_marker_time = current_time
             self._last_memory_build_times[build_scope_key] = current_time
+        elif bypass_interval:
+            # 海马体采样模式:不更新构建时间记录,避免影响即时模式
+            logger.debug("海马体采样模式:绕过构建间隔检查")
 
         conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
         logger.debug("开始构建记忆,文本长度: %d", len(conversation_text))
 
-        # 1. 信息价值评估
-        value_score = await self._assess_information_value(conversation_text, normalized_context)
-        if value_score < self.config.memory_value_threshold:
-            logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
-            self.status = original_status
-            return []
+        # 1. 信息价值评估(海马体采样器可以绕过)
+        if not bypass_interval and not context.get("bypass_value_threshold", False):
+            value_score = await self._assess_information_value(conversation_text, normalized_context)
+            if value_score < self.config.memory_value_threshold:
+                logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
+                self.status = original_status
+                return []
+        else:
+            # 海马体采样器:使用默认价值分数或简单评估
+            value_score = 0.6  # 默认中等价值
+            if context.get("is_hippocampus_sample", False):
+                # 对海马体样本进行简单价值评估
+                if len(conversation_text) > 100:  # 长文本可能有更多信息
+                    value_score = 0.7
+                elif len(conversation_text) > 50:
+                    value_score = 0.6
+                else:
+                    value_score = 0.5
+            logger.debug(f"海马体采样模式:使用价值评分 {value_score:.2f}")
 
         # 2. 构建记忆块(所有记忆统一使用 global 作用域,实现完全共享)
         memory_chunks = await self.memory_builder.build_memories(
@@ -469,7 +507,7 @@ class MemorySystem:
                     continue
                 search_tasks.append(
                     self.unified_storage.search_similar_memories(
-                        query_text=display_text, limit=8, scope_id=GLOBAL_MEMORY_SCOPE
+                        query_text=display_text, limit=8, filters={"user_id": GLOBAL_MEMORY_SCOPE}
                     )
                 )
@@ -512,12 +550,70 @@ class MemorySystem:
         return existing_candidates
 
     async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]:
-        """对外暴露的对话记忆处理接口,仅依赖上下文信息"""
+        """对外暴露的对话记忆处理接口,支持海马体、即时、所有三种采样模式"""
         start_time = time.time()
         try:
             context = dict(context or {})
+
+            # 获取配置的采样模式
+            sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate')
+            current_mode = MemorySamplingMode(sampling_mode)
+            logger.debug(f"使用记忆采样模式: {current_mode.value}")
+
+            # 根据采样模式处理记忆
+            if current_mode == MemorySamplingMode.HIPPOCAMPUS:
+                # 海马体模式:仅后台定时采样,不立即处理
+                return {
+                    "success": True,
+                    "created_memories": [],
+                    "memory_count": 0,
+                    "processing_time": time.time() - start_time,
+                    "status": self.status.value,
+                    "processing_mode": "hippocampus",
+                    "message": "海马体模式:记忆将由后台定时任务采样处理",
+                }
+            elif current_mode == MemorySamplingMode.IMMEDIATE:
+                # 即时模式:立即处理记忆构建
+                return await self._process_immediate_memory(context, start_time)
+            elif current_mode == MemorySamplingMode.ALL:
+                # 所有模式:同时进行即时处理和海马体采样
+                immediate_result = await self._process_immediate_memory(context, start_time)
+                # 海马体采样器会在后台继续处理,这里只是记录
+                if self.hippocampus_sampler:
+                    immediate_result["processing_mode"] = "all_modes"
+                    immediate_result["hippocampus_status"] = "background_sampling_enabled"
+                    immediate_result["message"] = "所有模式:即时处理已完成,海马体采样将在后台继续"
+                else:
+                    immediate_result["processing_mode"] = "immediate_fallback"
+                    immediate_result["hippocampus_status"] = "not_available"
+                    immediate_result["message"] = "海马体采样器不可用,回退到即时模式"
+                return immediate_result
+            else:
+                # 默认回退到即时模式
+                logger.warning(f"未知的采样模式 {sampling_mode},回退到即时模式")
+                return await self._process_immediate_memory(context, start_time)
+
+        except Exception as e:
+            processing_time = time.time() - start_time
+            logger.error(f"对话记忆处理失败: {e}", exc_info=True)
+            return {
+                "success": False,
+                "error": str(e),
+                "processing_time": processing_time,
+                "status": self.status.value,
+                "processing_mode": "error",
+            }
+
+    async def _process_immediate_memory(self, context: dict[str, Any], start_time: float) -> dict[str, Any]:
+        """即时记忆处理的辅助方法"""
+        try:
             conversation_candidate = (
                 context.get("conversation_text")
                 or context.get("message_content")
@@ -537,6 +633,23 @@ class MemorySystem:
             normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
             normalized_context.setdefault("conversation_text", conversation_text)
+
+            # 检查信息价值阈值
+            value_score = await self._assess_information_value(conversation_text, normalized_context)
+            threshold = getattr(global_config.memory, 'precision_memory_reply_threshold', 0.5)
+            if value_score < threshold:
+                logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建")
+                return {
+                    "success": True,
+                    "created_memories": [],
+                    "memory_count": 0,
+                    "processing_time": time.time() - start_time,
+                    "status": self.status.value,
+                    "processing_mode": "immediate",
+                    "skip_reason": f"value_score_{value_score:.2f}_below_threshold_{threshold}",
+                    "value_score": value_score,
+                }
+
             memories = await self.build_memory_from_conversation(
                 conversation_text=conversation_text, context=normalized_context, timestamp=timestamp
             )
@@ -550,12 +663,20 @@ class MemorySystem:
                 "memory_count": memory_count,
                 "processing_time": processing_time,
                 "status": self.status.value,
+                "processing_mode": "immediate",
+                "value_score": value_score,
             }
         except Exception as e:
             processing_time = time.time() - start_time
-            logger.error(f"对话记忆处理失败: {e}", exc_info=True)
-            return {"success": False, "error": str(e), "processing_time": processing_time, "status": self.status.value}
+            logger.error(f"即时记忆处理失败: {e}", exc_info=True)
+            return {
+                "success": False,
+                "error": str(e),
+                "processing_time": processing_time,
+                "status": self.status.value,
+                "processing_mode": "immediate_error",
+            }
 
     async def retrieve_relevant_memories(
         self,
@@ -1372,11 +1493,53 @@ class MemorySystem:
         except Exception as e:
             logger.error(f"❌ 记忆系统维护失败: {e}", exc_info=True)
 
+    def start_hippocampus_sampling(self):
+        """启动海马体采样"""
+        if self.hippocampus_sampler:
+            asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
+            logger.info("🚀 海马体后台采样已启动")
+        else:
+            logger.warning("海马体采样器未初始化,无法启动采样")
+
+    def stop_hippocampus_sampling(self):
+        """停止海马体采样"""
+        if self.hippocampus_sampler:
+            self.hippocampus_sampler.stop_background_sampling()
+            logger.info("🛑 海马体后台采样已停止")
+
+    def get_system_stats(self) -> dict[str, Any]:
+        """获取系统统计信息"""
+        base_stats = {
+            "status": self.status.value,
+            "total_memories": self.total_memories,
+            "last_build_time": self.last_build_time,
+            "last_retrieval_time": self.last_retrieval_time,
+            "config": asdict(self.config),
+        }
+
+        # 添加海马体采样器统计
+        if self.hippocampus_sampler:
+            base_stats["hippocampus_sampler"] = self.hippocampus_sampler.get_sampling_stats()
+
+        # 添加存储统计
+        if self.unified_storage:
+            try:
+                storage_stats = self.unified_storage.get_storage_stats()
+                base_stats["storage_stats"] = storage_stats
+            except Exception as e:
+                logger.debug(f"获取存储统计失败: {e}")
+
+        return base_stats
 
     async def shutdown(self):
         """关闭系统(简化版)"""
         try:
             logger.info("正在关闭简化记忆系统...")
+
+            # 停止海马体采样
+            if self.hippocampus_sampler:
+                self.hippocampus_sampler.stop_background_sampling()
+
             # 保存统一存储数据
             if self.unified_storage:
                 await self.unified_storage.cleanup()
@@ -1456,4 +1619,10 @@ async def initialize_memory_system(llm_model: LLMRequest | None = None):
     if memory_system is None:
         memory_system = MemorySystem(llm_model=llm_model)
         await memory_system.initialize()
+
+        # 根据配置启动海马体采样
+        sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate')
+        if sampling_mode in ['hippocampus', 'all']:
+            memory_system.start_hippocampus_sampling()
+
     return memory_system
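
The reworked process_conversation_memory above is driven entirely by configuration. A minimal sketch of the relevant keys (the TOML section layout is an assumption; only the key names memory_sampling_mode and enable_hippocampus_sampling are taken from the code):

    [memory]
    memory_sampling_mode = "all"          # "hippocampus" | "immediate" | "all"
    enable_hippocampus_sampling = true    # needed for "hippocampus" and "all"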

View File

@@ -0,0 +1,489 @@
"""
自适应流管理器 - 动态并发限制和异步流池管理
根据系统负载和流优先级动态调整并发限制
"""
import asyncio
import psutil
import time
from typing import Dict, List, Optional, Set
from dataclasses import dataclass, field
from enum import Enum
from src.common.logger import get_logger
logger = get_logger("adaptive_stream_manager")
class StreamPriority(Enum):
"""流优先级"""
LOW = 1
NORMAL = 2
HIGH = 3
CRITICAL = 4
@dataclass
class SystemMetrics:
"""系统指标"""
cpu_usage: float = 0.0
memory_usage: float = 0.0
active_coroutines: int = 0
event_loop_lag: float = 0.0
timestamp: float = field(default_factory=time.time)
@dataclass
class StreamMetrics:
"""流指标"""
stream_id: str
priority: StreamPriority
message_rate: float = 0.0 # 消息速率(消息/分钟)
response_time: float = 0.0 # 平均响应时间
last_activity: float = field(default_factory=time.time)
consecutive_failures: int = 0
is_active: bool = True
class AdaptiveStreamManager:
"""自适应流管理器"""
def __init__(
self,
base_concurrent_limit: int = 50,
max_concurrent_limit: int = 200,
min_concurrent_limit: int = 10,
metrics_window: float = 60.0, # 指标窗口时间
adjustment_interval: float = 30.0, # 调整间隔
cpu_threshold_high: float = 0.8, # CPU高负载阈值
cpu_threshold_low: float = 0.3, # CPU低负载阈值
memory_threshold_high: float = 0.85, # 内存高负载阈值
):
self.base_concurrent_limit = base_concurrent_limit
self.max_concurrent_limit = max_concurrent_limit
self.min_concurrent_limit = min_concurrent_limit
self.metrics_window = metrics_window
self.adjustment_interval = adjustment_interval
self.cpu_threshold_high = cpu_threshold_high
self.cpu_threshold_low = cpu_threshold_low
self.memory_threshold_high = memory_threshold_high
# 当前状态
self.current_limit = base_concurrent_limit
self.active_streams: Set[str] = set()
self.pending_streams: Set[str] = set()
        self.stream_metrics: Dict[str, StreamMetrics] = {}
        self.forced_streams: Set[str] = set()  # 强制分发(未占用信号量)的流
# 异步信号量
self.semaphore = asyncio.Semaphore(base_concurrent_limit)
self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量
# 系统监控
self.system_metrics: List[SystemMetrics] = []
self.last_adjustment_time = 0.0
# 统计信息
self.stats = {
"total_requests": 0,
"accepted_requests": 0,
"rejected_requests": 0,
"priority_accepts": 0,
"limit_adjustments": 0,
"avg_concurrent_streams": 0,
"peak_concurrent_streams": 0,
}
# 监控任务
self.monitor_task: Optional[asyncio.Task] = None
self.adjustment_task: Optional[asyncio.Task] = None
self.is_running = False
logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")
async def start(self):
"""启动自适应管理器"""
if self.is_running:
logger.warning("自适应流管理器已经在运行")
return
self.is_running = True
self.monitor_task = asyncio.create_task(self._system_monitor_loop(), name="system_monitor")
self.adjustment_task = asyncio.create_task(self._adjustment_loop(), name="limit_adjustment")
logger.info("自适应流管理器已启动")
async def stop(self):
"""停止自适应管理器"""
if not self.is_running:
return
self.is_running = False
# 停止监控任务
if self.monitor_task and not self.monitor_task.done():
self.monitor_task.cancel()
try:
await asyncio.wait_for(self.monitor_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("系统监控任务停止超时")
except Exception as e:
logger.error(f"停止系统监控任务时出错: {e}")
if self.adjustment_task and not self.adjustment_task.done():
self.adjustment_task.cancel()
try:
await asyncio.wait_for(self.adjustment_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("限制调整任务停止超时")
except Exception as e:
logger.error(f"停止限制调整任务时出错: {e}")
logger.info("自适应流管理器已停止")
async def acquire_stream_slot(
self,
stream_id: str,
priority: StreamPriority = StreamPriority.NORMAL,
force: bool = False
) -> bool:
"""
获取流处理槽位
Args:
stream_id: 流ID
priority: 优先级
force: 是否强制获取(突破限制)
Returns:
bool: 是否成功获取槽位
"""
# 检查管理器是否已启动
if not self.is_running:
logger.warning(f"自适应流管理器未运行,直接允许流 {stream_id}")
return True
self.stats["total_requests"] += 1
current_time = time.time()
# 更新流指标
if stream_id not in self.stream_metrics:
self.stream_metrics[stream_id] = StreamMetrics(
stream_id=stream_id,
priority=priority
)
self.stream_metrics[stream_id].last_activity = current_time
# 检查是否已经活跃
if stream_id in self.active_streams:
logger.debug(f"{stream_id} 已经在活跃列表中")
return True
# 优先级处理
if priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
return await self._acquire_priority_slot(stream_id, priority, force)
# 检查是否需要强制分发(消息积压)
if not force and self._should_force_dispatch(stream_id):
force = True
logger.info(f"{stream_id} 消息积压严重,强制分发")
# 尝试获取常规信号量
try:
# 使用wait_for实现非阻塞获取
acquired = await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
if acquired:
self.active_streams.add(stream_id)
self.stats["accepted_requests"] += 1
logger.debug(f"{stream_id} 获取常规槽位成功 (当前活跃: {len(self.active_streams)})")
return True
except asyncio.TimeoutError:
logger.debug(f"常规信号量已满: {stream_id}")
except Exception as e:
logger.warning(f"获取常规槽位时出错: {e}")
# 如果强制分发,尝试突破限制
if force:
return await self._force_acquire_slot(stream_id)
# 无法获取槽位
self.stats["rejected_requests"] += 1
logger.debug(f"{stream_id} 获取槽位失败,当前限制: {self.current_limit}, 活跃流: {len(self.active_streams)}")
return False
async def _acquire_priority_slot(self, stream_id: str, priority: StreamPriority, force: bool) -> bool:
"""获取优先级槽位"""
try:
# 优先级信号量有少量槽位
acquired = await asyncio.wait_for(self.priority_semaphore.acquire(), timeout=0.001)
if acquired:
self.active_streams.add(stream_id)
self.stats["priority_accepts"] += 1
self.stats["accepted_requests"] += 1
logger.debug(f"{stream_id} 获取优先级槽位成功 (优先级: {priority.name})")
return True
except asyncio.TimeoutError:
logger.debug(f"优先级信号量已满: {stream_id}")
except Exception as e:
logger.warning(f"获取优先级槽位时出错: {e}")
# 如果优先级槽位也满了,检查是否强制
if force or priority == StreamPriority.CRITICAL:
return await self._force_acquire_slot(stream_id)
return False
async def _force_acquire_slot(self, stream_id: str) -> bool:
"""强制获取槽位(突破限制)"""
# 检查是否超过最大限制
if len(self.active_streams) >= self.max_concurrent_limit:
logger.warning(f"达到最大并发限制 {self.max_concurrent_limit},无法为流 {stream_id} 强制分发")
return False
        # 强制添加到活跃列表,并记录为强制流(释放时不归还信号量)
        self.active_streams.add(stream_id)
        self.forced_streams.add(stream_id)
        self.stats["accepted_requests"] += 1
        logger.warning(f"流 {stream_id} 突破并发限制强制分发 (当前活跃: {len(self.active_streams)})")
        return True
    def release_stream_slot(self, stream_id: str):
        """释放流处理槽位"""
        if stream_id in self.active_streams:
            self.active_streams.remove(stream_id)

            # 强制分发的流没有占用信号量,释放时不能归还,否则会虚增可用槽位
            if stream_id in self.forced_streams:
                self.forced_streams.discard(stream_id)
            else:
                # 释放相应的信号量
                metrics = self.stream_metrics.get(stream_id)
                if metrics and metrics.priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
                    self.priority_semaphore.release()
                else:
                    self.semaphore.release()

            logger.debug(f"流 {stream_id} 释放槽位 (当前活跃: {len(self.active_streams)})")
def _should_force_dispatch(self, stream_id: str) -> bool:
"""判断是否应该强制分发"""
# 这里可以实现基于消息积压的判断逻辑
# 简化版本:基于流的历史活跃度和优先级
metrics = self.stream_metrics.get(stream_id)
if not metrics:
return False
# 如果是高优先级流,更容易强制分发
if metrics.priority == StreamPriority.HIGH:
return True
# 如果最近有活跃且响应时间较长,可能需要强制分发
current_time = time.time()
if (current_time - metrics.last_activity < 300 and # 5分钟内有活动
metrics.response_time > 5.0): # 响应时间超过5秒
return True
return False
async def _system_monitor_loop(self):
"""系统监控循环"""
logger.info("系统监控循环启动")
while self.is_running:
try:
await asyncio.sleep(5.0) # 每5秒监控一次
await self._collect_system_metrics()
except asyncio.CancelledError:
logger.info("系统监控循环被取消")
break
except Exception as e:
logger.error(f"系统监控出错: {e}")
logger.info("系统监控循环结束")
async def _collect_system_metrics(self):
"""收集系统指标"""
try:
# CPU使用率
cpu_usage = psutil.cpu_percent(interval=None) / 100.0
# 内存使用率
memory = psutil.virtual_memory()
memory_usage = memory.percent / 100.0
# 活跃协程数量
try:
active_coroutines = len(asyncio.all_tasks())
            except Exception:
active_coroutines = 0
# 事件循环延迟
event_loop_lag = 0.0
try:
loop = asyncio.get_running_loop()
start_time = time.time()
await asyncio.sleep(0)
event_loop_lag = time.time() - start_time
            except Exception:
pass
metrics = SystemMetrics(
cpu_usage=cpu_usage,
memory_usage=memory_usage,
active_coroutines=active_coroutines,
event_loop_lag=event_loop_lag,
timestamp=time.time()
)
self.system_metrics.append(metrics)
# 保持指标窗口大小
cutoff_time = time.time() - self.metrics_window
self.system_metrics = [
m for m in self.system_metrics
if m.timestamp > cutoff_time
]
# 更新统计信息
self.stats["avg_concurrent_streams"] = (
self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1
)
self.stats["peak_concurrent_streams"] = max(
self.stats["peak_concurrent_streams"],
len(self.active_streams)
)
except Exception as e:
logger.error(f"收集系统指标失败: {e}")
async def _adjustment_loop(self):
"""限制调整循环"""
logger.info("限制调整循环启动")
while self.is_running:
try:
await asyncio.sleep(self.adjustment_interval)
await self._adjust_concurrent_limit()
except asyncio.CancelledError:
logger.info("限制调整循环被取消")
break
except Exception as e:
logger.error(f"限制调整出错: {e}")
logger.info("限制调整循环结束")
async def _adjust_concurrent_limit(self):
"""调整并发限制"""
if not self.system_metrics:
return
current_time = time.time()
if current_time - self.last_adjustment_time < self.adjustment_interval:
return
# 计算平均系统指标
recent_metrics = self.system_metrics[-10:] if len(self.system_metrics) >= 10 else self.system_metrics
if not recent_metrics:
return
avg_cpu = sum(m.cpu_usage for m in recent_metrics) / len(recent_metrics)
avg_memory = sum(m.memory_usage for m in recent_metrics) / len(recent_metrics)
avg_coroutines = sum(m.active_coroutines for m in recent_metrics) / len(recent_metrics)
# 调整策略
old_limit = self.current_limit
adjustment_factor = 1.0
# CPU负载调整
if avg_cpu > self.cpu_threshold_high:
adjustment_factor *= 0.8 # 减少20%
elif avg_cpu < self.cpu_threshold_low:
adjustment_factor *= 1.2 # 增加20%
# 内存负载调整
if avg_memory > self.memory_threshold_high:
adjustment_factor *= 0.7 # 减少30%
# 协程数量调整
if avg_coroutines > 1000:
adjustment_factor *= 0.9 # 减少10%
# 应用调整
new_limit = int(self.current_limit * adjustment_factor)
new_limit = max(self.min_concurrent_limit, min(self.max_concurrent_limit, new_limit))
# 检查是否需要调整信号量
if new_limit != self.current_limit:
await self._adjust_semaphore(self.current_limit, new_limit)
self.current_limit = new_limit
self.stats["limit_adjustments"] += 1
self.last_adjustment_time = current_time
logger.info(
f"并发限制调整: {old_limit} -> {new_limit} "
f"(CPU: {avg_cpu:.2f}, 内存: {avg_memory:.2f}, 协程: {avg_coroutines:.0f})"
)
async def _adjust_semaphore(self, old_limit: int, new_limit: int):
"""调整信号量大小"""
if new_limit > old_limit:
# 增加信号量槽位
for _ in range(new_limit - old_limit):
self.semaphore.release()
elif new_limit < old_limit:
# 减少信号量槽位(通过等待槽位被释放)
reduction = old_limit - new_limit
for _ in range(reduction):
try:
await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
                except asyncio.TimeoutError:
# 如果无法立即获取,说明当前使用量接近限制
break
def update_stream_metrics(self, stream_id: str, **kwargs):
"""更新流指标"""
if stream_id not in self.stream_metrics:
return
metrics = self.stream_metrics[stream_id]
for key, value in kwargs.items():
if hasattr(metrics, key):
setattr(metrics, key, value)
def get_stats(self) -> Dict:
"""获取统计信息"""
stats = self.stats.copy()
stats.update({
"current_limit": self.current_limit,
"active_streams": len(self.active_streams),
"pending_streams": len(self.pending_streams),
"is_running": self.is_running,
"system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
"system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
})
# 计算接受率
if stats["total_requests"] > 0:
stats["acceptance_rate"] = stats["accepted_requests"] / stats["total_requests"]
else:
stats["acceptance_rate"] = 0
return stats
# 全局自适应管理器实例
_adaptive_manager: Optional[AdaptiveStreamManager] = None
def get_adaptive_stream_manager() -> AdaptiveStreamManager:
"""获取自适应流管理器实例"""
global _adaptive_manager
if _adaptive_manager is None:
_adaptive_manager = AdaptiveStreamManager()
return _adaptive_manager
async def init_adaptive_stream_manager():
"""初始化自适应流管理器"""
manager = get_adaptive_stream_manager()
await manager.start()
async def shutdown_adaptive_stream_manager():
"""关闭自适应流管理器"""
manager = get_adaptive_stream_manager()
await manager.stop()
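# A minimal usage sketch, assuming the manager was started via
# init_adaptive_stream_manager():
#
#     manager = get_adaptive_stream_manager()
#     if await manager.acquire_stream_slot("stream-42", StreamPriority.HIGH):
#         try:
#             ...  # handle the stream
#         finally:
#             manager.release_stream_slot("stream-42")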

View File

@@ -0,0 +1,348 @@
"""
异步批量数据库写入器
优化频繁的数据库写入操作减少I/O阻塞
"""
import asyncio
import time
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from collections import defaultdict
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("batch_database_writer")
@dataclass
class StreamUpdatePayload:
"""流更新数据结构"""
stream_id: str
update_data: Dict[str, Any]
priority: int = 0 # 优先级,数字越大优先级越高
timestamp: float = field(default_factory=time.time)
class BatchDatabaseWriter:
"""异步批量数据库写入器"""
def __init__(self, batch_size: int = 50, flush_interval: float = 5.0, max_queue_size: int = 1000):
"""
初始化批量写入器
Args:
batch_size: 批量写入的大小
flush_interval: 刷新间隔(秒)
max_queue_size: 最大队列大小
"""
self.batch_size = batch_size
self.flush_interval = flush_interval
self.max_queue_size = max_queue_size
# 异步队列
self.write_queue: asyncio.Queue[StreamUpdatePayload] = asyncio.Queue(maxsize=max_queue_size)
# 运行状态
self.is_running = False
self.writer_task: Optional[asyncio.Task] = None
# 统计信息
self.stats = {
"total_writes": 0,
"batch_writes": 0,
"failed_writes": 0,
"queue_size": 0,
"avg_batch_size": 0,
"last_flush_time": 0,
}
# 按优先级分类的批次
self.priority_batches: Dict[int, List[StreamUpdatePayload]] = defaultdict(list)
logger.info(f"批量数据库写入器初始化完成 (batch_size={batch_size}, interval={flush_interval}s)")
async def start(self):
"""启动批量写入器"""
if self.is_running:
logger.warning("批量写入器已经在运行")
return
self.is_running = True
self.writer_task = asyncio.create_task(self._batch_writer_loop(), name="batch_database_writer")
logger.info("批量数据库写入器已启动")
async def stop(self):
"""停止批量写入器"""
if not self.is_running:
return
self.is_running = False
# 等待当前批次写入完成
if self.writer_task and not self.writer_task.done():
try:
# 先处理剩余的数据
await self._flush_all_batches()
# 取消任务
self.writer_task.cancel()
await asyncio.wait_for(self.writer_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("批量写入器停止超时")
except Exception as e:
logger.error(f"停止批量写入器时出错: {e}")
logger.info("批量数据库写入器已停止")
async def schedule_stream_update(
self,
stream_id: str,
update_data: Dict[str, Any],
priority: int = 0
) -> bool:
"""
调度流更新
Args:
stream_id: 流ID
update_data: 更新数据
priority: 优先级
Returns:
bool: 是否成功加入队列
"""
try:
if not self.is_running:
logger.warning("批量写入器未运行,直接写入数据库")
await self._direct_write(stream_id, update_data)
return True
# 创建更新载荷
payload = StreamUpdatePayload(
stream_id=stream_id,
update_data=update_data,
priority=priority
)
# 非阻塞方式加入队列
try:
self.write_queue.put_nowait(payload)
self.stats["total_writes"] += 1
self.stats["queue_size"] = self.write_queue.qsize()
return True
except asyncio.QueueFull:
logger.warning(f"写入队列已满,丢弃低优先级更新: stream_id={stream_id}")
return False
except Exception as e:
logger.error(f"调度流更新失败: {e}")
return False
async def _batch_writer_loop(self):
"""批量写入主循环"""
logger.info("批量写入循环启动")
while self.is_running:
try:
# 等待批次填满或超时
batch = await self._collect_batch()
if batch:
await self._write_batch(batch)
# 更新统计信息
self.stats["queue_size"] = self.write_queue.qsize()
except asyncio.CancelledError:
logger.info("批量写入循环被取消")
break
except Exception as e:
logger.error(f"批量写入循环出错: {e}")
# 短暂等待后继续
await asyncio.sleep(1.0)
# 循环结束前处理剩余数据
await self._flush_all_batches()
logger.info("批量写入循环结束")
async def _collect_batch(self) -> List[StreamUpdatePayload]:
"""收集一个批次的数据"""
batch = []
deadline = time.time() + self.flush_interval
while len(batch) < self.batch_size and time.time() < deadline:
try:
# 计算剩余等待时间
remaining_time = max(0, deadline - time.time())
if remaining_time == 0:
break
payload = await asyncio.wait_for(
self.write_queue.get(),
timeout=remaining_time
)
batch.append(payload)
except asyncio.TimeoutError:
break
return batch
async def _write_batch(self, batch: List[StreamUpdatePayload]):
"""批量写入数据库"""
if not batch:
return
start_time = time.time()
try:
# 按优先级排序
batch.sort(key=lambda x: (-x.priority, x.timestamp))
# 合并同一流ID的更新保留最新的
merged_updates = {}
for payload in batch:
if payload.stream_id not in merged_updates or payload.timestamp > merged_updates[payload.stream_id].timestamp:
merged_updates[payload.stream_id] = payload
# 批量写入
await self._batch_write_to_database(list(merged_updates.values()))
# 更新统计
self.stats["batch_writes"] += 1
self.stats["avg_batch_size"] = (
self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1
) # 滑动平均
self.stats["last_flush_time"] = start_time
logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")
except Exception as e:
self.stats["failed_writes"] += 1
logger.error(f"批量写入失败: {e}")
# 降级到单个写入
for payload in batch:
try:
await self._direct_write(payload.stream_id, payload.update_data)
except Exception as single_e:
logger.error(f"单个写入也失败: {single_e}")
async def _batch_write_to_database(self, payloads: List[StreamUpdatePayload]):
"""批量写入数据库"""
async with get_db_session() as session:
for payload in payloads:
stream_id = payload.stream_id
update_data = payload.update_data
# 根据数据库类型选择不同的插入/更新策略
if global_config.database.database_type == "sqlite":
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(
stream_id=stream_id, **update_data
)
stmt = stmt.on_conflict_do_update(
index_elements=["stream_id"],
set_=update_data
)
elif global_config.database.database_type == "mysql":
from sqlalchemy.dialects.mysql import insert as mysql_insert
stmt = mysql_insert(ChatStreams).values(
stream_id=stream_id, **update_data
)
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in update_data.items() if key != "stream_id"}
)
else:
# 默认使用SQLite语法
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(
stream_id=stream_id, **update_data
)
stmt = stmt.on_conflict_do_update(
index_elements=["stream_id"],
set_=update_data
)
await session.execute(stmt)
await session.commit()
async def _direct_write(self, stream_id: str, update_data: Dict[str, Any]):
"""直接写入数据库(降级方案)"""
async with get_db_session() as session:
if global_config.database.database_type == "sqlite":
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(
stream_id=stream_id, **update_data
)
stmt = stmt.on_conflict_do_update(
index_elements=["stream_id"],
set_=update_data
)
elif global_config.database.database_type == "mysql":
from sqlalchemy.dialects.mysql import insert as mysql_insert
stmt = mysql_insert(ChatStreams).values(
stream_id=stream_id, **update_data
)
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in update_data.items() if key != "stream_id"}
)
else:
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(
stream_id=stream_id, **update_data
)
stmt = stmt.on_conflict_do_update(
index_elements=["stream_id"],
set_=update_data
)
await session.execute(stmt)
await session.commit()
async def _flush_all_batches(self):
"""刷新所有剩余批次"""
# 收集所有剩余数据
remaining_batch = []
while not self.write_queue.empty():
try:
payload = self.write_queue.get_nowait()
remaining_batch.append(payload)
except asyncio.QueueEmpty:
break
if remaining_batch:
await self._write_batch(remaining_batch)
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
stats = self.stats.copy()
stats["is_running"] = self.is_running
stats["current_queue_size"] = self.write_queue.qsize() if self.is_running else 0
return stats
# 全局批量写入器实例
_batch_writer: Optional[BatchDatabaseWriter] = None
def get_batch_writer() -> BatchDatabaseWriter:
"""获取批量写入器实例"""
global _batch_writer
if _batch_writer is None:
_batch_writer = BatchDatabaseWriter()
return _batch_writer
async def init_batch_writer():
"""初始化批量写入器"""
writer = get_batch_writer()
await writer.start()
async def shutdown_batch_writer():
"""关闭批量写入器"""
writer = get_batch_writer()
await writer.stop()
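下面是调度写入的示意用法(非仓库代码;update_data 中的 last_active_time 仅作为 ChatStreams 表列名的假设):
import asyncio
import time

from src.chat.message_manager.batch_database_writer import (
    get_batch_writer,
    init_batch_writer,
    shutdown_batch_writer,
)

async def demo():
    await init_batch_writer()
    writer = get_batch_writer()
    # 同一 stream_id 在一个批次内的多次更新会被合并,只保留时间戳最新的一条
    await writer.schedule_stream_update("stream-1", {"last_active_time": time.time()}, priority=0)
    await writer.schedule_stream_update("stream-1", {"last_active_time": time.time()}, priority=1)
    await asyncio.sleep(6)  # 超过 flush_interval(默认5秒)后批次自动落库
    print(writer.get_stats())  # batch_writes 应至少为 1
    await shutdown_batch_writer()

asyncio.run(demo())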

View File

@@ -23,6 +23,8 @@ class StreamLoopManager:
def __init__(self, max_concurrent_streams: int | None = None):
# 流循环任务管理
self.stream_loops: dict[str, asyncio.Task] = {}
# 跟踪流使用的管理器类型
self.stream_management_type: dict[str, str] = {} # stream_id -> "adaptive" or "fallback"
# 统计信息
self.stats: dict[str, Any] = {
@@ -99,7 +101,7 @@ class StreamLoopManager:
logger.info("流循环管理器已停止") logger.info("流循环管理器已停止")
async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool: async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
"""启动指定流的循环任务 """启动指定流的循环任务 - 优化版本使用自适应管理器
Args: Args:
stream_id: 流ID stream_id: 流ID
@@ -113,6 +115,71 @@ class StreamLoopManager:
logger.debug(f"{stream_id} 循环已在运行") logger.debug(f"{stream_id} 循环已在运行")
return True return True
# 使用自适应流管理器获取槽位
use_adaptive = False
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager, StreamPriority
adaptive_manager = get_adaptive_stream_manager()
if adaptive_manager.is_running:
# 确定流优先级
priority = self._determine_stream_priority(stream_id)
# 获取处理槽位
slot_acquired = await adaptive_manager.acquire_stream_slot(
stream_id=stream_id,
priority=priority,
force=force
)
if slot_acquired:
use_adaptive = True
logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})")
else:
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
else:
logger.debug(f"自适应管理器未运行,使用原始方法")
except Exception as e:
logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}")
# 如果自适应管理器失败或未运行,使用回退方案
if not use_adaptive:
if not await self._fallback_acquire_slot(stream_id, force):
logger.debug(f"回退方案也失败: {stream_id}")
return False
# 创建流循环任务
try:
loop_task = asyncio.create_task(
self._stream_loop_worker(stream_id),
name=f"stream_loop_{stream_id}"
)
self.stream_loops[stream_id] = loop_task
# 记录管理器类型
self.stream_management_type[stream_id] = "adaptive" if use_adaptive else "fallback"
# 更新统计信息
self.stats["active_streams"] += 1
self.stats["total_loops"] += 1
logger.info(f"启动流循环任务: {stream_id} (管理器: {'adaptive' if use_adaptive else 'fallback'})")
return True
except Exception as e:
logger.error(f"启动流循环任务失败 {stream_id}: {e}")
# 释放槽位
if use_adaptive:
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
adaptive_manager.release_stream_slot(stream_id)
except Exception:
pass
return False
async def _fallback_acquire_slot(self, stream_id: str, force: bool) -> bool:
"""回退方案:获取槽位(原始方法)"""
# 判断是否需要强制分发
should_force = force or self._should_force_dispatch_for_stream(stream_id)
@@ -149,6 +216,28 @@ class StreamLoopManager:
del self.stream_loops[stream_id]
current_streams -= 1  # 更新当前流数量
return True
def _determine_stream_priority(self, stream_id: str) -> "StreamPriority":
"""确定流优先级"""
try:
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
# 这里可以基于流的历史数据、用户身份等确定优先级
# 简化版本基于流ID的哈希值分配优先级
hash_value = hash(stream_id) % 10
if hash_value >= 8: # 20% 高优先级
return StreamPriority.HIGH
elif hash_value >= 5: # 30% 中等优先级
return StreamPriority.NORMAL
else: # 50% 低优先级
return StreamPriority.LOW
except Exception:
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
return StreamPriority.NORMAL
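一点实现上的注意:内置 hash() 对字符串默认带进程级随机盐(受 PYTHONHASHSEED 影响),同一 stream_id 在重启后可能落入不同的优先级桶。若需要跨进程稳定的分配,可改用确定性哈希,例如下面的草稿:
import hashlib

def stable_priority_bucket(stream_id: str) -> int:
    """用 MD5 摘要生成跨进程稳定的 0-9 桶号,替代带随机盐的内置 hash()。"""
    digest = hashlib.md5(stream_id.encode("utf-8")).hexdigest()
    return int(digest, 16) % 10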
# 创建流循环任务
try:
task = asyncio.create_task(
@@ -201,13 +290,13 @@ class StreamLoopManager:
logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})") logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})")
return True return True
async def _stream_loop(self, stream_id: str) -> None: async def _stream_loop_worker(self, stream_id: str) -> None:
"""单个流的无限循环 """单个流的工作循环 - 优化版本
Args: Args:
stream_id: 流ID stream_id: 流ID
""" """
logger.info(f"流循环开始: {stream_id}") logger.info(f"流循环工作器启动: {stream_id}")
try: try:
while self.is_running: while self.is_running:
@@ -223,6 +312,18 @@ class StreamLoopManager:
unread_count = self._get_unread_count(context)
force_dispatch = self._needs_force_dispatch_for_context(context, unread_count)
# 3. 更新自适应管理器指标
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
adaptive_manager.update_stream_metrics(
stream_id,
message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算
last_activity=time.time()
)
except Exception as e:
logger.debug(f"更新流指标失败: {e}")
has_messages = force_dispatch or await self._has_messages_to_process(context)
if has_messages:
@@ -278,6 +379,24 @@ class StreamLoopManager:
del self.stream_loops[stream_id]
logger.debug(f"清理流循环标记: {stream_id}")
# 根据管理器类型释放相应的槽位
management_type = self.stream_management_type.get(stream_id, "fallback")
if management_type == "adaptive":
# 释放自适应管理器的槽位
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
adaptive_manager.release_stream_slot(stream_id)
logger.debug(f"释放自适应流处理槽位: {stream_id}")
except Exception as e:
logger.debug(f"释放自适应流处理槽位失败: {e}")
else:
logger.debug(f"{stream_id} 使用回退方案,无需释放自适应槽位")
# 清理管理器类型记录
if stream_id in self.stream_management_type:
del self.stream_management_type[stream_id]
logger.info(f"流循环结束: {stream_id}") logger.info(f"流循环结束: {stream_id}")
async def _get_stream_context(self, stream_id: str) -> Any | None: async def _get_stream_context(self, stream_id: str) -> Any | None:
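上述 start_stream_loop 的槽位获取可以概括为“自适应优先、异常即回退”。下面是剥离日志与统计后的模式草稿(其中 fallback_acquire 是对原始信号量获取逻辑的假设占位):
from src.chat.message_manager.adaptive_stream_manager import (
    StreamPriority,
    get_adaptive_stream_manager,
)

async def fallback_acquire(stream_id: str, force: bool) -> bool:
    """占位:对应原始的信号量获取逻辑(仅为示意)。"""
    return True

async def acquire_slot_with_fallback(stream_id: str, force: bool = False) -> str | None:
    """先尝试自适应管理器,不可用或被拒绝时回退;返回所用管理器类型或 None。"""
    try:
        manager = get_adaptive_stream_manager()
        if manager.is_running and await manager.acquire_stream_slot(
            stream_id=stream_id, priority=StreamPriority.NORMAL, force=force
        ):
            return "adaptive"
    except Exception:
        pass  # 自适应管理器异常时静默降级,不影响主流程
    if await fallback_acquire(stream_id, force):
        return "fallback"
    return None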

View File

@@ -56,6 +56,30 @@ class MessageManager:
self.is_running = True
# 启动批量数据库写入器
try:
from src.chat.message_manager.batch_database_writer import init_batch_writer
await init_batch_writer()
logger.info("📦 批量数据库写入器已启动")
except Exception as e:
logger.error(f"启动批量数据库写入器失败: {e}")
# 启动流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
await init_stream_cache_manager()
logger.info("🗄️ 流缓存管理器已启动")
except Exception as e:
logger.error(f"启动流缓存管理器失败: {e}")
# 启动自适应流管理器
try:
from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
await init_adaptive_stream_manager()
logger.info("🎯 自适应流管理器已启动")
except Exception as e:
logger.error(f"启动自适应流管理器失败: {e}")
# 启动睡眠和唤醒管理器
await self.wakeup_manager.start()
@@ -72,6 +96,30 @@ class MessageManager:
self.is_running = False
# 停止批量数据库写入器
try:
from src.chat.message_manager.batch_database_writer import shutdown_batch_writer
await shutdown_batch_writer()
logger.info("📦 批量数据库写入器已停止")
except Exception as e:
logger.error(f"停止批量数据库写入器失败: {e}")
# 停止流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
await shutdown_stream_cache_manager()
logger.info("🗄️ 流缓存管理器已停止")
except Exception as e:
logger.error(f"停止流缓存管理器失败: {e}")
# 停止自适应流管理器
try:
from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
await shutdown_adaptive_stream_manager()
logger.info("🎯 自适应流管理器已停止")
except Exception as e:
logger.error(f"停止自适应流管理器失败: {e}")
# 停止睡眠和唤醒管理器
await self.wakeup_manager.stop()

View File

@@ -0,0 +1,381 @@
"""
流缓存管理器 - 使用优化版聊天流和智能缓存策略
提供分层缓存和自动清理功能
"""
import asyncio
import time
from typing import Dict, List, Optional, Set
from dataclasses import dataclass
from collections import OrderedDict
from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
logger = get_logger("stream_cache_manager")
@dataclass
class StreamCacheStats:
"""缓存统计信息"""
hot_cache_size: int = 0
warm_storage_size: int = 0
cold_storage_size: int = 0
total_memory_usage: int = 0 # 估算的内存使用(字节)
cache_hits: int = 0
cache_misses: int = 0
evictions: int = 0
last_cleanup_time: float = 0
class TieredStreamCache:
"""分层流缓存管理器"""
def __init__(
self,
max_hot_size: int = 100,
max_warm_size: int = 500,
max_cold_size: int = 2000,
cleanup_interval: float = 300.0, # 5分钟清理一次
hot_timeout: float = 1800.0, # 30分钟未访问降级到warm
warm_timeout: float = 7200.0, # 2小时未访问降级到cold
cold_timeout: float = 86400.0, # 24小时未访问删除
):
self.max_hot_size = max_hot_size
self.max_warm_size = max_warm_size
self.max_cold_size = max_cold_size
self.cleanup_interval = cleanup_interval
self.hot_timeout = hot_timeout
self.warm_timeout = warm_timeout
self.cold_timeout = cold_timeout
# 三层缓存存储
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据LRU
self.warm_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
self.cold_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
# 统计信息
self.stats = StreamCacheStats()
# 清理任务
self.cleanup_task: Optional[asyncio.Task] = None
self.is_running = False
logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
async def start(self):
"""启动缓存管理器"""
if self.is_running:
logger.warning("缓存管理器已经在运行")
return
self.is_running = True
self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup")
logger.info("分层流缓存管理器已启动")
async def stop(self):
"""停止缓存管理器"""
if not self.is_running:
return
self.is_running = False
if self.cleanup_task and not self.cleanup_task.done():
self.cleanup_task.cancel()
try:
await asyncio.wait_for(self.cleanup_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("缓存清理任务停止超时")
except Exception as e:
logger.error(f"停止缓存清理任务时出错: {e}")
logger.info("分层流缓存管理器已停止")
async def get_or_create_stream(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[Dict] = None,
) -> OptimizedChatStream:
"""获取或创建流 - 优化版本"""
current_time = time.time()
# 1. 检查热缓存
if stream_id in self.hot_cache:
stream = self.hot_cache[stream_id]
# 移动到末尾LRU更新
self.hot_cache.move_to_end(stream_id)
self.stats.cache_hits += 1
logger.debug(f"热缓存命中: {stream_id}")
return stream.create_snapshot()
# 2. 检查温存储
if stream_id in self.warm_storage:
stream, last_access = self.warm_storage[stream_id]
self.warm_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"温缓存命中: {stream_id}")
# 提升到热缓存
await self._promote_to_hot(stream_id, stream)
return stream.create_snapshot()
# 3. 检查冷存储
if stream_id in self.cold_storage:
stream, last_access = self.cold_storage[stream_id]
self.cold_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"冷缓存命中: {stream_id}")
# 提升到温缓存
await self._promote_to_warm(stream_id, stream)
return stream.create_snapshot()
# 4. 缓存未命中,创建新流
self.stats.cache_misses += 1
stream = create_optimized_chat_stream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info,
data=data
)
logger.debug(f"缓存未命中,创建新流: {stream_id}")
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
return stream
async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""添加到热缓存"""
# 检查是否需要驱逐
if len(self.hot_cache) >= self.max_hot_size:
await self._evict_from_hot()
self.hot_cache[stream_id] = stream
self.stats.hot_cache_size = len(self.hot_cache)
async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""提升到热缓存"""
# 从温存储中移除
if stream_id in self.warm_storage:
del self.warm_storage[stream_id]
self.stats.warm_storage_size = len(self.warm_storage)
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
logger.debug(f"{stream_id} 提升到热缓存")
async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
"""提升到温缓存"""
# 从冷存储中移除
if stream_id in self.cold_storage:
del self.cold_storage[stream_id]
self.stats.cold_storage_size = len(self.cold_storage)
# 添加到温存储
if len(self.warm_storage) >= self.max_warm_size:
await self._evict_from_warm()
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
logger.debug(f"{stream_id} 提升到温缓存")
async def _evict_from_hot(self):
"""从热缓存驱逐最久未使用的流"""
if not self.hot_cache:
return
# LRU驱逐
stream_id, stream = self.hot_cache.popitem(last=False)
self.stats.evictions += 1
logger.debug(f"从热缓存驱逐: {stream_id}")
# 移动到温存储
if len(self.warm_storage) < self.max_warm_size:
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
else:
# 温存储也满了,直接删除
logger.debug(f"温存储已满,删除流: {stream_id}")
self.stats.hot_cache_size = len(self.hot_cache)
async def _evict_from_warm(self):
"""从温存储驱逐最久未使用的流"""
if not self.warm_storage:
return
# 找到最久未访问的流
oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
stream, last_access = self.warm_storage.pop(oldest_stream_id)
self.stats.evictions += 1
logger.debug(f"从温存储驱逐: {oldest_stream_id}")
# 移动到冷存储
if len(self.cold_storage) < self.max_cold_size:
current_time = time.time()
self.cold_storage[oldest_stream_id] = (stream, current_time)
self.stats.cold_storage_size = len(self.cold_storage)
else:
# 冷存储也满了,直接删除
logger.debug(f"冷存储已满,删除流: {oldest_stream_id}")
self.stats.warm_storage_size = len(self.warm_storage)
async def _cleanup_loop(self):
"""清理循环"""
logger.info("流缓存清理循环启动")
while self.is_running:
try:
await asyncio.sleep(self.cleanup_interval)
await self._perform_cleanup()
except asyncio.CancelledError:
logger.info("流缓存清理循环被取消")
break
except Exception as e:
logger.error(f"流缓存清理出错: {e}")
logger.info("流缓存清理循环结束")
async def _perform_cleanup(self):
"""执行清理操作"""
current_time = time.time()
cleanup_stats = {
"hot_to_warm": 0,
"warm_to_cold": 0,
"cold_removed": 0,
}
# 1. 检查热缓存超时
hot_to_demote = []
for stream_id, stream in self.hot_cache.items():
# 获取最后访问时间(简化:使用创建时间作为近似)
last_access = getattr(stream, 'last_active_time', stream.create_time)
if current_time - last_access > self.hot_timeout:
hot_to_demote.append(stream_id)
for stream_id in hot_to_demote:
stream = self.hot_cache.pop(stream_id)
current_time_local = time.time()
self.warm_storage[stream_id] = (stream, current_time_local)
cleanup_stats["hot_to_warm"] += 1
# 2. 检查温存储超时
warm_to_demote = []
for stream_id, (stream, last_access) in self.warm_storage.items():
if current_time - last_access > self.warm_timeout:
warm_to_demote.append(stream_id)
for stream_id in warm_to_demote:
stream, last_access = self.warm_storage.pop(stream_id)
self.cold_storage[stream_id] = (stream, last_access)
cleanup_stats["warm_to_cold"] += 1
# 3. 检查冷存储超时
cold_to_remove = []
for stream_id, (stream, last_access) in self.cold_storage.items():
if current_time - last_access > self.cold_timeout:
cold_to_remove.append(stream_id)
for stream_id in cold_to_remove:
self.cold_storage.pop(stream_id)
cleanup_stats["cold_removed"] += 1
# 更新统计信息
self.stats.hot_cache_size = len(self.hot_cache)
self.stats.warm_storage_size = len(self.warm_storage)
self.stats.cold_storage_size = len(self.cold_storage)
self.stats.last_cleanup_time = current_time
# 估算内存使用(粗略估计)
self.stats.total_memory_usage = (
len(self.hot_cache) * 1024 + # 每个热流约1KB
len(self.warm_storage) * 512 + # 每个温流约512B
len(self.cold_storage) * 256 # 每个冷流约256B
)
if sum(cleanup_stats.values()) > 0:
logger.info(
f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, "
f"{cleanup_stats['warm_to_cold']}温→冷, "
f"{cleanup_stats['cold_removed']}冷删除"
)
def get_stats(self) -> StreamCacheStats:
"""获取缓存统计信息"""
# 计算命中率
total_requests = self.stats.cache_hits + self.stats.cache_misses
hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0
stats_copy = StreamCacheStats(
hot_cache_size=self.stats.hot_cache_size,
warm_storage_size=self.stats.warm_storage_size,
cold_storage_size=self.stats.cold_storage_size,
total_memory_usage=self.stats.total_memory_usage,
cache_hits=self.stats.cache_hits,
cache_misses=self.stats.cache_misses,
evictions=self.stats.evictions,
last_cleanup_time=self.stats.last_cleanup_time,
)
# 添加命中率信息
stats_copy.hit_rate = hit_rate
return stats_copy
def clear_cache(self):
"""清空所有缓存"""
self.hot_cache.clear()
self.warm_storage.clear()
self.cold_storage.clear()
self.stats.hot_cache_size = 0
self.stats.warm_storage_size = 0
self.stats.cold_storage_size = 0
self.stats.total_memory_usage = 0
logger.info("所有缓存已清空")
async def get_stream_snapshot(self, stream_id: str) -> Optional[OptimizedChatStream]:
"""获取流的快照(不修改缓存状态)"""
if stream_id in self.hot_cache:
return self.hot_cache[stream_id].create_snapshot()
elif stream_id in self.warm_storage:
return self.warm_storage[stream_id][0].create_snapshot()
elif stream_id in self.cold_storage:
return self.cold_storage[stream_id][0].create_snapshot()
return None
def get_cached_stream_ids(self) -> Set[str]:
"""获取所有缓存的流ID"""
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
# 全局缓存管理器实例
_cache_manager: Optional[TieredStreamCache] = None
def get_stream_cache_manager() -> TieredStreamCache:
"""获取流缓存管理器实例"""
global _cache_manager
if _cache_manager is None:
_cache_manager = TieredStreamCache()
return _cache_manager
async def init_stream_cache_manager():
"""初始化流缓存管理器"""
manager = get_stream_cache_manager()
await manager.start()
async def shutdown_stream_cache_manager():
"""关闭流缓存管理器"""
manager = get_stream_cache_manager()
await manager.stop()
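下面是分层缓存的示意用法(非仓库代码;UserInfo 的构造参数沿用上文出现的字段):
import asyncio

from maim_message import UserInfo
from src.chat.message_manager.stream_cache_manager import (
    get_stream_cache_manager,
    init_stream_cache_manager,
    shutdown_stream_cache_manager,
)

async def demo():
    await init_stream_cache_manager()
    cache = get_stream_cache_manager()
    user = UserInfo(platform="qq", user_id="10001", user_nickname="示例用户")
    # 第一次调用未命中,创建新流并放入热缓存;第二次命中热缓存并返回快照
    s1 = await cache.get_or_create_stream("qq_10001_private", "qq", user)
    s2 = await cache.get_or_create_stream("qq_10001_private", "qq", user)
    stats = cache.get_stats()
    print(stats.cache_misses, stats.cache_hits)  # 预期输出: 1 1
    await shutdown_stream_cache_manager()

asyncio.run(demo())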

View File

@@ -464,7 +464,7 @@ class ChatManager:
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
) -> ChatStream:
"""获取或创建聊天流 - 优化版本,使用缓存管理器
Args:
platform: 平台标识
@@ -478,6 +478,31 @@ class ChatManager:
try:
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 优先使用缓存管理器(优化版本)
try:
from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
cache_manager = get_stream_cache_manager()
if cache_manager.is_running:
optimized_stream = await cache_manager.get_or_create_stream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
# 设置消息上下文
from .message import MessageRecv
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
optimized_stream.set_context(self.last_messages[stream_id])
# 转换为原始ChatStream以保持兼容性
return self._convert_to_original_stream(optimized_stream)
except Exception as e:
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
# 回退到原始方法
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
@@ -634,12 +659,35 @@ class ChatManager:
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库 - 优化版本,使用异步批量写入"""
if stream.saved:
return
stream_data_dict = stream.to_dict()
# 优先使用新的批量写入器
try:
from src.chat.message_manager.batch_database_writer import get_batch_writer
batch_writer = get_batch_writer()
if batch_writer.is_running:
success = await batch_writer.schedule_stream_update(
stream_id=stream_data_dict["stream_id"],
update_data=ChatManager._prepare_stream_data(stream_data_dict),
priority=1 # 流更新的优先级
)
if success:
stream.saved = True
logger.debug(f"聊天流 {stream.stream_id} 通过批量写入器调度成功")
return
else:
logger.warning(f"批量写入器队列已满,使用原始方法: {stream.stream_id}")
else:
logger.debug(f"批量写入器未运行,使用原始方法: {stream.stream_id}")
except Exception as e:
logger.debug(f"批量写入器保存聊天流失败,使用原始方法: {e}")
# 尝试使用数据库批量调度器回退方案1
try:
from src.common.database.db_batch_scheduler import batch_update, get_batch_session
@@ -657,7 +705,7 @@ class ChatManager:
except (ImportError, Exception) as e:
logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}")
# 回退到原始方法(最终方案)
async def _db_save_stream_async(s_data_dict: dict):
async with get_db_session() as session:
user_info_d = s_data_dict.get("user_info")
@@ -782,6 +830,46 @@ class ChatManager:
chat_manager = None
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
"""将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
try:
# 创建原始ChatStream实例
original_stream = ChatStream(
stream_id=optimized_stream.stream_id,
platform=optimized_stream.platform,
user_info=optimized_stream._get_effective_user_info(),
group_info=optimized_stream._get_effective_group_info()
)
# 复制状态
original_stream.create_time = optimized_stream.create_time
original_stream.last_active_time = optimized_stream.last_active_time
original_stream.sleep_pressure = optimized_stream.sleep_pressure
original_stream.base_interest_energy = optimized_stream.base_interest_energy
original_stream._focus_energy = optimized_stream._focus_energy
original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
original_stream.saved = optimized_stream.saved
# 复制上下文信息(如果存在)
if hasattr(optimized_stream, '_stream_context') and optimized_stream._stream_context:
original_stream.stream_context = optimized_stream._stream_context
if hasattr(optimized_stream, '_context_manager') and optimized_stream._context_manager:
original_stream.context_manager = optimized_stream._context_manager
return original_stream
except Exception as e:
logger.error(f"转换OptimizedChatStream失败: {e}")
# 如果转换失败,创建一个新的原始流
return ChatStream(
stream_id=optimized_stream.stream_id,
platform=optimized_stream.platform,
user_info=optimized_stream._get_effective_user_info(),
group_info=optimized_stream._get_effective_group_info()
)
def get_chat_manager():
global chat_manager
if chat_manager is None:

View File

@@ -0,0 +1,494 @@
"""
优化版聊天流 - 实现写时复制机制
避免不必要的深拷贝开销,提升多流并发性能
"""
import asyncio
import copy
import hashlib
import time
from typing import TYPE_CHECKING, Any, Dict, Optional
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from .message import MessageRecv
install(extra_lines=3)
logger = get_logger("optimized_chat_stream")
class SharedContext:
"""共享上下文数据 - 只读数据结构"""
def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = time.time()
self._frozen = True
def __setattr__(self, name, value):
if hasattr(self, '_frozen') and self._frozen and name not in ['_frozen']:
raise AttributeError(f"SharedContext is frozen, cannot modify {name}")
super().__setattr__(name, value)
class LocalChanges:
"""本地修改跟踪器"""
def __init__(self):
self._changes: Dict[str, Any] = {}
self._dirty = False
def set_change(self, key: str, value: Any):
"""设置修改项"""
self._changes[key] = value
self._dirty = True
def get_change(self, key: str, default: Any = None) -> Any:
"""获取修改项"""
return self._changes.get(key, default)
def has_changes(self) -> bool:
"""是否有修改"""
return self._dirty
def get_changes(self) -> Dict[str, Any]:
"""获取所有修改"""
return self._changes.copy()
def clear_changes(self):
"""清除修改记录"""
self._changes.clear()
self._dirty = False
class OptimizedChatStream:
"""优化版聊天流 - 使用写时复制机制"""
def __init__(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[Dict] = None,
):
# 共享的只读数据
self._shared_context = SharedContext(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
# 本地修改数据
self._local_changes = LocalChanges()
# 写时复制标志
self._copy_on_write = False
# 基础参数
self.base_interest_energy = data.get("base_interest_energy", 0.5) if data else 0.5
self._focus_energy = data.get("focus_energy", 0.5) if data else 0.5
self.no_reply_consecutive = 0
# 创建StreamContext延迟创建
self._stream_context = None
self._context_manager = None
# 更新活跃时间
self.update_active_time()
# 保存标志
self.saved = False
@property
def stream_id(self) -> str:
return self._shared_context.stream_id
@property
def platform(self) -> str:
return self._shared_context.platform
@property
def user_info(self) -> UserInfo:
# 本地修改优先于共享上下文,与 group_info 的读取逻辑保持一致
if self._local_changes.has_changes() and 'user_info' in self._local_changes._changes:
return self._local_changes.get_change('user_info')
return self._shared_context.user_info
@user_info.setter
def user_info(self, value: UserInfo):
"""修改用户信息时触发写时复制"""
self._ensure_copy_on_write()
# 由于SharedContext是frozen的我们需要在本地修改中记录
self._local_changes.set_change('user_info', value)
@property
def group_info(self) -> Optional[GroupInfo]:
if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
return self._local_changes.get_change('group_info')
return self._shared_context.group_info
@group_info.setter
def group_info(self, value: Optional[GroupInfo]):
"""修改群组信息时触发写时复制"""
self._ensure_copy_on_write()
self._local_changes.set_change('group_info', value)
@property
def create_time(self) -> float:
if self._local_changes.has_changes() and 'create_time' in self._local_changes._changes:
return self._local_changes.get_change('create_time')
return self._shared_context.create_time
@property
def last_active_time(self) -> float:
return self._local_changes.get_change('last_active_time', self.create_time)
@last_active_time.setter
def last_active_time(self, value: float):
self._local_changes.set_change('last_active_time', value)
self.saved = False
@property
def sleep_pressure(self) -> float:
return self._local_changes.get_change('sleep_pressure', 0.0)
@sleep_pressure.setter
def sleep_pressure(self, value: float):
self._local_changes.set_change('sleep_pressure', value)
self.saved = False
def _ensure_copy_on_write(self):
"""确保写时复制机制生效"""
if not self._copy_on_write:
self._copy_on_write = True
# 实际修改记录在 LocalChanges 中,共享上下文保持只读,这里仅标记进入写时复制状态
logger.debug(f"触发写时复制: {self.stream_id}")
def _get_effective_user_info(self) -> UserInfo:
"""获取有效的用户信息"""
if self._local_changes.has_changes() and 'user_info' in self._local_changes._changes:
return self._local_changes.get_change('user_info')
return self._shared_context.user_info
def _get_effective_group_info(self) -> Optional[GroupInfo]:
"""获取有效的群组信息"""
if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
return self._local_changes.get_change('group_info')
return self._shared_context.group_info
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = time.time()
def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
# 确保stream_context存在
if self._stream_context is None:
self._ensure_copy_on_write()
self._create_stream_context()
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
import json
from src.common.data_models.database_data_model import DatabaseMessages
message_info = getattr(message, "message_info", {})
user_info = getattr(message_info, "user_info", {})
group_info = getattr(message_info, "group_info", {})
reply_to = None
if hasattr(message, "message_segment") and message.message_segment:
reply_to = self._extract_reply_from_segment(message.message_segment)
db_message = DatabaseMessages(
message_id=getattr(message, "message_id", ""),
time=getattr(message, "time", time.time()),
chat_id=self._generate_chat_id(message_info),
reply_to=reply_to,
interest_value=getattr(message, "interest_value", 0.0),
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
if getattr(message, "key_words", None)
else None,
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
if getattr(message, "key_words_lite", None)
else None,
is_mentioned=getattr(message, "is_mentioned", None),
is_at=getattr(message, "is_at", False),
is_emoji=getattr(message, "is_emoji", False),
is_picid=getattr(message, "is_picid", False),
is_voice=getattr(message, "is_voice", False),
is_video=getattr(message, "is_video", False),
is_command=getattr(message, "is_command", False),
is_notify=getattr(message, "is_notify", False),
processed_plain_text=getattr(message, "processed_plain_text", ""),
display_message=getattr(message, "processed_plain_text", ""),
priority_mode=getattr(message, "priority_mode", None),
priority_info=json.dumps(getattr(message, "priority_info", None))
if getattr(message, "priority_info", None)
else None,
additional_config=getattr(message_info, "additional_config", None),
user_id=str(getattr(user_info, "user_id", "")),
user_nickname=getattr(user_info, "user_nickname", ""),
user_cardname=getattr(user_info, "user_cardname", None),
user_platform=getattr(user_info, "platform", ""),
chat_info_group_id=getattr(group_info, "group_id", None),
chat_info_group_name=getattr(group_info, "group_name", None),
chat_info_group_platform=getattr(group_info, "platform", None),
chat_info_user_id=str(getattr(user_info, "user_id", "")),
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
chat_info_user_platform=getattr(user_info, "platform", ""),
chat_info_stream_id=self.stream_id,
chat_info_platform=self.platform,
chat_info_create_time=self.create_time,
chat_info_last_active_time=self.last_active_time,
actions=self._safe_get_actions(message),
should_reply=getattr(message, "should_reply", False),
)
self._stream_context.set_current_message(db_message)
self._stream_context.priority_mode = getattr(message, "priority_mode", None)
self._stream_context.priority_info = getattr(message, "priority_info", None)
logger.debug(
f"消息数据转移完成 - message_id: {db_message.message_id}, "
f"chat_id: {db_message.chat_id}, "
f"interest_value: {db_message.interest_value}"
)
def _create_stream_context(self):
"""创建StreamContext"""
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
self._stream_context = StreamContext(
stream_id=self.stream_id,
chat_type=ChatType.GROUP if self.group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL
)
# 创建单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
self._context_manager = SingleStreamContextManager(
stream_id=self.stream_id, context=self._stream_context
)
@property
def stream_context(self):
"""获取StreamContext"""
if self._stream_context is None:
self._ensure_copy_on_write()
self._create_stream_context()
return self._stream_context
@property
def context_manager(self):
"""获取ContextManager"""
if self._context_manager is None:
self._ensure_copy_on_write()
self._create_stream_context()
return self._context_manager
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式 - 考虑本地修改"""
user_info = self._get_effective_user_info()
group_info = self._get_effective_group_info()
return {
"stream_id": self.stream_id,
"platform": self.platform,
"user_info": user_info.to_dict() if user_info else None,
"group_info": group_info.to_dict() if group_info else None,
"create_time": self.create_time,
"last_active_time": self.last_active_time,
"sleep_pressure": self.sleep_pressure,
"focus_energy": self.focus_energy,
"base_interest_energy": self.base_interest_energy,
"stream_context_chat_type": self.stream_context.chat_type.value,
"stream_context_chat_mode": self.stream_context.chat_mode.value,
"interruption_count": self.stream_context.interruption_count,
}
@classmethod
def from_dict(cls, data: Dict) -> "OptimizedChatStream":
"""从字典创建实例"""
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
instance = cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info, # type: ignore
group_info=group_info,
data=data,
)
# 恢复stream_context信息
from src.plugin_system.base.component_types import ChatMode, ChatType
if "stream_context_chat_type" in data:
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data:
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
# 恢复interruption_count信息
if "interruption_count" in data:
instance.stream_context.interruption_count = data["interruption_count"]
return instance
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
"""安全获取消息的actions字段"""
try:
actions = getattr(message, "actions", None)
if actions is None:
return None
if isinstance(actions, str):
try:
import json
actions = json.loads(actions)
except json.JSONDecodeError:
logger.warning(f"无法解析actions JSON字符串: {actions}")
return None
if isinstance(actions, list):
filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
return filtered_actions if filtered_actions else None
else:
logger.warning(f"actions字段类型不支持: {type(actions)}")
return None
except Exception as e:
logger.warning(f"获取actions字段失败: {e}")
return None
def _extract_reply_from_segment(self, segment) -> str | None:
"""从消息段中提取reply_to信息"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
if hasattr(segment, "data") and segment.data:
for seg in segment.data:
reply_id = self._extract_reply_from_segment(seg)
if reply_id:
return reply_id
elif hasattr(segment, "type") and segment.type == "reply":
return str(segment.data) if segment.data else None
except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}")
return None
def _generate_chat_id(self, message_info) -> str:
"""生成chat_id基于群组或用户信息"""
try:
group_info = getattr(message_info, "group_info", None)
user_info = getattr(message_info, "user_info", None)
if group_info and hasattr(group_info, "group_id") and group_info.group_id:
return f"{self.platform}_{group_info.group_id}"
elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
return f"{self.platform}_{user_info.user_id}_private"
else:
return self.stream_id
except Exception as e:
logger.warning(f"生成chat_id失败: {e}")
return self.stream_id
@property
def focus_energy(self) -> float:
"""获取缓存的focus_energy值"""
return self._focus_energy
async def calculate_focus_energy(self) -> float:
"""异步计算focus_energy"""
try:
all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)
user_id = None
effective_user_info = self._get_effective_user_info()
if effective_user_info and hasattr(effective_user_info, "user_id"):
user_id = str(effective_user_info.user_id)
from src.chat.energy_system import energy_manager
energy = await energy_manager.calculate_focus_energy(
stream_id=self.stream_id, messages=all_messages, user_id=user_id
)
self._focus_energy = energy
logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
return energy
except Exception as e:
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
return self._focus_energy
@focus_energy.setter
def focus_energy(self, value: float):
"""设置focus_energy值"""
self._focus_energy = max(0.0, min(1.0, value))
async def _get_user_relationship_score(self) -> float:
"""获取用户关系分"""
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
effective_user_info = self._get_effective_user_info()
if effective_user_info and hasattr(effective_user_info, "user_id"):
user_id = str(effective_user_info.user_id)
relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id)
logger.debug(f"OptimizedChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
return max(0.0, min(1.0, relationship_score))
except Exception as e:
logger.warning(f"OptimizedChatStream {self.stream_id}: 插件内部关系分计算失败: {e}")
return 0.3
def create_snapshot(self) -> "OptimizedChatStream":
"""创建当前状态的快照(用于缓存)"""
# 创建一个新的实例,共享相同的上下文
snapshot = OptimizedChatStream(
stream_id=self.stream_id,
platform=self.platform,
user_info=self._get_effective_user_info(),
group_info=self._get_effective_group_info()
)
# 复制本地修改(但不触发写时复制)
snapshot._local_changes._changes = self._local_changes.get_changes()
snapshot._local_changes._dirty = self._local_changes._dirty
snapshot._focus_energy = self._focus_energy
snapshot.base_interest_energy = self.base_interest_energy
snapshot.no_reply_consecutive = self.no_reply_consecutive
snapshot.saved = self.saved
return snapshot
# 为了向后兼容,创建一个工厂函数
def create_optimized_chat_stream(
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[Dict] = None,
) -> OptimizedChatStream:
"""创建优化版聊天流实例"""
return OptimizedChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info,
data=data
)
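下面用一小段示意代码演示写时复制的可见行为(非仓库代码;GroupInfo 的构造参数为假设):
from maim_message import GroupInfo, UserInfo
from src.chat.message_receive.optimized_chat_stream import create_optimized_chat_stream

user = UserInfo(platform="qq", user_id="10001", user_nickname="示例用户")
stream = create_optimized_chat_stream("qq_10001_private", "qq", user)

assert stream.group_info is None  # 读取走共享的只读上下文
stream.group_info = GroupInfo(platform="qq", group_id="200", group_name="示例群")  # 触发写时复制
assert stream.group_info.group_id == "200"  # 读取改走本地修改记录
assert stream._shared_context.group_info is None  # 共享上下文保持冻结、未被改动

snapshot = stream.create_snapshot()  # 快照仅复制本地修改,不做深拷贝
assert snapshot.group_info.group_id == "200"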

View File

@@ -337,6 +337,41 @@ class MemoryConfig(ValidatedConfigBase):
# 休眠机制
dormant_threshold_days: int = Field(default=90, description="休眠状态判定天数")
# === 混合记忆系统配置 ===
# 采样模式配置
memory_sampling_mode: Literal["adaptive", "hippocampus", "precision"] = Field(
default="adaptive", description="记忆采样模式adaptive(自适应)hippocampus(海马体双峰采样)precision(精准记忆)"
)
# 海马体双峰采样配置
enable_hippocampus_sampling: bool = Field(default=True, description="启用海马体双峰采样策略")
hippocampus_sample_interval: int = Field(default=1800, description="海马体采样间隔(秒),默认30分钟")
hippocampus_sample_size: int = Field(default=30, description="海马体每次采样的消息数量")
hippocampus_batch_size: int = Field(default=5, description="海马体每批处理的记忆数量")
# 双峰分布配置 [近期均值, 近期标准差, 近期权重, 远期均值, 远期标准差, 远期权重]
hippocampus_distribution_config: list[float] = Field(
default=[12.0, 8.0, 0.7, 48.0, 24.0, 0.3],
description="海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]"
)
# 自适应采样配置
adaptive_sampling_enabled: bool = Field(default=True, description="启用自适应采样策略")
adaptive_sampling_threshold: float = Field(default=0.8, description="自适应采样负载阈值0-1")
adaptive_sampling_check_interval: int = Field(default=300, description="自适应采样检查间隔(秒)")
adaptive_sampling_max_concurrent_builds: int = Field(default=3, description="自适应采样最大并发记忆构建数")
# 精准记忆配置(现有系统的增强版本)
precision_memory_reply_threshold: float = Field(
default=0.6, description="精准记忆回复触发阈值(对话价值评分超过此值时触发记忆构建)"
)
precision_memory_max_builds_per_hour: int = Field(default=10, description="精准记忆每小时最大构建数量")
# 混合系统优化配置
memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡")
memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流")
memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列")
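其中 hippocampus_distribution_config 描述的是一个双分量高斯混合:按权重选择“近期”或“远期”分量,再从对应正态分布抽取消息年龄(小时)。下面是仅用于说明该配置含义的采样草稿(实际采样实现可能不同):
import random

def sample_message_age_hours(cfg: list[float]) -> float:
    """按 [近期均值, 近期标准差, 近期权重, 远期均值, 远期标准差, 远期权重] 抽取消息年龄(小时)。"""
    mu1, sigma1, w1, mu2, sigma2, w2 = cfg
    # 先按权重选分量,再从该分量的正态分布采样,负值截断为 0
    mu, sigma = (mu1, sigma1) if random.random() < w1 / (w1 + w2) else (mu2, sigma2)
    return max(0.0, random.gauss(mu, sigma))

ages = [sample_message_age_hours([12.0, 8.0, 0.7, 48.0, 24.0, 0.3]) for _ in range(5)]
print(ages)  # 多数样本落在 12h 附近,少数落在 48h 附近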
class MoodConfig(ValidatedConfigBase):
"""情绪配置类"""

View File

@@ -21,76 +21,69 @@ logger = get_logger(__name__)
class ColdStartTask(AsyncTask):
    """
    冷启动任务,在机器人启动时执行一次。
    它的核心职责是“唤醒”那些因重启而“沉睡”的聊天流,确保它们能够接收主动思考;
    对于在白名单中但从未有过记录的全新用户,它也会发起第一次“破冰”问候。
    """
    def __init__(self, bot_start_time: float):
        super().__init__(task_name="ColdStartTask")
        self.chat_manager = get_chat_manager()
        self.executor = ProactiveThinkerExecutor()
        self.bot_start_time = bot_start_time
    async def run(self):
        """任务主逻辑,在启动后执行一次白名单扫描"""
        logger.info("冷启动任务已启动,将在短暂延迟后开始唤醒沉睡的聊天流...")
        await asyncio.sleep(30)  # 延迟以确保所有服务和聊天流已从数据库加载完毕
        try:
            logger.info("【冷启动】开始扫描白名单,唤醒沉睡的聊天流...")
            enabled_private_chats = global_config.proactive_thinking.enabled_private_chats
            if not enabled_private_chats:
                logger.debug("【冷启动】私聊白名单为空,任务结束。")
                return
            for chat_id in enabled_private_chats:
                try:
                    platform, user_id_str = chat_id.split(":")
                    user_id = int(user_id_str)
                    should_wake_up = False
                    stream = chat_api.get_stream_by_user_id(user_id_str, platform)
                    if not stream:
                        should_wake_up = True
                        logger.info(f"【冷启动】发现全新用户 {chat_id},准备发起第一次问候。")
                    elif stream.last_active_time < self.bot_start_time:
                        should_wake_up = True
                        logger.info(f"【冷启动】发现沉睡的聊天流 {chat_id} (最后活跃于 {datetime.fromtimestamp(stream.last_active_time)}),准备唤醒。")
                    if should_wake_up:
                        person_id = person_api.get_person_id(platform, user_id)
                        nickname = await person_api.get_person_value(person_id, "nickname")
                        user_nickname = nickname or f"用户{user_id}"
                        user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname)
                        # 使用 get_or_create_stream 来安全地获取或创建流
                        stream = await self.chat_manager.get_or_create_stream(platform, user_info)
                        formatted_stream_id = f"{stream.user_info.platform}:{stream.user_info.user_id}:private"
                        await self.executor.execute(stream_id=formatted_stream_id, start_mode="cold_start")
                        logger.info(f"【冷启动】已为用户 {chat_id} (昵称: {user_nickname}) 发送唤醒/问候消息。")
                except ValueError:
                    logger.warning(f"【冷启动】白名单条目格式错误或用户ID无效,已跳过: {chat_id}")
                except Exception as e:
                    logger.error(f"【冷启动】处理用户 {chat_id} 时发生未知错误: {e}", exc_info=True)
        except asyncio.CancelledError:
            logger.info("冷启动任务被正常取消。")
        except Exception as e:
            logger.error(f"【冷启动】任务出现严重错误: {e}", exc_info=True)
        finally:
            logger.info("【冷启动】任务执行完毕。")
class ProactiveThinkingTask(AsyncTask):
@@ -222,13 +215,15 @@ class ProactiveThinkerEventHandler(BaseEventHandler):
logger.info("检测到插件启动事件,正在初始化【主动思考】") logger.info("检测到插件启动事件,正在初始化【主动思考】")
# 检查总开关 # 检查总开关
if global_config.proactive_thinking.enable: if global_config.proactive_thinking.enable:
bot_start_time = time.time() # 记录“诞生时刻”
# 启动负责“日常唤醒”的核心任务
proactive_task = ProactiveThinkingTask()
await async_task_manager.add_task(proactive_task)
# 检查“冷启动”功能的独立开关
if global_config.proactive_thinking.enable_cold_start:
cold_start_task = ColdStartTask(bot_start_time)
await async_task_manager.add_task(cold_start_task)
else:

View File

@@ -80,7 +80,7 @@ class ProactiveThinkerExecutor:
plan_prompt = self._build_plan_prompt(context, start_mode, topic, reason)
is_success, response, _, _ = await llm_api.generate_with_model(
prompt=plan_prompt, model_config=model_config.model_task_config.replyer
)
if is_success and response:
@@ -140,7 +140,7 @@ class ProactiveThinkerExecutor:
else "今天没有日程安排。" else "今天没有日程安排。"
) )
recent_messages = await message_api.get_recent_messages(stream_id, limit=10) recent_messages = await message_api.get_recent_messages(stream.stream_id)
recent_chat_history = ( recent_chat_history = (
await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "" await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else ""
) )
@@ -158,12 +158,12 @@ class ProactiveThinkerExecutor:
)
# 2. 构建基础上下文
mood_state = "暂时没有"
if global_config.mood.enable_mood:
try:
mood_state = mood_manager.get_mood_by_chat_id(stream.stream_id).mood_state
except Exception as e:
logger.error(f"获取情绪失败,原因:{e}")
base_context = {
"schedule_context": schedule_context,
"recent_chat_history": recent_chat_history,
@@ -281,30 +281,48 @@ class ProactiveThinkerExecutor:
# 构建通用尾部
prompt += """
# 决策指令
请综合以上所有信息,以稳定、真实、拟人的方式做出决策。你的决策需要以JSON格式输出,包含以下字段:
- `should_reply`: bool, 是否应该发起对话。
- `topic`: str, 如果 `should_reply` 为 true,你打算聊什么话题。
- `reason`: str, 做出此决策的简要理由。
# 决策原则
- **谨慎对待未回复的对话**: 在发起新话题前,请检查【最近的聊天摘要】。如果最后一条消息是你自己发送的,请仔细评估等待的时间和上下文,判断再次主动发起对话是否礼貌和自然。如果等待时间很短(例如几分钟或半小时内),通常应该选择“不回复”。
- **优先利用上下文**: 优先从【情境分析】中已有的信息(如最近的聊天摘要、你的日程、你对Ta的关系印象)寻找自然的话题切入点。
- **简单问候作为备选**: 如果上下文中没有合适的话题,可以生成一个简单、真诚的日常问候(例如“在忙吗?”,“下午好呀~”)。
- **避免抽象**: 避免创造过于复杂、抽象或需要对方思考很久才能明白的话题。目标是轻松、自然地开启对话。
- **避免过于频繁**: 如果你最近(尤其是在最近的几次决策中)已经主动发起过对话,请倾向于选择“不回复”,除非有非常重要和紧急的事情。
- **如果上下文中只有你的消息而没有别人的消息**:选择不回复,以防刷屏或者打扰到别人
---
示例1 (基于上下文):
{{
"should_reply": true,
"topic": "关心一下Ta昨天提到的那个项目进展如何了",
"reason": "用户昨天在聊天中提到了一个重要的项目,现在主动关心一下进展,会显得很体贴,也能自然地开启对话。"
}}
示例2 (简单问候):
{{
"should_reply": true,
"topic": "打个招呼问问Ta现在在忙些什么",
"reason": "最近没有聊天记录,日程也很常规,没有特别的切入点。一个简单的日常问候是最安全和自然的方式来重新连接。"
}}
示例3 (不应回复 - 过于频繁):
{{
"should_reply": false,
"topic": null,
"reason": "虽然群里很活跃,但现在是深夜,而且最近的聊天话题我也不熟悉,没有合适的理由去打扰大家。"
}}
示例4 (不应回复 - 等待回应):
{{
"should_reply": false,
"topic": null,
"reason": "我注意到上一条消息是我几分钟前主动发送的,对方可能正在忙。为了表现出耐心和体贴,我现在最好保持安静,等待对方的回应。"
}}
---
请输出你的决策:
@@ -369,10 +387,18 @@ class ProactiveThinkerExecutor:
# 决策上下文
- **决策理由**: {reason}
# 情境分析
1. **你的日程**:
{context["schedule_context"]}
2. **你和Ta的关系**:
- 简短印象: {relationship["short_impression"]} - 简短印象: {relationship["short_impression"]}
- 详细印象: {relationship["impression"]} - 详细印象: {relationship["impression"]}
- 好感度: {relationship["attitude"]}/100 - 好感度: {relationship["attitude"]}/100
3. **最近的聊天摘要**:
{context["recent_chat_history"]}
4. **你最近的相关动作**:
{context["action_history_context"]}
# 对话指引
- 你的目标是“破冰”,让对话自然地开始。
@@ -400,6 +426,7 @@ class ProactiveThinkerExecutor:
# 对话指引
- 你决定和Ta聊聊关于“{topic}”的话题。
- **重要**: 在开始你的话题前,必须先用一句通用的、礼貌的开场白进行问候(例如:“在吗?”、“上午好!”、“晚上好呀~”),然后再自然地衔接你的话题,确保整个回复在一条消息内流畅、自然、像人类的说话方式。
- 请结合以上所有情境信息,自然地开启对话。
- 你的语气应该符合你的人设({context["mood_state"]})以及你对Ta的好感度。
""" """
@@ -437,6 +464,7 @@ class ProactiveThinkerExecutor:
# 对话指引 # 对话指引
- 你决定和大家聊聊关于“{topic}”的话题。 - 你决定和大家聊聊关于“{topic}”的话题。
- **重要**: 在开始你的话题前,必须先用一句通用的、礼貌的开场白进行问候(例如:“哈喽,大家好呀~”、“下午好!”),然后再自然地衔接你的话题,确保整个回复在一条消息内流畅、自然、像人类的说话方式。
- 你的语气应该更活泼、更具包容性,以吸引更多群成员参与讨论。你的语气应该符合你的人设)。 - 你的语气应该更活泼、更具包容性,以吸引更多群成员参与讨论。你的语气应该符合你的人设)。
- 请结合以上所有情境信息,自然地开启对话。 - 请结合以上所有情境信息,自然地开启对话。
- 可以分享你的看法、提出相关问题,或者开个合适的玩笑。 - 可以分享你的看法、提出相关问题,或者开个合适的玩笑。

View File

@@ -1,5 +1,5 @@
[inner]
version = "7.1.6"
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
#如果你想要修改配置文件,请递增version的值
@@ -208,6 +208,19 @@ max_context_emojis = 30 # 每次随机传递给LLM的表情包详细描述的最
enable_memory = true # 是否启用记忆系统
memory_build_interval = 600 # 记忆构建间隔(秒)。间隔越低,学习越频繁,但可能产生更多冗余信息
# === 记忆采样系统配置 ===
memory_sampling_mode = "immediate" # 记忆采样模式hippocampus(海马体定时采样)immediate(即时采样)all(所有模式)
# 海马体双峰采样配置
enable_hippocampus_sampling = true # 启用海马体双峰采样策略
hippocampus_sample_interval = 1800 # 海马体采样间隔(秒),默认30分钟
hippocampus_sample_size = 30 # 海马体采样样本数量
hippocampus_batch_size = 10 # 海马体批量处理大小
hippocampus_distribution_config = [12.0, 8.0, 0.7, 48.0, 24.0, 0.3] # 海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]
# 即时采样配置
precision_memory_reply_threshold = 0.5 # 精准记忆回复阈值(0-1),高于此值的对话将立即构建记忆
min_memory_length = 10 # 最小记忆长度
max_memory_length = 500 # 最大记忆长度
memory_value_threshold = 0.5 # 记忆价值阈值,低于该值的记忆会被丢弃