Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
海马体双峰分布采样器
|
||||
基于旧版海马体的采样策略,适配新版记忆系统
|
||||
@@ -8,16 +7,15 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
build_readable_messages,
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||
@@ -47,7 +45,7 @@ class HippocampusSampleConfig:
|
||||
batch_size: int = 5 # 批处理大小
|
||||
|
||||
@classmethod
|
||||
def from_global_config(cls) -> 'HippocampusSampleConfig':
|
||||
def from_global_config(cls) -> "HippocampusSampleConfig":
|
||||
"""从全局配置创建海马体采样配置"""
|
||||
config = global_config.memory.hippocampus_distribution_config
|
||||
return cls(
|
||||
@@ -74,12 +72,12 @@ class HippocampusSampler:
|
||||
self.is_running = False
|
||||
|
||||
# 记忆构建模型
|
||||
self.memory_builder_model: Optional[LLMRequest] = None
|
||||
self.memory_builder_model: LLMRequest | None = None
|
||||
|
||||
# 统计信息
|
||||
self.sample_count = 0
|
||||
self.success_count = 0
|
||||
self.last_sample_results: List[Dict[str, Any]] = []
|
||||
self.last_sample_results: list[dict[str, Any]] = []
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化采样器"""
|
||||
@@ -101,7 +99,7 @@ class HippocampusSampler:
|
||||
logger.error(f"❌ 海马体采样器初始化失败: {e}")
|
||||
raise
|
||||
|
||||
def generate_time_samples(self) -> List[datetime]:
|
||||
def generate_time_samples(self) -> list[datetime]:
|
||||
"""生成双峰分布的时间采样点"""
|
||||
# 计算每个分布的样本数
|
||||
recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight))
|
||||
@@ -132,7 +130,7 @@ class HippocampusSampler:
|
||||
# 按时间排序(从最早到最近)
|
||||
return sorted(timestamps)
|
||||
|
||||
async def collect_message_samples(self, target_timestamp: float) -> Optional[List[Dict[str, Any]]]:
|
||||
async def collect_message_samples(self, target_timestamp: float) -> list[dict[str, Any]] | None:
|
||||
"""收集指定时间戳附近的消息样本"""
|
||||
try:
|
||||
# 随机时间窗口:5-30分钟
|
||||
@@ -190,7 +188,7 @@ class HippocampusSampler:
|
||||
logger.error(f"收集消息样本失败: {e}")
|
||||
return None
|
||||
|
||||
async def build_memory_from_samples(self, messages: List[Dict[str, Any]], target_timestamp: float) -> Optional[str]:
|
||||
async def build_memory_from_samples(self, messages: list[dict[str, Any]], target_timestamp: float) -> str | None:
|
||||
"""从消息样本构建记忆"""
|
||||
if not messages or not self.memory_system or not self.memory_builder_model:
|
||||
return None
|
||||
@@ -262,7 +260,7 @@ class HippocampusSampler:
|
||||
logger.error(f"海马体采样构建记忆失败: {e}")
|
||||
return None
|
||||
|
||||
async def perform_sampling_cycle(self) -> Dict[str, Any]:
|
||||
async def perform_sampling_cycle(self) -> dict[str, Any]:
|
||||
"""执行一次完整的采样周期(优化版:批量融合构建)"""
|
||||
if not self.should_sample():
|
||||
return {"status": "skipped", "reason": "interval_not_met"}
|
||||
@@ -363,7 +361,7 @@ class HippocampusSampler:
|
||||
"duration": time.time() - start_time,
|
||||
}
|
||||
|
||||
async def _collect_all_message_samples(self, time_samples: List[datetime]) -> List[List[Dict[str, Any]]]:
|
||||
async def _collect_all_message_samples(self, time_samples: list[datetime]) -> list[list[dict[str, Any]]]:
|
||||
"""批量收集所有时间点的消息样本"""
|
||||
collected_messages = []
|
||||
max_concurrent = min(5, len(time_samples)) # 提高并发数到5
|
||||
@@ -394,7 +392,7 @@ class HippocampusSampler:
|
||||
|
||||
return collected_messages
|
||||
|
||||
async def _fuse_and_deduplicate_messages(self, collected_messages: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
|
||||
async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
|
||||
"""融合和去重消息样本"""
|
||||
if not collected_messages:
|
||||
return []
|
||||
@@ -450,7 +448,7 @@ class HippocampusSampler:
|
||||
# 返回原始消息组作为备选
|
||||
return collected_messages[:5] # 限制返回数量
|
||||
|
||||
def _merge_adjacent_messages(self, messages: List[Dict[str, Any]], time_gap: int = 1800) -> List[List[Dict[str, Any]]]:
|
||||
def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
|
||||
"""合并时间间隔内的消息"""
|
||||
if not messages:
|
||||
return []
|
||||
@@ -481,7 +479,7 @@ class HippocampusSampler:
|
||||
|
||||
return result_groups
|
||||
|
||||
async def _build_batch_memory(self, fused_messages: List[List[Dict[str, Any]]], time_samples: List[datetime]) -> Dict[str, Any]:
|
||||
async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]:
|
||||
"""批量构建记忆"""
|
||||
if not fused_messages:
|
||||
return {"memory_count": 0, "memories": []}
|
||||
@@ -557,7 +555,7 @@ class HippocampusSampler:
|
||||
logger.error(f"批量构建记忆失败: {e}")
|
||||
return {"memory_count": 0, "error": str(e)}
|
||||
|
||||
async def _build_fused_conversation_text(self, fused_messages: List[List[Dict[str, Any]]]) -> str:
|
||||
async def _build_fused_conversation_text(self, fused_messages: list[list[dict[str, Any]]]) -> str:
|
||||
"""构建融合后的对话文本"""
|
||||
try:
|
||||
# 添加批次标识
|
||||
@@ -589,7 +587,7 @@ class HippocampusSampler:
|
||||
logger.error(f"构建融合文本失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _fallback_individual_build(self, fused_messages: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
|
||||
async def _fallback_individual_build(self, fused_messages: list[list[dict[str, Any]]]) -> dict[str, Any]:
|
||||
"""备选方案:单独构建每个消息组"""
|
||||
total_memories = []
|
||||
total_count = 0
|
||||
@@ -609,7 +607,7 @@ class HippocampusSampler:
|
||||
"fallback_mode": True
|
||||
}
|
||||
|
||||
async def process_sample_timestamp(self, target_timestamp: float) -> Optional[str]:
|
||||
async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
|
||||
"""处理单个时间戳采样(保留作为备选方法)"""
|
||||
try:
|
||||
# 收集消息样本
|
||||
@@ -676,7 +674,7 @@ class HippocampusSampler:
|
||||
self.is_running = False
|
||||
logger.info("🛑 停止海马体后台采样任务")
|
||||
|
||||
def get_sampling_stats(self) -> Dict[str, Any]:
|
||||
def get_sampling_stats(self) -> dict[str, Any]:
|
||||
"""获取采样统计信息"""
|
||||
success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
|
||||
|
||||
@@ -713,7 +711,7 @@ class HippocampusSampler:
|
||||
|
||||
|
||||
# 全局海马体采样器实例
|
||||
_hippocampus_sampler: Optional[HippocampusSampler] = None
|
||||
_hippocampus_sampler: HippocampusSampler | None = None
|
||||
|
||||
|
||||
def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
|
||||
@@ -728,4 +726,4 @@ async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampl
|
||||
"""初始化全局海马体采样器"""
|
||||
sampler = get_hippocampus_sampler(memory_system)
|
||||
await sampler.initialize()
|
||||
return sampler
|
||||
return sampler
|
||||
|
||||
@@ -32,7 +32,7 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Type, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
E = TypeVar("E", bound=Enum)
|
||||
|
||||
@@ -503,7 +503,7 @@ class MemoryBuilder:
|
||||
logger.warning(f"无法解析未知的记忆类型 '{type_str}',回退到上下文类型")
|
||||
return MemoryType.CONTEXTUAL
|
||||
|
||||
def _parse_enum_value(self, enum_cls: Type[E], raw_value: Any, default: E, field_name: str) -> E:
|
||||
def _parse_enum_value(self, enum_cls: type[E], raw_value: Any, default: E, field_name: str) -> E:
|
||||
"""解析枚举值,兼容数字/字符串表示"""
|
||||
if isinstance(raw_value, enum_cls):
|
||||
return raw_value
|
||||
|
||||
@@ -215,8 +215,8 @@ class MemoryFusionEngine:
|
||||
if not keywords1 or not keywords2:
|
||||
return 0.0
|
||||
|
||||
set1 = set(k.lower() for k in keywords1)
|
||||
set2 = set(k.lower() for k in keywords2)
|
||||
set1 = set(k.lower() for k in keywords1) # noqa: C401
|
||||
set2 = set(k.lower() for k in keywords2) # noqa: C401
|
||||
|
||||
intersection = set1 & set2
|
||||
union = set1 | set2
|
||||
|
||||
@@ -69,14 +69,11 @@ class MemoryManager:
|
||||
# 初始化记忆系统
|
||||
self.memory_system = await initialize_memory_system(llm_model)
|
||||
|
||||
# 设置全局实例
|
||||
global_memory_manager = self.memory_system
|
||||
|
||||
self.is_initialized = True
|
||||
logger.info("✅ 记忆系统初始化完成")
|
||||
logger.info(" 记忆系统初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 记忆系统初始化失败: {e}")
|
||||
logger.error(f"记忆系统初始化失败: {e}")
|
||||
# 如果系统初始化失败,创建一个空的管理器避免系统崩溃
|
||||
self.memory_system = None
|
||||
self.is_initialized = True # 标记为已初始化但系统不可用
|
||||
@@ -439,7 +436,7 @@ class MemoryManager:
|
||||
formatted_items = [self._format_object(item) for item in obj]
|
||||
filtered = [item for item in formatted_items if item]
|
||||
return self._clean_text("、".join(filtered)) if filtered else ""
|
||||
if isinstance(obj, (int, float)):
|
||||
if isinstance(obj, int | float):
|
||||
return str(obj)
|
||||
text = self._truncate(str(obj).strip())
|
||||
return self._clean_text(text)
|
||||
@@ -449,12 +446,12 @@ class MemoryManager:
|
||||
for key in keys:
|
||||
if obj.get(key):
|
||||
value = obj[key]
|
||||
if isinstance(value, (dict, list)):
|
||||
if isinstance(value, dict | list):
|
||||
return self._clean_text(self._format_object(value))
|
||||
return self._clean_text(value)
|
||||
if isinstance(obj, list) and obj:
|
||||
return self._clean_text(self._format_object(obj[0]))
|
||||
if isinstance(obj, (str, int, float)):
|
||||
if isinstance(obj, str | int | float):
|
||||
return self._clean_text(obj)
|
||||
return None
|
||||
|
||||
@@ -471,7 +468,7 @@ class MemoryManager:
|
||||
try:
|
||||
if self.memory_system:
|
||||
await self.memory_system.shutdown()
|
||||
logger.info("✅ 记忆系统已关闭")
|
||||
logger.info(" 记忆系统已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭记忆系统失败: {e}")
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||
|
||||
|
||||
# 记忆采样模式枚举
|
||||
class MemorySamplingMode(Enum):
|
||||
"""记忆采样模式"""
|
||||
@@ -31,9 +33,10 @@ from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
from src.chat.memory_system.vector_memory_storage_v2 import VectorMemoryStorage
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = get_logger("memory_system")
|
||||
|
||||
# 全局记忆作用域(共享记忆库)
|
||||
GLOBAL_MEMORY_SCOPE = "global"
|
||||
@@ -133,15 +136,15 @@ class MemorySystem:
|
||||
self.status = MemorySystemStatus.INITIALIZING
|
||||
|
||||
# 核心组件(简化版)
|
||||
self.memory_builder: MemoryBuilder = None
|
||||
self.fusion_engine: MemoryFusionEngine = None
|
||||
self.unified_storage = None # 统一存储系统
|
||||
self.query_planner: MemoryQueryPlanner = None
|
||||
self.memory_builder: MemoryBuilder | None = None
|
||||
self.fusion_engine: MemoryFusionEngine | None = None
|
||||
self.unified_storage: VectorMemoryStorage | None = None # 统一存储系统
|
||||
self.query_planner: MemoryQueryPlanner | None = None
|
||||
self.forgetting_engine: MemoryForgettingEngine | None = None
|
||||
|
||||
# LLM模型
|
||||
self.value_assessment_model: LLMRequest = None
|
||||
self.memory_extraction_model: LLMRequest = None
|
||||
self.value_assessment_model: LLMRequest | None = None
|
||||
self.memory_extraction_model: LLMRequest | None = None
|
||||
|
||||
# 统计信息
|
||||
self.total_memories = 0
|
||||
@@ -162,7 +165,6 @@ class MemorySystem:
|
||||
async def initialize(self):
|
||||
"""异步初始化记忆系统"""
|
||||
try:
|
||||
logger.info("正在初始化记忆系统...")
|
||||
|
||||
# 初始化LLM模型
|
||||
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
|
||||
@@ -249,7 +251,7 @@ class MemorySystem:
|
||||
|
||||
self.forgetting_engine = MemoryForgettingEngine(forgetting_config)
|
||||
|
||||
planner_task_config = getattr(model_config.model_task_config, "utils_small", None)
|
||||
planner_task_config = model_config.model_task_config.utils_small
|
||||
planner_model: LLMRequest | None = None
|
||||
try:
|
||||
planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner")
|
||||
@@ -269,10 +271,8 @@ class MemorySystem:
|
||||
self.hippocampus_sampler = None
|
||||
|
||||
# 统一存储已经自动加载数据,无需额外加载
|
||||
logger.info("✅ 简化版记忆系统初始化完成")
|
||||
|
||||
self.status = MemorySystemStatus.READY
|
||||
logger.info("✅ 记忆系统初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
self.status = MemorySystemStatus.ERROR
|
||||
@@ -479,7 +479,7 @@ class MemorySystem:
|
||||
existing_id = self._memory_fingerprints.get(fingerprint_key)
|
||||
if existing_id and existing_id not in new_memory_ids:
|
||||
candidate_ids.add(existing_id)
|
||||
except Exception as exc:
|
||||
except Exception as exc: # noqa: PERF203
|
||||
logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc)
|
||||
|
||||
# 基于主体索引的候选(使用统一存储)
|
||||
@@ -557,11 +557,11 @@ class MemorySystem:
|
||||
context = dict(context or {})
|
||||
|
||||
# 获取配置的采样模式
|
||||
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'precision')
|
||||
sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision")
|
||||
current_mode = MemorySamplingMode(sampling_mode)
|
||||
|
||||
|
||||
context['__sampling_mode'] = current_mode.value
|
||||
context["__sampling_mode"] = current_mode.value
|
||||
logger.debug(f"使用记忆采样模式: {current_mode.value}")
|
||||
|
||||
# 根据采样模式处理记忆
|
||||
@@ -637,7 +637,7 @@ class MemorySystem:
|
||||
|
||||
# 检查信息价值阈值
|
||||
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
||||
threshold = getattr(global_config.memory, 'precision_memory_reply_threshold', 0.5)
|
||||
threshold = getattr(global_config.memory, "precision_memory_reply_threshold", 0.5)
|
||||
|
||||
if value_score < threshold:
|
||||
logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建")
|
||||
@@ -843,7 +843,7 @@ class MemorySystem:
|
||||
for i, (mem, score, details) in enumerate(scored_memories[:3], 1):
|
||||
try:
|
||||
summary = mem.content[:60] if hasattr(mem, "content") and mem.content else ""
|
||||
except:
|
||||
except Exception:
|
||||
summary = ""
|
||||
logger.info(
|
||||
f" #{i} | final={details['final']:.3f} "
|
||||
@@ -1440,8 +1440,8 @@ class MemorySystem:
|
||||
context_keywords = context.get("keywords") or []
|
||||
keyword_overlap = 0.0
|
||||
if context_keywords:
|
||||
memory_keywords = set(k.lower() for k in memory.keywords)
|
||||
keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(
|
||||
memory_keywords = {k.lower() for k in memory.keywords}
|
||||
keyword_overlap = len(memory_keywords & {k.lower() for k in context_keywords}) / max(
|
||||
len(context_keywords), 1
|
||||
)
|
||||
|
||||
@@ -1489,7 +1489,7 @@ class MemorySystem:
|
||||
"""启动海马体采样"""
|
||||
if self.hippocampus_sampler:
|
||||
asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
|
||||
logger.info("🚀 海马体后台采样已启动")
|
||||
logger.info("海马体后台采样已启动")
|
||||
else:
|
||||
logger.warning("海马体采样器未初始化,无法启动采样")
|
||||
|
||||
@@ -1497,7 +1497,7 @@ class MemorySystem:
|
||||
"""停止海马体采样"""
|
||||
if self.hippocampus_sampler:
|
||||
self.hippocampus_sampler.stop_background_sampling()
|
||||
logger.info("🛑 海马体后台采样已停止")
|
||||
logger.info("海马体后台采样已停止")
|
||||
|
||||
def get_system_stats(self) -> dict[str, Any]:
|
||||
"""获取系统统计信息"""
|
||||
@@ -1536,10 +1536,10 @@ class MemorySystem:
|
||||
if self.unified_storage:
|
||||
self.unified_storage.cleanup()
|
||||
|
||||
logger.info("✅ 简化记忆系统已关闭")
|
||||
logger.info("简化记忆系统已关闭")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 记忆系统关闭失败: {e}", exc_info=True)
|
||||
logger.error(f"记忆系统关闭失败: {e}", exc_info=True)
|
||||
|
||||
async def _rebuild_vector_storage_if_needed(self):
|
||||
"""重建向量存储(如果需要)"""
|
||||
@@ -1553,12 +1553,13 @@ class MemorySystem:
|
||||
|
||||
# 收集需要重建向量的记忆
|
||||
memories_to_rebuild = []
|
||||
for memory_id, memory in self.unified_storage.memory_cache.items():
|
||||
# 检查记忆是否有有效的 display 文本
|
||||
if memory.display and memory.display.strip():
|
||||
memories_to_rebuild.append(memory)
|
||||
elif memory.text_content and memory.text_content.strip():
|
||||
memories_to_rebuild.append(memory)
|
||||
if self.unified_storage:
|
||||
for memory in self.unified_storage.memory_cache.values():
|
||||
# 检查记忆是否有有效的 display 文本
|
||||
if memory.display and memory.display.strip():
|
||||
memories_to_rebuild.append(memory)
|
||||
elif memory.text_content and memory.text_content.strip():
|
||||
memories_to_rebuild.append(memory)
|
||||
|
||||
if not memories_to_rebuild:
|
||||
logger.warning("没有找到可重建向量的记忆")
|
||||
@@ -1583,14 +1584,16 @@ class MemorySystem:
|
||||
logger.error(f"批量重建向量失败: {e}")
|
||||
continue
|
||||
|
||||
# 保存重建的向量存储
|
||||
await self.unified_storage.save_storage()
|
||||
|
||||
final_count = self.unified_storage.storage_stats.get("total_vectors", 0)
|
||||
logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}")
|
||||
# 向量数据在 store_memories 中已保存,此处无需额外操作
|
||||
if self.unified_storage:
|
||||
storage_stats = self.unified_storage.get_storage_stats()
|
||||
final_count = storage_stats.get("total_vectors", 0)
|
||||
logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}")
|
||||
else:
|
||||
logger.warning("向量存储重建完成,但无法获取最终向量数量,因为存储系统未初始化")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 向量存储重建失败: {e}", exc_info=True)
|
||||
logger.error(f"向量存储重建失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 全局记忆系统实例
|
||||
@@ -1613,8 +1616,8 @@ async def initialize_memory_system(llm_model: LLMRequest | None = None):
|
||||
await memory_system.initialize()
|
||||
|
||||
# 根据配置启动海马体采样
|
||||
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate')
|
||||
if sampling_mode in ['hippocampus', 'all']:
|
||||
sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "immediate")
|
||||
if sampling_mode in ["hippocampus", "all"]:
|
||||
memory_system.start_hippocampus_sampling()
|
||||
|
||||
return memory_system
|
||||
|
||||
Reference in New Issue
Block a user