re-style: format code
committed by Windpicker-owo
parent 00ba07e0e1
commit a79253c714
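This restyle pass mechanically modernizes the type annotations throughout the memory system: typing.Optional[X] becomes X | None (PEP 604), and typing.Dict/List/Tuple/Set become the built-in generics dict/list/tuple/set (PEP 585); imports are also regrouped and a few long lines reflowed. A minimal sketch of the annotation pattern (illustrative, not taken from the diff itself):

    from typing import Any

    # before: def lookup(context: Optional[Dict[str, Any]] = None) -> List[str]:
    def lookup(context: dict[str, Any] | None = None) -> list[str]:
        return sorted(context) if context else []

The X | Y union syntax needs Python 3.10+ (or from __future__ import annotations), and subscripting dict/list directly needs 3.9+, so the restyle assumes a reasonably modern interpreter.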
@@ -1,37 +1,35 @@
# -*- coding: utf-8 -*-
"""
简化记忆系统模块
移除即时记忆和长期记忆分类,实现统一记忆架构和智能遗忘机制
"""

# 核心数据结构
# 激活器
from .enhanced_memory_activator import MemoryActivator, enhanced_memory_activator, memory_activator
from .memory_chunk import (
    ConfidenceLevel,
    ContentStructure,
    ImportanceLevel,
    MemoryChunk,
    MemoryMetadata,
    ContentStructure,
    MemoryType,
    ImportanceLevel,
    ConfidenceLevel,
    create_memory_chunk,
)

# 兼容性别名
from .memory_chunk import MemoryChunk as Memory

# 遗忘引擎
from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine

# Vector DB存储系统
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage

# 记忆核心系统
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine

# 记忆管理器
from .memory_manager import MemoryManager, MemoryResult, memory_manager

# 激活器
from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator
# 记忆核心系统
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system

# 兼容性别名
from .memory_chunk import MemoryChunk as Memory
# Vector DB存储系统
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage

__all__ = [
# 核心数据结构
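The import churn in this first hunk is ordering, not behavior: the same imports reappear grouped and alphabetized. The formatter is not named in the commit, but Ruff and isort both produce this layout, which sorts each group alphabetically and separates standard library, third-party, and first-party imports:

    import time                                   # standard library

    import orjson                                 # third-party

    from src.common.logger import get_logger      # first-party, alphabetized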
@@ -1,17 +1,17 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统适配器
将增强记忆系统集成到现有MoFox Bot架构中
"""

import time
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from typing import Any

from src.common.logger import get_logger
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm

from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest

logger = get_logger(__name__)
@@ -47,10 +47,10 @@ class AdapterConfig:
class EnhancedMemoryAdapter:
"""增强记忆系统适配器"""

def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None):
def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
self.llm_model = llm_model
self.config = config or AdapterConfig()
self.integration_layer: Optional[MemoryIntegrationLayer] = None
self.integration_layer: MemoryIntegrationLayer | None = None
self._initialized = False

# 统计信息
@@ -96,7 +96,7 @@ class EnhancedMemoryAdapter:
# 如果初始化失败,禁用增强记忆功能
self.config.enable_enhanced_memory = False

async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
"""处理对话记忆,以上下文为唯一输入"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"success": False, "error": "Enhanced memory not available"}
@@ -105,7 +105,7 @@ class EnhancedMemoryAdapter:
self.adapter_stats["total_processed"] += 1

try:
payload_context: Dict[str, Any] = dict(context or {})
payload_context: dict[str, Any] = dict(context or {})

conversation_text = payload_context.get("conversation_text")
if not conversation_text:
@@ -146,8 +146,8 @@ class EnhancedMemoryAdapter:
return {"success": False, "error": str(e)}

async def retrieve_relevant_memories(
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None
) -> List[MemoryChunk]:
self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
) -> list[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.config.enable_enhanced_memory:
return []
@@ -166,7 +166,7 @@ class EnhancedMemoryAdapter:
return []

async def get_memory_context_for_prompt(
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5
self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
) -> str:
"""获取用于提示词的记忆上下文"""
memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
@@ -186,7 +186,7 @@ class EnhancedMemoryAdapter:

return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)

async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
"""获取增强记忆系统摘要"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"available": False, "reason": "Not initialized or disabled"}
@@ -222,7 +222,7 @@ class EnhancedMemoryAdapter:
new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
self.adapter_stats["average_processing_time"] = new_avg
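The surrounding hunk only retypes the signature; the average itself is the standard incremental mean: with n = total_processed and x = processing_time, new_avg = (old_avg * (n - 1) + x) / n, which is algebraically old_avg + (x - old_avg) / n and avoids storing every timing sample.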

def get_adapter_stats(self) -> Dict[str, Any]:
def get_adapter_stats(self) -> dict[str, Any]:
"""获取适配器统计信息"""
return self.adapter_stats.copy()

@@ -253,7 +253,7 @@ class EnhancedMemoryAdapter:


# 全局适配器实例
_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None


async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
@@ -292,8 +292,8 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):


async def process_conversation_with_enhanced_memory(
context: Dict[str, Any], llm_model: Optional[LLMRequest] = None
) -> Dict[str, Any]:
context: dict[str, Any], llm_model: LLMRequest | None = None
) -> dict[str, Any]:
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
if not llm_model:
# 获取默认的LLM模型
@@ -323,10 +323,10 @@ async def process_conversation_with_enhanced_memory(
async def retrieve_memories_with_enhanced_system(
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
limit: int = 10,
llm_model: Optional[LLMRequest] = None,
) -> List[MemoryChunk]:
llm_model: LLMRequest | None = None,
) -> list[MemoryChunk]:
"""使用增强记忆系统检索记忆"""
if not llm_model:
# 获取默认的LLM模型
@@ -345,9 +345,9 @@ async def retrieve_memories_with_enhanced_system(
async def get_memory_context_for_prompt(
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
max_memories: int = 5,
llm_model: Optional[LLMRequest] = None,
llm_model: LLMRequest | None = None,
) -> str:
"""获取用于提示词的记忆上下文"""
if not llm_model:

@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统钩子
用于在消息处理过程中自动构建和检索记忆
"""

from typing import Dict, List, Any, Optional
from datetime import datetime
from typing import Any

from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager

from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager

logger = get_logger(__name__)

@@ -27,7 +27,7 @@ class EnhancedMemoryHooks:
user_id: str,
chat_id: str,
message_id: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
) -> bool:
"""
处理消息并构建记忆
@@ -106,8 +106,8 @@ class EnhancedMemoryHooks:
user_id: str,
chat_id: str,
limit: int = 5,
extra_context: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
extra_context: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
"""
为回复获取相关记忆


@@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成脚本
用于在现有系统中无缝集成增强记忆功能
"""

from typing import Dict, Any, Optional
from typing import Any

from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks

from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks

logger = get_logger(__name__)


async def process_user_message_memory(
message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None
message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
) -> bool:
"""
处理用户消息并构建记忆
@@ -44,8 +44,8 @@ async def process_user_message_memory(


async def get_relevant_memories_for_response(
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
) -> dict[str, Any]:
"""
为回复获取相关记忆

@@ -74,7 +74,7 @@ async def get_relevant_memories_for_response(
return {"has_memories": False, "memories": [], "memory_count": 0}


def format_memories_for_prompt(memories: Dict[str, Any]) -> str:
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
"""
格式化记忆信息用于Prompt

@@ -114,7 +114,7 @@ async def cleanup_memory_system():
logger.error(f"记忆系统清理失败: {e}")


def get_memory_system_status() -> Dict[str, Any]:
def get_memory_system_status() -> dict[str, Any]:
"""
获取记忆系统状态

@@ -133,7 +133,7 @@ def get_memory_system_status() -> Dict[str, Any]:

# 便捷函数
async def remember_message(
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
) -> bool:
"""
便捷的记忆构建函数
@@ -159,8 +159,8 @@ async def recall_memories(
user_id: str = "default_user",
chat_id: str = "default_chat",
limit: int = 5,
context: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
context: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
便捷的记忆检索函数


@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
增强重排序器
实现文档设计的多维度评分模型
@@ -6,12 +5,12 @@

import math
import time
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from enum import Enum
from typing import Any

from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger

logger = get_logger(__name__)

@@ -44,7 +43,7 @@ class ReRankingConfig:
freq_max_score: float = 5.0  # 最大频率得分

# 类型匹配权重映射
type_match_weights: Dict[str, Dict[str, float]] = None
type_match_weights: dict[str, dict[str, float]] = None

def __post_init__(self):
"""初始化类型匹配权重"""
@@ -157,7 +156,7 @@ class IntentClassifier:
],
}

def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType:
def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
"""识别对话意图"""
if not query:
return IntentType.UNKNOWN
@@ -165,7 +164,7 @@ class IntentClassifier:
query_lower = query.lower()

# 统计各意图的匹配分数
intent_scores = {intent: 0 for intent in IntentType}
intent_scores = dict.fromkeys(IntentType, 0)

for intent, patterns in self.patterns.items():
for pattern in patterns:
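The dict.fromkeys rewrite above is behavior-preserving because the shared default is an immutable int: dict.fromkeys(IntentType, 0) builds the same mapping as {intent: 0 for intent in IntentType}. The general caveat, though not an issue here, is that fromkeys shares a single object across all keys:

    scores = dict.fromkeys(["a", "b"], 0)    # fine: ints are immutable
    buckets = dict.fromkeys(["a", "b"], [])  # trap: both keys share one list
    buckets["a"].append(1)                   # buckets["b"] is now [1] as well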
@@ -187,7 +186,7 @@ class IntentClassifier:
class EnhancedReRanker:
"""增强重排序器 - 实现文档设计的多维度评分模型"""

def __init__(self, config: Optional[ReRankingConfig] = None):
def __init__(self, config: ReRankingConfig | None = None):
self.config = config or ReRankingConfig()
self.intent_classifier = IntentClassifier()

@@ -210,10 +209,10 @@ class EnhancedReRanker:
def rerank_memories(
self,
query: str,
candidate_memories: List[Tuple[str, MemoryChunk, float]],  # (memory_id, memory, vector_similarity)
context: Dict[str, Any],
candidate_memories: list[tuple[str, MemoryChunk, float]],  # (memory_id, memory, vector_similarity)
context: dict[str, Any],
limit: int = 10,
) -> List[Tuple[str, MemoryChunk, float]]:
) -> list[tuple[str, MemoryChunk, float]]:
"""
对候选记忆进行重排序

@@ -341,11 +340,11 @@ default_reranker = EnhancedReRanker()

def rerank_candidate_memories(
query: str,
candidate_memories: List[Tuple[str, MemoryChunk, float]],
context: Dict[str, Any],
candidate_memories: list[tuple[str, MemoryChunk, float]],
context: dict[str, Any],
limit: int = 10,
config: Optional[ReRankingConfig] = None,
) -> List[Tuple[str, MemoryChunk, float]]:
config: ReRankingConfig | None = None,
) -> list[tuple[str, MemoryChunk, float]]:
"""
便捷函数:对候选记忆进行重排序
"""

@@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成层
现在只管理新的增强记忆系统,旧系统已被完全移除
"""

import time
import asyncio
from typing import Dict, List, Optional, Any
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any

from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem

from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest

logger = get_logger(__name__)
@@ -40,12 +40,12 @@ class IntegrationConfig:
class MemoryIntegrationLayer:
"""记忆系统集成层 - 现在只管理增强记忆系统"""

def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None):
def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
self.llm_model = llm_model
self.config = config or IntegrationConfig()

# 只初始化增强记忆系统
self.enhanced_memory: Optional[EnhancedMemorySystem] = None
self.enhanced_memory: EnhancedMemorySystem | None = None

# 集成统计
self.integration_stats = {
@@ -113,7 +113,7 @@ class MemoryIntegrationLayer:
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
raise

async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]:
async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
"""处理对话记忆,仅使用上下文信息"""
if not self._initialized or not self.enhanced_memory:
return {"success": False, "error": "Memory system not available"}
@@ -150,10 +150,10 @@ class MemoryIntegrationLayer:
async def retrieve_relevant_memories(
self,
query: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
) -> List[MemoryChunk]:
user_id: str | None = None,
context: dict[str, Any] | None = None,
limit: int | None = None,
) -> list[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.enhanced_memory:
return []
@@ -172,7 +172,7 @@ class MemoryIntegrationLayer:
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
return []

async def get_system_status(self) -> Dict[str, Any]:
async def get_system_status(self) -> dict[str, Any]:
"""获取系统状态"""
if not self._initialized:
return {"status": "not_initialized"}
@@ -193,7 +193,7 @@ class MemoryIntegrationLayer:
logger.error(f"获取系统状态失败: {e}", exc_info=True)
return {"status": "error", "error": str(e)}

def get_integration_stats(self) -> Dict[str, Any]:
def get_integration_stats(self) -> dict[str, Any]:
"""获取集成统计信息"""
return self.integration_stats.copy()


@@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
"""
记忆系统集成钩子
提供与现有MoFox Bot系统的无缝集成点
"""

import time
from typing import Dict, Optional, Any
from dataclasses import dataclass
from typing import Any

from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_adapter import (
get_memory_context_for_prompt,
process_conversation_with_enhanced_memory,
retrieve_memories_with_enhanced_system,
get_memory_context_for_prompt,
)

from src.common.logger import get_logger

logger = get_logger(__name__)


@@ -24,7 +24,7 @@ class HookResult:

success: bool
data: Any = None
error: Optional[str] = None
error: str | None = None
processing_time: float = 0.0


@@ -125,8 +125,8 @@ class MemoryIntegrationHooks:

# 尝试注册到事件系统
try:
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager

# 注册消息后处理事件
event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
@@ -238,11 +238,11 @@ class MemoryIntegrationHooks:

# 钩子处理器方法

async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult:
async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
"""事件系统的消息处理处理器"""
return await self._on_message_processed_hook(event_data)

async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult:
async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
"""消息后处理钩子"""
start_time = time.time()

@@ -289,7 +289,7 @@ class MemoryIntegrationHooks:
logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)

async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult:
async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
"""聊天流保存钩子"""
start_time = time.time()

@@ -345,7 +345,7 @@ class MemoryIntegrationHooks:
logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)

async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult:
async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
"""回复前钩子"""
start_time = time.time()

@@ -380,7 +380,7 @@ class MemoryIntegrationHooks:
logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)

async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult:
async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
"""知识库查询钩子"""
start_time = time.time()

@@ -411,7 +411,7 @@ class MemoryIntegrationHooks:
logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)

async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult:
async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
"""提示词构建钩子"""
start_time = time.time()

@@ -459,7 +459,7 @@ class MemoryIntegrationHooks:
new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
self.hook_stats["average_hook_time"] = new_avg

def get_hook_stats(self) -> Dict[str, Any]:
def get_hook_stats(self) -> dict[str, Any]:
"""获取钩子统计信息"""
return self.hook_stats.copy()

@@ -501,7 +501,7 @@ class MemoryMaintenanceTask:


# 全局钩子实例
_memory_hooks: Optional[MemoryIntegrationHooks] = None
_memory_hooks: MemoryIntegrationHooks | None = None


async def get_memory_integration_hooks() -> MemoryIntegrationHooks:

@@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
"""
元数据索引系统
为记忆系统提供多维度的精准过滤和查询能力
"""

import threading
import time
import orjson
from typing import Dict, List, Optional, Tuple, Set, Any, Union
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
import threading
from collections import defaultdict
from pathlib import Path
from typing import Any

import orjson

from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk, MemoryType
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel

logger = get_logger(__name__)

@@ -40,21 +40,21 @@ class IndexType(Enum):
class IndexQuery:
"""索引查询条件"""

user_ids: Optional[List[str]] = None
memory_types: Optional[List[MemoryType]] = None
subjects: Optional[List[str]] = None
keywords: Optional[List[str]] = None
tags: Optional[List[str]] = None
categories: Optional[List[str]] = None
time_range: Optional[Tuple[float, float]] = None
confidence_levels: Optional[List[ConfidenceLevel]] = None
importance_levels: Optional[List[ImportanceLevel]] = None
min_relationship_score: Optional[float] = None
max_relationship_score: Optional[float] = None
min_access_count: Optional[int] = None
semantic_hashes: Optional[List[str]] = None
limit: Optional[int] = None
sort_by: Optional[str] = None  # "created_at", "access_count", "relevance_score"
user_ids: list[str] | None = None
memory_types: list[MemoryType] | None = None
subjects: list[str] | None = None
keywords: list[str] | None = None
tags: list[str] | None = None
categories: list[str] | None = None
time_range: tuple[float, float] | None = None
confidence_levels: list[ConfidenceLevel] | None = None
importance_levels: list[ImportanceLevel] | None = None
min_relationship_score: float | None = None
max_relationship_score: float | None = None
min_access_count: int | None = None
semantic_hashes: list[str] | None = None
limit: int | None = None
sort_by: str | None = None  # "created_at", "access_count", "relevance_score"
sort_order: str = "desc"  # "asc", "desc"


@@ -62,10 +62,10 @@ class IndexQuery:
class IndexResult:
"""索引结果"""

memory_ids: List[str]
memory_ids: list[str]
total_count: int
query_time: float
filtered_by: List[str]
filtered_by: list[str]


class MetadataIndexManager:
@@ -94,7 +94,7 @@ class MetadataIndexManager:
self.access_frequency_index = []  # [(access_count, memory_id), ...]

# 内存缓存
self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {}
self.memory_metadata_cache: dict[str, dict[str, Any]] = {}

# 统计信息
self.index_stats = {
@@ -140,7 +140,7 @@ class MetadataIndexManager:
return key

@staticmethod
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
def _serialize_metadata_entry(metadata: dict[str, Any]) -> dict[str, Any]:
serialized = {}
for field_name, value in metadata.items():
if isinstance(value, Enum):
@@ -149,7 +149,7 @@ class MetadataIndexManager:
serialized[field_name] = value
return serialized

async def index_memories(self, memories: List[MemoryChunk]):
async def index_memories(self, memories: list[MemoryChunk]):
"""为记忆建立索引"""
if not memories:
return
@@ -375,7 +375,7 @@ class MetadataIndexManager:
logger.error(f"❌ 元数据查询失败: {e}", exc_info=True)
return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[])

def _get_candidate_memories(self, query: IndexQuery) -> Set[str]:
def _get_candidate_memories(self, query: IndexQuery) -> set[str]:
"""获取候选记忆ID集合"""
candidate_ids = set()

@@ -444,7 +444,7 @@ class MetadataIndexManager:

return candidate_ids

def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
def _collect_index_matches(self, index_type: IndexType, token: str | Enum | None) -> set[str]:
"""根据给定token收集索引匹配,支持部分匹配"""
mapping = self.indices.get(index_type)
if mapping is None:
@@ -461,7 +461,7 @@ class MetadataIndexManager:
if not key:
return set()

matches: Set[str] = set(mapping.get(key, set()))
matches: set[str] = set(mapping.get(key, set()))

if matches:
return set(matches)
@@ -477,7 +477,7 @@ class MetadataIndexManager:

return matches

def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
def _apply_filters(self, candidate_ids: set[str], query: IndexQuery) -> list[str]:
"""应用过滤条件"""
filtered_ids = list(candidate_ids)

@@ -545,7 +545,7 @@ class MetadataIndexManager:
created_at = self.memory_metadata_cache[memory_id]["created_at"]
return start_time <= created_at <= end_time

def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]:
def _sort_memories(self, memory_ids: list[str], sort_by: str, sort_order: str) -> list[str]:
"""对记忆进行排序"""
if sort_by == "created_at":
# 使用时间索引(已经有序)
@@ -582,7 +582,7 @@ class MetadataIndexManager:

return memory_ids

def _get_applied_filters(self, query: IndexQuery) -> List[str]:
def _get_applied_filters(self, query: IndexQuery) -> list[str]:
"""获取应用的过滤器列表"""
filters = []
if query.memory_types:
@@ -686,11 +686,11 @@ class MetadataIndexManager:
except Exception as e:
logger.error(f"❌ 移除记忆索引失败: {e}")

async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]:
async def get_memory_metadata(self, memory_id: str) -> dict[str, Any] | None:
"""获取记忆元数据"""
return self.memory_metadata_cache.get(memory_id)

async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]:
async def get_user_memory_ids(self, user_id: str, limit: int | None = None) -> list[str]:
"""获取用户的所有记忆ID"""
user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set()))

@@ -699,7 +699,7 @@ class MetadataIndexManager:

return user_memory_ids

async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]:
async def get_memory_statistics(self, user_id: str | None = None) -> dict[str, Any]:
"""获取记忆统计信息"""
stats = {
"total_memories": self.index_stats["total_memories"],
@@ -784,7 +784,7 @@ class MetadataIndexManager:
logger.info("正在保存元数据索引...")

# 保存各类索引
indices_data: Dict[str, Dict[str, List[str]]] = {}
indices_data: dict[str, dict[str, list[str]]] = {}
for index_type, index_data in self.indices.items():
serialized_index = {}
for key, values in index_data.items():
@@ -839,7 +839,7 @@ class MetadataIndexManager:
# 加载各类索引
indices_file = self.index_path / "indices.json"
if indices_file.exists():
with open(indices_file, "r", encoding="utf-8") as f:
with open(indices_file, encoding="utf-8") as f:
indices_data = orjson.loads(f.read())
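Dropping the "r" argument in these open() calls, here and in the hunks below, is a no-op: "r" is the default mode of the built-in open(), so open(path, encoding="utf-8") behaves exactly like open(path, "r", encoding="utf-8").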

for index_type_value, index_data in indices_data.items():
@@ -853,25 +853,25 @@ class MetadataIndexManager:
# 加载时间索引
time_index_file = self.index_path / "time_index.json"
if time_index_file.exists():
with open(time_index_file, "r", encoding="utf-8") as f:
with open(time_index_file, encoding="utf-8") as f:
self.time_index = orjson.loads(f.read())

# 加载关系分索引
relationship_index_file = self.index_path / "relationship_index.json"
if relationship_index_file.exists():
with open(relationship_index_file, "r", encoding="utf-8") as f:
with open(relationship_index_file, encoding="utf-8") as f:
self.relationship_index = orjson.loads(f.read())

# 加载访问频率索引
access_frequency_index_file = self.index_path / "access_frequency_index.json"
if access_frequency_index_file.exists():
with open(access_frequency_index_file, "r", encoding="utf-8") as f:
with open(access_frequency_index_file, encoding="utf-8") as f:
self.access_frequency_index = orjson.loads(f.read())

# 加载元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
if metadata_cache_file.exists():
with open(metadata_cache_file, "r", encoding="utf-8") as f:
with open(metadata_cache_file, encoding="utf-8") as f:
cache_data = orjson.loads(f.read())

# 转换置信度和重要性为枚举类型
@@ -914,7 +914,7 @@ class MetadataIndexManager:
# 加载统计信息
stats_file = self.index_path / "index_stats.json"
if stats_file.exists():
with open(stats_file, "r", encoding="utf-8") as f:
with open(stats_file, encoding="utf-8") as f:
self.index_stats = orjson.loads(f.read())

# 更新记忆计数
@@ -1004,7 +1004,7 @@ class MetadataIndexManager:
if len(self.indices[IndexType.CATEGORY][category]) < min_frequency:
del self.indices[IndexType.CATEGORY][category]

def get_index_stats(self) -> Dict[str, Any]:
def get_index_stats(self) -> dict[str, Any]:
"""获取索引统计信息"""
stats = self.index_stats.copy()
if stats["total_queries"] > 0:

@@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-
"""
多阶段召回机制
实现粗粒度到细粒度的记忆检索优化
"""

import time
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, field
from enum import Enum
import orjson
from typing import Any

from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
import orjson
from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig

from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger

logger = get_logger(__name__)


@@ -73,11 +73,11 @@ class StageResult:
"""阶段结果"""

stage: RetrievalStage
memory_ids: List[str]
memory_ids: list[str]
processing_time: float
filtered_count: int
score_threshold: float
details: List[Dict[str, Any]] = field(default_factory=list)
details: list[dict[str, Any]] = field(default_factory=list)
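The field(default_factory=list) on this attribute is required rather than stylistic: @dataclass rejects a bare mutable default such as details: list = [] with a ValueError, and default_factory gives each instance its own fresh list. A minimal illustration:

    from dataclasses import dataclass, field

    @dataclass
    class Stage:
        details: list = field(default_factory=list)  # fresh list per instance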


@dataclass
@@ -86,17 +86,17 @@ class RetrievalResult:

query: str
user_id: str
final_memories: List[MemoryChunk]
stage_results: List[StageResult]
final_memories: list[MemoryChunk]
stage_results: list[StageResult]
total_processing_time: float
total_filtered: int
retrieval_stats: Dict[str, Any]
retrieval_stats: dict[str, Any]


class MultiStageRetrieval:
"""多阶段召回系统"""

def __init__(self, config: Optional[RetrievalConfig] = None):
def __init__(self, config: RetrievalConfig | None = None):
self.config = config or RetrievalConfig.from_global_config()

# 初始化增强重排序器
@@ -124,11 +124,11 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
context: dict[str, Any],
metadata_index,
vector_storage,
all_memories_cache: Dict[str, MemoryChunk],
limit: Optional[int] = None,
all_memories_cache: dict[str, MemoryChunk],
limit: int | None = None,
) -> RetrievalResult:
"""多阶段记忆检索"""
start_time = time.time()
@@ -136,7 +136,7 @@ class MultiStageRetrieval:

stage_results = []
current_memory_ids = set()
memory_debug_info: Dict[str, Dict[str, Any]] = {}
memory_debug_info: dict[str, dict[str, Any]] = {}

try:
logger.debug(f"开始多阶段检索:query='{query}', user_id='{user_id}'")
@@ -311,11 +311,11 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
context: dict[str, Any],
metadata_index,
all_memories_cache: Dict[str, MemoryChunk],
all_memories_cache: dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段1:元数据过滤"""
start_time = time.time()
@@ -345,7 +345,7 @@ class MultiStageRetrieval:
result = await metadata_index.query_memories(index_query)
result_ids = list(result.memory_ids)
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []

# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
if not result_ids:
@@ -440,12 +440,12 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
context: dict[str, Any],
vector_storage,
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
candidate_ids: set[str],
all_memories_cache: dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段2:向量搜索"""
start_time = time.time()
@@ -479,8 +479,8 @@ class MultiStageRetrieval:

# 过滤候选记忆
filtered_memories = []
details: List[Dict[str, Any]] = []
raw_details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
raw_details: list[dict[str, Any]] = []
threshold = self.config.vector_similarity_threshold

for memory_id, similarity in search_result:
@@ -561,7 +561,7 @@ class MultiStageRetrieval:
)

def _create_text_search_fallback(
self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float
self, candidate_ids: set[str], all_memories_cache: dict[str, MemoryChunk], query_text: str, start_time: float
) -> StageResult:
"""当向量搜索失败时,使用文本搜索作为回退策略"""
try:
@@ -618,18 +618,18 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
candidate_ids: set[str],
all_memories_cache: dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段3:语义重排序"""
start_time = time.time()

try:
reranked_memories = []
details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
threshold = self.config.semantic_similarity_threshold

for memory_id in candidate_ids:
@@ -704,19 +704,19 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: List[str],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
candidate_ids: list[str],
all_memories_cache: dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段4:上下文过滤"""
start_time = time.time()

try:
final_memories = []
details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []

for memory_id in candidate_ids:
if memory_id not in all_memories_cache:
@@ -793,12 +793,12 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
all_memories_cache: dict[str, MemoryChunk],
limit: int,
*,
excluded_ids: Optional[Set[str]] = None,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
excluded_ids: set[str] | None = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""回退检索阶段 - 当主检索失败时使用更宽松的策略"""
start_time = time.time()
@@ -881,8 +881,8 @@ class MultiStageRetrieval:
)

async def _generate_query_embedding(
self, query: str, context: Dict[str, Any], vector_storage
) -> Optional[List[float]]:
self, query: str, context: dict[str, Any], vector_storage
) -> list[float] | None:
"""生成查询向量"""
try:
query_plan = context.get("query_plan")
@@ -916,7 +916,7 @@ class MultiStageRetrieval:
logger.error(f"生成查询向量时发生异常: {e}", exc_info=True)
return None

async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
"""计算语义相似度 - 简化优化版本,提升召回率"""
try:
query_plan = context.get("query_plan")
@@ -947,9 +947,10 @@ class MultiStageRetrieval:
# 核心匹配策略2:词汇匹配
word_score = 0.0
try:
import jieba
import re

import jieba

# 分词处理
query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text)
memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text)
@@ -1059,7 +1060,7 @@ class MultiStageRetrieval:
logger.warning(f"计算语义相似度失败: {e}")
return 0.0

async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
"""计算上下文相关度"""
try:
score = 0.0
@@ -1132,7 +1133,7 @@ class MultiStageRetrieval:
return 0.0

async def _calculate_final_score(
self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float
self, query: str, memory: MemoryChunk, context: dict[str, Any], context_score: float
) -> float:
"""计算最终评分"""
try:
@@ -1184,7 +1185,7 @@ class MultiStageRetrieval:
logger.warning(f"计算最终评分失败: {e}")
return 0.0

def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: list[str] | None) -> float:
if not required_subjects:
return 0.0

@@ -1229,7 +1230,7 @@ class MultiStageRetrieval:
except Exception:
return 0.5

def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
def _extract_memory_types_from_context(self, context: dict[str, Any]) -> list[MemoryType]:
"""从上下文中提取记忆类型"""
try:
query_plan = context.get("query_plan")
@@ -1256,10 +1257,10 @@ class MultiStageRetrieval:
except Exception:
return []

def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
def _extract_keywords_from_query(self, query: str, query_plan: Any | None = None) -> list[str]:
"""从查询中提取关键词"""
try:
extracted: List[str] = []
extracted: list[str] = []

if query_plan and getattr(query_plan, "required_keywords", None):
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
@@ -1283,7 +1284,7 @@ class MultiStageRetrieval:
except Exception:
return []

def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]):
def _update_retrieval_stats(self, total_time: float, stage_results: list[StageResult]):
"""更新检索统计"""
self.retrieval_stats["total_queries"] += 1

@@ -1306,7 +1307,7 @@ class MultiStageRetrieval:
]
stage_stat["avg_time"] = new_stage_avg

def get_retrieval_stats(self) -> Dict[str, Any]:
def get_retrieval_stats(self) -> dict[str, Any]:
"""获取检索统计信息"""
return self.retrieval_stats.copy()

@@ -1328,12 +1329,12 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: List[str],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
candidate_ids: list[str],
all_memories_cache: dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段5:增强重排序 - 使用多维度评分模型"""
start_time = time.time()

@@ -1,24 +1,23 @@
# -*- coding: utf-8 -*-
"""
向量数据库存储接口
为记忆系统提供高效的向量存储和语义搜索能力
"""

import time
import orjson
import asyncio
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
from pathlib import Path
import orjson

from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.config_helpers import resolve_embedding_dimension
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.config_helpers import resolve_embedding_dimension
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest

logger = get_logger(__name__)

@@ -48,7 +47,7 @@ class VectorStorageConfig:
class VectorStorageManager:
"""向量存储管理器"""

def __init__(self, config: Optional[VectorStorageConfig] = None):
def __init__(self, config: VectorStorageConfig | None = None):
self.config = config or VectorStorageConfig()

resolved_dimension = resolve_embedding_dimension(self.config.dimension)
@@ -68,8 +67,8 @@ class VectorStorageManager:
self.index_to_memory_id = {}  # vector index -> memory_id

# 内存缓存
self.memory_cache: Dict[str, MemoryChunk] = {}
self.vector_cache: Dict[str, List[float]] = {}
self.memory_cache: dict[str, MemoryChunk] = {}
self.vector_cache: dict[str, list[float]] = {}

# 统计信息
self.storage_stats = {
@@ -125,7 +124,7 @@ class VectorStorageManager:
)
logger.info("✅ 嵌入模型初始化完成")

async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
async def generate_query_embedding(self, query_text: str) -> list[float] | None:
"""生成查询向量,用于记忆召回"""
if not query_text:
logger.warning("查询文本为空,无法生成向量")
@@ -155,7 +154,7 @@ class VectorStorageManager:
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
return None

async def store_memories(self, memories: List[MemoryChunk]):
async def store_memories(self, memories: list[MemoryChunk]):
"""存储记忆向量"""
if not memories:
return
@@ -231,7 +230,7 @@ class VectorStorageManager:
logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id)
return memory.memory_id

async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]):
async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
"""批量生成和存储嵌入向量"""
if not memory_texts:
return
@@ -253,12 +252,12 @@ class VectorStorageManager:
except Exception as e:
logger.error(f"❌ 批量生成嵌入向量失败: {e}")

async def _batch_generate_embeddings(self, memory_ids: List[str], texts: List[str]) -> Dict[str, List[float]]:
async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
"""批量生成嵌入向量"""
if not texts:
return {}

results: Dict[str, List[float]] = {}
results: dict[str, list[float]] = {}

try:
semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))
@@ -281,7 +280,9 @@ class VectorStorageManager:
logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
results[memory_id] = []

tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)]
tasks = [
    asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
]
await asyncio.gather(*tasks, return_exceptions=True)
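This gather call completes a semaphore-bounded fan-out: one task per text, with the asyncio.Semaphore(min(4, ...)) created earlier in the hunk capping how many embedding requests run concurrently. A self-contained sketch of the same idiom, with fake_embed standing in for the real embedding call:

    import asyncio

    async def fake_embed(text: str) -> list[float]:
        await asyncio.sleep(0.01)          # placeholder for the real request
        return [float(len(text))]

    async def embed_all(texts: list[str]) -> list[list[float]]:
        semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))

        async def embed_one(text: str) -> list[float]:
            async with semaphore:          # at most 4 requests in flight
                return await fake_embed(text)

        tasks = [asyncio.create_task(embed_one(t)) for t in texts]
        return await asyncio.gather(*tasks)

    # asyncio.run(embed_all(["a", "bb"])) -> [[1.0], [2.0]]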

except Exception as e:
@@ -291,7 +292,7 @@ class VectorStorageManager:

return results

async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]):
async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
"""添加单个记忆到向量存储"""
with self._lock:
try:
@@ -337,7 +338,7 @@ class VectorStorageManager:
except Exception as e:
logger.error(f"❌ 添加记忆到向量存储失败: {e}")

def _normalize_vector(self, vector: List[float]) -> List[float]:
def _normalize_vector(self, vector: list[float]) -> list[float]:
"""L2归一化向量"""
if not vector:
return vector
@@ -357,12 +358,12 @@ class VectorStorageManager:

async def search_similar_memories(
self,
query_vector: Optional[List[float]] = None,
query_vector: list[float] | None = None,
*,
query_text: Optional[str] = None,
query_text: str | None = None,
limit: int = 10,
scope_id: Optional[str] = None,
) -> List[Tuple[str, float]]:
scope_id: str | None = None,
) -> list[tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()

@@ -379,7 +380,7 @@ class VectorStorageManager:
logger.warning("查询向量生成失败")
return []

scope_filter: Optional[str] = None
scope_filter: str | None = None
if isinstance(scope_id, str):
normalized_scope = scope_id.strip().lower()
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
@@ -491,7 +492,7 @@ class VectorStorageManager:
logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
return []

async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
"""根据ID获取记忆"""
# 先检查缓存
if memory_id in self.memory_cache:
@@ -501,7 +502,7 @@ class VectorStorageManager:
self.storage_stats["total_searches"] += 1
return None

async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]):
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
"""更新记忆的嵌入向量"""
with self._lock:
try:
@@ -636,7 +637,7 @@ class VectorStorageManager:
# 加载记忆缓存
cache_file = self.storage_path / "memory_cache.json"
if cache_file.exists():
with open(cache_file, "r", encoding="utf-8") as f:
with open(cache_file, encoding="utf-8") as f:
cache_data = orjson.loads(f.read())

self.memory_cache = {
@@ -646,13 +647,13 @@ class VectorStorageManager:
# 加载向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
if vector_cache_file.exists():
with open(vector_cache_file, "r", encoding="utf-8") as f:
with open(vector_cache_file, encoding="utf-8") as f:
self.vector_cache = orjson.loads(f.read())

# 加载映射关系
mapping_file = self.storage_path / "id_mapping.json"
if mapping_file.exists():
with open(mapping_file, "r", encoding="utf-8") as f:
with open(mapping_file, encoding="utf-8") as f:
mapping_data = orjson.loads(f.read())
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
@@ -689,7 +690,7 @@ class VectorStorageManager:
# 加载统计信息
stats_file = self.storage_path / "storage_stats.json"
if stats_file.exists():
with open(stats_file, "r", encoding="utf-8") as f:
with open(stats_file, encoding="utf-8") as f:
self.storage_stats = orjson.loads(f.read())

# 更新向量计数
@@ -806,7 +807,7 @@ class VectorStorageManager:
if invalid_memory_ids:
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")

def get_storage_stats(self) -> Dict[str, Any]:
def get_storage_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
stats = self.storage_stats.copy()
if stats["total_searches"] > 0:
@@ -821,11 +822,11 @@ class SimpleVectorIndex:

def __init__(self, dimension: int):
self.dimension = dimension
self.vectors: List[List[float]] = []
self.vector_ids: List[int] = []
self.vectors: list[list[float]] = []
self.vector_ids: list[int] = []
self.next_id = 0

def add_vector(self, vector: List[float]) -> int:
def add_vector(self, vector: list[float]) -> int:
"""添加向量"""
if len(vector) != self.dimension:
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
@@ -837,7 +838,7 @@ class SimpleVectorIndex:

return vector_id

def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]:
def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
"""搜索相似向量"""
if len(query_vector) != self.dimension:
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
@@ -853,7 +854,7 @@ class SimpleVectorIndex:

return results[:limit]

def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
"""计算余弦相似度"""
try:
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
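This generator expression is the numerator of cosine similarity, cos(v1, v2) = (v1 · v2) / (|v1| |v2|). The strict=False keeps zip's historical truncate-to-shortest behavior; strict=True (Python 3.10+) would instead raise ValueError on mismatched lengths. A compact reference version, assuming equal-length, non-zero vectors:

    import math

    def cosine_similarity(v1: list[float], v2: list[float]) -> float:
        dot = sum(x * y for x, y in zip(v1, v2, strict=True))
        norm1 = math.sqrt(sum(x * x for x in v1))
        norm2 = math.sqrt(sum(y * y for y in v2))
        return dot / (norm1 * norm2)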

@@ -1,25 +1,24 @@
# -*- coding: utf-8 -*-
"""
记忆激活器
记忆系统的激活器组件
"""

import difflib
import orjson
from typing import List, Dict, Optional
from datetime import datetime

import orjson
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager

from src.chat.memory_system.memory_manager import MemoryResult
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest

logger = get_logger("memory_activator")


def get_keywords_from_json(json_str) -> List:
def get_keywords_from_json(json_str) -> list:
"""
从JSON字符串中提取关键词列表

@@ -81,7 +80,7 @@ class MemoryActivator:
self.cached_keywords = set()  # 用于缓存历史关键词
self.last_memory_query_time = 0  # 上次查询记忆的时间

async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
"""
激活记忆
"""
@@ -155,7 +154,7 @@ class MemoryActivator:

return self.running_memory

async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]:
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
"""查询统一记忆系统"""
try:
# 使用记忆系统
@@ -198,7 +197,7 @@ class MemoryActivator:
logger.error(f"查询统一记忆失败: {e}")
return []

async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
"""
获取即时记忆 - 兼容原有接口(使用统一存储)
"""
|
||||
|
||||
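This activator file pairs `orjson` with `json_repair.repair_json`; a minimal sketch of how `get_keywords_from_json` plausibly uses them (the "keywords" field name and the empty-list fallback are assumptions):

    import orjson
    from json_repair import repair_json

    def get_keywords_from_json(json_str) -> list:
        # Repair common LLM JSON glitches (trailing commas, single quotes) first.
        repaired = repair_json(json_str)
        try:
            data = orjson.loads(repaired)
        except orjson.JSONDecodeError:
            return []
        # The "keywords" field name is an assumption for illustration.
        keywords = data.get("keywords", []) if isinstance(data, dict) else []
        return [k for k in keywords if isinstance(k, str)]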
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
Memory building module
Extracts high-quality, structured memory units from the conversation stream
@@ -33,19 +32,19 @@ import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union, Type
from typing import Any

import orjson

from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
    MemoryChunk,
    MemoryType,
    ConfidenceLevel,
    ImportanceLevel,
    MemoryChunk,
    MemoryType,
    create_memory_chunk,
)
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest

logger = get_logger(__name__)

@@ -62,8 +61,8 @@ class ExtractionStrategy(Enum):
class ExtractionResult:
    """Extraction result"""

    memories: List[MemoryChunk]
    confidence_scores: List[float]
    memories: list[MemoryChunk]
    confidence_scores: list[float]
    extraction_time: float
    strategy_used: ExtractionStrategy

@@ -85,8 +84,8 @@ class MemoryBuilder:
        }

    async def build_memories(
        self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float
    ) -> List[MemoryChunk]:
        self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float
    ) -> list[MemoryChunk]:
        """Build memories from a conversation."""
        start_time = time.time()

@@ -116,8 +115,8 @@ class MemoryBuilder:
            raise

    async def _extract_with_llm(
        self, text: str, context: Dict[str, Any], user_id: str, timestamp: float
    ) -> List[MemoryChunk]:
        self, text: str, context: dict[str, Any], user_id: str, timestamp: float
    ) -> list[MemoryChunk]:
        """Extract memories with the LLM."""
        try:
            prompt = self._build_llm_extraction_prompt(text, context)
@@ -135,7 +134,7 @@ class MemoryBuilder:
            logger.error(f"LLM提取失败: {e}")
            raise MemoryExtractionError(str(e)) from e

    def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
    def _build_llm_extraction_prompt(self, text: str, context: dict[str, Any]) -> str:
        """Build the LLM extraction prompt."""
        current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        message_type = context.get("message_type", "normal")
@@ -315,7 +314,7 @@ class MemoryBuilder:

        return prompt

    def _extract_json_payload(self, response: str) -> Optional[str]:
    def _extract_json_payload(self, response: str) -> str | None:
        """Extract the JSON portion of a model response, tolerating Markdown code blocks and similar formats."""
        if not response:
            return None
@@ -338,8 +337,8 @@ class MemoryBuilder:
        return stripped if stripped.startswith("{") and stripped.endswith("}") else None

    def _parse_llm_response(
        self, response: str, user_id: str, timestamp: float, context: Dict[str, Any]
    ) -> List[MemoryChunk]:
        self, response: str, user_id: str, timestamp: float, context: dict[str, Any]
    ) -> list[MemoryChunk]:
        """Parse the LLM response."""
        if not response:
            raise MemoryExtractionError("LLM未返回任何响应")
@@ -385,7 +384,7 @@ class MemoryBuilder:

        bot_display = self._clean_subject_text(bot_display)

        memories: List[MemoryChunk] = []
        memories: list[MemoryChunk] = []

        for mem_data in memory_list:
            try:
@@ -460,7 +459,7 @@ class MemoryBuilder:

        return memories

    def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
    def _parse_enum_value(self, enum_cls: type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
        """Parse an enum value, accepting numeric or string representations."""
        if isinstance(raw_value, enum_cls):
            return raw_value
@@ -514,7 +513,7 @@ class MemoryBuilder:
        )
        return default

    def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
    def _collect_bot_identifiers(self, context: dict[str, Any] | None) -> set[str]:
        identifiers: set[str] = {"bot", "机器人", "ai助手"}
        if not context:
            return identifiers
@@ -540,7 +539,7 @@ class MemoryBuilder:

        return identifiers

    def _collect_system_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
    def _collect_system_identifiers(self, context: dict[str, Any] | None) -> set[str]:
        identifiers: set[str] = set()
        if not context:
            return identifiers
@@ -568,8 +567,8 @@ class MemoryBuilder:

        return identifiers

    def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
        participants: List[str] = []
    def _resolve_conversation_participants(self, context: dict[str, Any] | None, user_id: str) -> list[str]:
        participants: list[str] = []

        if context:
            candidate_keys = [
@@ -609,7 +608,7 @@ class MemoryBuilder:
        if not participants:
            participants = ["对话参与者"]

        deduplicated: List[str] = []
        deduplicated: list[str] = []
        seen = set()
        for name in participants:
            key = name.lower()
@@ -620,7 +619,7 @@ class MemoryBuilder:

        return deduplicated

    def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
    def _resolve_user_display(self, context: dict[str, Any] | None, user_id: str) -> str:
        candidate_keys = [
            "user_display_name",
            "user_name",
@@ -683,7 +682,7 @@ class MemoryBuilder:

        return False

    def _split_subject_string(self, value: str) -> List[str]:
    def _split_subject_string(self, value: str) -> list[str]:
        if not value:
            return []

@@ -699,12 +698,12 @@ class MemoryBuilder:
        subject: Any,
        bot_identifiers: set[str],
        system_identifiers: set[str],
        default_subjects: List[str],
        bot_display: Optional[str] = None,
    ) -> List[str]:
        default_subjects: list[str],
        bot_display: str | None = None,
    ) -> list[str]:
        defaults = default_subjects or ["对话参与者"]

        raw_candidates: List[str] = []
        raw_candidates: list[str] = []
        if isinstance(subject, list):
            for item in subject:
                if isinstance(item, str):
@@ -716,7 +715,7 @@ class MemoryBuilder:
        elif subject is not None:
            raw_candidates.extend(self._split_subject_string(str(subject)))

        normalized: List[str] = []
        normalized: list[str] = []
        bot_primary = self._clean_subject_text(bot_display or "")

        for candidate in raw_candidates:
@@ -741,7 +740,7 @@ class MemoryBuilder:
        if not normalized:
            normalized = list(defaults)

        deduplicated: List[str] = []
        deduplicated: list[str] = []
        seen = set()
        for name in normalized:
            key = name.lower()
@@ -752,7 +751,7 @@ class MemoryBuilder:

        return deduplicated

    def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
    def _extract_value_from_object(self, obj: str | dict[str, Any] | list[Any], keys: list[str]) -> str | None:
        if isinstance(obj, dict):
            for key in keys:
                value = obj.get(key)
@@ -773,9 +772,7 @@ class MemoryBuilder:
            return obj.strip() or None
        return None

    def _compose_display_text(
        self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]
    ) -> str:
    def _compose_display_text(self, subjects: list[str], predicate: str, obj: str | dict[str, Any] | list[Any]) -> str:
        subject_phrase = "、".join(subjects) if subjects else "对话参与者"
        predicate = (predicate or "").strip()

@@ -841,7 +838,7 @@ class MemoryBuilder:
        return f"{subject_phrase}{predicate}".strip()
        return subject_phrase

    def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]:
    def _validate_and_enhance_memories(self, memories: list[MemoryChunk], context: dict[str, Any]) -> list[MemoryChunk]:
        """Validate and enhance memories."""
        validated_memories = []

@@ -876,7 +873,7 @@ class MemoryBuilder:

        return True

    def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk:
    def _enhance_memory(self, memory: MemoryChunk, context: dict[str, Any]) -> MemoryChunk:
        """Enhance a memory chunk."""
        # normalize time information
        self._normalize_time_in_memory(memory)
@@ -985,7 +982,7 @@ class MemoryBuilder:
            total_confidence / self.extraction_stats["successful_extractions"]
        )

    def get_extraction_stats(self) -> Dict[str, Any]:
    def get_extraction_stats(self) -> dict[str, Any]:
        """Get extraction statistics."""
        return self.extraction_stats.copy()
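`_extract_json_payload` above (and an identical twin in memory_system.py later in this commit) unwraps a model response before parsing; a sketch consistent with the guard and the final check the diff exposes (the fence-stripping regex in the middle is an assumption):

    import re

    def extract_json_payload(response: str) -> str | None:
        if not response:
            return None
        stripped = response.strip()
        # Strip a ```json ... ``` fence if the model wrapped its answer in one (assumed).
        fence = re.match(r"^```(?:json)?\s*(.*?)\s*```$", stripped, re.DOTALL)
        if fence:
            stripped = fence.group(1).strip()
        # The diff shows this exact final check: accept only a bare JSON object.
        return stripped if stripped.startswith("{") and stripped.endswith("}") else None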
@@ -1,18 +1,19 @@
# -*- coding: utf-8 -*-
"""
Structured memory unit design
Implements high-quality, structured memory units in line with the design document
"""

import hashlib
import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union, Iterable
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
import hashlib
from typing import Any

import numpy as np
import orjson

from src.common.logger import get_logger

logger = get_logger(__name__)
@@ -56,17 +57,17 @@ class ImportanceLevel(Enum):
class ContentStructure:
    """Subject-predicate-object structure with a natural-language description"""

    subject: Union[str, List[str]]
    subject: str | list[str]
    predicate: str
    object: Union[str, Dict]
    object: str | dict
    display: str = ""

    def to_dict(self) -> Dict[str, Any]:
    def to_dict(self) -> dict[str, Any]:
        """Convert to dict format."""
        return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure":
    def from_dict(cls, data: dict[str, Any]) -> "ContentStructure":
        """Create an instance from a dict."""
        return cls(
            subject=data.get("subject", ""),
@@ -75,7 +76,7 @@ class ContentStructure:
            display=data.get("display", ""),
        )

    def to_subject_list(self) -> List[str]:
    def to_subject_list(self) -> list[str]:
        """Convert the subject to list form."""
        if isinstance(self.subject, list):
            return [s for s in self.subject if isinstance(s, str) and s.strip()]
@@ -99,7 +100,7 @@ class MemoryMetadata:
    # basic info
    memory_id: str  # unique identifier
    user_id: str  # user ID
    chat_id: Optional[str] = None  # chat ID (group or private chat)
    chat_id: str | None = None  # chat ID (group or private chat)

    # time info
    created_at: float = 0.0  # creation timestamp
@@ -124,9 +125,9 @@ class MemoryMetadata:
    last_forgetting_check: float = 0.0  # time of the last forgetting check

    # source info
    source_context: Optional[str] = None  # source context snippet
    source_context: str | None = None  # source context snippet
    # legacy field: some code or older versions may access metadata.source directly
    source: Optional[str] = None
    source: str | None = None

    def __post_init__(self):
        """Post-init processing."""
@@ -209,7 +210,7 @@ class MemoryMetadata:
        # clamp to minimum and maximum thresholds
        return max(7.0, min(threshold, 365.0))  # between 7 days and 1 year

    def should_forget(self, current_time: Optional[float] = None) -> bool:
    def should_forget(self, current_time: float | None = None) -> bool:
        """Decide whether the memory should be forgotten."""
        if current_time is None:
            current_time = time.time()
@@ -222,7 +223,7 @@ class MemoryMetadata:

        return days_since_activation > self.forgetting_threshold

    def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
    def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
        """Decide whether the memory is dormant (not activated for a long time)."""
        if current_time is None:
            current_time = time.time()
@@ -230,7 +231,7 @@ class MemoryMetadata:
        days_since_last_access = (current_time - self.last_accessed) / 86400
        return days_since_last_access > inactive_days

    def to_dict(self) -> Dict[str, Any]:
    def to_dict(self) -> dict[str, Any]:
        """Convert to dict format."""
        return {
            "memory_id": self.memory_id,
@@ -252,7 +253,7 @@ class MemoryMetadata:
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata":
    def from_dict(cls, data: dict[str, Any]) -> "MemoryMetadata":
        """Create an instance from a dict."""
        return cls(
            memory_id=data.get("memory_id", ""),
@@ -286,17 +287,17 @@ class MemoryChunk:
    memory_type: MemoryType  # memory type

    # extended info
    keywords: List[str] = field(default_factory=list)  # keyword list
    tags: List[str] = field(default_factory=list)  # tag list
    categories: List[str] = field(default_factory=list)  # category list
    keywords: list[str] = field(default_factory=list)  # keyword list
    tags: list[str] = field(default_factory=list)  # tag list
    categories: list[str] = field(default_factory=list)  # category list

    # semantic info
    embedding: Optional[List[float]] = None  # semantic vector
    semantic_hash: Optional[str] = None  # semantic hash
    embedding: list[float] | None = None  # semantic vector
    semantic_hash: str | None = None  # semantic hash

    # relation info
    related_memories: List[str] = field(default_factory=list)  # related memory ID list
    temporal_context: Optional[Dict[str, Any]] = None  # temporal context
    related_memories: list[str] = field(default_factory=list)  # related memory ID list
    temporal_context: dict[str, Any] | None = None  # temporal context

    def __post_init__(self):
        """Post-init processing."""
@@ -310,7 +311,7 @@ class MemoryChunk:

        try:
            # generate a stable hash from the vector and content
            content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}"
            content_str = f"{self.content.subject}:{self.content.predicate}:{self.content.object!s}"
            embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))

            hash_input = f"{content_str}|{embedding_str}"
@@ -342,7 +343,7 @@ class MemoryChunk:
        return self.content.display or str(self.content)

    @property
    def subjects(self) -> List[str]:
    def subjects(self) -> list[str]:
        """Get the subject list."""
        return self.content.to_subject_list()

@@ -354,11 +355,11 @@ class MemoryChunk:
        """Update the relevance score."""
        self.metadata.update_relevance(new_score)

    def should_forget(self, current_time: Optional[float] = None) -> bool:
    def should_forget(self, current_time: float | None = None) -> bool:
        """Decide whether the memory should be forgotten."""
        return self.metadata.should_forget(current_time)

    def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
    def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
        """Decide whether the memory is dormant (not activated for a long time)."""
        return self.metadata.is_dormant(current_time, inactive_days)

@@ -386,7 +387,7 @@ class MemoryChunk:
        if memory_id and memory_id not in self.related_memories:
            self.related_memories.append(memory_id)

    def set_embedding(self, embedding: List[float]):
    def set_embedding(self, embedding: list[float]):
        """Set the semantic vector."""
        self.embedding = embedding
        self._generate_semantic_hash()
@@ -415,7 +416,7 @@ class MemoryChunk:
            logger.warning(f"计算记忆相似度失败: {e}")
            return 0.0

    def to_dict(self) -> Dict[str, Any]:
    def to_dict(self) -> dict[str, Any]:
        """Convert to the full dict format."""
        return {
            "metadata": self.metadata.to_dict(),
@@ -431,7 +432,7 @@ class MemoryChunk:
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk":
    def from_dict(cls, data: dict[str, Any]) -> "MemoryChunk":
        """Create an instance from a dict."""
        metadata = MemoryMetadata.from_dict(data.get("metadata", {}))
        content = ContentStructure.from_dict(data.get("content", {}))
@@ -541,7 +542,7 @@ class MemoryChunk:
        return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"


def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
def _build_display_text(subjects: Iterable[str], predicate: str, obj: str | dict) -> str:
    """Build a natural-language description from subject, predicate, and object."""
    subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
    subject_part = "、".join(subjects_clean) if subjects_clean else "对话参与者"
@@ -569,15 +570,15 @@ def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str,

def create_memory_chunk(
    user_id: str,
    subject: Union[str, List[str]],
    subject: str | list[str],
    predicate: str,
    obj: Union[str, Dict],
    obj: str | dict,
    memory_type: MemoryType,
    chat_id: Optional[str] = None,
    source_context: Optional[str] = None,
    chat_id: str | None = None,
    source_context: str | None = None,
    importance: ImportanceLevel = ImportanceLevel.NORMAL,
    confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
    display: Optional[str] = None,
    display: str | None = None,
    **kwargs,
) -> MemoryChunk:
    """Convenience factory for memory chunks."""
@@ -593,10 +594,10 @@ def create_memory_chunk(
        source_context=source_context,
    )

    subjects: List[str]
    subjects: list[str]
    if isinstance(subject, list):
        subjects = [s for s in subject if isinstance(s, str) and s.strip()]
        subject_payload: Union[str, List[str]] = subjects
        subject_payload: str | list[str] = subjects
    else:
        cleaned = subject.strip() if isinstance(subject, str) else ""
        subjects = [cleaned] if cleaned else []
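The `!s` conversion introduced above is exactly `str()` applied inside the f-string, so the hash input is unchanged by the restyle. A sketch of the stable-hash scheme those lines belong to (the sha256 digest is an assumption; the diff does not show which digest `_generate_semantic_hash` uses):

    import hashlib

    def semantic_hash(subject, predicate, obj, embedding: list[float]) -> str:
        # Mirrors the hunk above: the content triple plus the embedding rounded
        # to 6 decimals, so tiny float noise does not change the hash.
        content_str = f"{subject}:{predicate}:{obj!s}"
        embedding_str = ",".join(map(str, [round(x, 6) for x in embedding]))
        hash_input = f"{content_str}|{embedding_str}"
        # sha256 is an assumed digest choice for illustration.
        return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()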
@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
"""
Intelligent memory forgetting engine
An intelligent forgetting mechanism based on importance, confidence, and activation frequency
"""

import time
import asyncio
from typing import List, Dict, Optional, Tuple
from datetime import datetime
import time
from dataclasses import dataclass
from datetime import datetime

from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, ImportanceLevel, ConfidenceLevel

logger = get_logger(__name__)

@@ -65,7 +63,7 @@ class ForgettingConfig:
class MemoryForgettingEngine:
    """Intelligent memory forgetting engine"""

    def __init__(self, config: Optional[ForgettingConfig] = None):
    def __init__(self, config: ForgettingConfig | None = None):
        self.config = config or ForgettingConfig()
        self.stats = ForgettingStats()
        self._last_forgetting_check = 0.0
@@ -116,7 +114,7 @@ class MemoryForgettingEngine:
        # keep within a reasonable range
        return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days))

    def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
    def should_forget_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
        """
        Decide whether a memory should be forgotten

@@ -155,7 +153,7 @@ class MemoryForgettingEngine:

        return should_forget

    def is_dormant_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
    def is_dormant_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
        """
        Decide whether a memory is dormant

@@ -168,7 +166,7 @@ class MemoryForgettingEngine:
        """
        return memory.is_dormant(current_time, self.config.dormant_threshold_days)

    def should_force_forget_dormant(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
    def should_force_forget_dormant(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
        """
        Decide whether a dormant memory should be force-forgotten

@@ -189,7 +187,7 @@ class MemoryForgettingEngine:
        days_since_last_access = (current_time - memory.metadata.last_accessed) / 86400
        return days_since_last_access > self.config.force_forget_dormant_days

    async def check_memories_for_forgetting(self, memories: List[MemoryChunk]) -> Tuple[List[str], List[str]]:
    async def check_memories_for_forgetting(self, memories: list[MemoryChunk]) -> tuple[list[str], list[str]]:
        """
        Scan a list of memories and identify those that should be forgotten

@@ -241,7 +239,7 @@ class MemoryForgettingEngine:

        return normal_forgetting_ids, force_forgetting_ids

    async def perform_forgetting_check(self, memories: List[MemoryChunk]) -> Dict[str, any]:
    async def perform_forgetting_check(self, memories: list[MemoryChunk]) -> dict[str, any]:
        """
        Run the full forgetting-check pipeline

@@ -314,7 +312,7 @@ class MemoryForgettingEngine:
        except Exception as e:
            logger.error(f"定期遗忘检查失败: {e}", exc_info=True)

    def get_forgetting_stats(self) -> Dict[str, any]:
    def get_forgetting_stats(self) -> dict[str, any]:
        """Get forgetting statistics."""
        return {
            "total_checked": self.stats.total_checked,
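Two notes on this file. First, the new annotations `dict[str, any]` reference the builtin function `any`, not `typing.Any`; the module still imports and runs, but type checkers will flag it, and `Any` was almost certainly intended. Second, the hunks show that every check reduces to day arithmetic on timestamps, as in this sketch:

    import time

    SECONDS_PER_DAY = 86400

    def days_since(timestamp: float, now: float | None = None) -> float:
        # The divisor 86400 matches the hunks above.
        now = time.time() if now is None else now
        return (now - timestamp) / SECONDS_PER_DAY

    # Per the hunks: forget when days_since(last activation) exceeds the
    # per-memory forgetting threshold, and force-forget a dormant memory when
    # days_since(last_accessed) exceeds config.force_forget_dormant_days.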
@@ -1,16 +1,14 @@
# -*- coding: utf-8 -*-
"""
Memory fusion and deduplication
Avoids memory fragmentation and keeps the long-term memory store high quality
"""

import time
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from typing import Any


from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel

logger = get_logger(__name__)

@@ -22,9 +20,9 @@ class FusionResult:
    original_count: int
    fused_count: int
    removed_duplicates: int
    merged_memories: List[MemoryChunk]
    merged_memories: list[MemoryChunk]
    fusion_time: float
    details: List[str]
    details: list[str]


@dataclass
@@ -32,9 +30,9 @@ class DuplicateGroup:
    """Duplicate memory group"""

    group_id: str
    memories: List[MemoryChunk]
    similarity_matrix: List[List[float]]
    representative_memory: Optional[MemoryChunk] = None
    memories: list[MemoryChunk]
    similarity_matrix: list[list[float]]
    representative_memory: MemoryChunk | None = None


class MemoryFusionEngine:
@@ -59,8 +57,8 @@ class MemoryFusionEngine:
        }

    async def fuse_memories(
        self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None
    ) -> List[MemoryChunk]:
        self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] | None = None
    ) -> list[MemoryChunk]:
        """Fuse a list of memories."""
        start_time = time.time()

@@ -106,8 +104,8 @@ class MemoryFusionEngine:
            return new_memories  # return the original memories on failure

    async def _detect_duplicate_groups(
        self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk]
    ) -> List[DuplicateGroup]:
        self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk]
    ) -> list[DuplicateGroup]:
        """Detect duplicate memory groups."""
        all_memories = new_memories + existing_memories
        new_memory_ids = {memory.memory_id for memory in new_memories}
@@ -212,7 +210,7 @@ class MemoryFusionEngine:
        jaccard_similarity = len(intersection) / len(union)
        return jaccard_similarity

    def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float:
    def _calculate_keyword_similarity(self, keywords1: list[str], keywords2: list[str]) -> float:
        """Compute keyword similarity."""
        if not keywords1 or not keywords2:
            return 0.0
@@ -302,7 +300,7 @@ class MemoryFusionEngine:

        return best_memory

    async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]:
    async def _fuse_memory_group(self, group: DuplicateGroup) -> MemoryChunk | None:
        """Fuse a memory group."""
        if not group.memories:
            return None
@@ -328,7 +326,7 @@ class MemoryFusionEngine:
        # return the memory with the highest confidence
        return max(group.memories, key=lambda m: m.metadata.confidence.value)

    async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk:
    async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: list[MemoryChunk]) -> MemoryChunk:
        """Merge memory attributes."""
        # create a deep copy of the base memory
        fused_memory = MemoryChunk.from_dict(base_memory.to_dict())
@@ -395,7 +393,7 @@ class MemoryFusionEngine:
        source_ids = [m.memory_id[:8] for m in group.memories]
        fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}"

    def _merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]:
    def _merge_temporal_context(self, memories: list[MemoryChunk]) -> dict[str, Any]:
        """Merge temporal contexts."""
        contexts = [m.temporal_context for m in memories if m.temporal_context]

@@ -426,8 +424,8 @@ class MemoryFusionEngine:
        return merged_context

    async def incremental_fusion(
        self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk]
    ) -> Tuple[MemoryChunk, List[MemoryChunk]]:
        self, new_memory: MemoryChunk, existing_memories: list[MemoryChunk]
    ) -> tuple[MemoryChunk, list[MemoryChunk]]:
        """Incremental fusion (fuse a single new memory with existing memories)."""
        # find similar memories
        similar_memories = []
@@ -493,7 +491,7 @@ class MemoryFusionEngine:
        except Exception as e:
            logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True)

    def get_fusion_stats(self) -> Dict[str, Any]:
    def get_fusion_stats(self) -> dict[str, Any]:
        """Get fusion statistics."""
        return self.fusion_stats.copy()
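The fusion hunks expose the keyword-similarity core: a Jaccard ratio with an empty-input guard. Assembled into one runnable helper:

    def keyword_jaccard(keywords1: list[str], keywords2: list[str]) -> float:
        # Guard and formula both appear verbatim in the hunks above.
        if not keywords1 or not keywords2:
            return 0.0
        set1, set2 = set(keywords1), set(keywords2)
        intersection = set1 & set2
        union = set1 | set2
        return len(intersection) / len(union)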
@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
"""
Memory system manager
Replaces the legacy Hippocampus and instant_memory systems
"""

import re
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from typing import Any

from src.common.logger import get_logger
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_system import initialize_memory_system
from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system
from src.common.logger import get_logger

logger = get_logger(__name__)

@@ -27,14 +25,14 @@ class MemoryResult:
    timestamp: float
    source: str = "memory"
    relevance_score: float = 0.0
    structure: Dict[str, Any] | None = None
    structure: dict[str, Any] | None = None


class MemoryManager:
    """Memory system manager - replaces the legacy HippocampusManager"""

    def __init__(self):
        self.memory_system: Optional[MemorySystem] = None
        self.memory_system: MemorySystem | None = None
        self.is_initialized = False
        self.user_cache = {}  # user memory cache

@@ -63,8 +61,8 @@ class MemoryManager:
        logger.info("正在初始化记忆系统...")

        # fetch the LLM model
        from src.llm_models.utils_model import LLMRequest
        from src.config.config import model_config
        from src.llm_models.utils_model import LLMRequest

        llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")

@@ -121,7 +119,7 @@ class MemoryManager:
        max_memory_length: int = 2,
        time_weight: float = 1.0,
        keyword_weight: float = 1.0,
    ) -> List[Tuple[str, str]]:
    ) -> list[tuple[str, str]]:
        """Get relevant memories from text - compatible with the legacy interface."""
        if not self.is_initialized or not self.memory_system:
            return []
@@ -152,8 +150,8 @@ class MemoryManager:
            return []

    async def get_memory_from_topic(
        self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
    ) -> List[Tuple[str, str]]:
        self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
    ) -> list[tuple[str, str]]:
        """Get memories from keywords - compatible with the legacy interface."""
        if not self.is_initialized or not self.memory_system:
            return []
@@ -208,8 +206,8 @@ class MemoryManager:
            return []

    async def process_conversation(
        self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None
    ) -> List[MemoryChunk]:
        self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float | None = None
    ) -> list[MemoryChunk]:
        """Process a conversation and build memories - new feature."""
        if not self.is_initialized or not self.memory_system:
            return []
@@ -235,8 +233,8 @@ class MemoryManager:
            return []

    async def get_enhanced_memory_context(
        self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5
    ) -> List[MemoryResult]:
        self, query_text: str, user_id: str, context: dict[str, Any] | None = None, limit: int = 5
    ) -> list[MemoryResult]:
        """Get enhanced memory context - new feature."""
        if not self.is_initialized or not self.memory_system:
            return []
@@ -267,7 +265,7 @@ class MemoryManager:
            logger.error(f"get_enhanced_memory_context 失败: {e}")
            return []

    def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
    def _format_memory_chunk(self, memory: MemoryChunk) -> tuple[str, dict[str, Any]]:
        """Convert a memory chunk into a more readable text description."""
        structure = memory.content.to_dict()
        if memory.display:
@@ -289,7 +287,7 @@ class MemoryManager:

        return formatted, structure

    def _format_subject(self, subject: Optional[str], memory: MemoryChunk) -> str:
    def _format_subject(self, subject: str | None, memory: MemoryChunk) -> str:
        if not subject:
            return "该用户"

@@ -299,7 +297,7 @@ class MemoryManager:
            return "该聊天"
        return self._clean_text(subject)

    def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> Optional[str]:
    def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> str | None:
        predicate = (predicate or "").strip()
        obj_value = obj

@@ -446,10 +444,10 @@ class MemoryManager:
        text = self._truncate(str(obj).strip())
        return self._clean_text(text)

    def _extract_from_object(self, obj: Any, keys: List[str]) -> Optional[str]:
    def _extract_from_object(self, obj: Any, keys: list[str]) -> str | None:
        if isinstance(obj, dict):
            for key in keys:
                if key in obj and obj[key]:
                if obj.get(key):
                    value = obj[key]
                    if isinstance(value, (dict, list)):
                        return self._clean_text(self._format_object(value))
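The last hunk in this file swaps `if key in obj and obj[key]:` for `if obj.get(key):`; for a plain dict these are equivalent (both false when the key is missing or the value is falsy), so the restyle preserves behavior. A quick check:

    obj = {"name": "", "nickname": "Fox"}  # hypothetical sample data

    for key in ["name", "nickname", "alias"]:
        assert bool(key in obj and obj[key]) == bool(obj.get(key))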
@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
"""
Memory metadata index manager
Stores memory metadata in a JSON file, supporting fast fuzzy search and filtering
"""

import orjson
import threading
from pathlib import Path
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, asdict
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

import orjson

from src.common.logger import get_logger

@@ -25,10 +25,10 @@ class MemoryMetadataIndexEntry:

    # classification info
    memory_type: str  # MemoryType.value
    subjects: List[str]  # subject list
    objects: List[str]  # object list
    keywords: List[str]  # keyword list
    tags: List[str]  # tag list
    subjects: list[str]  # subject list
    objects: list[str]  # object list
    keywords: list[str]  # keyword list
    tags: list[str]  # tag list

    # numeric fields (for range filtering)
    importance: int  # ImportanceLevel.value (1-4)
@@ -37,8 +37,8 @@ class MemoryMetadataIndexEntry:
    access_count: int  # access count

    # optional fields
    chat_id: Optional[str] = None
    content_preview: Optional[str] = None  # content preview (first 100 chars)
    chat_id: str | None = None
    content_preview: str | None = None  # content preview (first 100 chars)


class MemoryMetadataIndex:
@@ -46,13 +46,13 @@ class MemoryMetadataIndex:

    def __init__(self, index_file: str = "data/memory_metadata_index.json"):
        self.index_file = Path(index_file)
        self.index: Dict[str, MemoryMetadataIndexEntry] = {}  # memory_id -> entry
        self.index: dict[str, MemoryMetadataIndexEntry] = {}  # memory_id -> entry

        # inverted indices (for fast lookup)
        self.type_index: Dict[str, Set[str]] = {}  # type -> {memory_ids}
        self.subject_index: Dict[str, Set[str]] = {}  # subject -> {memory_ids}
        self.keyword_index: Dict[str, Set[str]] = {}  # keyword -> {memory_ids}
        self.tag_index: Dict[str, Set[str]] = {}  # tag -> {memory_ids}
        self.type_index: dict[str, set[str]] = {}  # type -> {memory_ids}
        self.subject_index: dict[str, set[str]] = {}  # subject -> {memory_ids}
        self.keyword_index: dict[str, set[str]] = {}  # keyword -> {memory_ids}
        self.tag_index: dict[str, set[str]] = {}  # tag -> {memory_ids}

        self.lock = threading.RLock()

@@ -178,7 +178,7 @@ class MemoryMetadataIndex:
            self._remove_from_inverted_indices(memory_id)
            del self.index[memory_id]

    def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
    def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
        """Batch add or update."""
        with self.lock:
            for entry in entries:
@@ -191,18 +191,18 @@ class MemoryMetadataIndex:

    def search(
        self,
        memory_types: Optional[List[str]] = None,
        subjects: Optional[List[str]] = None,
        keywords: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
        importance_min: Optional[int] = None,
        importance_max: Optional[int] = None,
        created_after: Optional[float] = None,
        created_before: Optional[float] = None,
        user_id: Optional[str] = None,
        limit: Optional[int] = None,
        memory_types: list[str] | None = None,
        subjects: list[str] | None = None,
        keywords: list[str] | None = None,
        tags: list[str] | None = None,
        importance_min: int | None = None,
        importance_max: int | None = None,
        created_after: float | None = None,
        created_before: float | None = None,
        user_id: str | None = None,
        limit: int | None = None,
        flexible_mode: bool = True,  # new: flexible matching mode
    ) -> List[str]:
    ) -> list[str]:
        """
        Search for memory IDs matching the given criteria (supports fuzzy matching)

@@ -237,14 +237,14 @@ class MemoryMetadataIndex:

    def _search_flexible(
        self,
        memory_types: Optional[List[str]] = None,
        subjects: Optional[List[str]] = None,
        created_after: Optional[float] = None,
        created_before: Optional[float] = None,
        user_id: Optional[str] = None,
        limit: Optional[int] = None,
        memory_types: list[str] | None = None,
        subjects: list[str] | None = None,
        created_after: float | None = None,
        created_before: float | None = None,
        user_id: str | None = None,
        limit: int | None = None,
        **kwargs,  # accepted but unused parameters
    ) -> List[str]:
    ) -> list[str]:
        """
        Flexible search mode: matching 2 of 4 criteria is enough; supports partial matches

@@ -374,20 +374,20 @@ class MemoryMetadataIndex:

    def _search_strict(
        self,
        memory_types: Optional[List[str]] = None,
        subjects: Optional[List[str]] = None,
        keywords: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
        importance_min: Optional[int] = None,
        importance_max: Optional[int] = None,
        created_after: Optional[float] = None,
        created_before: Optional[float] = None,
        user_id: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        memory_types: list[str] | None = None,
        subjects: list[str] | None = None,
        keywords: list[str] | None = None,
        tags: list[str] | None = None,
        importance_min: int | None = None,
        importance_max: int | None = None,
        created_after: float | None = None,
        created_before: float | None = None,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[str]:
        """Strict search mode (original logic)."""
        # initial candidate set (all memories)
        candidate_ids: Optional[Set[str]] = None
        candidate_ids: set[str] | None = None

        # user filter (mandatory)
        if user_id:
@@ -471,11 +471,11 @@ class MemoryMetadataIndex:

        return result_ids

    def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
    def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
        """Get a single index entry."""
        return self.index.get(memory_id)

    def get_stats(self) -> Dict[str, Any]:
    def get_stats(self) -> dict[str, Any]:
        """Get index statistics."""
        with self.lock:
            return {
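The inverted indices declared above drive the fast lookups; a minimal sketch of the pattern (lowercase normalization and the union-style lookup are assumptions; the strict mode above appears to narrow a candidate set instead):

    from collections import defaultdict

    # Each keyword maps to the set of memory IDs that carry it.
    keyword_index: dict[str, set[str]] = defaultdict(set)

    def index_memory(memory_id: str, keywords: list[str]) -> None:
        for kw in keywords:
            keyword_index[kw.lower()].add(memory_id)  # normalization is assumed

    def lookup(keywords: list[str]) -> set[str]:
        # Union over query keywords, i.e. any-match semantics.
        result: set[str] = set()
        for kw in keywords:
            result |= keyword_index.get(kw.lower(), set())
        return result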
@@ -1,11 +1,10 @@
# -*- coding: utf-8 -*-
"""Memory retrieval query planner"""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any

import orjson

@@ -21,16 +20,16 @@ class MemoryQueryPlan:
    """Query-planning result"""

    semantic_query: str
    memory_types: List[MemoryType] = field(default_factory=list)
    subject_includes: List[str] = field(default_factory=list)
    object_includes: List[str] = field(default_factory=list)
    required_keywords: List[str] = field(default_factory=list)
    optional_keywords: List[str] = field(default_factory=list)
    owner_filters: List[str] = field(default_factory=list)
    memory_types: list[MemoryType] = field(default_factory=list)
    subject_includes: list[str] = field(default_factory=list)
    object_includes: list[str] = field(default_factory=list)
    required_keywords: list[str] = field(default_factory=list)
    optional_keywords: list[str] = field(default_factory=list)
    owner_filters: list[str] = field(default_factory=list)
    recency_preference: str = "any"
    limit: int = 10
    emphasis: Optional[str] = None
    raw_plan: Dict[str, Any] = field(default_factory=dict)
    emphasis: str | None = None
    raw_plan: dict[str, Any] = field(default_factory=dict)

    def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
        if not self.semantic_query:
@@ -46,11 +45,11 @@ class MemoryQueryPlan:
class MemoryQueryPlanner:
    """Memory retrieval query planner backed by a small model"""

    def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
    def __init__(self, planner_model: LLMRequest | None, default_limit: int = 10):
        self.model = planner_model
        self.default_limit = default_limit

    async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
    async def plan_query(self, query_text: str, context: dict[str, Any]) -> MemoryQueryPlan:
        if not self.model:
            logger.debug("未提供查询规划模型,使用默认规划")
            return self._default_plan(query_text)
@@ -82,10 +81,10 @@ class MemoryQueryPlanner:
    def _default_plan(self, query_text: str) -> MemoryQueryPlan:
        return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit)

    def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
    def _parse_plan_dict(self, data: dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
        semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query

        def _collect_list(key: str) -> List[str]:
        def _collect_list(key: str) -> list[str]:
            value = data.get(key)
            if isinstance(value, str):
                return [value]
@@ -94,7 +93,7 @@ class MemoryQueryPlanner:
            return []

        memory_type_values = _collect_list("memory_types")
        memory_types: List[MemoryType] = []
        memory_types: list[MemoryType] = []
        for item in memory_type_values:
            if not item:
                continue
@@ -123,7 +122,7 @@ class MemoryQueryPlanner:
        )
        return plan

    def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
    def _build_prompt(self, query_text: str, context: dict[str, Any]) -> str:
        participants = context.get("participants") or context.get("speaker_names") or []
        if isinstance(participants, str):
            participants = [participants]
@@ -206,7 +205,7 @@ class MemoryQueryPlanner:
请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。
"""

    def _extract_json_payload(self, response: str) -> Optional[str]:
    def _extract_json_payload(self, response: str) -> str | None:
        if not response:
            return None
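`_collect_list` above normalizes a planner-JSON field that may arrive as either a bare string or a list; a sketch of that normalization (the list-branch filtering is an assumption, the diff cuts off before it):

    from typing import Any

    def collect_list(data: dict[str, Any], key: str) -> list[str]:
        # A bare string becomes a one-element list, matching the hunk above.
        value = data.get(key)
        if isinstance(value, str):
            return [value]
        if isinstance(value, list):
            # Assumed: keep only non-empty string items.
            return [item for item in value if isinstance(item, str) and item.strip()]
        return []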
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
精准记忆系统核心模块
|
||||
1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。
|
||||
@@ -6,26 +5,27 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import orjson
|
||||
import re
|
||||
import hashlib
|
||||
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
||||
import re
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -121,7 +121,7 @@ class MemorySystemConfig:
|
||||
class MemorySystem:
|
||||
"""精准记忆系统核心类"""
|
||||
|
||||
def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None):
|
||||
def __init__(self, llm_model: LLMRequest | None = None, config: MemorySystemConfig | None = None):
|
||||
self.config = config or MemorySystemConfig.from_global_config()
|
||||
self.llm_model = llm_model
|
||||
self.status = MemorySystemStatus.INITIALIZING
|
||||
@@ -131,7 +131,7 @@ class MemorySystem:
|
||||
self.fusion_engine: MemoryFusionEngine = None
|
||||
self.unified_storage = None # 统一存储系统
|
||||
self.query_planner: MemoryQueryPlanner = None
|
||||
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
|
||||
self.forgetting_engine: MemoryForgettingEngine | None = None
|
||||
|
||||
# LLM模型
|
||||
self.value_assessment_model: LLMRequest = None
|
||||
@@ -143,10 +143,10 @@ class MemorySystem:
|
||||
self.last_retrieval_time = None
|
||||
|
||||
# 构建节流记录
|
||||
self._last_memory_build_times: Dict[str, float] = {}
|
||||
self._last_memory_build_times: dict[str, float] = {}
|
||||
|
||||
# 记忆指纹缓存,用于快速检测重复记忆
|
||||
self._memory_fingerprints: Dict[str, str] = {}
|
||||
self._memory_fingerprints: dict[str, str] = {}
|
||||
|
||||
logger.info("MemorySystem 初始化开始")
|
||||
|
||||
@@ -210,7 +210,7 @@ class MemorySystem:
|
||||
raise
|
||||
|
||||
# 初始化遗忘引擎
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig
|
||||
from src.chat.memory_system.memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine
|
||||
|
||||
# 从全局配置创建遗忘引擎配置
|
||||
forgetting_config = ForgettingConfig(
|
||||
@@ -241,7 +241,7 @@ class MemorySystem:
|
||||
self.forgetting_engine = MemoryForgettingEngine(forgetting_config)
|
||||
|
||||
planner_task_config = getattr(model_config.model_task_config, "utils_small", None)
|
||||
planner_model: Optional[LLMRequest] = None
|
||||
planner_model: LLMRequest | None = None
|
||||
try:
|
||||
planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner")
|
||||
except Exception as planner_exc:
|
||||
@@ -261,8 +261,8 @@ class MemorySystem:
|
||||
raise
|
||||
|
||||
async def retrieve_memories_for_building(
|
||||
self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5
|
||||
) -> List[MemoryChunk]:
|
||||
self, query_text: str, user_id: str | None = None, context: dict[str, Any] | None = None, limit: int = 5
|
||||
) -> list[MemoryChunk]:
|
||||
"""在构建记忆时检索相关记忆,使用统一存储系统
|
||||
|
||||
Args:
|
||||
@@ -302,8 +302,8 @@ class MemorySystem:
|
||||
return []
|
||||
|
||||
async def build_memory_from_conversation(
|
||||
self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None
|
||||
) -> List[MemoryChunk]:
|
||||
self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""从对话中构建记忆
|
||||
|
||||
Args:
|
||||
@@ -318,8 +318,8 @@ class MemorySystem:
|
||||
self.status = MemorySystemStatus.BUILDING
|
||||
start_time = time.time()
|
||||
|
||||
build_scope_key: Optional[str] = None
|
||||
build_marker_time: Optional[float] = None
|
||||
build_scope_key: str | None = None
|
||||
build_marker_time: float | None = None
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
|
||||
@@ -408,7 +408,7 @@ class MemorySystem:
|
||||
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _log_memory_preview(self, memories: List[MemoryChunk]) -> None:
|
||||
def _log_memory_preview(self, memories: list[MemoryChunk]) -> None:
|
||||
"""在控制台输出记忆预览,便于人工检查"""
|
||||
if not memories:
|
||||
logger.info("📝 本次未生成新的记忆")
|
||||
@@ -425,12 +425,12 @@ class MemorySystem:
|
||||
f"置信度={memory.metadata.confidence.name} | 内容={text}"
|
||||
)
|
||||
|
||||
async def _collect_fusion_candidates(self, new_memories: List[MemoryChunk]) -> List[MemoryChunk]:
|
||||
async def _collect_fusion_candidates(self, new_memories: list[MemoryChunk]) -> list[MemoryChunk]:
|
||||
"""收集与新记忆相似的现有记忆,便于融合去重"""
|
||||
if not new_memories:
|
||||
return []
|
||||
|
||||
candidate_ids: Set[str] = set()
|
||||
candidate_ids: set[str] = set()
|
||||
new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)}
|
||||
|
||||
# 基于指纹的直接匹配
|
||||
@@ -493,7 +493,7 @@ class MemorySystem:
|
||||
continue
|
||||
candidate_ids.add(memory_id)
|
||||
|
||||
existing_candidates: List[MemoryChunk] = []
|
||||
existing_candidates: list[MemoryChunk] = []
|
||||
cache = self.unified_storage.memory_cache if self.unified_storage else {}
|
||||
for candidate_id in candidate_ids:
|
||||
if candidate_id in new_memory_ids:
|
||||
@@ -511,7 +511,7 @@ class MemorySystem:
|
||||
|
||||
return existing_candidates
|
||||
|
||||
async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -559,12 +559,12 @@ class MemorySystem:
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self,
|
||||
query_text: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
query_text: str | None = None,
|
||||
user_id: str | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int = 5,
|
||||
**kwargs,
|
||||
) -> List[MemoryChunk]:
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)"""
|
||||
raw_query = query_text or kwargs.get("query")
|
||||
if not raw_query:
|
||||
@@ -750,7 +750,7 @@ class MemorySystem:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_payload(response: str) -> Optional[str]:
|
||||
def _extract_json_payload(response: str) -> str | None:
|
||||
"""从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
|
||||
if not response:
|
||||
return None
|
||||
@@ -773,10 +773,10 @@ class MemorySystem:
|
||||
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
|
||||
|
||||
def _normalize_context(
|
||||
self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float]
|
||||
) -> Dict[str, Any]:
|
||||
self, raw_context: dict[str, Any] | None, user_id: str | None, timestamp: float | None
|
||||
) -> dict[str, Any]:
|
||||
"""标准化上下文,确保必备字段存在且格式正确"""
|
||||
context: Dict[str, Any] = {}
|
||||
context: dict[str, Any] = {}
|
||||
if raw_context:
|
||||
try:
|
||||
context = dict(raw_context)
|
||||
@@ -822,7 +822,7 @@ class MemorySystem:
|
||||
|
||||
return context
|
||||
|
||||
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""构建包含未读消息综合上下文的增强查询上下文
|
||||
|
||||
Args:
|
||||
@@ -861,7 +861,7 @@ class MemorySystem:
|
||||
|
||||
return enhanced_context
|
||||
|
||||
async def _collect_unread_messages_context(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def _collect_unread_messages_context(self, stream_id: str) -> dict[str, Any] | None:
|
||||
"""收集未读消息的综合上下文信息
|
||||
|
||||
Args:
|
||||
@@ -953,7 +953,7 @@ class MemorySystem:
|
||||
logger.warning(f"收集未读消息上下文失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _build_unread_context_summary(self, messages_summary: List[Dict[str, Any]]) -> str:
|
||||
def _build_unread_context_summary(self, messages_summary: list[dict[str, Any]]) -> str:
|
||||
"""构建未读消息的文本摘要
|
||||
|
||||
Args:
|
||||
@@ -974,7 +974,7 @@ class MemorySystem:
|
||||
|
||||
return " | ".join(summary_parts)
|
||||
|
||||
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
||||
async def _resolve_conversation_context(self, fallback_text: str, context: dict[str, Any] | None) -> str:
|
||||
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
|
||||
if not context:
|
||||
return fallback_text
|
||||
@@ -1043,11 +1043,11 @@ class MemorySystem:
|
||||
# 回退到传入文本
|
||||
return fallback_text
|
||||
|
||||
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
||||
def _get_build_scope_key(self, context: dict[str, Any], user_id: str | None) -> str | None:
|
||||
"""确定用于节流控制的记忆构建作用域"""
|
||||
return "global_scope"
|
||||
|
||||
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
|
||||
def _determine_history_limit(self, context: dict[str, Any]) -> int:
|
||||
"""确定历史消息获取数量,限制在30-50之间"""
|
||||
default_limit = 40
|
||||
candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit")
|
||||
@@ -1065,12 +1065,12 @@ class MemorySystem:
|
||||
|
||||
return history_limit
|
||||
|
||||
def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]:
|
||||
def _format_history_messages(self, messages: list["DatabaseMessages"]) -> str | None:
|
||||
"""将历史消息格式化为可供LLM处理的多轮对话文本"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
lines: List[str] = []
|
||||
lines: list[str] = []
|
||||
for msg in messages:
|
||||
try:
|
||||
content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None)
|
||||
@@ -1105,7 +1105,7 @@ class MemorySystem:
|
||||
|
||||
return "\n".join(lines) if lines else None
|
||||
|
||||
async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float:
|
||||
async def _assess_information_value(self, text: str, context: dict[str, Any]) -> float:
|
||||
"""评估信息价值
|
||||
|
||||
Args:
|
||||
@@ -1201,7 +1201,7 @@ class MemorySystem:
|
||||
logger.error(f"信息价值评估失败: {e}", exc_info=True)
|
||||
return 0.5 # 默认中等价值
|
||||
|
||||
async def _store_memories_unified(self, memory_chunks: List[MemoryChunk]) -> int:
|
||||
async def _store_memories_unified(self, memory_chunks: list[MemoryChunk]) -> int:
|
||||
"""使用统一存储系统存储记忆块"""
|
||||
if not memory_chunks or not self.unified_storage:
|
||||
return 0
|
||||
@@ -1222,7 +1222,7 @@ class MemorySystem:
|
||||
return 0
|
||||
|
||||
# 保留原有方法以兼容旧代码
|
||||
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
|
||||
async def _store_memories(self, memory_chunks: list[MemoryChunk]) -> int:
|
||||
"""兼容性方法:重定向到统一存储"""
|
||||
return await self._store_memories_unified(memory_chunks)
|
||||
|
||||
@@ -1271,7 +1271,7 @@ class MemorySystem:
|
||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||
self._memory_fingerprints[key] = memory.memory_id
|
||||
|
||||
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
|
||||
def _register_memory_fingerprints(self, memories: list[MemoryChunk]) -> None:
|
||||
for memory in memories:
|
||||
fingerprint = self._build_memory_fingerprint(memory)
|
||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||
@@ -1302,9 +1302,9 @@ class MemorySystem:
|
||||
|
||||
@staticmethod
|
||||
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
|
||||
return f"{str(user_id)}:{fingerprint}"
|
||||
return f"{user_id!s}:{fingerprint}"
|
||||
|
||||
    def get_system_stats(self) -> Dict[str, Any]:
    def get_system_stats(self) -> dict[str, Any]:
        """Get system statistics"""
        return {
            "status": self.status.value,
@@ -1314,7 +1314,7 @@ class MemorySystem:
            "config": asdict(self.config),
        }

    def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
    def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
        """Compute a match score for a memory based on the query and context"""
        tokens_query = self._tokenize_text(query_text)
        tokens_memory = self._tokenize_text(memory.text_content)
@@ -1338,7 +1338,7 @@ class MemorySystem:
        final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost
        return max(0.0, min(1.0, final_score))

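The visible scoring line weights the similarity term at 0.7 and keyword overlap at 0.15, with the importance and confidence boosts added on top before clamping to [0, 1]. A worked example (the input values are illustrative):

base_score = 0.8        # e.g. vector similarity to the query
keyword_overlap = 0.4   # token-overlap ratio between query and memory
importance_boost = 0.05
confidence_boost = 0.02

final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost
# = 0.56 + 0.06 + 0.05 + 0.02 = 0.69
print(max(0.0, min(1.0, final_score)))  # 0.69
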
    def _tokenize_text(self, text: str) -> Set[str]:
    def _tokenize_text(self, text: str) -> set[str]:
        """Simple tokenization, handling both Chinese and English"""
        if not text:
            return set()
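The tokenizer body is cut off by the hunk. A common way to satisfy "simple tokenization, handling both Chinese and English" is to treat ASCII word runs as single tokens and each CJK character as its own token; a sketch under that assumption (the committed implementation may differ):

import re

def tokenize_text(text: str) -> set[str]:
    # Latin/digit runs become one token each; every CJK character stands alone
    return set(re.findall(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text.lower()))

print(tokenize_text("MoFox 记忆系统 test"))  # {'mofox', 'test', '记', '忆', '系', '统'} (order may vary)
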
@@ -1450,7 +1450,7 @@ def get_memory_system() -> MemorySystem:
    return memory_system


async def initialize_memory_system(llm_model: Optional[LLMRequest] = None):
async def initialize_memory_system(llm_model: LLMRequest | None = None):
    """Initialize the global memory system"""
    global memory_system
    if memory_system is None:

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
Vector DB-based unified memory storage system V2
Uses ChromaDB as the underlying store, replacing the JSON storage approach
@@ -11,20 +10,21 @@
- Automatic cleanup of expired memories
"""

import time
import orjson
import asyncio
import threading
from typing import Dict, List, Optional, Tuple, Any
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from src.common.logger import get_logger
from src.common.vector_db import vector_db_service
from src.chat.utils.utils import get_embedding
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
import orjson

from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry
from src.chat.utils.utils import get_embedding
from src.common.logger import get_logger
from src.common.vector_db import vector_db_service

logger = get_logger(__name__)

@@ -32,7 +32,7 @@ logger = get_logger(__name__)
_ENUM_MAPPINGS_CACHE = {}


def _build_enum_mapping(enum_class: type) -> Dict[str, Any]:
def _build_enum_mapping(enum_class: type) -> dict[str, Any]:
    """Build the full mapping table for an enum class

    Args:
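The mapping body lies outside the hunk. Given the module-level `_ENUM_MAPPINGS_CACHE` dict, a plausible shape is a cached name-and-value lookup table per enum class; this sketch is an assumption, not the committed implementation:

from enum import Enum
from typing import Any

_ENUM_MAPPINGS_CACHE: dict[type, dict[str, Any]] = {}

def build_enum_mapping(enum_class: type[Enum]) -> dict[str, Any]:
    # Cache one mapping per enum class; keys cover member names and raw values
    if enum_class not in _ENUM_MAPPINGS_CACHE:
        mapping: dict[str, Any] = {}
        for member in enum_class:
            mapping[member.name.lower()] = member
            mapping[str(member.value)] = member
        _ENUM_MAPPINGS_CACHE[enum_class] = mapping
    return _ENUM_MAPPINGS_CACHE[enum_class]
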
@@ -145,7 +145,7 @@ class VectorMemoryStorage:

    """Vector DB-based memory storage system"""

    def __init__(self, config: Optional[VectorStorageConfig] = None):
    def __init__(self, config: VectorStorageConfig | None = None):
        # Read from the global config by default when no config is passed in
        if config is None:
            try:
@@ -163,15 +163,15 @@ class VectorMemoryStorage:
        self.vector_db_service = vector_db_service

        # In-memory cache
        self.memory_cache: Dict[str, MemoryChunk] = {}
        self.cache_timestamps: Dict[str, float] = {}
        self.memory_cache: dict[str, MemoryChunk] = {}
        self.cache_timestamps: dict[str, float] = {}
        self._cache = self.memory_cache  # Alias kept for old-code compatibility

        # Metadata index manager (JSON file index)
        self.metadata_index = MemoryMetadataIndex()

        # Forgetting engine
        self.forgetting_engine: Optional[MemoryForgettingEngine] = None
        self.forgetting_engine: MemoryForgettingEngine | None = None
        if self.config.enable_forgetting:
            self.forgetting_engine = MemoryForgettingEngine()

@@ -267,7 +267,7 @@ class VectorMemoryStorage:
        except Exception as e:
            logger.error(f"Automatic cleanup failed: {e}")

    def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]:
    def _memory_to_vector_format(self, memory: MemoryChunk) -> dict[str, Any]:
        """Convert a MemoryChunk into the vector storage format"""
        try:
            # Get the memory_id
@@ -323,7 +323,7 @@ class VectorMemoryStorage:
            logger.error(f"Failed to convert memory {memory_id} to vector format: {e}", exc_info=True)
            raise

    def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]:
    def _vector_result_to_memory(self, document: str, metadata: dict[str, Any]) -> MemoryChunk | None:
        """Convert a Vector DB result back into a MemoryChunk"""
        try:
            # Restore the full memory from metadata
@@ -440,7 +440,7 @@ class VectorMemoryStorage:
            logger.warning(f"Unsupported {enum_class.__name__} value type: {type(value)}, using the default")
            return default

    def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]:
    def _get_from_cache(self, memory_id: str) -> MemoryChunk | None:
        """Get a memory from the cache"""
        if not self.config.enable_caching:
            return None
@@ -472,7 +472,7 @@ class VectorMemoryStorage:
        self.memory_cache[memory_id] = memory
        self.cache_timestamps[memory_id] = time.time()

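`memory_cache` and `cache_timestamps` together form a simple TTL cache: writes record the insertion time, and reads presumably evict anything older than a configured lifetime. A self-contained sketch of that read path (the `ttl_seconds` default is invented; the real bound would come from `VectorStorageConfig`):

import time

def get_from_ttl_cache(
    cache: dict[str, object],
    timestamps: dict[str, float],
    key: str,
    ttl_seconds: float = 300.0,  # assumed TTL, not the project's actual setting
) -> object | None:
    ts = timestamps.get(key)
    if ts is None or time.time() - ts > ttl_seconds:
        # Expired or missing: drop any stale entry before reporting a miss
        cache.pop(key, None)
        timestamps.pop(key, None)
        return None
    return cache.get(key)
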
    async def store_memories(self, memories: List[MemoryChunk]) -> int:
    async def store_memories(self, memories: list[MemoryChunk]) -> int:
        """Store memories in batch"""
        if not memories:
            return 0
@@ -603,11 +603,11 @@ class VectorMemoryStorage:
        self,
        query_text: str,
        limit: int = 10,
        similarity_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        similarity_threshold: float | None = None,
        filters: dict[str, Any] | None = None,
        # New: metadata filter parameters (used for JSON-index coarse filtering)
        metadata_filters: Optional[Dict[str, Any]] = None,
    ) -> List[Tuple[MemoryChunk, float]]:
        metadata_filters: dict[str, Any] | None = None,
    ) -> list[tuple[MemoryChunk, float]]:
        """
        Search for similar memories (hybrid index mode)

@@ -632,7 +632,7 @@ class VectorMemoryStorage:

        try:
            # === Stage 1: JSON metadata coarse filtering (optional) ===
            candidate_ids: Optional[List[str]] = None
            candidate_ids: list[str] | None = None
            if metadata_filters:
                logger.debug(f"[JSON metadata coarse filter] starting, filter conditions: {metadata_filters}")
                candidate_ids = self.metadata_index.search(
@@ -746,7 +746,7 @@ class VectorMemoryStorage:
            logger.error(f"Similar-memory search failed: {e}")
            return []

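The hybrid flow visible here is: an optional stage 1 narrows candidates through the JSON metadata index, then the vector search ranks within that candidate set. A hypothetical call (the filter keys are illustrative, not a documented schema; run inside an async function):

storage = get_vector_memory_storage()
results = await storage.search_similar_memories(
    "What does the user like to eat?",
    limit=5,
    metadata_filters={"memory_type": "preference", "user_id": "u123"},  # coarse pre-filter via the JSON index
)
for memory, similarity in results:
    print(f"{similarity:.2f} {memory.text_content}")
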
    async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
    async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
        """Get a memory by its ID"""
        # Try the cache first
        memory = self._get_from_cache(memory_id)
@@ -772,7 +772,7 @@ class VectorMemoryStorage:

        return None

    async def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 100) -> List[MemoryChunk]:
    async def get_memories_by_filters(self, filters: dict[str, Any], limit: int = 100) -> list[MemoryChunk]:
        """Get memories matching the given filters"""
        try:
            results = vector_db_service.get(collection_name=self.config.memory_collection, where=filters, limit=limit)
@@ -848,7 +848,7 @@ class VectorMemoryStorage:
            logger.error(f"Failed to delete memory {memory_id}: {e}")
            return False

    async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int:
    async def delete_memories_by_filters(self, filters: dict[str, Any]) -> int:
        """Batch-delete memories matching the given filters"""
        try:
            # First fetch the IDs of the memories to delete
@@ -880,7 +880,7 @@ class VectorMemoryStorage:
            logger.error(f"Batch memory deletion failed: {e}")
            return 0

    async def perform_forgetting_check(self) -> Dict[str, Any]:
    async def perform_forgetting_check(self) -> dict[str, Any]:
        """Run a forgetting check"""
        if not self.forgetting_engine:
            return {"error": "Forgetting engine is not enabled"}
@@ -925,7 +925,7 @@ class VectorMemoryStorage:
            logger.error(f"Forgetting check failed: {e}")
            return {"error": str(e)}

    def get_storage_stats(self) -> Dict[str, Any]:
    def get_storage_stats(self) -> dict[str, Any]:
        """Get storage statistics"""
        try:
            current_total = vector_db_service.count(self.config.memory_collection)
@@ -960,7 +960,7 @@ class VectorMemoryStorage:
_global_vector_storage = None


def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage:
def get_vector_memory_storage(config: VectorStorageConfig | None = None) -> VectorMemoryStorage:
    """Get the global vector memory storage instance"""
    global _global_vector_storage

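`get_vector_memory_storage` follows the same lazy module-level singleton pattern as `get_memory_system` in the previous file; the body is truncated by the hunk, but it plausibly reads as follows (a sketch; note this plain version is not thread-safe):

_global_vector_storage: VectorMemoryStorage | None = None

def get_vector_memory_storage(config: VectorStorageConfig | None = None) -> VectorMemoryStorage:
    global _global_vector_storage
    if _global_vector_storage is None:
        # Construct once on first use; later calls ignore `config`
        _global_vector_storage = VectorMemoryStorage(config)
    return _global_vector_storage
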
@@ -974,15 +974,15 @@ def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage:
class VectorMemoryStorageAdapter:
    """Adapter class providing an interface compatible with the original UnifiedMemoryStorage"""

    def __init__(self, config: Optional[VectorStorageConfig] = None):
    def __init__(self, config: VectorStorageConfig | None = None):
        self.storage = VectorMemoryStorage(config)

    async def store_memories(self, memories: List[MemoryChunk]) -> int:
    async def store_memories(self, memories: list[MemoryChunk]) -> int:
        return await self.storage.store_memories(memories)

    async def search_similar_memories(
        self, query_text: str, limit: int = 10, scope_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None
    ) -> List[Tuple[str, float]]:
        self, query_text: str, limit: int = 10, scope_id: str | None = None, filters: dict[str, Any] | None = None
    ) -> list[tuple[str, float]]:
        results = await self.storage.search_similar_memories(query_text, limit, filters=filters)
        # Convert to the original format: (memory_id, similarity)
        return [
@@ -990,7 +990,7 @@ class VectorMemoryStorageAdapter:
            for memory, similarity in results
        ]

    def get_stats(self) -> Dict[str, Any]:
    def get_stats(self) -> dict[str, Any]:
        return self.storage.get_storage_stats()
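The adapter narrows the storage result from `(MemoryChunk, float)` pairs down to the legacy `(memory_id, similarity)` shape. The element expression sits outside the hunk, but given the loop and the `memory_id` attribute used earlier, it is presumably something like:

        return [
            (memory.memory_id, similarity)
            for memory, similarity in results
        ]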