re-style: format code

John Richard
2025-10-02 20:26:01 +08:00
committed by Windpicker-owo
parent 00ba07e0e1
commit a79253c714
263 changed files with 3781 additions and 3189 deletions
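
The diffs below are mechanical restyling: `typing.Dict/List/Tuple/Set/Optional/Union` give way to PEP 585 builtin generics and PEP 604 `|` unions, `typing.Iterable` moves to `collections.abc`, imports are re-sorted, and small idioms (`open(..., "r")`, dict-building comprehensions, overlong lines) are normalized. A minimal before/after sketch of the dominant pattern — assuming a Ruff/pyupgrade-style formatter, which the commit message does not name:

```python
from typing import Any

# Before (pre-PEP 585/604 spelling):
#   def lookup(ctx: Optional[Dict[str, Any]] = None) -> List[str]: ...

# After, as applied throughout this commit:
def lookup(ctx: dict[str, Any] | None = None) -> list[str]:
    """Return the context keys, tolerating a missing context."""
    return list(ctx or {})
```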

View File

@@ -1,37 +1,35 @@
# -*- coding: utf-8 -*-
"""
简化记忆系统模块
移除即时记忆和长期记忆分类,实现统一记忆架构和智能遗忘机制
"""
# 核心数据结构
# 激活器
from .enhanced_memory_activator import MemoryActivator, enhanced_memory_activator, memory_activator
from .memory_chunk import (
ConfidenceLevel,
ContentStructure,
ImportanceLevel,
MemoryChunk,
MemoryMetadata,
ContentStructure,
MemoryType,
ImportanceLevel,
ConfidenceLevel,
create_memory_chunk,
)
# 兼容性别名
from .memory_chunk import MemoryChunk as Memory
# 遗忘引擎
from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine
# Vector DB存储系统
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
# 记忆核心系统
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
# 记忆管理器
from .memory_manager import MemoryManager, MemoryResult, memory_manager
# 激活器
from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator
# 记忆核心系统
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
# 兼容性别名
from .memory_chunk import MemoryChunk as Memory
# Vector DB存储系统
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
__all__ = [
# 核心数据结构
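
The `__init__.py` hunk above is pure import re-ordering (alphabetical within each `from ... import (...)` block, blocks sorted by module path). Since only re-export order changes, the package's public surface is untouched — a hypothetical sanity check:

```python
import importlib

# Package path assumed from the relative imports above.
mod = importlib.import_module("src.chat.memory_system")
print(sorted(getattr(mod, "__all__", dir(mod))))  # identical output before/after the re-ordering
```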

View File

@@ -1,17 +1,17 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统适配器
将增强记忆系统集成到现有MoFox Bot架构中
"""
import time
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -47,10 +47,10 @@ class AdapterConfig:
class EnhancedMemoryAdapter:
"""增强记忆系统适配器"""
def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None):
def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
self.llm_model = llm_model
self.config = config or AdapterConfig()
self.integration_layer: Optional[MemoryIntegrationLayer] = None
self.integration_layer: MemoryIntegrationLayer | None = None
self._initialized = False
# 统计信息
@@ -96,7 +96,7 @@ class EnhancedMemoryAdapter:
# 如果初始化失败,禁用增强记忆功能
self.config.enable_enhanced_memory = False
async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
"""处理对话记忆,以上下文为唯一输入"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"success": False, "error": "Enhanced memory not available"}
@@ -105,7 +105,7 @@ class EnhancedMemoryAdapter:
self.adapter_stats["total_processed"] += 1
try:
payload_context: Dict[str, Any] = dict(context or {})
payload_context: dict[str, Any] = dict(context or {})
conversation_text = payload_context.get("conversation_text")
if not conversation_text:
@@ -146,8 +146,8 @@ class EnhancedMemoryAdapter:
return {"success": False, "error": str(e)}
async def retrieve_relevant_memories(
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None
) -> List[MemoryChunk]:
self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
) -> list[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.config.enable_enhanced_memory:
return []
@@ -166,7 +166,7 @@ class EnhancedMemoryAdapter:
return []
async def get_memory_context_for_prompt(
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5
self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
) -> str:
"""获取用于提示词的记忆上下文"""
memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
@@ -186,7 +186,7 @@ class EnhancedMemoryAdapter:
return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
"""获取增强记忆系统摘要"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"available": False, "reason": "Not initialized or disabled"}
@@ -222,7 +222,7 @@ class EnhancedMemoryAdapter:
new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
self.adapter_stats["average_processing_time"] = new_avg
def get_adapter_stats(self) -> Dict[str, Any]:
def get_adapter_stats(self) -> dict[str, Any]:
"""获取适配器统计信息"""
return self.adapter_stats.copy()
@@ -253,7 +253,7 @@ class EnhancedMemoryAdapter:
# 全局适配器实例
_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None
async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
@@ -292,8 +292,8 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
async def process_conversation_with_enhanced_memory(
context: Dict[str, Any], llm_model: Optional[LLMRequest] = None
) -> Dict[str, Any]:
context: dict[str, Any], llm_model: LLMRequest | None = None
) -> dict[str, Any]:
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
if not llm_model:
# 获取默认的LLM模型
@@ -323,10 +323,10 @@ async def process_conversation_with_enhanced_memory(
async def retrieve_memories_with_enhanced_system(
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
limit: int = 10,
llm_model: Optional[LLMRequest] = None,
) -> List[MemoryChunk]:
llm_model: LLMRequest | None = None,
) -> list[MemoryChunk]:
"""使用增强记忆系统检索记忆"""
if not llm_model:
# 获取默认的LLM模型
@@ -345,9 +345,9 @@ async def retrieve_memories_with_enhanced_system(
async def get_memory_context_for_prompt(
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
max_memories: int = 5,
llm_model: Optional[LLMRequest] = None,
llm_model: LLMRequest | None = None,
) -> str:
"""获取用于提示词的记忆上下文"""
if not llm_model:
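
Note that `AdapterConfig | None`, `dict[str, Any] | None`, etc. are evaluated when the defining module is imported, so these signatures require Python 3.10+ (3.9+ for the bare builtin generics). If older interpreters still had to be supported — an assumption, since the project's minimum Python version isn't shown in this diff — the same spelling would only be legal behind deferred annotation evaluation:

```python
from __future__ import annotations  # annotations stay strings, so `X | None` parses on 3.8/3.9 too

from typing import Any


def process(context: dict[str, Any] | None = None) -> dict[str, Any]:
    # Mirrors the `payload_context = dict(context or {})` pattern used above.
    return dict(context or {})
```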

View File

@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统钩子
用于在消息处理过程中自动构建和检索记忆
"""
from typing import Dict, List, Any, Optional
from datetime import datetime
from typing import Any
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
logger = get_logger(__name__)
@@ -27,7 +27,7 @@ class EnhancedMemoryHooks:
user_id: str,
chat_id: str,
message_id: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
) -> bool:
"""
处理消息并构建记忆
@@ -106,8 +106,8 @@ class EnhancedMemoryHooks:
user_id: str,
chat_id: str,
limit: int = 5,
extra_context: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
extra_context: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
"""
为回复获取相关记忆

View File

@@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成脚本
用于在现有系统中无缝集成增强记忆功能
"""
from typing import Dict, Any, Optional
from typing import Any
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
logger = get_logger(__name__)
async def process_user_message_memory(
message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None
message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
) -> bool:
"""
处理用户消息并构建记忆
@@ -44,8 +44,8 @@ async def process_user_message_memory(
async def get_relevant_memories_for_response(
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
) -> dict[str, Any]:
"""
为回复获取相关记忆
@@ -74,7 +74,7 @@ async def get_relevant_memories_for_response(
return {"has_memories": False, "memories": [], "memory_count": 0}
def format_memories_for_prompt(memories: Dict[str, Any]) -> str:
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
"""
格式化记忆信息用于Prompt
@@ -114,7 +114,7 @@ async def cleanup_memory_system():
logger.error(f"记忆系统清理失败: {e}")
def get_memory_system_status() -> Dict[str, Any]:
def get_memory_system_status() -> dict[str, Any]:
"""
获取记忆系统状态
@@ -133,7 +133,7 @@ def get_memory_system_status() -> Dict[str, Any]:
# 便捷函数
async def remember_message(
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
) -> bool:
"""
便捷的记忆构建函数
@@ -159,8 +159,8 @@ async def recall_memories(
user_id: str = "default_user",
chat_id: str = "default_chat",
limit: int = 5,
context: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
context: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
便捷的记忆检索函数

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
增强重排序器
实现文档设计的多维度评分模型
@@ -6,12 +5,12 @@
import math
import time
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -44,7 +43,7 @@ class ReRankingConfig:
freq_max_score: float = 5.0 # 最大频率得分
# 类型匹配权重映射
type_match_weights: Dict[str, Dict[str, float]] = None
type_match_weights: dict[str, dict[str, float]] = None
def __post_init__(self):
"""初始化类型匹配权重"""
@@ -157,7 +156,7 @@ class IntentClassifier:
],
}
def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType:
def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
"""识别对话意图"""
if not query:
return IntentType.UNKNOWN
@@ -165,7 +164,7 @@ class IntentClassifier:
query_lower = query.lower()
# 统计各意图的匹配分数
intent_scores = {intent: 0 for intent in IntentType}
intent_scores = dict.fromkeys(IntentType, 0)
for intent, patterns in self.patterns.items():
for pattern in patterns:
@@ -187,7 +186,7 @@ class IntentClassifier:
class EnhancedReRanker:
"""增强重排序器 - 实现文档设计的多维度评分模型"""
def __init__(self, config: Optional[ReRankingConfig] = None):
def __init__(self, config: ReRankingConfig | None = None):
self.config = config or ReRankingConfig()
self.intent_classifier = IntentClassifier()
@@ -210,10 +209,10 @@ class EnhancedReRanker:
def rerank_memories(
self,
query: str,
candidate_memories: List[Tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
context: Dict[str, Any],
candidate_memories: list[tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
context: dict[str, Any],
limit: int = 10,
) -> List[Tuple[str, MemoryChunk, float]]:
) -> list[tuple[str, MemoryChunk, float]]:
"""
对候选记忆进行重排序
@@ -341,11 +340,11 @@ default_reranker = EnhancedReRanker()
def rerank_candidate_memories(
query: str,
candidate_memories: List[Tuple[str, MemoryChunk, float]],
context: Dict[str, Any],
candidate_memories: list[tuple[str, MemoryChunk, float]],
context: dict[str, Any],
limit: int = 10,
config: Optional[ReRankingConfig] = None,
) -> List[Tuple[str, MemoryChunk, float]]:
config: ReRankingConfig | None = None,
) -> list[tuple[str, MemoryChunk, float]]:
"""
便捷函数:对候选记忆进行重排序
"""

View File

@@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成层
现在只管理新的增强记忆系统,旧系统已被完全移除
"""
import time
import asyncio
from typing import Dict, List, Optional, Any
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -40,12 +40,12 @@ class IntegrationConfig:
class MemoryIntegrationLayer:
"""记忆系统集成层 - 现在只管理增强记忆系统"""
def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None):
def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
self.llm_model = llm_model
self.config = config or IntegrationConfig()
# 只初始化增强记忆系统
self.enhanced_memory: Optional[EnhancedMemorySystem] = None
self.enhanced_memory: EnhancedMemorySystem | None = None
# 集成统计
self.integration_stats = {
@@ -113,7 +113,7 @@ class MemoryIntegrationLayer:
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
raise
async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]:
async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
"""处理对话记忆,仅使用上下文信息"""
if not self._initialized or not self.enhanced_memory:
return {"success": False, "error": "Memory system not available"}
@@ -150,10 +150,10 @@ class MemoryIntegrationLayer:
async def retrieve_relevant_memories(
self,
query: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
) -> List[MemoryChunk]:
user_id: str | None = None,
context: dict[str, Any] | None = None,
limit: int | None = None,
) -> list[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.enhanced_memory:
return []
@@ -172,7 +172,7 @@ class MemoryIntegrationLayer:
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
return []
async def get_system_status(self) -> Dict[str, Any]:
async def get_system_status(self) -> dict[str, Any]:
"""获取系统状态"""
if not self._initialized:
return {"status": "not_initialized"}
@@ -193,7 +193,7 @@ class MemoryIntegrationLayer:
logger.error(f"获取系统状态失败: {e}", exc_info=True)
return {"status": "error", "error": str(e)}
def get_integration_stats(self) -> Dict[str, Any]:
def get_integration_stats(self) -> dict[str, Any]:
"""获取集成统计信息"""
return self.integration_stats.copy()

View File

@@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
"""
记忆系统集成钩子
提供与现有MoFox Bot系统的无缝集成点
"""
import time
from typing import Dict, Optional, Any
from dataclasses import dataclass
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_adapter import (
get_memory_context_for_prompt,
process_conversation_with_enhanced_memory,
retrieve_memories_with_enhanced_system,
get_memory_context_for_prompt,
)
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -24,7 +24,7 @@ class HookResult:
success: bool
data: Any = None
error: Optional[str] = None
error: str | None = None
processing_time: float = 0.0
@@ -125,8 +125,8 @@ class MemoryIntegrationHooks:
# 尝试注册到事件系统
try:
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager
# 注册消息后处理事件
event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
@@ -238,11 +238,11 @@ class MemoryIntegrationHooks:
# 钩子处理器方法
async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult:
async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
"""事件系统的消息处理处理器"""
return await self._on_message_processed_hook(event_data)
async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult:
async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
"""消息后处理钩子"""
start_time = time.time()
@@ -289,7 +289,7 @@ class MemoryIntegrationHooks:
logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult:
async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
"""聊天流保存钩子"""
start_time = time.time()
@@ -345,7 +345,7 @@ class MemoryIntegrationHooks:
logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult:
async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
"""回复前钩子"""
start_time = time.time()
@@ -380,7 +380,7 @@ class MemoryIntegrationHooks:
logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult:
async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
"""知识库查询钩子"""
start_time = time.time()
@@ -411,7 +411,7 @@ class MemoryIntegrationHooks:
logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult:
async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
"""提示词构建钩子"""
start_time = time.time()
@@ -459,7 +459,7 @@ class MemoryIntegrationHooks:
new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
self.hook_stats["average_hook_time"] = new_avg
def get_hook_stats(self) -> Dict[str, Any]:
def get_hook_stats(self) -> dict[str, Any]:
"""获取钩子统计信息"""
return self.hook_stats.copy()
@@ -501,7 +501,7 @@ class MemoryMaintenanceTask:
# 全局钩子实例
_memory_hooks: Optional[MemoryIntegrationHooks] = None
_memory_hooks: MemoryIntegrationHooks | None = None
async def get_memory_integration_hooks() -> MemoryIntegrationHooks:

View File

@@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
"""
元数据索引系统
为记忆系统提供多维度的精准过滤和查询能力
"""
import threading
import time
import orjson
from typing import Dict, List, Optional, Tuple, Set, Any, Union
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
import threading
from collections import defaultdict
from pathlib import Path
from typing import Any
import orjson
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk, MemoryType
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
logger = get_logger(__name__)
@@ -40,21 +40,21 @@ class IndexType(Enum):
class IndexQuery:
"""索引查询条件"""
user_ids: Optional[List[str]] = None
memory_types: Optional[List[MemoryType]] = None
subjects: Optional[List[str]] = None
keywords: Optional[List[str]] = None
tags: Optional[List[str]] = None
categories: Optional[List[str]] = None
time_range: Optional[Tuple[float, float]] = None
confidence_levels: Optional[List[ConfidenceLevel]] = None
importance_levels: Optional[List[ImportanceLevel]] = None
min_relationship_score: Optional[float] = None
max_relationship_score: Optional[float] = None
min_access_count: Optional[int] = None
semantic_hashes: Optional[List[str]] = None
limit: Optional[int] = None
sort_by: Optional[str] = None # "created_at", "access_count", "relevance_score"
user_ids: list[str] | None = None
memory_types: list[MemoryType] | None = None
subjects: list[str] | None = None
keywords: list[str] | None = None
tags: list[str] | None = None
categories: list[str] | None = None
time_range: tuple[float, float] | None = None
confidence_levels: list[ConfidenceLevel] | None = None
importance_levels: list[ImportanceLevel] | None = None
min_relationship_score: float | None = None
max_relationship_score: float | None = None
min_access_count: int | None = None
semantic_hashes: list[str] | None = None
limit: int | None = None
sort_by: str | None = None # "created_at", "access_count", "relevance_score"
sort_order: str = "desc" # "asc", "desc"
@@ -62,10 +62,10 @@ class IndexQuery:
class IndexResult:
"""索引结果"""
memory_ids: List[str]
memory_ids: list[str]
total_count: int
query_time: float
filtered_by: List[str]
filtered_by: list[str]
class MetadataIndexManager:
@@ -94,7 +94,7 @@ class MetadataIndexManager:
self.access_frequency_index = [] # [(access_count, memory_id), ...]
# 内存缓存
self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {}
self.memory_metadata_cache: dict[str, dict[str, Any]] = {}
# 统计信息
self.index_stats = {
@@ -140,7 +140,7 @@ class MetadataIndexManager:
return key
@staticmethod
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
def _serialize_metadata_entry(metadata: dict[str, Any]) -> dict[str, Any]:
serialized = {}
for field_name, value in metadata.items():
if isinstance(value, Enum):
@@ -149,7 +149,7 @@ class MetadataIndexManager:
serialized[field_name] = value
return serialized
async def index_memories(self, memories: List[MemoryChunk]):
async def index_memories(self, memories: list[MemoryChunk]):
"""为记忆建立索引"""
if not memories:
return
@@ -375,7 +375,7 @@ class MetadataIndexManager:
logger.error(f"❌ 元数据查询失败: {e}", exc_info=True)
return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[])
def _get_candidate_memories(self, query: IndexQuery) -> Set[str]:
def _get_candidate_memories(self, query: IndexQuery) -> set[str]:
"""获取候选记忆ID集合"""
candidate_ids = set()
@@ -444,7 +444,7 @@ class MetadataIndexManager:
return candidate_ids
def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
def _collect_index_matches(self, index_type: IndexType, token: str | Enum | None) -> set[str]:
"""根据给定token收集索引匹配支持部分匹配"""
mapping = self.indices.get(index_type)
if mapping is None:
@@ -461,7 +461,7 @@ class MetadataIndexManager:
if not key:
return set()
matches: Set[str] = set(mapping.get(key, set()))
matches: set[str] = set(mapping.get(key, set()))
if matches:
return set(matches)
@@ -477,7 +477,7 @@ class MetadataIndexManager:
return matches
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
def _apply_filters(self, candidate_ids: set[str], query: IndexQuery) -> list[str]:
"""应用过滤条件"""
filtered_ids = list(candidate_ids)
@@ -545,7 +545,7 @@ class MetadataIndexManager:
created_at = self.memory_metadata_cache[memory_id]["created_at"]
return start_time <= created_at <= end_time
def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]:
def _sort_memories(self, memory_ids: list[str], sort_by: str, sort_order: str) -> list[str]:
"""对记忆进行排序"""
if sort_by == "created_at":
# 使用时间索引(已经有序)
@@ -582,7 +582,7 @@ class MetadataIndexManager:
return memory_ids
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
def _get_applied_filters(self, query: IndexQuery) -> list[str]:
"""获取应用的过滤器列表"""
filters = []
if query.memory_types:
@@ -686,11 +686,11 @@ class MetadataIndexManager:
except Exception as e:
logger.error(f"❌ 移除记忆索引失败: {e}")
async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]:
async def get_memory_metadata(self, memory_id: str) -> dict[str, Any] | None:
"""获取记忆元数据"""
return self.memory_metadata_cache.get(memory_id)
async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]:
async def get_user_memory_ids(self, user_id: str, limit: int | None = None) -> list[str]:
"""获取用户的所有记忆ID"""
user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set()))
@@ -699,7 +699,7 @@ class MetadataIndexManager:
return user_memory_ids
async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]:
async def get_memory_statistics(self, user_id: str | None = None) -> dict[str, Any]:
"""获取记忆统计信息"""
stats = {
"total_memories": self.index_stats["total_memories"],
@@ -784,7 +784,7 @@ class MetadataIndexManager:
logger.info("正在保存元数据索引...")
# 保存各类索引
indices_data: Dict[str, Dict[str, List[str]]] = {}
indices_data: dict[str, dict[str, list[str]]] = {}
for index_type, index_data in self.indices.items():
serialized_index = {}
for key, values in index_data.items():
@@ -839,7 +839,7 @@ class MetadataIndexManager:
# 加载各类索引
indices_file = self.index_path / "indices.json"
if indices_file.exists():
with open(indices_file, "r", encoding="utf-8") as f:
with open(indices_file, encoding="utf-8") as f:
indices_data = orjson.loads(f.read())
for index_type_value, index_data in indices_data.items():
@@ -853,25 +853,25 @@ class MetadataIndexManager:
# 加载时间索引
time_index_file = self.index_path / "time_index.json"
if time_index_file.exists():
with open(time_index_file, "r", encoding="utf-8") as f:
with open(time_index_file, encoding="utf-8") as f:
self.time_index = orjson.loads(f.read())
# 加载关系分索引
relationship_index_file = self.index_path / "relationship_index.json"
if relationship_index_file.exists():
with open(relationship_index_file, "r", encoding="utf-8") as f:
with open(relationship_index_file, encoding="utf-8") as f:
self.relationship_index = orjson.loads(f.read())
# 加载访问频率索引
access_frequency_index_file = self.index_path / "access_frequency_index.json"
if access_frequency_index_file.exists():
with open(access_frequency_index_file, "r", encoding="utf-8") as f:
with open(access_frequency_index_file, encoding="utf-8") as f:
self.access_frequency_index = orjson.loads(f.read())
# 加载元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
if metadata_cache_file.exists():
with open(metadata_cache_file, "r", encoding="utf-8") as f:
with open(metadata_cache_file, encoding="utf-8") as f:
cache_data = orjson.loads(f.read())
# 转换置信度和重要性为枚举类型
@@ -914,7 +914,7 @@ class MetadataIndexManager:
# 加载统计信息
stats_file = self.index_path / "index_stats.json"
if stats_file.exists():
with open(stats_file, "r", encoding="utf-8") as f:
with open(stats_file, encoding="utf-8") as f:
self.index_stats = orjson.loads(f.read())
# 更新记忆计数
@@ -1004,7 +1004,7 @@ class MetadataIndexManager:
if len(self.indices[IndexType.CATEGORY][category]) < min_frequency:
del self.indices[IndexType.CATEGORY][category]
def get_index_stats(self) -> Dict[str, Any]:
def get_index_stats(self) -> dict[str, Any]:
"""获取索引统计信息"""
stats = self.index_stats.copy()
if stats["total_queries"] > 0:

View File

@@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-
"""
多阶段召回机制
实现粗粒度到细粒度的记忆检索优化
"""
import time
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, field
from enum import Enum
import orjson
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
import orjson
from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -73,11 +73,11 @@ class StageResult:
"""阶段结果"""
stage: RetrievalStage
memory_ids: List[str]
memory_ids: list[str]
processing_time: float
filtered_count: int
score_threshold: float
details: List[Dict[str, Any]] = field(default_factory=list)
details: list[dict[str, Any]] = field(default_factory=list)
@dataclass
@@ -86,17 +86,17 @@ class RetrievalResult:
query: str
user_id: str
final_memories: List[MemoryChunk]
stage_results: List[StageResult]
final_memories: list[MemoryChunk]
stage_results: list[StageResult]
total_processing_time: float
total_filtered: int
retrieval_stats: Dict[str, Any]
retrieval_stats: dict[str, Any]
class MultiStageRetrieval:
"""多阶段召回系统"""
def __init__(self, config: Optional[RetrievalConfig] = None):
def __init__(self, config: RetrievalConfig | None = None):
self.config = config or RetrievalConfig.from_global_config()
# 初始化增强重排序器
@@ -124,11 +124,11 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
context: dict[str, Any],
metadata_index,
vector_storage,
all_memories_cache: Dict[str, MemoryChunk],
limit: Optional[int] = None,
all_memories_cache: dict[str, MemoryChunk],
limit: int | None = None,
) -> RetrievalResult:
"""多阶段记忆检索"""
start_time = time.time()
@@ -136,7 +136,7 @@ class MultiStageRetrieval:
stage_results = []
current_memory_ids = set()
memory_debug_info: Dict[str, Dict[str, Any]] = {}
memory_debug_info: dict[str, dict[str, Any]] = {}
try:
logger.debug(f"开始多阶段检索query='{query}', user_id='{user_id}'")
@@ -311,11 +311,11 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
context: dict[str, Any],
metadata_index,
all_memories_cache: Dict[str, MemoryChunk],
all_memories_cache: dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段1元数据过滤"""
start_time = time.time()
@@ -345,7 +345,7 @@ class MultiStageRetrieval:
result = await metadata_index.query_memories(index_query)
result_ids = list(result.memory_ids)
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
if not result_ids:
@@ -440,12 +440,12 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
context: dict[str, Any],
vector_storage,
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
candidate_ids: set[str],
all_memories_cache: dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段2向量搜索"""
start_time = time.time()
@@ -479,8 +479,8 @@ class MultiStageRetrieval:
# 过滤候选记忆
filtered_memories = []
details: List[Dict[str, Any]] = []
raw_details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
raw_details: list[dict[str, Any]] = []
threshold = self.config.vector_similarity_threshold
for memory_id, similarity in search_result:
@@ -561,7 +561,7 @@ class MultiStageRetrieval:
)
def _create_text_search_fallback(
self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float
self, candidate_ids: set[str], all_memories_cache: dict[str, MemoryChunk], query_text: str, start_time: float
) -> StageResult:
"""当向量搜索失败时,使用文本搜索作为回退策略"""
try:
@@ -618,18 +618,18 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
candidate_ids: set[str],
all_memories_cache: dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段3语义重排序"""
start_time = time.time()
try:
reranked_memories = []
details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
threshold = self.config.semantic_similarity_threshold
for memory_id in candidate_ids:
@@ -704,19 +704,19 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: List[str],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
candidate_ids: list[str],
all_memories_cache: dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段4上下文过滤"""
start_time = time.time()
try:
final_memories = []
details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
for memory_id in candidate_ids:
if memory_id not in all_memories_cache:
@@ -793,12 +793,12 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
all_memories_cache: dict[str, MemoryChunk],
limit: int,
*,
excluded_ids: Optional[Set[str]] = None,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
excluded_ids: set[str] | None = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""回退检索阶段 - 当主检索失败时使用更宽松的策略"""
start_time = time.time()
@@ -881,8 +881,8 @@ class MultiStageRetrieval:
)
async def _generate_query_embedding(
self, query: str, context: Dict[str, Any], vector_storage
) -> Optional[List[float]]:
self, query: str, context: dict[str, Any], vector_storage
) -> list[float] | None:
"""生成查询向量"""
try:
query_plan = context.get("query_plan")
@@ -916,7 +916,7 @@ class MultiStageRetrieval:
logger.error(f"生成查询向量时发生异常: {e}", exc_info=True)
return None
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
"""计算语义相似度 - 简化优化版本,提升召回率"""
try:
query_plan = context.get("query_plan")
@@ -947,9 +947,10 @@ class MultiStageRetrieval:
# 核心匹配策略2词汇匹配
word_score = 0.0
try:
import jieba
import re
import jieba
# 分词处理
query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text)
memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text)
@@ -1059,7 +1060,7 @@ class MultiStageRetrieval:
logger.warning(f"计算语义相似度失败: {e}")
return 0.0
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
"""计算上下文相关度"""
try:
score = 0.0
@@ -1132,7 +1133,7 @@ class MultiStageRetrieval:
return 0.0
async def _calculate_final_score(
self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float
self, query: str, memory: MemoryChunk, context: dict[str, Any], context_score: float
) -> float:
"""计算最终评分"""
try:
@@ -1184,7 +1185,7 @@ class MultiStageRetrieval:
logger.warning(f"计算最终评分失败: {e}")
return 0.0
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: list[str] | None) -> float:
if not required_subjects:
return 0.0
@@ -1229,7 +1230,7 @@ class MultiStageRetrieval:
except Exception:
return 0.5
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
def _extract_memory_types_from_context(self, context: dict[str, Any]) -> list[MemoryType]:
"""从上下文中提取记忆类型"""
try:
query_plan = context.get("query_plan")
@@ -1256,10 +1257,10 @@ class MultiStageRetrieval:
except Exception:
return []
def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
def _extract_keywords_from_query(self, query: str, query_plan: Any | None = None) -> list[str]:
"""从查询中提取关键词"""
try:
extracted: List[str] = []
extracted: list[str] = []
if query_plan and getattr(query_plan, "required_keywords", None):
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
@@ -1283,7 +1284,7 @@ class MultiStageRetrieval:
except Exception:
return []
def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]):
def _update_retrieval_stats(self, total_time: float, stage_results: list[StageResult]):
"""更新检索统计"""
self.retrieval_stats["total_queries"] += 1
@@ -1306,7 +1307,7 @@ class MultiStageRetrieval:
]
stage_stat["avg_time"] = new_stage_avg
def get_retrieval_stats(self) -> Dict[str, Any]:
def get_retrieval_stats(self) -> dict[str, Any]:
"""获取检索统计信息"""
return self.retrieval_stats.copy()
@@ -1328,12 +1329,12 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: List[str],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
candidate_ids: list[str],
all_memories_cache: dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段5增强重排序 - 使用多维度评分模型"""
start_time = time.time()
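
One mechanical artifact in this file: `Optional[Any]` becomes `Any | None` (see `_extract_keywords_from_query`). That union is redundant — `Any` already admits `None` — but pyupgrade-style rewrites translate the spelling literally rather than simplifying it; plain `Any` would be equivalent:

```python
from typing import Any

# To a type checker, `Any | None` collapses to `Any`; the restyle rewrites
# syntax only, so the redundant union is preserved rather than simplified.
def extract(query_plan: Any | None = None) -> list[str]:
    return [] if query_plan is None else [str(item) for item in query_plan]
```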

View File

@@ -1,24 +1,23 @@
# -*- coding: utf-8 -*-
"""
向量数据库存储接口
为记忆系统提供高效的向量存储和语义搜索能力
"""
import time
import orjson
import asyncio
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
from pathlib import Path
import orjson
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.config_helpers import resolve_embedding_dimension
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.config_helpers import resolve_embedding_dimension
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -48,7 +47,7 @@ class VectorStorageConfig:
class VectorStorageManager:
"""向量存储管理器"""
def __init__(self, config: Optional[VectorStorageConfig] = None):
def __init__(self, config: VectorStorageConfig | None = None):
self.config = config or VectorStorageConfig()
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
@@ -68,8 +67,8 @@ class VectorStorageManager:
self.index_to_memory_id = {} # vector index -> memory_id
# 内存缓存
self.memory_cache: Dict[str, MemoryChunk] = {}
self.vector_cache: Dict[str, List[float]] = {}
self.memory_cache: dict[str, MemoryChunk] = {}
self.vector_cache: dict[str, list[float]] = {}
# 统计信息
self.storage_stats = {
@@ -125,7 +124,7 @@ class VectorStorageManager:
)
logger.info("✅ 嵌入模型初始化完成")
async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
async def generate_query_embedding(self, query_text: str) -> list[float] | None:
"""生成查询向量,用于记忆召回"""
if not query_text:
logger.warning("查询文本为空,无法生成向量")
@@ -155,7 +154,7 @@ class VectorStorageManager:
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
return None
async def store_memories(self, memories: List[MemoryChunk]):
async def store_memories(self, memories: list[MemoryChunk]):
"""存储记忆向量"""
if not memories:
return
@@ -231,7 +230,7 @@ class VectorStorageManager:
logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id)
return memory.memory_id
async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]):
async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
"""批量生成和存储嵌入向量"""
if not memory_texts:
return
@@ -253,12 +252,12 @@ class VectorStorageManager:
except Exception as e:
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
async def _batch_generate_embeddings(self, memory_ids: List[str], texts: List[str]) -> Dict[str, List[float]]:
async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
"""批量生成嵌入向量"""
if not texts:
return {}
results: Dict[str, List[float]] = {}
results: dict[str, list[float]] = {}
try:
semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))
@@ -281,7 +280,9 @@ class VectorStorageManager:
logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
results[memory_id] = []
tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)]
tasks = [
asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
]
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
@@ -291,7 +292,7 @@ class VectorStorageManager:
return results
async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]):
async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
"""添加单个记忆到向量存储"""
with self._lock:
try:
@@ -337,7 +338,7 @@ class VectorStorageManager:
except Exception as e:
logger.error(f"❌ 添加记忆到向量存储失败: {e}")
def _normalize_vector(self, vector: List[float]) -> List[float]:
def _normalize_vector(self, vector: list[float]) -> list[float]:
"""L2归一化向量"""
if not vector:
return vector
@@ -357,12 +358,12 @@ class VectorStorageManager:
async def search_similar_memories(
self,
query_vector: Optional[List[float]] = None,
query_vector: list[float] | None = None,
*,
query_text: Optional[str] = None,
query_text: str | None = None,
limit: int = 10,
scope_id: Optional[str] = None,
) -> List[Tuple[str, float]]:
scope_id: str | None = None,
) -> list[tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()
@@ -379,7 +380,7 @@ class VectorStorageManager:
logger.warning("查询向量生成失败")
return []
scope_filter: Optional[str] = None
scope_filter: str | None = None
if isinstance(scope_id, str):
normalized_scope = scope_id.strip().lower()
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
@@ -491,7 +492,7 @@ class VectorStorageManager:
logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
return []
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
"""根据ID获取记忆"""
# 先检查缓存
if memory_id in self.memory_cache:
@@ -501,7 +502,7 @@ class VectorStorageManager:
self.storage_stats["total_searches"] += 1
return None
async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]):
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
"""更新记忆的嵌入向量"""
with self._lock:
try:
@@ -636,7 +637,7 @@ class VectorStorageManager:
# 加载记忆缓存
cache_file = self.storage_path / "memory_cache.json"
if cache_file.exists():
with open(cache_file, "r", encoding="utf-8") as f:
with open(cache_file, encoding="utf-8") as f:
cache_data = orjson.loads(f.read())
self.memory_cache = {
@@ -646,13 +647,13 @@ class VectorStorageManager:
# 加载向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
if vector_cache_file.exists():
with open(vector_cache_file, "r", encoding="utf-8") as f:
with open(vector_cache_file, encoding="utf-8") as f:
self.vector_cache = orjson.loads(f.read())
# 加载映射关系
mapping_file = self.storage_path / "id_mapping.json"
if mapping_file.exists():
with open(mapping_file, "r", encoding="utf-8") as f:
with open(mapping_file, encoding="utf-8") as f:
mapping_data = orjson.loads(f.read())
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
@@ -689,7 +690,7 @@ class VectorStorageManager:
# 加载统计信息
stats_file = self.storage_path / "storage_stats.json"
if stats_file.exists():
with open(stats_file, "r", encoding="utf-8") as f:
with open(stats_file, encoding="utf-8") as f:
self.storage_stats = orjson.loads(f.read())
# 更新向量计数
@@ -806,7 +807,7 @@ class VectorStorageManager:
if invalid_memory_ids:
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
def get_storage_stats(self) -> Dict[str, Any]:
def get_storage_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
stats = self.storage_stats.copy()
if stats["total_searches"] > 0:
@@ -821,11 +822,11 @@ class SimpleVectorIndex:
def __init__(self, dimension: int):
self.dimension = dimension
self.vectors: List[List[float]] = []
self.vector_ids: List[int] = []
self.vectors: list[list[float]] = []
self.vector_ids: list[int] = []
self.next_id = 0
def add_vector(self, vector: List[float]) -> int:
def add_vector(self, vector: list[float]) -> int:
"""添加向量"""
if len(vector) != self.dimension:
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
@@ -837,7 +838,7 @@ class SimpleVectorIndex:
return vector_id
def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]:
def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
"""搜索相似向量"""
if len(query_vector) != self.dimension:
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
@@ -853,7 +854,7 @@ class SimpleVectorIndex:
return results[:limit]
def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
"""计算余弦相似度"""
try:
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
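
`zip(memory_ids, texts, strict=False)` and `zip(v1, v2, strict=False)` spell out the pre-existing truncating behavior now that `zip` takes a `strict` flag (PEP 618, 3.10+). For the batch-embedding path, `strict=True` would arguably be safer, surfacing id/text drift instead of silently dropping tail items:

```python
ids = ["m1", "m2", "m3"]
texts = ["a", "b"]  # one text missing

assert list(zip(ids, texts, strict=False)) == [("m1", "a"), ("m2", "b")]  # "m3" silently dropped

try:
    list(zip(ids, texts, strict=True))
except ValueError as e:
    print(e)  # zip() argument 2 is shorter than argument 1
```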

View File

@@ -1,25 +1,24 @@
# -*- coding: utf-8 -*-
"""
记忆激活器
记忆系统的激活器组件
"""
import difflib
import orjson
from typing import List, Dict, Optional
from datetime import datetime
import orjson
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.memory_manager import MemoryResult
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> List:
def get_keywords_from_json(json_str) -> list:
"""
从JSON字符串中提取关键词列表
@@ -81,7 +80,7 @@ class MemoryActivator:
self.cached_keywords = set() # 用于缓存历史关键词
self.last_memory_query_time = 0 # 上次查询记忆的时间
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
"""
激活记忆
"""
@@ -155,7 +154,7 @@ class MemoryActivator:
return self.running_memory
async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]:
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
"""查询统一记忆系统"""
try:
# 使用记忆系统
@@ -198,7 +197,7 @@ class MemoryActivator:
logger.error(f"查询统一记忆失败: {e}")
return []
async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
"""
获取即时记忆 - 兼容原有接口(使用统一存储)
"""

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
记忆构建模块
从对话流中提取高质量、结构化记忆单元
@@ -33,19 +32,19 @@ import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union, Type
from typing import Any
import orjson
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
MemoryChunk,
MemoryType,
ConfidenceLevel,
ImportanceLevel,
MemoryChunk,
MemoryType,
create_memory_chunk,
)
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -62,8 +61,8 @@ class ExtractionStrategy(Enum):
class ExtractionResult:
"""提取结果"""
memories: List[MemoryChunk]
confidence_scores: List[float]
memories: list[MemoryChunk]
confidence_scores: list[float]
extraction_time: float
strategy_used: ExtractionStrategy
@@ -85,8 +84,8 @@ class MemoryBuilder:
}
async def build_memories(
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float
) -> List[MemoryChunk]:
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float
) -> list[MemoryChunk]:
"""从对话中构建记忆"""
start_time = time.time()
@@ -116,8 +115,8 @@ class MemoryBuilder:
raise
async def _extract_with_llm(
self, text: str, context: Dict[str, Any], user_id: str, timestamp: float
) -> List[MemoryChunk]:
self, text: str, context: dict[str, Any], user_id: str, timestamp: float
) -> list[MemoryChunk]:
"""使用LLM提取记忆"""
try:
prompt = self._build_llm_extraction_prompt(text, context)
@@ -135,7 +134,7 @@ class MemoryBuilder:
logger.error(f"LLM提取失败: {e}")
raise MemoryExtractionError(str(e)) from e
def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
def _build_llm_extraction_prompt(self, text: str, context: dict[str, Any]) -> str:
"""构建LLM提取提示"""
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
message_type = context.get("message_type", "normal")
@@ -315,7 +314,7 @@ class MemoryBuilder:
return prompt
def _extract_json_payload(self, response: str) -> Optional[str]:
def _extract_json_payload(self, response: str) -> str | None:
"""从模型响应中提取JSON部分兼容Markdown代码块等格式"""
if not response:
return None
@@ -338,8 +337,8 @@ class MemoryBuilder:
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
def _parse_llm_response(
self, response: str, user_id: str, timestamp: float, context: Dict[str, Any]
) -> List[MemoryChunk]:
self, response: str, user_id: str, timestamp: float, context: dict[str, Any]
) -> list[MemoryChunk]:
"""解析LLM响应"""
if not response:
raise MemoryExtractionError("LLM未返回任何响应")
@@ -385,7 +384,7 @@ class MemoryBuilder:
bot_display = self._clean_subject_text(bot_display)
memories: List[MemoryChunk] = []
memories: list[MemoryChunk] = []
for mem_data in memory_list:
try:
@@ -460,7 +459,7 @@ class MemoryBuilder:
return memories
def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
def _parse_enum_value(self, enum_cls: type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
"""解析枚举值,兼容数字/字符串表示"""
if isinstance(raw_value, enum_cls):
return raw_value
@@ -514,7 +513,7 @@ class MemoryBuilder:
)
return default
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
def _collect_bot_identifiers(self, context: dict[str, Any] | None) -> set[str]:
identifiers: set[str] = {"bot", "机器人", "ai助手"}
if not context:
return identifiers
@@ -540,7 +539,7 @@ class MemoryBuilder:
return identifiers
def _collect_system_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
def _collect_system_identifiers(self, context: dict[str, Any] | None) -> set[str]:
identifiers: set[str] = set()
if not context:
return identifiers
@@ -568,8 +567,8 @@ class MemoryBuilder:
return identifiers
def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
participants: List[str] = []
def _resolve_conversation_participants(self, context: dict[str, Any] | None, user_id: str) -> list[str]:
participants: list[str] = []
if context:
candidate_keys = [
@@ -609,7 +608,7 @@ class MemoryBuilder:
if not participants:
participants = ["对话参与者"]
deduplicated: List[str] = []
deduplicated: list[str] = []
seen = set()
for name in participants:
key = name.lower()
@@ -620,7 +619,7 @@ class MemoryBuilder:
return deduplicated
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
def _resolve_user_display(self, context: dict[str, Any] | None, user_id: str) -> str:
candidate_keys = [
"user_display_name",
"user_name",
@@ -683,7 +682,7 @@ class MemoryBuilder:
return False
def _split_subject_string(self, value: str) -> List[str]:
def _split_subject_string(self, value: str) -> list[str]:
if not value:
return []
@@ -699,12 +698,12 @@ class MemoryBuilder:
subject: Any,
bot_identifiers: set[str],
system_identifiers: set[str],
default_subjects: List[str],
bot_display: Optional[str] = None,
) -> List[str]:
default_subjects: list[str],
bot_display: str | None = None,
) -> list[str]:
defaults = default_subjects or ["对话参与者"]
raw_candidates: List[str] = []
raw_candidates: list[str] = []
if isinstance(subject, list):
for item in subject:
if isinstance(item, str):
@@ -716,7 +715,7 @@ class MemoryBuilder:
elif subject is not None:
raw_candidates.extend(self._split_subject_string(str(subject)))
normalized: List[str] = []
normalized: list[str] = []
bot_primary = self._clean_subject_text(bot_display or "")
for candidate in raw_candidates:
@@ -741,7 +740,7 @@ class MemoryBuilder:
if not normalized:
normalized = list(defaults)
deduplicated: List[str] = []
deduplicated: list[str] = []
seen = set()
for name in normalized:
key = name.lower()
@@ -752,7 +751,7 @@ class MemoryBuilder:
return deduplicated
def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
def _extract_value_from_object(self, obj: str | dict[str, Any] | list[Any], keys: list[str]) -> str | None:
if isinstance(obj, dict):
for key in keys:
value = obj.get(key)
@@ -773,9 +772,7 @@ class MemoryBuilder:
return obj.strip() or None
return None
def _compose_display_text(
self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]
) -> str:
def _compose_display_text(self, subjects: list[str], predicate: str, obj: str | dict[str, Any] | list[Any]) -> str:
subject_phrase = "".join(subjects) if subjects else "对话参与者"
predicate = (predicate or "").strip()
@@ -841,7 +838,7 @@ class MemoryBuilder:
return f"{subject_phrase}{predicate}".strip()
return subject_phrase
def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]:
def _validate_and_enhance_memories(self, memories: list[MemoryChunk], context: dict[str, Any]) -> list[MemoryChunk]:
"""验证和增强记忆"""
validated_memories = []
@@ -876,7 +873,7 @@ class MemoryBuilder:
return True
def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk:
def _enhance_memory(self, memory: MemoryChunk, context: dict[str, Any]) -> MemoryChunk:
"""增强记忆块"""
# 时间规范化处理
self._normalize_time_in_memory(memory)
@@ -985,7 +982,7 @@ class MemoryBuilder:
total_confidence / self.extraction_stats["successful_extractions"]
)
def get_extraction_stats(self) -> Dict[str, Any]:
def get_extraction_stats(self) -> dict[str, Any]:
"""获取提取统计信息"""
return self.extraction_stats.copy()
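
`Type[Enum]` → `type[Enum]` is the same PEP 585 rule applied to the builtin `type`, legal as an annotation from 3.9 on. A simplified sketch of how such a parameter drives enum coercion, mirroring `_parse_enum_value` above (stand-in enum and reduced fallback logic):

```python
from enum import Enum


class ConfidenceLevel(Enum):  # stand-in for the enum imported in this module
    LOW = 1
    HIGH = 2


def parse_enum_value(enum_cls: type[Enum], raw_value: object, default: Enum) -> Enum:
    """Coerce a raw value into enum_cls, falling back to default on bad input."""
    if isinstance(raw_value, enum_cls):
        return raw_value
    try:
        return enum_cls(raw_value)
    except ValueError:
        return default


assert parse_enum_value(ConfidenceLevel, 2, ConfidenceLevel.LOW) is ConfidenceLevel.HIGH
```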

View File

@@ -1,18 +1,19 @@
# -*- coding: utf-8 -*-
"""
结构化记忆单元设计
实现高质量、结构化的记忆单元,符合文档设计规范
"""
import hashlib
import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union, Iterable
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
import hashlib
from typing import Any
import numpy as np
import orjson
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -56,17 +57,17 @@ class ImportanceLevel(Enum):
class ContentStructure:
"""主谓宾结构,包含自然语言描述"""
subject: Union[str, List[str]]
subject: str | list[str]
predicate: str
object: Union[str, Dict]
object: str | dict
display: str = ""
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure":
def from_dict(cls, data: dict[str, Any]) -> "ContentStructure":
"""从字典创建实例"""
return cls(
subject=data.get("subject", ""),
@@ -75,7 +76,7 @@ class ContentStructure:
display=data.get("display", ""),
)
def to_subject_list(self) -> List[str]:
def to_subject_list(self) -> list[str]:
"""将主语转换为列表形式"""
if isinstance(self.subject, list):
return [s for s in self.subject if isinstance(s, str) and s.strip()]
@@ -99,7 +100,7 @@ class MemoryMetadata:
# 基础信息
memory_id: str # 唯一标识符
user_id: str # 用户ID
chat_id: Optional[str] = None # 聊天ID(群聊或私聊)
chat_id: str | None = None # 聊天ID(群聊或私聊)
# 时间信息
created_at: float = 0.0 # 创建时间戳
@@ -124,9 +125,9 @@ class MemoryMetadata:
last_forgetting_check: float = 0.0 # time of the last forgetting check
# Source info
source_context: Optional[str] = None # source context snippet
source_context: str | None = None # source context snippet
# Legacy field: some code or older versions may access metadata.source directly
source: Optional[str] = None
source: str | None = None
def __post_init__(self):
"""后初始化处理"""
@@ -209,7 +210,7 @@ class MemoryMetadata:
# Clamp to the minimum and maximum thresholds
return max(7.0, min(threshold, 365.0)) # between 7 days and 1 year
def should_forget(self, current_time: Optional[float] = None) -> bool:
def should_forget(self, current_time: float | None = None) -> bool:
"""判断是否应该遗忘"""
if current_time is None:
current_time = time.time()
@@ -222,7 +223,7 @@ class MemoryMetadata:
return days_since_activation > self.forgetting_threshold
def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
"""判断是否处于休眠状态(长期未激活)"""
if current_time is None:
current_time = time.time()
@@ -230,7 +231,7 @@ class MemoryMetadata:
days_since_last_access = (current_time - self.last_accessed) / 86400
return days_since_last_access > inactive_days
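Both `should_forget` and `is_dormant` reduce to "days since some timestamp exceed a threshold", with days computed as seconds / 86400 as shown above. A small self-contained sketch of that check (the concrete values below are illustrative only):

import time

def days_since(ts: float, now: float | None = None) -> float:
    # Days elapsed since an epoch timestamp
    now = time.time() if now is None else now
    return (now - ts) / 86400

last_activation = time.time() - 30 * 86400   # activated 30 days ago
threshold_days = max(7.0, min(42.0, 365.0))  # clamped as in the code above
print(days_since(last_activation) > threshold_days)  # False: 30 < 42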
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"memory_id": self.memory_id,
@@ -252,7 +253,7 @@ class MemoryMetadata:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata":
def from_dict(cls, data: dict[str, Any]) -> "MemoryMetadata":
"""从字典创建实例"""
return cls(
memory_id=data.get("memory_id", ""),
@@ -286,17 +287,17 @@ class MemoryChunk:
memory_type: MemoryType # memory type
# Extended info
keywords: List[str] = field(default_factory=list) # keyword list
tags: List[str] = field(default_factory=list) # tag list
categories: List[str] = field(default_factory=list) # category list
keywords: list[str] = field(default_factory=list) # keyword list
tags: list[str] = field(default_factory=list) # tag list
categories: list[str] = field(default_factory=list) # category list
# Semantic info
embedding: Optional[List[float]] = None # semantic vector
semantic_hash: Optional[str] = None # semantic hash
embedding: list[float] | None = None # semantic vector
semantic_hash: str | None = None # semantic hash
# Association info
related_memories: List[str] = field(default_factory=list) # IDs of related memories
temporal_context: Optional[Dict[str, Any]] = None # temporal context
related_memories: list[str] = field(default_factory=list) # IDs of related memories
temporal_context: dict[str, Any] | None = None # temporal context
def __post_init__(self):
"""后初始化处理"""
@@ -310,7 +311,7 @@ class MemoryChunk:
try:
# Generate a stable hash from the vector and content
content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}"
content_str = f"{self.content.subject}:{self.content.predicate}:{self.content.object!s}"
embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))
hash_input = f"{content_str}|{embedding_str}"
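The hash input combines the SPO content string with the embedding rounded to six decimals, so tiny float jitter does not change the digest. A runnable sketch of the same recipe (the digest algorithm is an assumption; the hunk does not show which one the real code uses):

import hashlib

def semantic_hash(subject: str, predicate: str, obj: str, embedding: list[float]) -> str:
    content_str = f"{subject}:{predicate}:{obj!s}"
    embedding_str = ",".join(map(str, [round(x, 6) for x in embedding]))
    # md5 assumed here purely for illustration
    return hashlib.md5(f"{content_str}|{embedding_str}".encode()).hexdigest()

print(semantic_hash("小明", "喜欢", "猫", [0.1234567, -0.98]))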
@@ -342,7 +343,7 @@ class MemoryChunk:
return self.content.display or str(self.content)
@property
def subjects(self) -> List[str]:
def subjects(self) -> list[str]:
"""获取主语列表"""
return self.content.to_subject_list()
@@ -354,11 +355,11 @@ class MemoryChunk:
"""更新相关度评分"""
self.metadata.update_relevance(new_score)
def should_forget(self, current_time: Optional[float] = None) -> bool:
def should_forget(self, current_time: float | None = None) -> bool:
"""判断是否应该遗忘"""
return self.metadata.should_forget(current_time)
def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
"""判断是否处于休眠状态(长期未激活)"""
return self.metadata.is_dormant(current_time, inactive_days)
@@ -386,7 +387,7 @@ class MemoryChunk:
if memory_id and memory_id not in self.related_memories:
self.related_memories.append(memory_id)
def set_embedding(self, embedding: List[float]):
def set_embedding(self, embedding: list[float]):
"""设置语义向量"""
self.embedding = embedding
self._generate_semantic_hash()
@@ -415,7 +416,7 @@ class MemoryChunk:
logger.warning(f"计算记忆相似度失败: {e}")
return 0.0
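The similarity computation guards with a warning and a 0.0 fallback. Its body is not shown in this hunk, but given the numpy import and the stored embeddings, cosine similarity is the natural reading; a hedged sketch:

import numpy as np

def embedding_similarity(a: list[float], b: list[float]) -> float:
    # Cosine similarity with the same 0.0 fallback on failure
    try:
        va, vb = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
        denom = float(np.linalg.norm(va) * np.linalg.norm(vb))
        return float(np.dot(va, vb) / denom) if denom else 0.0
    except Exception:
        return 0.0

print(embedding_similarity([1.0, 0.0], [0.6, 0.8]))  # 0.6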
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为完整的字典格式"""
return {
"metadata": self.metadata.to_dict(),
@@ -431,7 +432,7 @@ class MemoryChunk:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk":
def from_dict(cls, data: dict[str, Any]) -> "MemoryChunk":
"""从字典创建实例"""
metadata = MemoryMetadata.from_dict(data.get("metadata", {}))
content = ContentStructure.from_dict(data.get("content", {}))
@@ -541,7 +542,7 @@ class MemoryChunk:
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
def _build_display_text(subjects: Iterable[str], predicate: str, obj: str | dict) -> str:
"""根据主谓宾生成自然语言描述"""
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
subject_part = "".join(subjects_clean) if subjects_clean else "对话参与者"
@@ -569,15 +570,15 @@ def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str,
def create_memory_chunk(
user_id: str,
subject: Union[str, List[str]],
subject: str | list[str],
predicate: str,
obj: Union[str, Dict],
obj: str | dict,
memory_type: MemoryType,
chat_id: Optional[str] = None,
source_context: Optional[str] = None,
chat_id: str | None = None,
source_context: str | None = None,
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
display: Optional[str] = None,
display: str | None = None,
**kwargs,
) -> MemoryChunk:
"""便捷的内存块创建函数"""
@@ -593,10 +594,10 @@ def create_memory_chunk(
source_context=source_context,
)
subjects: List[str]
subjects: list[str]
if isinstance(subject, list):
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
subject_payload: Union[str, List[str]] = subjects
subject_payload: str | list[str] = subjects
else:
cleaned = subject.strip() if isinstance(subject, str) else ""
subjects = [cleaned] if cleaned else []
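A usage sketch for the factory above. The import path follows the repo layout in this commit; the MemoryType member is picked generically because concrete member names are not visible in this diff:

from src.chat.memory_system.memory_chunk import (
    ConfidenceLevel,
    ImportanceLevel,
    MemoryType,
    create_memory_chunk,
)

chunk = create_memory_chunk(
    user_id="u123",
    subject=["小明"],
    predicate="喜欢",
    obj="猫",
    memory_type=next(iter(MemoryType)),  # any member; names not shown here
    importance=ImportanceLevel.NORMAL,
    confidence=ConfidenceLevel.MEDIUM,
)
print(chunk.text_content)  # e.g. 小明喜欢猫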


@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
"""
Intelligent memory forgetting engine
Smart forgetting mechanism based on importance, confidence, and activation frequency
"""
import time
import asyncio
from typing import List, Dict, Optional, Tuple
from datetime import datetime
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, ImportanceLevel, ConfidenceLevel
logger = get_logger(__name__)
@@ -65,7 +63,7 @@ class ForgettingConfig:
class MemoryForgettingEngine:
"""智能记忆遗忘引擎"""
def __init__(self, config: Optional[ForgettingConfig] = None):
def __init__(self, config: ForgettingConfig | None = None):
self.config = config or ForgettingConfig()
self.stats = ForgettingStats()
self._last_forgetting_check = 0.0
@@ -116,7 +114,7 @@ class MemoryForgettingEngine:
# Keep within a sensible range
return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days))
def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
def should_forget_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
Decide whether a memory should be forgotten
@@ -155,7 +153,7 @@ class MemoryForgettingEngine:
return should_forget
def is_dormant_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
def is_dormant_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
Decide whether a memory is dormant
@@ -168,7 +166,7 @@ class MemoryForgettingEngine:
"""
return memory.is_dormant(current_time, self.config.dormant_threshold_days)
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
Decide whether a dormant memory should be force-forgotten
@@ -189,7 +187,7 @@ class MemoryForgettingEngine:
days_since_last_access = (current_time - memory.metadata.last_accessed) / 86400
return days_since_last_access > self.config.force_forget_dormant_days
async def check_memories_for_forgetting(self, memories: List[MemoryChunk]) -> Tuple[List[str], List[str]]:
async def check_memories_for_forgetting(self, memories: list[MemoryChunk]) -> tuple[list[str], list[str]]:
"""
Scan a list of memories and identify those that should be forgotten
@@ -241,7 +239,7 @@ class MemoryForgettingEngine:
return normal_forgetting_ids, force_forgetting_ids
async def perform_forgetting_check(self, memories: List[MemoryChunk]) -> Dict[str, any]:
async def perform_forgetting_check(self, memories: list[MemoryChunk]) -> dict[str, Any]:
"""
Run the full forgetting-check workflow
@@ -314,7 +312,7 @@ class MemoryForgettingEngine:
except Exception as e:
logger.error(f"定期遗忘检查失败: {e}", exc_info=True)
def get_forgetting_stats(self) -> Dict[str, any]:
def get_forgetting_stats(self) -> dict[str, Any]:
"""Return forgetting statistics."""
return {
"total_checked": self.stats.total_checked,


@@ -1,16 +1,14 @@
# -*- coding: utf-8 -*-
"""
Memory fusion and deduplication
Avoids memory fragmentation and keeps the long-term memory store high quality
"""
import time
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
logger = get_logger(__name__)
@@ -22,9 +20,9 @@ class FusionResult:
original_count: int
fused_count: int
removed_duplicates: int
merged_memories: List[MemoryChunk]
merged_memories: list[MemoryChunk]
fusion_time: float
details: List[str]
details: list[str]
@dataclass
@@ -32,9 +30,9 @@ class DuplicateGroup:
"""重复记忆组"""
group_id: str
memories: List[MemoryChunk]
similarity_matrix: List[List[float]]
representative_memory: Optional[MemoryChunk] = None
memories: list[MemoryChunk]
similarity_matrix: list[list[float]]
representative_memory: MemoryChunk | None = None
class MemoryFusionEngine:
@@ -59,8 +57,8 @@ class MemoryFusionEngine:
}
async def fuse_memories(
self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None
) -> List[MemoryChunk]:
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] | None = None
) -> list[MemoryChunk]:
"""融合记忆列表"""
start_time = time.time()
@@ -106,8 +104,8 @@ class MemoryFusionEngine:
return new_memories # fall back to the original memories on failure
async def _detect_duplicate_groups(
self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk]
) -> List[DuplicateGroup]:
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk]
) -> list[DuplicateGroup]:
"""检测重复记忆组"""
all_memories = new_memories + existing_memories
new_memory_ids = {memory.memory_id for memory in new_memories}
@@ -212,7 +210,7 @@ class MemoryFusionEngine:
jaccard_similarity = len(intersection) / len(union)
return jaccard_similarity
def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float:
def _calculate_keyword_similarity(self, keywords1: list[str], keywords2: list[str]) -> float:
"""计算关键词相似度"""
if not keywords1 or not keywords2:
return 0.0
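Keyword similarity is plain Jaccard, matching the `jaccard_similarity = len(intersection) / len(union)` line a few hunks up. A standalone equivalent (the lowercasing is an assumption for robustness, not shown in the hunk):

def keyword_similarity(keywords1: list[str], keywords2: list[str]) -> float:
    if not keywords1 or not keywords2:
        return 0.0
    s1 = {k.lower() for k in keywords1}
    s2 = {k.lower() for k in keywords2}
    return len(s1 & s2) / len(s1 | s2)  # Jaccard: |A∩B| / |A∪B|

print(keyword_similarity(["猫", "宠物"], ["猫", "狗"]))  # 0.333...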
@@ -302,7 +300,7 @@ class MemoryFusionEngine:
return best_memory
async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]:
async def _fuse_memory_group(self, group: DuplicateGroup) -> MemoryChunk | None:
"""融合记忆组"""
if not group.memories:
return None
@@ -328,7 +326,7 @@ class MemoryFusionEngine:
# Return the memory with the highest confidence
return max(group.memories, key=lambda m: m.metadata.confidence.value)
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk:
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: list[MemoryChunk]) -> MemoryChunk:
"""合并记忆属性"""
# 创建基础记忆的深拷贝
fused_memory = MemoryChunk.from_dict(base_memory.to_dict())
@@ -395,7 +393,7 @@ class MemoryFusionEngine:
source_ids = [m.memory_id[:8] for m in group.memories]
fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}"
def _merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]:
def _merge_temporal_context(self, memories: list[MemoryChunk]) -> dict[str, Any]:
"""合并时间上下文"""
contexts = [m.temporal_context for m in memories if m.temporal_context]
@@ -426,8 +424,8 @@ class MemoryFusionEngine:
return merged_context
async def incremental_fusion(
self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk]
) -> Tuple[MemoryChunk, List[MemoryChunk]]:
self, new_memory: MemoryChunk, existing_memories: list[MemoryChunk]
) -> tuple[MemoryChunk, list[MemoryChunk]]:
"""增量融合(单个新记忆与现有记忆融合)"""
# 寻找相似记忆
similar_memories = []
@@ -493,7 +491,7 @@ class MemoryFusionEngine:
except Exception as e:
logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True)
def get_fusion_stats(self) -> Dict[str, Any]:
def get_fusion_stats(self) -> dict[str, Any]:
"""获取融合统计信息"""
return self.fusion_stats.copy()


@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
"""
Memory system manager
Replaces the legacy Hippocampus and instant_memory systems
"""
import re
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_system import initialize_memory_system
from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -27,14 +25,14 @@ class MemoryResult:
timestamp: float
source: str = "memory"
relevance_score: float = 0.0
structure: Dict[str, Any] | None = None
structure: dict[str, Any] | None = None
class MemoryManager:
"""记忆系统管理器 - 替代原有的 HippocampusManager"""
def __init__(self):
self.memory_system: Optional[MemorySystem] = None
self.memory_system: MemorySystem | None = None
self.is_initialized = False
self.user_cache = {} # per-user memory cache
@@ -63,8 +61,8 @@ class MemoryManager:
logger.info("正在初始化记忆系统...")
# Acquire the LLM model
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")
@@ -121,7 +119,7 @@ class MemoryManager:
max_memory_length: int = 2,
time_weight: float = 1.0,
keyword_weight: float = 1.0,
) -> List[Tuple[str, str]]:
) -> list[tuple[str, str]]:
"""从文本获取相关记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
@@ -152,8 +150,8 @@ class MemoryManager:
return []
async def get_memory_from_topic(
self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> List[Tuple[str, str]]:
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> list[tuple[str, str]]:
"""从关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
@@ -208,8 +206,8 @@ class MemoryManager:
return []
async def process_conversation(
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None
) -> List[MemoryChunk]:
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float | None = None
) -> list[MemoryChunk]:
"""处理对话并构建记忆 - 新增功能"""
if not self.is_initialized or not self.memory_system:
return []
@@ -235,8 +233,8 @@ class MemoryManager:
return []
async def get_enhanced_memory_context(
self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5
) -> List[MemoryResult]:
self, query_text: str, user_id: str, context: dict[str, Any] | None = None, limit: int = 5
) -> list[MemoryResult]:
"""获取增强记忆上下文 - 新增功能"""
if not self.is_initialized or not self.memory_system:
return []
@@ -267,7 +265,7 @@ class MemoryManager:
logger.error(f"get_enhanced_memory_context 失败: {e}")
return []
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
def _format_memory_chunk(self, memory: MemoryChunk) -> tuple[str, dict[str, Any]]:
"""将记忆块转换为更易读的文本描述"""
structure = memory.content.to_dict()
if memory.display:
@@ -289,7 +287,7 @@ class MemoryManager:
return formatted, structure
def _format_subject(self, subject: Optional[str], memory: MemoryChunk) -> str:
def _format_subject(self, subject: str | None, memory: MemoryChunk) -> str:
if not subject:
return "该用户"
@@ -299,7 +297,7 @@ class MemoryManager:
return "该聊天"
return self._clean_text(subject)
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> Optional[str]:
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> str | None:
predicate = (predicate or "").strip()
obj_value = obj
@@ -446,10 +444,10 @@ class MemoryManager:
text = self._truncate(str(obj).strip())
return self._clean_text(text)
def _extract_from_object(self, obj: Any, keys: List[str]) -> Optional[str]:
def _extract_from_object(self, obj: Any, keys: list[str]) -> str | None:
if isinstance(obj, dict):
for key in keys:
if key in obj and obj[key]:
if obj.get(key):
value = obj[key]
if isinstance(value, (dict, list)):
return self._clean_text(self._format_object(value))


@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
"""
Memory metadata index manager
Stores memory metadata in a JSON file, supporting fast fuzzy search and filtering
"""
import orjson
import threading
from pathlib import Path
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, asdict
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
import orjson
from src.common.logger import get_logger
@@ -25,10 +25,10 @@ class MemoryMetadataIndexEntry:
# Classification info
memory_type: str # MemoryType.value
subjects: List[str] # subject list
objects: List[str] # object list
keywords: List[str] # keyword list
tags: List[str] # tag list
subjects: list[str] # subject list
objects: list[str] # object list
keywords: list[str] # keyword list
tags: list[str] # tag list
# Numeric fields (for range filtering)
importance: int # ImportanceLevel.value (1-4)
@@ -37,8 +37,8 @@ class MemoryMetadataIndexEntry:
access_count: int # access count
# Optional fields
chat_id: Optional[str] = None
content_preview: Optional[str] = None # preview of the first 100 characters
chat_id: str | None = None
content_preview: str | None = None # preview of the first 100 characters
class MemoryMetadataIndex:
@@ -46,13 +46,13 @@ class MemoryMetadataIndex:
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
self.index_file = Path(index_file)
self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
# Inverted indexes (for fast lookup)
self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids}
self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids}
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids}
self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids}
self.type_index: dict[str, set[str]] = {} # type -> {memory_ids}
self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids}
self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids}
self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids}
self.lock = threading.RLock()
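The manager pairs the forward map (memory_id → entry) with four inverted maps so lookups by type, subject, keyword, or tag avoid a full scan. A toy analogue covering the keyword axis only:

from collections import defaultdict

class TinyInvertedIndex:
    # Toy analogue of MemoryMetadataIndex's forward + inverted maps
    def __init__(self) -> None:
        self.index: dict[str, dict] = {}  # memory_id -> entry
        self.keyword_index: dict[str, set[str]] = defaultdict(set)

    def add(self, memory_id: str, keywords: list[str]) -> None:
        self.index[memory_id] = {"keywords": keywords}
        for kw in keywords:
            self.keyword_index[kw.lower()].add(memory_id)

    def lookup(self, keyword: str) -> set[str]:
        return self.keyword_index.get(keyword.lower(), set())

idx = TinyInvertedIndex()
idx.add("m1", ["猫", "宠物"])
idx.add("m2", ["狗"])
print(idx.lookup("猫"))  # {'m1'}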
@@ -178,7 +178,7 @@ class MemoryMetadataIndex:
self._remove_from_inverted_indices(memory_id)
del self.index[memory_id]
def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
"""批量添加或更新"""
with self.lock:
for entry in entries:
@@ -191,18 +191,18 @@ class MemoryMetadataIndex:
def search(
self,
memory_types: Optional[List[str]] = None,
subjects: Optional[List[str]] = None,
keywords: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
importance_min: Optional[int] = None,
importance_max: Optional[int] = None,
created_after: Optional[float] = None,
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None,
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
keywords: list[str] | None = None,
tags: list[str] | None = None,
importance_min: int | None = None,
importance_max: int | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
flexible_mode: bool = True, # new: flexible matching mode
) -> List[str]:
) -> list[str]:
"""
Search for the IDs of memories matching the criteria (fuzzy matching supported)
@@ -237,14 +237,14 @@ class MemoryMetadataIndex:
def _search_flexible(
self,
memory_types: Optional[List[str]] = None,
subjects: Optional[List[str]] = None,
created_after: Optional[float] = None,
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None,
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
**kwargs, # accepted but unused
) -> List[str]:
) -> list[str]:
"""
Flexible search mode: matching 2 of 4 criteria is enough; partial matches are supported
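The exact criterion grouping sits outside this hunk, so the following is only a sketch of the "2 of 4 dimensions suffice" idea the docstring describes:

def flexible_match(entry: dict, criteria: dict, min_hits: int = 2) -> bool:
    # Count how many of the four criterion groups overlap with the entry
    hits = 0
    for key in ("memory_types", "subjects", "keywords", "tags"):
        wanted = criteria.get(key)
        if not wanted:
            continue
        if {w.lower() for w in wanted} & {v.lower() for v in entry.get(key, [])}:
            hits += 1
    return hits >= min_hits

entry = {"memory_types": ["fact"], "subjects": ["小明"], "keywords": ["猫"], "tags": []}
print(flexible_match(entry, {"memory_types": ["fact"], "subjects": ["小明"]}))  # True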
@@ -374,20 +374,20 @@ class MemoryMetadataIndex:
def _search_strict(
self,
memory_types: Optional[List[str]] = None,
subjects: Optional[List[str]] = None,
keywords: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
importance_min: Optional[int] = None,
importance_max: Optional[int] = None,
created_after: Optional[float] = None,
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None,
) -> List[str]:
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
keywords: list[str] | None = None,
tags: list[str] | None = None,
importance_min: int | None = None,
importance_max: int | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
) -> list[str]:
"""严格搜索模式(原有逻辑)"""
# 初始候选集(所有记忆)
candidate_ids: Optional[Set[str]] = None
candidate_ids: set[str] | None = None
# 用户过滤(必选)
if user_id:
@@ -471,11 +471,11 @@ class MemoryMetadataIndex:
return result_ids
def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
"""获取单个索引条目"""
return self.index.get(memory_id)
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""获取索引统计信息"""
with self.lock:
return {


@@ -1,11 +1,10 @@
# -*- coding: utf-8 -*-
"""记忆检索查询规划器"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any
import orjson
@@ -21,16 +20,16 @@ class MemoryQueryPlan:
"""查询规划结果"""
semantic_query: str
memory_types: List[MemoryType] = field(default_factory=list)
subject_includes: List[str] = field(default_factory=list)
object_includes: List[str] = field(default_factory=list)
required_keywords: List[str] = field(default_factory=list)
optional_keywords: List[str] = field(default_factory=list)
owner_filters: List[str] = field(default_factory=list)
memory_types: list[MemoryType] = field(default_factory=list)
subject_includes: list[str] = field(default_factory=list)
object_includes: list[str] = field(default_factory=list)
required_keywords: list[str] = field(default_factory=list)
optional_keywords: list[str] = field(default_factory=list)
owner_filters: list[str] = field(default_factory=list)
recency_preference: str = "any"
limit: int = 10
emphasis: Optional[str] = None
raw_plan: Dict[str, Any] = field(default_factory=dict)
emphasis: str | None = None
raw_plan: dict[str, Any] = field(default_factory=dict)
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
if not self.semantic_query:
@@ -46,11 +45,11 @@ class MemoryQueryPlan:
class MemoryQueryPlanner:
"""基于小模型的记忆检索查询规划器"""
def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
def __init__(self, planner_model: LLMRequest | None, default_limit: int = 10):
self.model = planner_model
self.default_limit = default_limit
async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
async def plan_query(self, query_text: str, context: dict[str, Any]) -> MemoryQueryPlan:
if not self.model:
logger.debug("未提供查询规划模型,使用默认规划")
return self._default_plan(query_text)
@@ -82,10 +81,10 @@ class MemoryQueryPlanner:
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit)
def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
def _parse_plan_dict(self, data: dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
def _collect_list(key: str) -> List[str]:
def _collect_list(key: str) -> list[str]:
value = data.get(key)
if isinstance(value, str):
return [value]
@@ -94,7 +93,7 @@ class MemoryQueryPlanner:
return []
memory_type_values = _collect_list("memory_types")
memory_types: List[MemoryType] = []
memory_types: list[MemoryType] = []
for item in memory_type_values:
if not item:
continue
@@ -123,7 +122,7 @@ class MemoryQueryPlanner:
)
return plan
def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
def _build_prompt(self, query_text: str, context: dict[str, Any]) -> str:
participants = context.get("participants") or context.get("speaker_names") or []
if isinstance(participants, str):
participants = [participants]
@@ -206,7 +205,7 @@ class MemoryQueryPlanner:
请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。
"""
def _extract_json_payload(self, response: str) -> Optional[str]:
def _extract_json_payload(self, response: str) -> str | None:
if not response:
return None
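Even though the prompt forbids Markdown fences, `_extract_json_payload` still has to defend against models that emit them anyway. A sketch of the typical fence-stripping approach (the real body is not fully shown here):

import re

def extract_json_payload(response: str) -> str | None:
    # Accept either a ```json fenced block or a bare {...} object
    if not response:
        return None
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", response, re.DOTALL)
    if fenced:
        return fenced.group(1)
    stripped = response.strip()
    return stripped if stripped.startswith("{") and stripped.endswith("}") else None

print(extract_json_payload('```json\n{"semantic_query": "猫"}\n```'))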


@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
Core module of the precision memory system
1. An efficient memory construction, storage, and recall system based on the design doc, covering the full pipeline of building, vectorization, and multi-stage retrieval.
@@ -6,26 +5,27 @@
"""
import asyncio
import time
import orjson
import re
import hashlib
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
import re
import time
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
from enum import Enum
from typing import TYPE_CHECKING, Any
import orjson
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger(__name__)
@@ -121,7 +121,7 @@ class MemorySystemConfig:
class MemorySystem:
"""精准记忆系统核心类"""
def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None):
def __init__(self, llm_model: LLMRequest | None = None, config: MemorySystemConfig | None = None):
self.config = config or MemorySystemConfig.from_global_config()
self.llm_model = llm_model
self.status = MemorySystemStatus.INITIALIZING
@@ -131,7 +131,7 @@ class MemorySystem:
self.fusion_engine: MemoryFusionEngine | None = None
self.unified_storage = None # unified storage system
self.query_planner: MemoryQueryPlanner | None = None
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
self.forgetting_engine: MemoryForgettingEngine | None = None
# LLM models
self.value_assessment_model: LLMRequest | None = None
@@ -143,10 +143,10 @@ class MemorySystem:
self.last_retrieval_time = None
# Build-throttling records
self._last_memory_build_times: Dict[str, float] = {}
self._last_memory_build_times: dict[str, float] = {}
# Memory fingerprint cache for fast duplicate detection
self._memory_fingerprints: Dict[str, str] = {}
self._memory_fingerprints: dict[str, str] = {}
logger.info("MemorySystem 初始化开始")
@@ -210,7 +210,7 @@ class MemorySystem:
raise
# Initialize the forgetting engine
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig
from src.chat.memory_system.memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine
# Create the forgetting-engine config from the global config
forgetting_config = ForgettingConfig(
@@ -241,7 +241,7 @@ class MemorySystem:
self.forgetting_engine = MemoryForgettingEngine(forgetting_config)
planner_task_config = getattr(model_config.model_task_config, "utils_small", None)
planner_model: Optional[LLMRequest] = None
planner_model: LLMRequest | None = None
try:
planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner")
except Exception as planner_exc:
@@ -261,8 +261,8 @@ class MemorySystem:
raise
async def retrieve_memories_for_building(
self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5
) -> List[MemoryChunk]:
self, query_text: str, user_id: str | None = None, context: dict[str, Any] | None = None, limit: int = 5
) -> list[MemoryChunk]:
"""在构建记忆时检索相关记忆,使用统一存储系统
Args:
@@ -302,8 +302,8 @@ class MemorySystem:
return []
async def build_memory_from_conversation(
self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None
) -> List[MemoryChunk]:
self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None
) -> list[MemoryChunk]:
"""从对话中构建记忆
Args:
@@ -318,8 +318,8 @@ class MemorySystem:
self.status = MemorySystemStatus.BUILDING
start_time = time.time()
build_scope_key: Optional[str] = None
build_marker_time: Optional[float] = None
build_scope_key: str | None = None
build_marker_time: float | None = None
try:
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
@@ -408,7 +408,7 @@ class MemorySystem:
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
raise
def _log_memory_preview(self, memories: List[MemoryChunk]) -> None:
def _log_memory_preview(self, memories: list[MemoryChunk]) -> None:
"""在控制台输出记忆预览,便于人工检查"""
if not memories:
logger.info("📝 本次未生成新的记忆")
@@ -425,12 +425,12 @@ class MemorySystem:
f"置信度={memory.metadata.confidence.name} | 内容={text}"
)
async def _collect_fusion_candidates(self, new_memories: List[MemoryChunk]) -> List[MemoryChunk]:
async def _collect_fusion_candidates(self, new_memories: list[MemoryChunk]) -> list[MemoryChunk]:
"""收集与新记忆相似的现有记忆,便于融合去重"""
if not new_memories:
return []
candidate_ids: Set[str] = set()
candidate_ids: set[str] = set()
new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)}
# Direct matching based on fingerprints
@@ -493,7 +493,7 @@ class MemorySystem:
continue
candidate_ids.add(memory_id)
existing_candidates: List[MemoryChunk] = []
existing_candidates: list[MemoryChunk] = []
cache = self.unified_storage.memory_cache if self.unified_storage else {}
for candidate_id in candidate_ids:
if candidate_id in new_memory_ids:
@@ -511,7 +511,7 @@ class MemorySystem:
return existing_candidates
async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]:
async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]:
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
start_time = time.time()
@@ -559,12 +559,12 @@ class MemorySystem:
async def retrieve_relevant_memories(
self,
query_text: Optional[str] = None,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
query_text: str | None = None,
user_id: str | None = None,
context: dict[str, Any] | None = None,
limit: int = 5,
**kwargs,
) -> List[MemoryChunk]:
) -> list[MemoryChunk]:
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)"""
raw_query = query_text or kwargs.get("query")
if not raw_query:
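A toy sketch of the three-stage recall named in the docstring above; every helper and field below is a stand-in, not the real MemorySystem API:

def three_stage_recall(query_tokens: set[str], memories: list[dict], limit: int = 3) -> list[dict]:
    # Stage 1: metadata coarse filter (keyword overlap with the index entry)
    coarse = [m for m in memories if query_tokens & set(m["keywords"])]
    # Stage 2: vector fine filter (similarity above a threshold)
    fine = [m for m in coarse if m["similarity"] >= 0.5]
    # Stage 3: combined rerank (similarity blended with importance)
    fine.sort(key=lambda m: m["similarity"] * 0.7 + m["importance"] / 4 * 0.3, reverse=True)
    return fine[:limit]

mems = [
    {"keywords": ["猫"], "similarity": 0.9, "importance": 3},
    {"keywords": ["狗"], "similarity": 0.8, "importance": 2},
]
print(three_stage_recall({"猫"}, mems))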
@@ -750,7 +750,7 @@ class MemorySystem:
raise
@staticmethod
def _extract_json_payload(response: str) -> Optional[str]:
def _extract_json_payload(response: str) -> str | None:
"""从模型响应中提取JSON部分兼容Markdown代码块等格式"""
if not response:
return None
@@ -773,10 +773,10 @@ class MemorySystem:
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
def _normalize_context(
self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float]
) -> Dict[str, Any]:
self, raw_context: dict[str, Any] | None, user_id: str | None, timestamp: float | None
) -> dict[str, Any]:
"""标准化上下文,确保必备字段存在且格式正确"""
context: Dict[str, Any] = {}
context: dict[str, Any] = {}
if raw_context:
try:
context = dict(raw_context)
@@ -822,7 +822,7 @@ class MemorySystem:
return context
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: Dict[str, Any]) -> Dict[str, Any]:
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: dict[str, Any]) -> dict[str, Any]:
"""构建包含未读消息综合上下文的增强查询上下文
Args:
@@ -861,7 +861,7 @@ class MemorySystem:
return enhanced_context
async def _collect_unread_messages_context(self, stream_id: str) -> Optional[Dict[str, Any]]:
async def _collect_unread_messages_context(self, stream_id: str) -> dict[str, Any] | None:
"""收集未读消息的综合上下文信息
Args:
@@ -953,7 +953,7 @@ class MemorySystem:
logger.warning(f"收集未读消息上下文失败: {e}", exc_info=True)
return None
def _build_unread_context_summary(self, messages_summary: List[Dict[str, Any]]) -> str:
def _build_unread_context_summary(self, messages_summary: list[dict[str, Any]]) -> str:
"""构建未读消息的文本摘要
Args:
@@ -974,7 +974,7 @@ class MemorySystem:
return " | ".join(summary_parts)
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
async def _resolve_conversation_context(self, fallback_text: str, context: dict[str, Any] | None) -> str:
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
if not context:
return fallback_text
@@ -1043,11 +1043,11 @@ class MemorySystem:
# Fall back to the provided text
return fallback_text
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
def _get_build_scope_key(self, context: dict[str, Any], user_id: str | None) -> str | None:
"""确定用于节流控制的记忆构建作用域"""
return "global_scope"
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
def _determine_history_limit(self, context: dict[str, Any]) -> int:
"""确定历史消息获取数量限制在30-50之间"""
default_limit = 40
candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit")
@@ -1065,12 +1065,12 @@ class MemorySystem:
return history_limit
def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]:
def _format_history_messages(self, messages: list["DatabaseMessages"]) -> str | None:
"""将历史消息格式化为可供LLM处理的多轮对话文本"""
if not messages:
return None
lines: List[str] = []
lines: list[str] = []
for msg in messages:
try:
content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None)
@@ -1105,7 +1105,7 @@ class MemorySystem:
return "\n".join(lines) if lines else None
async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float:
async def _assess_information_value(self, text: str, context: dict[str, Any]) -> float:
"""评估信息价值
Args:
@@ -1201,7 +1201,7 @@ class MemorySystem:
logger.error(f"信息价值评估失败: {e}", exc_info=True)
return 0.5 # default to medium value
async def _store_memories_unified(self, memory_chunks: List[MemoryChunk]) -> int:
async def _store_memories_unified(self, memory_chunks: list[MemoryChunk]) -> int:
"""使用统一存储系统存储记忆块"""
if not memory_chunks or not self.unified_storage:
return 0
@@ -1222,7 +1222,7 @@ class MemorySystem:
return 0
# Original method kept for backwards compatibility
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
async def _store_memories(self, memory_chunks: list[MemoryChunk]) -> int:
"""Compatibility shim: redirects to unified storage."""
return await self._store_memories_unified(memory_chunks)
@@ -1271,7 +1271,7 @@ class MemorySystem:
key = self._fingerprint_key(memory.user_id, fingerprint)
self._memory_fingerprints[key] = memory.memory_id
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
def _register_memory_fingerprints(self, memories: list[MemoryChunk]) -> None:
for memory in memories:
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
@@ -1302,9 +1302,9 @@ class MemorySystem:
@staticmethod
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
return f"{str(user_id)}:{fingerprint}"
return f"{user_id!s}:{fingerprint}"
def get_system_stats(self) -> Dict[str, Any]:
def get_system_stats(self) -> dict[str, Any]:
"""获取系统统计信息"""
return {
"status": self.status.value,
@@ -1314,7 +1314,7 @@ class MemorySystem:
"config": asdict(self.config),
}
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
"""根据查询和上下文为记忆计算匹配分数"""
tokens_query = self._tokenize_text(query_text)
tokens_memory = self._tokenize_text(memory.text_content)
@@ -1338,7 +1338,7 @@ class MemorySystem:
final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost
return max(0.0, min(1.0, final_score))
def _tokenize_text(self, text: str) -> Set[str]:
def _tokenize_text(self, text: str) -> set[str]:
"""简单分词,兼容中英文"""
if not text:
return set()
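The scoring line above blends a base token-overlap score (70%), keyword overlap (15%), and importance/confidence boosts, clamped to [0, 1]; the tokenizer treats each CJK character and each ASCII word as a token. A sketch with an assumed base-score definition (only the blend itself is visible in the hunk):

import re

def tokenize(text: str) -> set[str]:
    # Single CJK characters plus ASCII word runs
    if not text:
        return set()
    return set(re.findall(r"[\u4e00-\u9fff]|[a-zA-Z0-9]+", text.lower()))

def memory_score(query: str, memory_text: str, keyword_overlap: float,
                 importance_boost: float, confidence_boost: float) -> float:
    tq, tm = tokenize(query), tokenize(memory_text)
    base_score = len(tq & tm) / len(tq) if tq else 0.0  # assumed definition
    final = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost
    return max(0.0, min(1.0, final))

print(memory_score("小明喜欢什么", "小明喜欢猫", 0.5, 0.05, 0.02))  # ~0.61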
@@ -1450,7 +1450,7 @@ def get_memory_system() -> MemorySystem:
return memory_system
async def initialize_memory_system(llm_model: Optional[LLMRequest] = None):
async def initialize_memory_system(llm_model: LLMRequest | None = None):
"""初始化全局记忆系统"""
global memory_system
if memory_system is None:


@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
Unified memory storage system V2 based on a vector DB
Uses ChromaDB as the underlying store, replacing JSON-based storage
@@ -11,20 +10,21 @@
- Automatic cleanup of expired memories
"""
import time
import orjson
import asyncio
import threading
from typing import Dict, List, Optional, Tuple, Any
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from src.common.logger import get_logger
from src.common.vector_db import vector_db_service
from src.chat.utils.utils import get_embedding
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
import orjson
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry
from src.chat.utils.utils import get_embedding
from src.common.logger import get_logger
from src.common.vector_db import vector_db_service
logger = get_logger(__name__)
@@ -32,7 +32,7 @@ logger = get_logger(__name__)
_ENUM_MAPPINGS_CACHE = {}
def _build_enum_mapping(enum_class: type) -> Dict[str, Any]:
def _build_enum_mapping(enum_class: type) -> dict[str, Any]:
"""构建枚举类的完整映射表
Args:
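A plausible shape for such a mapping, keying every member by both its name and its stringified value so lookups tolerate either form (a sketch; the real table's contents are not shown in this hunk):

from enum import Enum

def build_enum_mapping(enum_class: type[Enum]) -> dict[str, Enum]:
    mapping: dict[str, Enum] = {}
    for member in enum_class:
        mapping[member.name.lower()] = member        # by name
        mapping[str(member.value).lower()] = member  # by value
    return mapping

class Color(Enum):
    RED = 1

print(build_enum_mapping(Color))  # {'red': <Color.RED: 1>, '1': <Color.RED: 1>}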
@@ -145,7 +145,7 @@ class VectorMemoryStorage:
"""基于Vector DB的记忆存储系统"""
def __init__(self, config: Optional[VectorStorageConfig] = None):
def __init__(self, config: VectorStorageConfig | None = None):
# Read from the global config by default when no config is passed in
if config is None:
try:
@@ -163,15 +163,15 @@ class VectorMemoryStorage:
self.vector_db_service = vector_db_service
# In-memory cache
self.memory_cache: Dict[str, MemoryChunk] = {}
self.cache_timestamps: Dict[str, float] = {}
self.memory_cache: dict[str, MemoryChunk] = {}
self.cache_timestamps: dict[str, float] = {}
self._cache = self.memory_cache # alias kept for legacy code
# Metadata index manager (JSON file index)
self.metadata_index = MemoryMetadataIndex()
# Forgetting engine
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
self.forgetting_engine: MemoryForgettingEngine | None = None
if self.config.enable_forgetting:
self.forgetting_engine = MemoryForgettingEngine()
@@ -267,7 +267,7 @@ class VectorMemoryStorage:
except Exception as e:
logger.error(f"自动清理失败: {e}")
def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]:
def _memory_to_vector_format(self, memory: MemoryChunk) -> dict[str, Any]:
"""将MemoryChunk转换为向量存储格式"""
try:
# 获取memory_id
@@ -323,7 +323,7 @@ class VectorMemoryStorage:
logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True)
raise
def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]:
def _vector_result_to_memory(self, document: str, metadata: dict[str, Any]) -> MemoryChunk | None:
"""将Vector DB结果转换为MemoryChunk"""
try:
# 从元数据中恢复完整记忆
@@ -440,7 +440,7 @@ class VectorMemoryStorage:
logger.warning(f"不支持的{enum_class.__name__}值类型: {type(value)},使用默认值")
return default
def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]:
def _get_from_cache(self, memory_id: str) -> MemoryChunk | None:
"""从缓存获取记忆"""
if not self.config.enable_caching:
return None
@@ -472,7 +472,7 @@ class VectorMemoryStorage:
self.memory_cache[memory_id] = memory
self.cache_timestamps[memory_id] = time.time()
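Reads consult `cache_timestamps` so stale entries can be rejected; the eviction policy itself sits outside this hunk, so the TTL below is an assumption:

import time

class TtlCache:
    # Toy analogue of the memory_cache / cache_timestamps pair
    def __init__(self, ttl_seconds: float = 300.0) -> None:
        self.ttl = ttl_seconds
        self.values: dict[str, object] = {}
        self.timestamps: dict[str, float] = {}

    def put(self, key: str, value: object) -> None:
        self.values[key] = value
        self.timestamps[key] = time.time()

    def get(self, key: str) -> object | None:
        ts = self.timestamps.get(key)
        if ts is None or time.time() - ts > self.ttl:
            return None  # miss or expired
        return self.values.get(key)

cache = TtlCache()
cache.put("m1", "memory")
print(cache.get("m1"))  # 'memory'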
async def store_memories(self, memories: List[MemoryChunk]) -> int:
async def store_memories(self, memories: list[MemoryChunk]) -> int:
"""批量存储记忆"""
if not memories:
return 0
@@ -603,11 +603,11 @@ class VectorMemoryStorage:
self,
query_text: str,
limit: int = 10,
similarity_threshold: Optional[float] = None,
filters: Optional[Dict[str, Any]] = None,
similarity_threshold: float | None = None,
filters: dict[str, Any] | None = None,
# New metadata-filter parameters for JSON-index coarse filtering
metadata_filters: Optional[Dict[str, Any]] = None,
) -> List[Tuple[MemoryChunk, float]]:
metadata_filters: dict[str, Any] | None = None,
) -> list[tuple[MemoryChunk, float]]:
"""
Search for similar memories (hybrid index mode)
@@ -632,7 +632,7 @@ class VectorMemoryStorage:
try:
# === Stage 1: JSON metadata coarse filtering (optional) ===
candidate_ids: Optional[List[str]] = None
candidate_ids: list[str] | None = None
if metadata_filters:
logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}")
candidate_ids = self.metadata_index.search(
@@ -746,7 +746,7 @@ class VectorMemoryStorage:
logger.error(f"搜索相似记忆失败: {e}")
return []
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
"""根据ID获取记忆"""
# 首先尝试从缓存获取
memory = self._get_from_cache(memory_id)
@@ -772,7 +772,7 @@ class VectorMemoryStorage:
return None
async def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 100) -> List[MemoryChunk]:
async def get_memories_by_filters(self, filters: dict[str, Any], limit: int = 100) -> list[MemoryChunk]:
"""根据过滤条件获取记忆"""
try:
results = vector_db_service.get(collection_name=self.config.memory_collection, where=filters, limit=limit)
@@ -848,7 +848,7 @@ class VectorMemoryStorage:
logger.error(f"删除记忆 {memory_id} 失败: {e}")
return False
async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int:
async def delete_memories_by_filters(self, filters: dict[str, Any]) -> int:
"""根据过滤条件批量删除记忆"""
try:
# 先获取要删除的记忆ID
@@ -880,7 +880,7 @@ class VectorMemoryStorage:
logger.error(f"批量删除记忆失败: {e}")
return 0
async def perform_forgetting_check(self) -> Dict[str, Any]:
async def perform_forgetting_check(self) -> dict[str, Any]:
"""执行遗忘检查"""
if not self.forgetting_engine:
return {"error": "遗忘引擎未启用"}
@@ -925,7 +925,7 @@ class VectorMemoryStorage:
logger.error(f"执行遗忘检查失败: {e}")
return {"error": str(e)}
def get_storage_stats(self) -> Dict[str, Any]:
def get_storage_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
try:
current_total = vector_db_service.count(self.config.memory_collection)
@@ -960,7 +960,7 @@ class VectorMemoryStorage:
_global_vector_storage = None
def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage:
def get_vector_memory_storage(config: VectorStorageConfig | None = None) -> VectorMemoryStorage:
"""获取全局Vector记忆存储实例"""
global _global_vector_storage
@@ -974,15 +974,15 @@ def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> V
class VectorMemoryStorageAdapter:
"""适配器类提供与原UnifiedMemoryStorage兼容的接口"""
def __init__(self, config: Optional[VectorStorageConfig] = None):
def __init__(self, config: VectorStorageConfig | None = None):
self.storage = VectorMemoryStorage(config)
async def store_memories(self, memories: List[MemoryChunk]) -> int:
async def store_memories(self, memories: list[MemoryChunk]) -> int:
return await self.storage.store_memories(memories)
async def search_similar_memories(
self, query_text: str, limit: int = 10, scope_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None
) -> List[Tuple[str, float]]:
self, query_text: str, limit: int = 10, scope_id: str | None = None, filters: dict[str, Any] | None = None
) -> list[tuple[str, float]]:
results = await self.storage.search_similar_memories(query_text, limit, filters=filters)
# Convert to the original format: (memory_id, similarity)
return [
@@ -990,7 +990,7 @@ class VectorMemoryStorageAdapter:
for memory, similarity in results
]
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
return self.storage.get_storage_stats()