feat(memory): 重构记忆系统并移除插件热重载
重构记忆系统核心模块,引入全局记忆作用域、记忆指纹去重机制和查询规划器,优化多阶段检索性能。移除插件热重载系统及其相关依赖。 主要变更: - 引入全局记忆作用域,简化记忆管理 - 实现记忆指纹去重,避免重复记忆存储 - 新增查询规划器,支持语义查询规划和记忆类型过滤 - 优化多阶段检索,增加语义重排和权重配置 - 改进向量存储,支持嵌入维度自动解析和查询向量生成 - 增强元数据索引,支持主体索引和更新操作 - 记忆构建器支持多主体和自然语言展示 - 移除watchdog依赖和插件热重载模块 - 更新配置模板,简化记忆配置项 BREAKING CHANGE: 移除插件热重载系统,相关API和命令不再可用。记忆系统接口有较大调整,使用该系统的模块需要适配新接口。
This commit is contained in:
@@ -70,7 +70,6 @@ dependencies = [
|
||||
"tqdm>=4.67.1",
|
||||
"urllib3>=2.5.0",
|
||||
"uvicorn>=0.35.0",
|
||||
"watchdog>=6.0.0",
|
||||
"websockets>=15.0.1",
|
||||
"aiomysql>=0.2.0",
|
||||
"aiosqlite>=0.21.0",
|
||||
|
||||
@@ -50,7 +50,6 @@ reportportal-client
|
||||
scikit-learn
|
||||
seaborn
|
||||
structlog
|
||||
watchdog
|
||||
httpx
|
||||
requests
|
||||
beautifulsoup4
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||
|
||||
logger = get_logger("bot_interest_manager")
|
||||
@@ -28,7 +29,9 @@ class BotInterestManager:
|
||||
# Embedding客户端配置
|
||||
self.embedding_request = None
|
||||
self.embedding_config = None
|
||||
self.embedding_dimension = 1024 # 默认BGE-M3 embedding维度
|
||||
configured_dim = resolve_embedding_dimension()
|
||||
self.embedding_dimension = int(configured_dim) if configured_dim else 0
|
||||
self._detected_embedding_dimension: Optional[int] = None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
@@ -82,8 +85,11 @@ class BotInterestManager:
|
||||
|
||||
logger.info("📋 找到embedding模型配置")
|
||||
self.embedding_config = model_config.model_task_config.embedding
|
||||
self.embedding_dimension = 1024 # BGE-M3的维度
|
||||
logger.info(f"📐 使用模型维度: {self.embedding_dimension}")
|
||||
|
||||
if self.embedding_dimension:
|
||||
logger.info(f"📐 配置的embedding维度: {self.embedding_dimension}")
|
||||
else:
|
||||
logger.info("📐 未在配置中检测到embedding维度,将根据首次返回的向量自动识别")
|
||||
|
||||
# 创建LLMRequest实例用于embedding
|
||||
self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding")
|
||||
@@ -350,7 +356,27 @@ class BotInterestManager:
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
self.embedding_cache[text] = embedding
|
||||
logger.debug(f"✅ Embedding获取成功,维度: {len(embedding)}, 模型: {model_name}")
|
||||
|
||||
current_dim = len(embedding)
|
||||
if self._detected_embedding_dimension is None:
|
||||
self._detected_embedding_dimension = current_dim
|
||||
if self.embedding_dimension and self.embedding_dimension != current_dim:
|
||||
logger.warning(
|
||||
"⚠️ 实际embedding维度(%d)与配置值(%d)不一致,请在 model_config.model_task_config.embedding.embedding_dimension 中同步更新",
|
||||
current_dim,
|
||||
self.embedding_dimension,
|
||||
)
|
||||
else:
|
||||
self.embedding_dimension = current_dim
|
||||
logger.info(f"📏 检测到embedding维度: {current_dim}")
|
||||
elif current_dim != self.embedding_dimension:
|
||||
logger.warning(
|
||||
"⚠️ 收到的embedding维度发生变化: 之前=%d, 当前=%d。请确认模型配置是否正确。",
|
||||
self.embedding_dimension,
|
||||
current_dim,
|
||||
)
|
||||
|
||||
logger.debug(f"✅ Embedding获取成功,维度: {current_dim}, 模型: {model_name}")
|
||||
return embedding
|
||||
else:
|
||||
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
|
||||
|
||||
@@ -26,6 +26,7 @@ from rich.progress import (
|
||||
TextColumn,
|
||||
)
|
||||
from src.config.config import global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -504,7 +505,10 @@ class EmbeddingStore:
|
||||
# L2归一化
|
||||
faiss.normalize_L2(embeddings)
|
||||
# 构建索引
|
||||
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
||||
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
|
||||
if not embedding_dim:
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.faiss_index.add(embeddings)
|
||||
|
||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||
|
||||
@@ -11,12 +11,27 @@ from dataclasses import dataclass
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
MEMORY_TYPE_LABELS = {
|
||||
MemoryType.PERSONAL_FACT: "个人事实",
|
||||
MemoryType.EVENT: "事件",
|
||||
MemoryType.PREFERENCE: "偏好",
|
||||
MemoryType.OPINION: "观点",
|
||||
MemoryType.RELATIONSHIP: "关系",
|
||||
MemoryType.EMOTION: "情感",
|
||||
MemoryType.KNOWLEDGE: "知识",
|
||||
MemoryType.SKILL: "技能",
|
||||
MemoryType.GOAL: "目标",
|
||||
MemoryType.EXPERIENCE: "经验",
|
||||
MemoryType.CONTEXTUAL: "上下文",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterConfig:
|
||||
"""适配器配置"""
|
||||
@@ -85,12 +100,9 @@ class EnhancedMemoryAdapter:
|
||||
|
||||
async def process_conversation_memory(
|
||||
self,
|
||||
conversation_text: str,
|
||||
context: Dict[str, Any],
|
||||
user_id: str,
|
||||
timestamp: Optional[float] = None
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""处理对话记忆"""
|
||||
"""处理对话记忆,以上下文为唯一输入"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return {"success": False, "error": "Enhanced memory not available"}
|
||||
|
||||
@@ -98,10 +110,30 @@ class EnhancedMemoryAdapter:
|
||||
self.adapter_stats["total_processed"] += 1
|
||||
|
||||
try:
|
||||
payload_context: Dict[str, Any] = dict(context or {})
|
||||
|
||||
conversation_text = payload_context.get("conversation_text")
|
||||
if not conversation_text:
|
||||
conversation_candidate = (
|
||||
payload_context.get("message_content")
|
||||
or payload_context.get("latest_message")
|
||||
or payload_context.get("raw_text")
|
||||
)
|
||||
if conversation_candidate is not None:
|
||||
conversation_text = str(conversation_candidate)
|
||||
payload_context["conversation_text"] = conversation_text
|
||||
else:
|
||||
conversation_text = ""
|
||||
else:
|
||||
conversation_text = str(conversation_text)
|
||||
|
||||
if "timestamp" not in payload_context:
|
||||
payload_context["timestamp"] = time.time()
|
||||
|
||||
logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text))
|
||||
|
||||
# 使用集成层处理对话
|
||||
result = await self.integration_layer.process_conversation(
|
||||
conversation_text, context, user_id, timestamp
|
||||
)
|
||||
result = await self.integration_layer.process_conversation(payload_context)
|
||||
|
||||
# 更新统计
|
||||
processing_time = time.time() - start_time
|
||||
@@ -132,7 +164,7 @@ class EnhancedMemoryAdapter:
|
||||
try:
|
||||
limit = limit or self.config.max_retrieval_results
|
||||
memories = await self.integration_layer.retrieve_relevant_memories(
|
||||
query, user_id, context, limit
|
||||
query, None, context, limit
|
||||
)
|
||||
|
||||
self.adapter_stats["memories_retrieved"] += len(memories)
|
||||
@@ -157,12 +189,15 @@ class EnhancedMemoryAdapter:
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
# 格式化记忆为提示词友好的格式
|
||||
memory_context_parts = []
|
||||
for memory in memories:
|
||||
memory_context_parts.append(f"- {memory.text_content}")
|
||||
# 格式化记忆为提示词友好的Markdown结构
|
||||
lines: List[str] = ["### 🧠 相关记忆 (Relevant Memories)", ""]
|
||||
|
||||
return "\n".join(memory_context_parts)
|
||||
for memory in memories:
|
||||
type_label = MEMORY_TYPE_LABELS.get(memory.memory_type, memory.memory_type.value)
|
||||
display_text = memory.display or memory.text_content
|
||||
lines.append(f"- **[{type_label}]** {display_text}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
|
||||
"""获取增强记忆系统摘要"""
|
||||
@@ -270,13 +305,10 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
|
||||
|
||||
|
||||
async def process_conversation_with_enhanced_memory(
|
||||
conversation_text: str,
|
||||
context: Dict[str, Any],
|
||||
user_id: str,
|
||||
timestamp: Optional[float] = None,
|
||||
llm_model: Optional[LLMRequest] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""使用增强记忆系统处理对话"""
|
||||
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
from src.llm_models.utils_model import get_global_llm_model
|
||||
@@ -284,7 +316,18 @@ async def process_conversation_with_enhanced_memory(
|
||||
|
||||
try:
|
||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||
return await adapter.process_conversation_memory(conversation_text, context, user_id, timestamp)
|
||||
payload_context = dict(context or {})
|
||||
|
||||
if "conversation_text" not in payload_context:
|
||||
conversation_candidate = (
|
||||
payload_context.get("message_content")
|
||||
or payload_context.get("latest_message")
|
||||
or payload_context.get("raw_text")
|
||||
)
|
||||
if conversation_candidate is not None:
|
||||
payload_context["conversation_text"] = str(conversation_candidate)
|
||||
|
||||
return await adapter.process_conversation_memory(payload_context)
|
||||
except Exception as e:
|
||||
logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强型精准记忆系统核心模块
|
||||
基于文档设计的高效记忆构建、存储与召回优化系统
|
||||
1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。
|
||||
2. 内置 LLM 查询规划器与嵌入维度自动解析机制,直接从模型配置推断向量存储参数。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import orjson
|
||||
import re
|
||||
import hashlib
|
||||
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, asdict
|
||||
@@ -22,12 +24,16 @@ from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
|
||||
from src.chat.memory_system.metadata_index import MetadataIndexManager
|
||||
from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig
|
||||
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 全局记忆作用域(共享记忆库)
|
||||
GLOBAL_MEMORY_SCOPE = "global"
|
||||
|
||||
|
||||
class MemorySystemStatus(Enum):
|
||||
"""记忆系统状态"""
|
||||
@@ -47,14 +53,20 @@ class MemorySystemConfig:
|
||||
memory_value_threshold: float = 0.7
|
||||
min_build_interval_seconds: float = 300.0
|
||||
|
||||
# 向量存储配置
|
||||
vector_dimension: int = 768
|
||||
# 向量存储配置(嵌入维度自动来自模型配置)
|
||||
vector_dimension: int = 1024
|
||||
similarity_threshold: float = 0.8
|
||||
|
||||
# 召回配置
|
||||
coarse_recall_limit: int = 50
|
||||
fine_recall_limit: int = 10
|
||||
semantic_rerank_limit: int = 20
|
||||
final_recall_limit: int = 5
|
||||
semantic_similarity_threshold: float = 0.6
|
||||
vector_weight: float = 0.4
|
||||
semantic_weight: float = 0.3
|
||||
context_weight: float = 0.2
|
||||
recency_weight: float = 0.1
|
||||
|
||||
# 融合配置
|
||||
fusion_similarity_threshold: float = 0.85
|
||||
@@ -64,6 +76,23 @@ class MemorySystemConfig:
|
||||
def from_global_config(cls):
|
||||
"""从全局配置创建配置实例"""
|
||||
|
||||
embedding_dimension = None
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task is not None:
|
||||
embedding_dimension = getattr(embedding_task, "embedding_dimension", None)
|
||||
except Exception:
|
||||
embedding_dimension = None
|
||||
|
||||
if not embedding_dimension:
|
||||
try:
|
||||
embedding_dimension = getattr(global_config.lpmm_knowledge, "embedding_dimension", None)
|
||||
except Exception:
|
||||
embedding_dimension = None
|
||||
|
||||
if not embedding_dimension:
|
||||
embedding_dimension = 1024
|
||||
|
||||
return cls(
|
||||
# 记忆构建配置
|
||||
min_memory_length=global_config.memory.min_memory_length,
|
||||
@@ -72,13 +101,19 @@ class MemorySystemConfig:
|
||||
min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0),
|
||||
|
||||
# 向量存储配置
|
||||
vector_dimension=global_config.memory.vector_dimension,
|
||||
vector_dimension=int(embedding_dimension),
|
||||
similarity_threshold=global_config.memory.vector_similarity_threshold,
|
||||
|
||||
# 召回配置
|
||||
coarse_recall_limit=global_config.memory.metadata_filter_limit,
|
||||
fine_recall_limit=global_config.memory.final_result_limit,
|
||||
fine_recall_limit=global_config.memory.vector_search_limit,
|
||||
semantic_rerank_limit=global_config.memory.semantic_rerank_limit,
|
||||
final_recall_limit=global_config.memory.final_result_limit,
|
||||
semantic_similarity_threshold=getattr(global_config.memory, "semantic_similarity_threshold", 0.6),
|
||||
vector_weight=global_config.memory.vector_weight,
|
||||
semantic_weight=global_config.memory.semantic_weight,
|
||||
context_weight=global_config.memory.context_weight,
|
||||
recency_weight=global_config.memory.recency_weight,
|
||||
|
||||
# 融合配置
|
||||
fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
|
||||
@@ -104,6 +139,7 @@ class EnhancedMemorySystem:
|
||||
self.vector_storage: VectorStorageManager = None
|
||||
self.metadata_index: MetadataIndexManager = None
|
||||
self.retrieval_system: MultiStageRetrieval = None
|
||||
self.query_planner: MemoryQueryPlanner = None
|
||||
|
||||
# LLM模型
|
||||
self.value_assessment_model: LLMRequest = None
|
||||
@@ -117,6 +153,9 @@ class EnhancedMemorySystem:
|
||||
# 构建节流记录
|
||||
self._last_memory_build_times: Dict[str, float] = {}
|
||||
|
||||
# 记忆指纹缓存,用于快速检测重复记忆
|
||||
self._memory_fingerprints: Dict[str, str] = {}
|
||||
|
||||
logger.info("EnhancedMemorySystem 初始化开始")
|
||||
|
||||
async def initialize(self):
|
||||
@@ -125,19 +164,29 @@ class EnhancedMemorySystem:
|
||||
logger.info("正在初始化增强型记忆系统...")
|
||||
|
||||
# 初始化LLM模型
|
||||
task_config = (
|
||||
self.llm_model.model_for_task
|
||||
if self.llm_model is not None
|
||||
else model_config.model_task_config.utils_small
|
||||
)
|
||||
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
|
||||
|
||||
value_task_config = getattr(model_config.model_task_config, "utils_small", None)
|
||||
extraction_task_config = getattr(model_config.model_task_config, "utils", None)
|
||||
|
||||
if value_task_config is None:
|
||||
logger.warning("未找到 utils_small 模型配置,回退到 utils 或外部提供的模型配置。")
|
||||
value_task_config = extraction_task_config or fallback_task
|
||||
|
||||
if extraction_task_config is None:
|
||||
logger.warning("未找到 utils 模型配置,回退到 utils_small 或外部提供的模型配置。")
|
||||
extraction_task_config = value_task_config or fallback_task
|
||||
|
||||
if value_task_config is None or extraction_task_config is None:
|
||||
raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。")
|
||||
|
||||
self.value_assessment_model = LLMRequest(
|
||||
model_set=task_config,
|
||||
model_set=value_task_config,
|
||||
request_type="memory.value_assessment"
|
||||
)
|
||||
|
||||
self.memory_extraction_model = LLMRequest(
|
||||
model_set=task_config,
|
||||
model_set=extraction_task_config,
|
||||
request_type="memory.extraction"
|
||||
)
|
||||
|
||||
@@ -155,13 +204,36 @@ class EnhancedMemorySystem:
|
||||
retrieval_config = RetrievalConfig(
|
||||
metadata_filter_limit=self.config.coarse_recall_limit,
|
||||
vector_search_limit=self.config.fine_recall_limit,
|
||||
final_result_limit=self.config.final_recall_limit
|
||||
semantic_rerank_limit=self.config.semantic_rerank_limit,
|
||||
final_result_limit=self.config.final_recall_limit,
|
||||
vector_similarity_threshold=self.config.similarity_threshold,
|
||||
semantic_similarity_threshold=self.config.semantic_similarity_threshold,
|
||||
vector_weight=self.config.vector_weight,
|
||||
semantic_weight=self.config.semantic_weight,
|
||||
context_weight=self.config.context_weight,
|
||||
recency_weight=self.config.recency_weight,
|
||||
)
|
||||
self.retrieval_system = MultiStageRetrieval(retrieval_config)
|
||||
|
||||
planner_task_config = getattr(model_config.model_task_config, "planner", None)
|
||||
planner_model: Optional[LLMRequest] = None
|
||||
try:
|
||||
planner_model = LLMRequest(
|
||||
model_set=planner_task_config,
|
||||
request_type="memory.query_planner"
|
||||
)
|
||||
except Exception as planner_exc:
|
||||
logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True)
|
||||
|
||||
self.query_planner = MemoryQueryPlanner(
|
||||
planner_model,
|
||||
default_limit=self.config.final_recall_limit
|
||||
)
|
||||
|
||||
# 加载持久化数据
|
||||
await self.vector_storage.load_storage()
|
||||
await self.metadata_index.load_index()
|
||||
self._populate_memory_fingerprints()
|
||||
|
||||
self.status = MemorySystemStatus.READY
|
||||
logger.info("✅ 增强型记忆系统初始化完成")
|
||||
@@ -174,7 +246,7 @@ class EnhancedMemorySystem:
|
||||
async def retrieve_memories_for_building(
|
||||
self,
|
||||
query_text: str,
|
||||
user_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
limit: int = 5
|
||||
) -> List[MemoryChunk]:
|
||||
@@ -182,7 +254,6 @@ class EnhancedMemorySystem:
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
user_id: 用户ID
|
||||
context: 上下文信息
|
||||
limit: 返回结果数量限制
|
||||
|
||||
@@ -201,7 +272,6 @@ class EnhancedMemorySystem:
|
||||
# 执行检索
|
||||
memories = await self.vector_storage.search_similar_memories(
|
||||
query_text=query_text,
|
||||
user_id=user_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
@@ -218,23 +288,18 @@ class EnhancedMemorySystem:
|
||||
self,
|
||||
conversation_text: str,
|
||||
context: Dict[str, Any],
|
||||
user_id: str,
|
||||
timestamp: Optional[float] = None
|
||||
) -> List[MemoryChunk]:
|
||||
"""从对话中构建记忆
|
||||
|
||||
Args:
|
||||
conversation_text: 对话文本
|
||||
context: 上下文信息(包括用户信息、群组信息等)
|
||||
user_id: 用户ID
|
||||
context: 上下文信息
|
||||
timestamp: 时间戳,默认为当前时间
|
||||
|
||||
Returns:
|
||||
构建的记忆块列表
|
||||
"""
|
||||
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
|
||||
raise RuntimeError("记忆系统未就绪")
|
||||
|
||||
original_status = self.status
|
||||
self.status = MemorySystemStatus.BUILDING
|
||||
start_time = time.time()
|
||||
@@ -243,9 +308,9 @@ class EnhancedMemorySystem:
|
||||
build_marker_time: Optional[float] = None
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, user_id, timestamp)
|
||||
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
|
||||
|
||||
build_scope_key = self._get_build_scope_key(normalized_context, user_id)
|
||||
build_scope_key = self._get_build_scope_key(normalized_context, GLOBAL_MEMORY_SCOPE)
|
||||
min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0))
|
||||
current_time = time.time()
|
||||
|
||||
@@ -266,7 +331,7 @@ class EnhancedMemorySystem:
|
||||
|
||||
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
|
||||
|
||||
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
|
||||
logger.debug("开始构建记忆,文本长度: %d", len(conversation_text))
|
||||
|
||||
# 1. 信息价值评估
|
||||
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
||||
@@ -280,7 +345,7 @@ class EnhancedMemorySystem:
|
||||
memory_chunks = await self.memory_builder.build_memories(
|
||||
conversation_text,
|
||||
normalized_context,
|
||||
user_id,
|
||||
GLOBAL_MEMORY_SCOPE,
|
||||
timestamp or time.time()
|
||||
)
|
||||
|
||||
@@ -293,19 +358,24 @@ class EnhancedMemorySystem:
|
||||
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks)
|
||||
|
||||
# 4. 存储记忆
|
||||
await self._store_memories(fused_chunks)
|
||||
stored_count = await self._store_memories(fused_chunks)
|
||||
|
||||
# 4.1 控制台预览
|
||||
self._log_memory_preview(fused_chunks)
|
||||
|
||||
# 5. 更新统计
|
||||
self.total_memories += len(fused_chunks)
|
||||
self.total_memories += stored_count
|
||||
self.last_build_time = time.time()
|
||||
if build_scope_key:
|
||||
self._last_memory_build_times[build_scope_key] = self.last_build_time
|
||||
|
||||
build_time = time.time() - start_time
|
||||
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}秒")
|
||||
logger.info(
|
||||
"✅ 生成 %d 条记忆,成功入库 %d 条,耗时 %.2f秒",
|
||||
len(fused_chunks),
|
||||
stored_count,
|
||||
build_time,
|
||||
)
|
||||
|
||||
self.status = original_status
|
||||
return fused_chunks
|
||||
@@ -347,21 +417,34 @@ class EnhancedMemorySystem:
|
||||
|
||||
async def process_conversation_memory(
|
||||
self,
|
||||
conversation_text: str,
|
||||
context: Dict[str, Any],
|
||||
user_id: str,
|
||||
timestamp: Optional[float] = None
|
||||
context: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""对外暴露的对话记忆处理接口,兼容旧调用方式"""
|
||||
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, user_id, timestamp)
|
||||
context = dict(context or {})
|
||||
|
||||
conversation_candidate = (
|
||||
context.get("conversation_text")
|
||||
or context.get("message_content")
|
||||
or context.get("latest_message")
|
||||
or context.get("raw_text")
|
||||
or ""
|
||||
)
|
||||
|
||||
conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate)
|
||||
|
||||
timestamp = context.get("timestamp")
|
||||
if timestamp is None:
|
||||
timestamp = time.time()
|
||||
|
||||
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
|
||||
normalized_context.setdefault("conversation_text", conversation_text)
|
||||
|
||||
memories = await self.build_memory_from_conversation(
|
||||
conversation_text=conversation_text,
|
||||
context=normalized_context,
|
||||
user_id=user_id,
|
||||
timestamp=timestamp
|
||||
)
|
||||
|
||||
@@ -395,52 +478,77 @@ class EnhancedMemorySystem:
|
||||
**kwargs
|
||||
) -> List[MemoryChunk]:
|
||||
"""检索相关记忆,兼容 query/query_text 参数形式"""
|
||||
if self.status != MemorySystemStatus.READY:
|
||||
raise RuntimeError("记忆系统未就绪")
|
||||
|
||||
query_text = query_text or kwargs.get("query")
|
||||
if not query_text:
|
||||
raw_query = query_text or kwargs.get("query")
|
||||
if not raw_query:
|
||||
raise ValueError("query_text 或 query 参数不能为空")
|
||||
|
||||
context = context or {}
|
||||
user_id = user_id or kwargs.get("user_id")
|
||||
resolved_user_id = GLOBAL_MEMORY_SCOPE
|
||||
|
||||
if self.retrieval_system is None or self.metadata_index is None:
|
||||
raise RuntimeError("检索组件未初始化")
|
||||
|
||||
all_memories_cache = self.vector_storage.memory_cache
|
||||
if not all_memories_cache:
|
||||
logger.debug("记忆缓存为空,返回空结果")
|
||||
self.last_retrieval_time = time.time()
|
||||
self.status = MemorySystemStatus.READY
|
||||
return []
|
||||
|
||||
self.status = MemorySystemStatus.RETRIEVING
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, user_id, None)
|
||||
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
|
||||
|
||||
candidate_memories = list(self.vector_storage.memory_cache.values())
|
||||
if user_id:
|
||||
candidate_memories = [m for m in candidate_memories if m.user_id == user_id]
|
||||
effective_limit = limit or self.config.final_recall_limit
|
||||
query_plan = None
|
||||
planner_ran = False
|
||||
resolved_query_text = raw_query
|
||||
if self.query_planner:
|
||||
try:
|
||||
planner_ran = True
|
||||
query_plan = await self.query_planner.plan_query(raw_query, normalized_context)
|
||||
normalized_context["query_plan"] = query_plan
|
||||
effective_limit = min(effective_limit, query_plan.limit or effective_limit)
|
||||
if getattr(query_plan, "semantic_query", None):
|
||||
resolved_query_text = query_plan.semantic_query
|
||||
logger.debug(
|
||||
"查询规划: semantic='%s', types=%s, subjects=%s, limit=%d",
|
||||
query_plan.semantic_query,
|
||||
[mt.value for mt in query_plan.memory_types],
|
||||
query_plan.subject_includes,
|
||||
query_plan.limit,
|
||||
)
|
||||
except Exception as plan_exc:
|
||||
logger.warning("查询规划失败,使用默认检索策略: %s", plan_exc, exc_info=True)
|
||||
|
||||
if not candidate_memories:
|
||||
self.status = MemorySystemStatus.READY
|
||||
self.last_retrieval_time = time.time()
|
||||
logger.debug(f"未找到用户 {user_id} 的候选记忆")
|
||||
return []
|
||||
effective_limit = effective_limit or self.config.final_recall_limit
|
||||
effective_limit = max(1, min(effective_limit, self.config.final_recall_limit))
|
||||
normalized_context["resolved_query_text"] = resolved_query_text
|
||||
|
||||
scored_memories = []
|
||||
for memory in candidate_memories:
|
||||
score = self._compute_memory_score(query_text, memory, normalized_context)
|
||||
if score > 0:
|
||||
scored_memories.append((memory, score))
|
||||
|
||||
if not scored_memories:
|
||||
# 如果所有分数为0,返回最近的记忆作为降级策略
|
||||
if normalized_context.get("__memory_building__"):
|
||||
logger.debug("当前处于记忆构建流程,跳过查询规划并进行降级检索")
|
||||
self.status = MemorySystemStatus.BUILDING
|
||||
final_memories = []
|
||||
candidate_memories = list(all_memories_cache.values())
|
||||
candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True)
|
||||
scored_memories = [(memory, 0.0) for memory in candidate_memories[:limit]]
|
||||
final_memories = candidate_memories[:effective_limit]
|
||||
else:
|
||||
scored_memories.sort(key=lambda item: item[1], reverse=True)
|
||||
retrieval_result = await self.retrieval_system.retrieve_memories(
|
||||
query=resolved_query_text,
|
||||
user_id=resolved_user_id,
|
||||
context=normalized_context,
|
||||
metadata_index=self.metadata_index,
|
||||
vector_storage=self.vector_storage,
|
||||
all_memories_cache=all_memories_cache,
|
||||
limit=effective_limit,
|
||||
)
|
||||
|
||||
top_memories = [memory for memory, _ in scored_memories[:limit]]
|
||||
final_memories = retrieval_result.final_memories
|
||||
|
||||
# 更新访问信息和缓存
|
||||
for memory, score in scored_memories[:limit]:
|
||||
for memory in final_memories:
|
||||
memory.update_access()
|
||||
memory.update_relevance(score)
|
||||
|
||||
cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id)
|
||||
if cache_entry is not None:
|
||||
cache_entry["last_accessed"] = memory.metadata.last_accessed
|
||||
@@ -448,14 +556,34 @@ class EnhancedMemorySystem:
|
||||
cache_entry["relevance_score"] = memory.metadata.relevance_score
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"✅ 为用户 {user_id or 'unknown'} 检索到 {len(top_memories)} 条相关记忆,耗时 {retrieval_time:.3f}秒"
|
||||
plan_summary = ""
|
||||
if planner_ran and query_plan:
|
||||
plan_types = ",".join(mt.value for mt in query_plan.memory_types) or "-"
|
||||
plan_subjects = ",".join(query_plan.subject_includes) or "-"
|
||||
plan_summary = (
|
||||
f" | planner.semantic='{query_plan.semantic_query}'"
|
||||
f" | planner.limit={query_plan.limit}"
|
||||
f" | planner.types={plan_types}"
|
||||
f" | planner.subjects={plan_subjects}"
|
||||
)
|
||||
|
||||
log_message = (
|
||||
"✅ 记忆检索完成"
|
||||
f" | user={resolved_user_id}"
|
||||
f" | count={len(final_memories)}"
|
||||
f" | duration={retrieval_time:.3f}s"
|
||||
f" | applied_limit={effective_limit}"
|
||||
f" | raw_query='{raw_query}'"
|
||||
f" | semantic_query='{resolved_query_text}'"
|
||||
f"{plan_summary}"
|
||||
)
|
||||
|
||||
logger.info(log_message)
|
||||
|
||||
self.last_retrieval_time = time.time()
|
||||
self.status = MemorySystemStatus.READY
|
||||
|
||||
return top_memories
|
||||
return final_memories
|
||||
|
||||
except Exception as e:
|
||||
self.status = MemorySystemStatus.ERROR
|
||||
@@ -499,8 +627,8 @@ class EnhancedMemorySystem:
|
||||
except Exception:
|
||||
context = dict(raw_context or {})
|
||||
|
||||
# 基础字段
|
||||
context["user_id"] = context.get("user_id") or user_id or "unknown"
|
||||
# 基础字段(统一使用全局作用域)
|
||||
context["user_id"] = GLOBAL_MEMORY_SCOPE
|
||||
context["timestamp"] = context.get("timestamp") or timestamp or time.time()
|
||||
context["message_type"] = context.get("message_type") or "normal"
|
||||
context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"
|
||||
@@ -523,8 +651,8 @@ class EnhancedMemorySystem:
|
||||
if stream_id:
|
||||
context["stream_id"] = stream_id
|
||||
|
||||
# chat_id 兜底
|
||||
context["chat_id"] = context.get("chat_id") or context.get("stream_id") or f"session_{context['user_id']}"
|
||||
# 全局记忆无需聊天隔离
|
||||
context["chat_id"] = context.get("chat_id") or "global_chat"
|
||||
|
||||
# 历史窗口配置
|
||||
window_candidate = (
|
||||
@@ -616,18 +744,7 @@ class EnhancedMemorySystem:
|
||||
|
||||
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
||||
"""确定用于节流控制的记忆构建作用域"""
|
||||
stream_id = context.get("stream_id")
|
||||
if stream_id:
|
||||
return f"stream::{stream_id}"
|
||||
|
||||
chat_id = context.get("chat_id")
|
||||
if chat_id:
|
||||
return f"chat::{chat_id}"
|
||||
|
||||
if user_id:
|
||||
return f"user::{user_id}"
|
||||
|
||||
return None
|
||||
return "global_scope"
|
||||
|
||||
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
|
||||
"""确定历史消息获取数量,限制在30-50之间"""
|
||||
@@ -789,24 +906,134 @@ class EnhancedMemorySystem:
|
||||
logger.error(f"信息价值评估失败: {e}", exc_info=True)
|
||||
return 0.5 # 默认中等价值
|
||||
|
||||
async def _store_memories(self, memory_chunks: List[MemoryChunk]):
|
||||
"""存储记忆块到各个存储系统"""
|
||||
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
|
||||
"""存储记忆块到各个存储系统,返回成功入库数量"""
|
||||
if not memory_chunks:
|
||||
return
|
||||
return 0
|
||||
|
||||
unique_memories: List[MemoryChunk] = []
|
||||
skipped_duplicates = 0
|
||||
|
||||
for memory in memory_chunks:
|
||||
fingerprint = self._build_memory_fingerprint(memory)
|
||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||
|
||||
existing_id = self._memory_fingerprints.get(key)
|
||||
if existing_id:
|
||||
existing = self.vector_storage.memory_cache.get(existing_id)
|
||||
if existing:
|
||||
self._merge_existing_memory(existing, memory)
|
||||
await self.metadata_index.update_memory_entry(existing)
|
||||
skipped_duplicates += 1
|
||||
logger.debug(
|
||||
"检测到重复记忆,已合并到现有记录 | memory_id=%s",
|
||||
existing.memory_id,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 指纹存在但缓存缺失,视为新记忆并覆盖旧映射
|
||||
logger.debug("检测到过期指纹映射,重写现有条目")
|
||||
|
||||
unique_memories.append(memory)
|
||||
|
||||
if not unique_memories:
|
||||
if skipped_duplicates:
|
||||
logger.info("本次记忆全部与现有内容重复,跳过入库")
|
||||
return 0
|
||||
|
||||
# 并行存储到向量数据库和元数据索引
|
||||
storage_tasks = []
|
||||
storage_tasks = [
|
||||
self.vector_storage.store_memories(unique_memories),
|
||||
self.metadata_index.index_memories(unique_memories),
|
||||
]
|
||||
|
||||
# 向量存储
|
||||
storage_tasks.append(self.vector_storage.store_memories(memory_chunks))
|
||||
|
||||
# 元数据索引
|
||||
storage_tasks.append(self.metadata_index.index_memories(memory_chunks))
|
||||
|
||||
# 等待所有存储任务完成
|
||||
await asyncio.gather(*storage_tasks, return_exceptions=True)
|
||||
|
||||
logger.debug(f"成功存储 {len(memory_chunks)} 条记忆到各个存储系统")
|
||||
self._register_memory_fingerprints(unique_memories)
|
||||
|
||||
logger.debug(
|
||||
"成功存储 %d 条记忆(跳过重复 %d 条)",
|
||||
len(unique_memories),
|
||||
skipped_duplicates,
|
||||
)
|
||||
|
||||
return len(unique_memories)
|
||||
|
||||
def _merge_existing_memory(self, existing: MemoryChunk, incoming: MemoryChunk) -> None:
|
||||
"""将新记忆的信息合并到已存在的记忆中"""
|
||||
updated = False
|
||||
|
||||
for keyword in incoming.keywords:
|
||||
if keyword not in existing.keywords:
|
||||
existing.add_keyword(keyword)
|
||||
updated = True
|
||||
|
||||
for tag in incoming.tags:
|
||||
if tag not in existing.tags:
|
||||
existing.add_tag(tag)
|
||||
updated = True
|
||||
|
||||
for category in incoming.categories:
|
||||
if category not in existing.categories:
|
||||
existing.add_category(category)
|
||||
updated = True
|
||||
|
||||
if incoming.metadata.source_context:
|
||||
existing.metadata.source_context = incoming.metadata.source_context
|
||||
|
||||
if incoming.metadata.importance.value > existing.metadata.importance.value:
|
||||
existing.metadata.importance = incoming.metadata.importance
|
||||
updated = True
|
||||
|
||||
if incoming.metadata.confidence.value > existing.metadata.confidence.value:
|
||||
existing.metadata.confidence = incoming.metadata.confidence
|
||||
updated = True
|
||||
|
||||
if incoming.metadata.relevance_score > existing.metadata.relevance_score:
|
||||
existing.metadata.relevance_score = incoming.metadata.relevance_score
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
existing.metadata.last_modified = time.time()
|
||||
|
||||
def _populate_memory_fingerprints(self) -> None:
|
||||
"""基于当前缓存构建记忆指纹映射"""
|
||||
self._memory_fingerprints.clear()
|
||||
for memory in self.vector_storage.memory_cache.values():
|
||||
fingerprint = self._build_memory_fingerprint(memory)
|
||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||
self._memory_fingerprints[key] = memory.memory_id
|
||||
|
||||
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
|
||||
for memory in memories:
|
||||
fingerprint = self._build_memory_fingerprint(memory)
|
||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||
self._memory_fingerprints[key] = memory.memory_id
|
||||
|
||||
def _build_memory_fingerprint(self, memory: MemoryChunk) -> str:
|
||||
subjects = memory.subjects or []
|
||||
subject_part = "|".join(sorted(s.strip() for s in subjects if s))
|
||||
predicate_part = (memory.content.predicate or "").strip()
|
||||
|
||||
obj = memory.content.object
|
||||
if isinstance(obj, (dict, list)):
|
||||
obj_part = orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||
else:
|
||||
obj_part = str(obj).strip()
|
||||
|
||||
base = "|".join([
|
||||
str(memory.user_id or "unknown"),
|
||||
memory.memory_type.value,
|
||||
subject_part,
|
||||
predicate_part,
|
||||
obj_part,
|
||||
])
|
||||
|
||||
return hashlib.sha256(base.encode("utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
|
||||
return f"{str(user_id)}:{fingerprint}"
|
||||
|
||||
def get_system_stats(self) -> Dict[str, Any]:
|
||||
"""获取系统统计信息"""
|
||||
|
||||
@@ -241,12 +241,12 @@ class EnhancedMemoryManager:
|
||||
return []
|
||||
|
||||
try:
|
||||
result = await self.enhanced_system.process_conversation_memory(
|
||||
conversation_text=conversation_text,
|
||||
context=context,
|
||||
user_id=user_id,
|
||||
timestamp=timestamp
|
||||
)
|
||||
payload_context = dict(context or {})
|
||||
payload_context.setdefault("conversation_text", conversation_text)
|
||||
if timestamp is not None:
|
||||
payload_context.setdefault("timestamp", timestamp)
|
||||
|
||||
result = await self.enhanced_system.process_conversation_memory(payload_context)
|
||||
|
||||
# 从结果中提取记忆块
|
||||
memory_chunks = []
|
||||
@@ -274,7 +274,7 @@ class EnhancedMemoryManager:
|
||||
try:
|
||||
relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
|
||||
query=query_text,
|
||||
user_id=user_id,
|
||||
user_id=None,
|
||||
context=context or {},
|
||||
limit=limit
|
||||
)
|
||||
@@ -303,6 +303,9 @@ class EnhancedMemoryManager:
|
||||
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
|
||||
"""将记忆块转换为更易读的文本描述"""
|
||||
structure = memory.content.to_dict()
|
||||
if memory.display:
|
||||
return self._clean_text(memory.display), structure
|
||||
|
||||
subject = structure.get("subject")
|
||||
predicate = structure.get("predicate") or ""
|
||||
obj = structure.get("object")
|
||||
|
||||
@@ -114,12 +114,9 @@ class MemoryIntegrationLayer:
|
||||
|
||||
async def process_conversation(
|
||||
self,
|
||||
conversation_text: str,
|
||||
context: Dict[str, Any],
|
||||
user_id: str,
|
||||
timestamp: Optional[float] = None
|
||||
context: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""处理对话记忆"""
|
||||
"""处理对话记忆,仅使用上下文信息"""
|
||||
if not self._initialized or not self.enhanced_memory:
|
||||
return {"success": False, "error": "Memory system not available"}
|
||||
|
||||
@@ -128,13 +125,12 @@ class MemoryIntegrationLayer:
|
||||
self.integration_stats["enhanced_queries"] += 1
|
||||
|
||||
try:
|
||||
payload_context = dict(context or {})
|
||||
conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
|
||||
logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text))
|
||||
|
||||
# 直接使用增强记忆系统处理
|
||||
result = await self.enhanced_memory.process_conversation_memory(
|
||||
conversation_text=conversation_text,
|
||||
context=context,
|
||||
user_id=user_id,
|
||||
timestamp=timestamp
|
||||
)
|
||||
result = await self.enhanced_memory.process_conversation_memory(payload_context)
|
||||
|
||||
# 更新统计
|
||||
processing_time = time.time() - start_time
|
||||
@@ -156,7 +152,7 @@ class MemoryIntegrationLayer:
|
||||
async def retrieve_relevant_memories(
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None
|
||||
) -> List[MemoryChunk]:
|
||||
@@ -168,7 +164,7 @@ class MemoryIntegrationLayer:
|
||||
limit = limit or self.config.max_retrieval_results
|
||||
memories = await self.enhanced_memory.retrieve_relevant_memories(
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
user_id=None,
|
||||
context=context or {},
|
||||
limit=limit
|
||||
)
|
||||
|
||||
@@ -2,21 +2,48 @@
|
||||
"""
|
||||
记忆构建模块
|
||||
从对话流中提取高质量、结构化记忆单元
|
||||
输出格式要求:
|
||||
{{
|
||||
"memories": [
|
||||
{{
|
||||
"type": "记忆类型",
|
||||
"display": "用于直接展示和检索的自然语言描述",
|
||||
"subject": ["主体1", "主体2"],
|
||||
"predicate": "谓语(动作/状态)",
|
||||
"object": "宾语(对象/属性或结构体)",
|
||||
"keywords": ["关键词1", "关键词2"],
|
||||
"importance": "重要性等级(1-4)",
|
||||
"confidence": "置信度(1-4)",
|
||||
"reasoning": "提取理由"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
注意:
|
||||
1. `subject` 可包含多个主体,请用数组表示;若主体不明确,请根据上下文给出最合理的称呼
|
||||
2. `display` 必须是一句完整流畅的中文描述,可直接用于用户展示和向量搜索
|
||||
3. 只提取确实值得记忆的信息,不要提取琐碎内容
|
||||
4. 确保信息准确、具体、有价值
|
||||
5. 重要性: 1=低, 2=一般, 3=高, 4=关键;置信度: 1=低, 2=中等, 3=高, 4=已验证
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
import orjson
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union, Type
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.memory_system.memory_chunk import (
|
||||
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
|
||||
create_memory_chunk
|
||||
MemoryChunk,
|
||||
MemoryType,
|
||||
ConfidenceLevel,
|
||||
ImportanceLevel,
|
||||
create_memory_chunk,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -24,6 +51,7 @@ logger = get_logger(__name__)
|
||||
|
||||
class ExtractionStrategy(Enum):
|
||||
"""提取策略"""
|
||||
|
||||
LLM_BASED = "llm_based" # 基于LLM的智能提取
|
||||
RULE_BASED = "rule_based" # 基于规则的提取
|
||||
HYBRID = "hybrid" # 混合策略
|
||||
@@ -171,18 +199,18 @@ class MemoryBuilder:
|
||||
"""使用规则提取记忆"""
|
||||
memories = []
|
||||
|
||||
subject_display = self._resolve_user_display(context, user_id)
|
||||
subjects = self._resolve_conversation_participants(context, user_id)
|
||||
|
||||
# 规则1: 检测个人信息
|
||||
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subject_display)
|
||||
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subjects)
|
||||
memories.extend(personal_info)
|
||||
|
||||
# 规则2: 检测偏好信息
|
||||
preferences = self._extract_preferences(text, user_id, timestamp, context, subject_display)
|
||||
preferences = self._extract_preferences(text, user_id, timestamp, context, subjects)
|
||||
memories.extend(preferences)
|
||||
|
||||
# 规则3: 检测事件信息
|
||||
events = self._extract_events(text, user_id, timestamp, context, subject_display)
|
||||
events = self._extract_events(text, user_id, timestamp, context, subjects)
|
||||
memories.extend(events)
|
||||
|
||||
return memories
|
||||
@@ -258,10 +286,7 @@ class MemoryBuilder:
|
||||
你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。
|
||||
|
||||
当前时间: {current_date}
|
||||
聊天ID: {chat_id}
|
||||
消息类型: {message_type}
|
||||
目标用户ID: {target_user_id_display}
|
||||
目标用户称呼: {target_user_name}
|
||||
|
||||
## 🤖 机器人身份(仅供参考,禁止写入记忆)
|
||||
- 机器人名称: {bot_name_display}
|
||||
@@ -272,7 +297,6 @@ class MemoryBuilder:
|
||||
|
||||
请务必遵守以下命名规范:
|
||||
- 当说话者是机器人时,请使用“{bot_name_display}”或其他明确称呼作为主语;
|
||||
- 如果看到系统自动生成的长ID(类似 {target_user_id}),请改用“{target_user_name}”、机器人的称呼或“该用户”描述,不要把ID写入记忆;
|
||||
- 记录关键事实时,请准确标记主体是机器人还是用户,避免混淆。
|
||||
|
||||
对话内容:
|
||||
@@ -450,7 +474,7 @@ class MemoryBuilder:
|
||||
|
||||
bot_identifiers = self._collect_bot_identifiers(context)
|
||||
system_identifiers = self._collect_system_identifiers(context)
|
||||
default_subject = self._resolve_user_display(context, user_id)
|
||||
default_subjects = self._resolve_conversation_participants(context, user_id)
|
||||
|
||||
bot_display = None
|
||||
if context:
|
||||
@@ -481,19 +505,33 @@ class MemoryBuilder:
|
||||
for mem_data in memory_list:
|
||||
try:
|
||||
subject_value = mem_data.get("subject")
|
||||
normalized_subject = self._normalize_subject(
|
||||
normalized_subject = self._normalize_subjects(
|
||||
subject_value,
|
||||
bot_identifiers,
|
||||
system_identifiers,
|
||||
default_subject,
|
||||
default_subjects,
|
||||
bot_display
|
||||
)
|
||||
|
||||
if normalized_subject is None:
|
||||
if not normalized_subject:
|
||||
logger.debug("跳过疑似机器人自身信息的记忆: %s", mem_data)
|
||||
continue
|
||||
|
||||
# 创建记忆块
|
||||
importance_level = self._parse_enum_value(
|
||||
ImportanceLevel,
|
||||
mem_data.get("importance"),
|
||||
ImportanceLevel.NORMAL,
|
||||
"importance"
|
||||
)
|
||||
|
||||
confidence_level = self._parse_enum_value(
|
||||
ConfidenceLevel,
|
||||
mem_data.get("confidence"),
|
||||
ConfidenceLevel.MEDIUM,
|
||||
"confidence"
|
||||
)
|
||||
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=normalized_subject,
|
||||
@@ -502,8 +540,9 @@ class MemoryBuilder:
|
||||
memory_type=MemoryType(mem_data.get("type", "contextual")),
|
||||
chat_id=context.get("chat_id"),
|
||||
source_context=mem_data.get("reasoning", ""),
|
||||
importance=ImportanceLevel(mem_data.get("importance", 2)),
|
||||
confidence=ConfidenceLevel(mem_data.get("confidence", 2))
|
||||
importance=importance_level,
|
||||
confidence=confidence_level,
|
||||
display=mem_data.get("display")
|
||||
)
|
||||
|
||||
# 添加关键词
|
||||
@@ -511,13 +550,6 @@ class MemoryBuilder:
|
||||
for keyword in keywords:
|
||||
memory.add_keyword(keyword)
|
||||
|
||||
subject_text = memory.content.subject.strip() if isinstance(memory.content.subject, str) else str(memory.content.subject)
|
||||
if not subject_text:
|
||||
memory.content.subject = default_subject
|
||||
elif subject_text.lower() in system_identifiers or self._looks_like_system_identifier(subject_text):
|
||||
logger.debug("将系统标识主语替换为默认用户名称: %s", subject_text)
|
||||
memory.content.subject = default_subject
|
||||
|
||||
memories.append(memory)
|
||||
|
||||
except Exception as e:
|
||||
@@ -526,6 +558,64 @@ class MemoryBuilder:
|
||||
|
||||
return memories
|
||||
|
||||
def _parse_enum_value(
|
||||
self,
|
||||
enum_cls: Type[Enum],
|
||||
raw_value: Any,
|
||||
default: Enum,
|
||||
field_name: str
|
||||
) -> Enum:
|
||||
"""解析枚举值,兼容数字/字符串表示"""
|
||||
if isinstance(raw_value, enum_cls):
|
||||
return raw_value
|
||||
|
||||
if raw_value is None:
|
||||
return default
|
||||
|
||||
# 直接尝试整数转换
|
||||
if isinstance(raw_value, (int, float)):
|
||||
int_value = int(raw_value)
|
||||
try:
|
||||
return enum_cls(int_value)
|
||||
except ValueError:
|
||||
logger.debug("%s=%s 无法解析为 %s", field_name, raw_value, enum_cls.__name__)
|
||||
return default
|
||||
|
||||
if isinstance(raw_value, str):
|
||||
value_str = raw_value.strip()
|
||||
if not value_str:
|
||||
return default
|
||||
|
||||
if value_str.isdigit():
|
||||
try:
|
||||
return enum_cls(int(value_str))
|
||||
except ValueError:
|
||||
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
|
||||
else:
|
||||
normalized = value_str.replace("-", "_").replace(" ", "_").upper()
|
||||
for member in enum_cls:
|
||||
if member.name == normalized:
|
||||
return member
|
||||
for member in enum_cls:
|
||||
if str(member.value).lower() == value_str.lower():
|
||||
return member
|
||||
|
||||
try:
|
||||
return enum_cls(value_str)
|
||||
except ValueError:
|
||||
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
|
||||
|
||||
try:
|
||||
return enum_cls(raw_value)
|
||||
except Exception:
|
||||
logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s",
|
||||
field_name,
|
||||
raw_value,
|
||||
type(raw_value).__name__,
|
||||
enum_cls.__name__,
|
||||
default.name)
|
||||
return default
|
||||
|
||||
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
|
||||
identifiers: set[str] = {"bot", "机器人", "ai助手"}
|
||||
if not context:
|
||||
@@ -580,6 +670,58 @@ class MemoryBuilder:
|
||||
|
||||
return identifiers
|
||||
|
||||
def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
|
||||
participants: List[str] = []
|
||||
|
||||
if context:
|
||||
candidate_keys = [
|
||||
"participants",
|
||||
"participant_names",
|
||||
"speaker_names",
|
||||
"members",
|
||||
"member_names",
|
||||
"mention_users",
|
||||
"audiences"
|
||||
]
|
||||
|
||||
for key in candidate_keys:
|
||||
value = context.get(key)
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
cleaned = self._clean_subject_text(item)
|
||||
if cleaned:
|
||||
participants.append(cleaned)
|
||||
elif isinstance(value, str):
|
||||
for part in self._split_subject_string(value):
|
||||
if part:
|
||||
participants.append(part)
|
||||
|
||||
fallback = self._resolve_user_display(context, user_id)
|
||||
if fallback:
|
||||
participants.append(fallback)
|
||||
|
||||
if context:
|
||||
bot_name = context.get("bot_name") or context.get("bot_identity")
|
||||
if isinstance(bot_name, str):
|
||||
cleaned = self._clean_subject_text(bot_name)
|
||||
if cleaned:
|
||||
participants.append(cleaned)
|
||||
|
||||
if not participants:
|
||||
participants = ["对话参与者"]
|
||||
|
||||
deduplicated: List[str] = []
|
||||
seen = set()
|
||||
for name in participants:
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduplicated.append(name)
|
||||
|
||||
return deduplicated
|
||||
|
||||
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
|
||||
candidate_keys = [
|
||||
"user_display_name",
|
||||
@@ -626,51 +768,160 @@ class MemoryBuilder:
|
||||
|
||||
return False
|
||||
|
||||
def _normalize_subject(
|
||||
def _split_subject_string(self, value: str) -> List[str]:
|
||||
if not value:
|
||||
return []
|
||||
|
||||
replaced = re.sub(r"\band\b", "、", value, flags=re.IGNORECASE)
|
||||
replaced = replaced.replace("和", "、").replace("与", "、").replace("及", "、")
|
||||
replaced = replaced.replace("&", "、").replace("/", "、").replace("+", "、")
|
||||
|
||||
tokens = [self._clean_subject_text(token) for token in re.split(r"[、,,;;]+", replaced)]
|
||||
return [token for token in tokens if token]
|
||||
|
||||
def _normalize_subjects(
|
||||
self,
|
||||
subject: Any,
|
||||
bot_identifiers: set[str],
|
||||
system_identifiers: set[str],
|
||||
default_subject: str,
|
||||
default_subjects: List[str],
|
||||
bot_display: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
if subject is None:
|
||||
return default_subject
|
||||
) -> List[str]:
|
||||
defaults = default_subjects or ["对话参与者"]
|
||||
|
||||
subject_str = subject if isinstance(subject, str) else str(subject)
|
||||
cleaned = self._clean_subject_text(subject_str)
|
||||
if not cleaned:
|
||||
return default_subject
|
||||
raw_candidates: List[str] = []
|
||||
if isinstance(subject, list):
|
||||
for item in subject:
|
||||
if isinstance(item, str):
|
||||
raw_candidates.extend(self._split_subject_string(item))
|
||||
elif item is not None:
|
||||
raw_candidates.extend(self._split_subject_string(str(item)))
|
||||
elif isinstance(subject, str):
|
||||
raw_candidates.extend(self._split_subject_string(subject))
|
||||
elif subject is not None:
|
||||
raw_candidates.extend(self._split_subject_string(str(subject)))
|
||||
|
||||
lowered = cleaned.lower()
|
||||
normalized: List[str] = []
|
||||
bot_primary = self._clean_subject_text(bot_display or "")
|
||||
|
||||
if lowered in bot_identifiers:
|
||||
return bot_primary or cleaned
|
||||
for candidate in raw_candidates:
|
||||
if not candidate:
|
||||
continue
|
||||
|
||||
if lowered in {"用户", "user", "the user", "对方", "对手"}:
|
||||
return default_subject
|
||||
lowered = candidate.lower()
|
||||
if lowered in bot_identifiers:
|
||||
normalized.append(bot_primary or candidate)
|
||||
continue
|
||||
|
||||
prefix_match = re.match(r"^(用户|User|user|USER|成员|member|Member|target|Target|TARGET)[\s::\-\u2014_]*?(.*)$", cleaned)
|
||||
if prefix_match:
|
||||
remainder = self._clean_subject_text(prefix_match.group(2))
|
||||
if not remainder:
|
||||
return default_subject
|
||||
remainder_lower = remainder.lower()
|
||||
if remainder_lower in bot_identifiers:
|
||||
return bot_primary or remainder
|
||||
if (
|
||||
remainder_lower in system_identifiers
|
||||
or self._looks_like_system_identifier(remainder)
|
||||
):
|
||||
return default_subject
|
||||
cleaned = remainder
|
||||
lowered = cleaned.lower()
|
||||
if lowered in {"用户", "user", "the user", "对方", "对手"}:
|
||||
normalized.extend(defaults)
|
||||
continue
|
||||
|
||||
if lowered in system_identifiers or self._looks_like_system_identifier(cleaned):
|
||||
return default_subject
|
||||
if lowered in system_identifiers or self._looks_like_system_identifier(candidate):
|
||||
continue
|
||||
|
||||
return cleaned
|
||||
normalized.append(candidate)
|
||||
|
||||
if not normalized:
|
||||
normalized = list(defaults)
|
||||
|
||||
deduplicated: List[str] = []
|
||||
seen = set()
|
||||
for name in normalized:
|
||||
key = name.lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduplicated.append(name)
|
||||
|
||||
return deduplicated
|
||||
|
||||
def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
|
||||
if isinstance(obj, dict):
|
||||
for key in keys:
|
||||
value = obj.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, list):
|
||||
compact = "、".join(str(item) for item in value[:3])
|
||||
if compact:
|
||||
return compact
|
||||
else:
|
||||
value_str = str(value).strip()
|
||||
if value_str:
|
||||
return value_str
|
||||
elif isinstance(obj, list):
|
||||
compact = "、".join(str(item) for item in obj[:3])
|
||||
return compact or None
|
||||
elif isinstance(obj, str):
|
||||
return obj.strip() or None
|
||||
return None
|
||||
|
||||
def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str:
|
||||
subject_phrase = "、".join(subjects) if subjects else "对话参与者"
|
||||
predicate = (predicate or "").strip()
|
||||
|
||||
if predicate == "is_named":
|
||||
name = self._extract_value_from_object(obj, ["name", "nickname"]) or ""
|
||||
name = self._clean_subject_text(name)
|
||||
if name:
|
||||
quoted = name if (name.startswith("「") and name.endswith("」")) else f"「{name}」"
|
||||
return f"{subject_phrase}的昵称是{quoted}"
|
||||
elif predicate == "is_age":
|
||||
age = self._extract_value_from_object(obj, ["age"]) or ""
|
||||
age = self._clean_subject_text(age)
|
||||
if age:
|
||||
return f"{subject_phrase}今年{age}岁"
|
||||
elif predicate == "is_profession":
|
||||
profession = self._extract_value_from_object(obj, ["profession", "job"]) or ""
|
||||
profession = self._clean_subject_text(profession)
|
||||
if profession:
|
||||
return f"{subject_phrase}的职业是{profession}"
|
||||
elif predicate == "lives_in":
|
||||
location = self._extract_value_from_object(obj, ["location", "city", "place"]) or ""
|
||||
location = self._clean_subject_text(location)
|
||||
if location:
|
||||
return f"{subject_phrase}居住在{location}"
|
||||
elif predicate == "has_phone":
|
||||
phone = self._extract_value_from_object(obj, ["phone", "number"]) or ""
|
||||
phone = self._clean_subject_text(phone)
|
||||
if phone:
|
||||
return f"{subject_phrase}的电话号码是{phone}"
|
||||
elif predicate == "has_email":
|
||||
email = self._extract_value_from_object(obj, ["email"]) or ""
|
||||
email = self._clean_subject_text(email)
|
||||
if email:
|
||||
return f"{subject_phrase}的邮箱是{email}"
|
||||
elif predicate in {"likes", "likes_food", "favorite_is"}:
|
||||
liked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
|
||||
liked = self._clean_subject_text(liked)
|
||||
if liked:
|
||||
verb = "喜欢" if predicate != "likes_food" else "爱吃"
|
||||
if predicate == "favorite_is":
|
||||
verb = "最喜欢"
|
||||
return f"{subject_phrase}{verb}{liked}"
|
||||
elif predicate in {"dislikes", "hates"}:
|
||||
disliked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
|
||||
disliked = self._clean_subject_text(disliked)
|
||||
if disliked:
|
||||
verb = "不喜欢" if predicate == "dislikes" else "讨厌"
|
||||
return f"{subject_phrase}{verb}{disliked}"
|
||||
elif predicate == "mentioned_event":
|
||||
description = self._extract_value_from_object(obj, ["event_text", "description"]) or ""
|
||||
description = self._clean_subject_text(description)
|
||||
if description:
|
||||
return f"{subject_phrase}提到了:{description}"
|
||||
|
||||
obj_text = self._extract_value_from_object(obj, ["value", "detail", "content"]) or ""
|
||||
obj_text = self._clean_subject_text(obj_text)
|
||||
|
||||
if predicate and obj_text:
|
||||
return f"{subject_phrase}{predicate}{obj_text}".strip()
|
||||
if obj_text:
|
||||
return f"{subject_phrase}{obj_text}".strip()
|
||||
if predicate:
|
||||
return f"{subject_phrase}{predicate}".strip()
|
||||
return subject_phrase
|
||||
|
||||
def _extract_personal_info(
|
||||
self,
|
||||
@@ -678,7 +929,7 @@ class MemoryBuilder:
|
||||
user_id: str,
|
||||
timestamp: float,
|
||||
context: Dict[str, Any],
|
||||
subject_display: str
|
||||
subjects: List[str]
|
||||
) -> List[MemoryChunk]:
|
||||
"""提取个人信息"""
|
||||
memories = []
|
||||
@@ -702,13 +953,14 @@ class MemoryBuilder:
|
||||
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=subject_display,
|
||||
subject=subjects,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
memory_type=MemoryType.PERSONAL_FACT,
|
||||
chat_id=context.get("chat_id"),
|
||||
importance=ImportanceLevel.HIGH,
|
||||
confidence=ConfidenceLevel.HIGH
|
||||
confidence=ConfidenceLevel.HIGH,
|
||||
display=self._compose_display_text(subjects, predicate, obj)
|
||||
)
|
||||
|
||||
memories.append(memory)
|
||||
@@ -721,7 +973,7 @@ class MemoryBuilder:
|
||||
user_id: str,
|
||||
timestamp: float,
|
||||
context: Dict[str, Any],
|
||||
subject_display: str
|
||||
subjects: List[str]
|
||||
) -> List[MemoryChunk]:
|
||||
"""提取偏好信息"""
|
||||
memories = []
|
||||
@@ -740,13 +992,14 @@ class MemoryBuilder:
|
||||
if match:
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=subject_display,
|
||||
subject=subjects,
|
||||
predicate=predicate,
|
||||
obj=match.group(1),
|
||||
memory_type=MemoryType.PREFERENCE,
|
||||
chat_id=context.get("chat_id"),
|
||||
importance=ImportanceLevel.NORMAL,
|
||||
confidence=ConfidenceLevel.MEDIUM
|
||||
confidence=ConfidenceLevel.MEDIUM,
|
||||
display=self._compose_display_text(subjects, predicate, match.group(1))
|
||||
)
|
||||
|
||||
memories.append(memory)
|
||||
@@ -759,7 +1012,7 @@ class MemoryBuilder:
|
||||
user_id: str,
|
||||
timestamp: float,
|
||||
context: Dict[str, Any],
|
||||
subject_display: str
|
||||
subjects: List[str]
|
||||
) -> List[MemoryChunk]:
|
||||
"""提取事件信息"""
|
||||
memories = []
|
||||
@@ -770,13 +1023,14 @@ class MemoryBuilder:
|
||||
if any(keyword in text for keyword in event_keywords):
|
||||
memory = create_memory_chunk(
|
||||
user_id=user_id,
|
||||
subject=subject_display,
|
||||
subject=subjects,
|
||||
predicate="mentioned_event",
|
||||
obj={"event_text": text, "timestamp": timestamp},
|
||||
memory_type=MemoryType.EVENT,
|
||||
chat_id=context.get("chat_id"),
|
||||
importance=ImportanceLevel.NORMAL,
|
||||
confidence=ConfidenceLevel.MEDIUM
|
||||
confidence=ConfidenceLevel.MEDIUM,
|
||||
display=self._compose_display_text(subjects, "mentioned_event", text)
|
||||
)
|
||||
|
||||
memories.append(memory)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import time
|
||||
import uuid
|
||||
import orjson
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from typing import Dict, List, Optional, Any, Union, Iterable
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -52,17 +52,20 @@ class ImportanceLevel(Enum):
|
||||
|
||||
@dataclass
|
||||
class ContentStructure:
|
||||
"""主谓宾三元组结构"""
|
||||
subject: str # 主语(通常为用户)
|
||||
predicate: str # 谓语(动作、状态、关系)
|
||||
object: Union[str, Dict] # 宾语(对象、属性、值)
|
||||
"""主谓宾结构,包含自然语言描述"""
|
||||
|
||||
subject: Union[str, List[str]]
|
||||
predicate: str
|
||||
object: Union[str, Dict]
|
||||
display: str = ""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"subject": self.subject,
|
||||
"predicate": self.predicate,
|
||||
"object": self.object
|
||||
"object": self.object,
|
||||
"display": self.display
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -71,16 +74,25 @@ class ContentStructure:
|
||||
return cls(
|
||||
subject=data.get("subject", ""),
|
||||
predicate=data.get("predicate", ""),
|
||||
object=data.get("object", "")
|
||||
object=data.get("object", ""),
|
||||
display=data.get("display", "")
|
||||
)
|
||||
|
||||
def to_subject_list(self) -> List[str]:
|
||||
"""将主语转换为列表形式"""
|
||||
if isinstance(self.subject, list):
|
||||
return [s for s in self.subject if isinstance(s, str) and s.strip()]
|
||||
if isinstance(self.subject, str) and self.subject.strip():
|
||||
return [self.subject.strip()]
|
||||
return []
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
if isinstance(self.object, dict):
|
||||
object_str = str(self.object)
|
||||
else:
|
||||
object_str = str(self.object)
|
||||
return f"{self.subject} {self.predicate} {object_str}"
|
||||
if self.display:
|
||||
return self.display
|
||||
subjects = "、".join(self.to_subject_list()) or str(self.subject)
|
||||
object_str = self.object if isinstance(self.object, str) else str(self.object)
|
||||
return f"{subjects} {self.predicate} {object_str}".strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -236,9 +248,19 @@ class MemoryChunk:
|
||||
|
||||
@property
|
||||
def text_content(self) -> str:
|
||||
"""获取文本内容"""
|
||||
"""获取文本内容(优先使用display)"""
|
||||
return str(self.content)
|
||||
|
||||
@property
|
||||
def display(self) -> str:
|
||||
"""获取展示文本"""
|
||||
return self.content.display or str(self.content)
|
||||
|
||||
@property
|
||||
def subjects(self) -> List[str]:
|
||||
"""获取主语列表"""
|
||||
return self.content.to_subject_list()
|
||||
|
||||
def update_access(self):
|
||||
"""更新访问信息"""
|
||||
self.metadata.update_access()
|
||||
@@ -415,16 +437,42 @@ class MemoryChunk:
|
||||
confidence_icon = "●" * self.metadata.confidence.value
|
||||
importance_icon = "★" * self.metadata.importance.value
|
||||
|
||||
return f"{emoji} [{self.memory_type.value}] {self.text_content} {confidence_icon} {importance_icon}"
|
||||
return f"{emoji} [{self.memory_type.value}] {self.display} {confidence_icon} {importance_icon}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""调试表示"""
|
||||
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
|
||||
|
||||
|
||||
def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
|
||||
"""根据主谓宾生成自然语言描述"""
|
||||
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
|
||||
subject_part = "、".join(subjects_clean) if subjects_clean else "对话参与者"
|
||||
|
||||
if isinstance(obj, dict):
|
||||
object_candidates = []
|
||||
for key, value in obj.items():
|
||||
if isinstance(value, (str, int, float)):
|
||||
object_candidates.append(f"{key}:{value}")
|
||||
elif isinstance(value, list):
|
||||
compact = "、".join(str(item) for item in value[:3])
|
||||
object_candidates.append(f"{key}:{compact}")
|
||||
object_part = ",".join(object_candidates) if object_candidates else str(obj)
|
||||
else:
|
||||
object_part = str(obj).strip()
|
||||
|
||||
predicate_clean = predicate.strip()
|
||||
if not predicate_clean:
|
||||
return f"{subject_part} {object_part}".strip()
|
||||
|
||||
if object_part:
|
||||
return f"{subject_part}{predicate_clean}{object_part}".strip()
|
||||
return f"{subject_part}{predicate_clean}".strip()
|
||||
|
||||
|
||||
def create_memory_chunk(
|
||||
user_id: str,
|
||||
subject: str,
|
||||
subject: Union[str, List[str]],
|
||||
predicate: str,
|
||||
obj: Union[str, Dict],
|
||||
memory_type: MemoryType,
|
||||
@@ -432,6 +480,7 @@ def create_memory_chunk(
|
||||
source_context: Optional[str] = None,
|
||||
importance: ImportanceLevel = ImportanceLevel.NORMAL,
|
||||
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
|
||||
display: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> MemoryChunk:
|
||||
"""便捷的内存块创建函数"""
|
||||
@@ -447,10 +496,22 @@ def create_memory_chunk(
|
||||
source_context=source_context
|
||||
)
|
||||
|
||||
subjects: List[str]
|
||||
if isinstance(subject, list):
|
||||
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
|
||||
subject_payload: Union[str, List[str]] = subjects
|
||||
else:
|
||||
cleaned = subject.strip() if isinstance(subject, str) else ""
|
||||
subjects = [cleaned] if cleaned else []
|
||||
subject_payload = cleaned
|
||||
|
||||
display_text = display or _build_display_text(subjects, predicate, obj)
|
||||
|
||||
content = ContentStructure(
|
||||
subject=subject,
|
||||
subject=subject_payload,
|
||||
predicate=predicate,
|
||||
object=obj
|
||||
object=obj,
|
||||
display=display_text
|
||||
)
|
||||
|
||||
chunk = MemoryChunk(
|
||||
|
||||
@@ -266,8 +266,12 @@ class MemoryFusionEngine:
|
||||
consistency_score = 0.0
|
||||
|
||||
# 主语一致性
|
||||
if mem1.content.subject == mem2.content.subject:
|
||||
consistency_score += 0.4
|
||||
subjects1 = set(mem1.subjects)
|
||||
subjects2 = set(mem2.subjects)
|
||||
if subjects1 or subjects2:
|
||||
overlap = len(subjects1 & subjects2)
|
||||
union_count = max(len(subjects1 | subjects2), 1)
|
||||
consistency_score += (overlap / union_count) * 0.4
|
||||
|
||||
# 谓语相似性
|
||||
predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)
|
||||
|
||||
@@ -282,9 +282,11 @@ class MemoryIntegrationHooks:
|
||||
}
|
||||
|
||||
# 使用增强记忆系统处理对话
|
||||
result = await process_conversation_with_enhanced_memory(
|
||||
conversation_text, context, user_id
|
||||
)
|
||||
memory_context = dict(context)
|
||||
memory_context["conversation_text"] = conversation_text
|
||||
memory_context["user_id"] = user_id
|
||||
|
||||
result = await process_conversation_with_enhanced_memory(memory_context)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
self._update_hook_stats(processing_time)
|
||||
@@ -336,9 +338,11 @@ class MemoryIntegrationHooks:
|
||||
}
|
||||
|
||||
# 使用增强记忆系统处理对话
|
||||
result = await process_conversation_with_enhanced_memory(
|
||||
conversation_text, context, user_id
|
||||
)
|
||||
memory_context = dict(context)
|
||||
memory_context["conversation_text"] = conversation_text
|
||||
memory_context["user_id"] = user_id
|
||||
|
||||
result = await process_conversation_with_enhanced_memory(memory_context)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
self._update_hook_stats(processing_time)
|
||||
|
||||
194
src/chat/memory_system/memory_query_planner.py
Normal file
194
src/chat/memory_system/memory_query_planner.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""记忆检索查询规划器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import orjson
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryType
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryQueryPlan:
|
||||
"""查询规划结果"""
|
||||
|
||||
semantic_query: str
|
||||
memory_types: List[MemoryType] = field(default_factory=list)
|
||||
subject_includes: List[str] = field(default_factory=list)
|
||||
object_includes: List[str] = field(default_factory=list)
|
||||
required_keywords: List[str] = field(default_factory=list)
|
||||
optional_keywords: List[str] = field(default_factory=list)
|
||||
owner_filters: List[str] = field(default_factory=list)
|
||||
recency_preference: str = "any"
|
||||
limit: int = 10
|
||||
emphasis: Optional[str] = None
|
||||
raw_plan: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
|
||||
if not self.semantic_query:
|
||||
self.semantic_query = fallback_query
|
||||
if self.limit <= 0:
|
||||
self.limit = default_limit
|
||||
self.recency_preference = (self.recency_preference or "any").lower()
|
||||
if self.recency_preference not in {"any", "recent", "historical"}:
|
||||
self.recency_preference = "any"
|
||||
self.emphasis = (self.emphasis or "balanced").lower()
|
||||
|
||||
|
||||
class MemoryQueryPlanner:
|
||||
"""基于小模型的记忆检索查询规划器"""
|
||||
|
||||
def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
|
||||
self.model = planner_model
|
||||
self.default_limit = default_limit
|
||||
|
||||
async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
|
||||
if not self.model:
|
||||
logger.debug("未提供查询规划模型,使用默认规划")
|
||||
return self._default_plan(query_text)
|
||||
|
||||
prompt = self._build_prompt(query_text, context)
|
||||
|
||||
try:
|
||||
response, _ = await self.model.generate_response_async(prompt, temperature=0.2)
|
||||
payload = self._extract_json_payload(response)
|
||||
if not payload:
|
||||
logger.debug("查询规划模型未返回结构化结果,使用默认规划")
|
||||
return self._default_plan(query_text)
|
||||
|
||||
try:
|
||||
data = orjson.loads(payload)
|
||||
except orjson.JSONDecodeError as exc:
|
||||
preview = payload[:200]
|
||||
logger.warning("解析查询规划JSON失败: %s,片段: %s", exc, preview)
|
||||
return self._default_plan(query_text)
|
||||
|
||||
plan = self._parse_plan_dict(data, query_text)
|
||||
plan.ensure_defaults(query_text, self.default_limit)
|
||||
return plan
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("查询规划模型调用失败: %s", exc, exc_info=True)
|
||||
return self._default_plan(query_text)
|
||||
|
||||
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
|
||||
return MemoryQueryPlan(
|
||||
semantic_query=query_text,
|
||||
limit=self.default_limit
|
||||
)
|
||||
|
||||
def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
|
||||
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
|
||||
|
||||
def _collect_list(key: str) -> List[str]:
|
||||
value = data.get(key)
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
if isinstance(value, list):
|
||||
return [self._safe_str(item) for item in value if self._safe_str(item)]
|
||||
return []
|
||||
|
||||
memory_type_values = _collect_list("memory_types")
|
||||
memory_types: List[MemoryType] = []
|
||||
for item in memory_type_values:
|
||||
if not item:
|
||||
continue
|
||||
try:
|
||||
memory_types.append(MemoryType(item))
|
||||
except ValueError:
|
||||
# 尝试匹配value值
|
||||
normalized = item.lower()
|
||||
for mt in MemoryType:
|
||||
if mt.value == normalized:
|
||||
memory_types.append(mt)
|
||||
break
|
||||
|
||||
plan = MemoryQueryPlan(
|
||||
semantic_query=semantic_query,
|
||||
memory_types=memory_types,
|
||||
subject_includes=_collect_list("subject_includes"),
|
||||
object_includes=_collect_list("object_includes"),
|
||||
required_keywords=_collect_list("required_keywords"),
|
||||
optional_keywords=_collect_list("optional_keywords"),
|
||||
owner_filters=_collect_list("owner_filters"),
|
||||
recency_preference=self._safe_str(data.get("recency")) or "any",
|
||||
limit=self._safe_int(data.get("limit"), self.default_limit),
|
||||
emphasis=self._safe_str(data.get("emphasis")) or "balanced",
|
||||
raw_plan=data
|
||||
)
|
||||
return plan
|
||||
|
||||
def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
|
||||
participants = context.get("participants") or context.get("speaker_names") or []
|
||||
if isinstance(participants, str):
|
||||
participants = [participants]
|
||||
participants = [p for p in participants if isinstance(p, str) and p.strip()]
|
||||
participant_preview = "、".join(participants[:5]) or "未知"
|
||||
|
||||
persona = context.get("bot_personality") or context.get("bot_identity") or "未知"
|
||||
|
||||
return f"""
|
||||
你是一名记忆检索分析师,将根据对话查询生成结构化的检索计划。
|
||||
请结合提供的上下文,输出一个JSON对象,字段含义如下:
|
||||
- semantic_query: 提供给向量检索的自然语言查询,要求清晰具体;
|
||||
- memory_types: 建议检索的记忆类型数组,取值范围参见 MemoryType 枚举 (personal_fact,event,preference,opinion,relationship,emotion,knowledge,skill,goal,experience,contextual);
|
||||
- subject_includes: 需要出现在记忆主语中的人或角色列表;
|
||||
- object_includes: 记忆中需要提到的重要对象或主题关键词列表;
|
||||
- required_keywords: 检索时必须包含的关键词;
|
||||
- optional_keywords: 可以提升相关性的附加关键词;
|
||||
- owner_filters: 如果需要限制检索所属主体,请列出用户ID或其它标识;
|
||||
- recency: 建议的时间偏好,可选 recent/any/historical;
|
||||
- emphasis: 检索策略倾向,可选 precision/recall/balanced;
|
||||
- limit: 推荐的最大返回数量(1-15之间);
|
||||
- notes: 额外说明,可选。
|
||||
|
||||
当前查询: "{query_text}"
|
||||
已知的对话参与者: {participant_preview}
|
||||
机器人设定: {persona}
|
||||
|
||||
请输出符合要求的JSON,禁止添加额外说明或Markdown代码块。
|
||||
"""
|
||||
|
||||
def _extract_json_payload(self, response: str) -> Optional[str]:
|
||||
if not response:
|
||||
return None
|
||||
|
||||
stripped = response.strip()
|
||||
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
|
||||
if code_block_match:
|
||||
candidate = code_block_match.group(1).strip()
|
||||
if candidate:
|
||||
return candidate
|
||||
|
||||
start = stripped.find("{")
|
||||
end = stripped.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return stripped[start:end + 1]
|
||||
|
||||
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
|
||||
|
||||
@staticmethod
|
||||
def _safe_str(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return value.strip()
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value).strip()
|
||||
|
||||
@staticmethod
|
||||
def _safe_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
number = int(value)
|
||||
if number <= 0:
|
||||
return default
|
||||
return number
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
@@ -25,6 +25,7 @@ class IndexType(Enum):
|
||||
"""索引类型"""
|
||||
MEMORY_TYPE = "memory_type" # 记忆类型索引
|
||||
USER_ID = "user_id" # 用户ID索引
|
||||
SUBJECT = "subject" # 主体索引
|
||||
KEYWORD = "keyword" # 关键词索引
|
||||
TAG = "tag" # 标签索引
|
||||
CATEGORY = "category" # 分类索引
|
||||
@@ -41,6 +42,7 @@ class IndexQuery:
|
||||
"""索引查询条件"""
|
||||
user_ids: Optional[List[str]] = None
|
||||
memory_types: Optional[List[MemoryType]] = None
|
||||
subjects: Optional[List[str]] = None
|
||||
keywords: Optional[List[str]] = None
|
||||
tags: Optional[List[str]] = None
|
||||
categories: Optional[List[str]] = None
|
||||
@@ -76,6 +78,7 @@ class MetadataIndexManager:
|
||||
self.indices = {
|
||||
IndexType.MEMORY_TYPE: defaultdict(set),
|
||||
IndexType.USER_ID: defaultdict(set),
|
||||
IndexType.SUBJECT: defaultdict(set),
|
||||
IndexType.KEYWORD: defaultdict(set),
|
||||
IndexType.TAG: defaultdict(set),
|
||||
IndexType.CATEGORY: defaultdict(set),
|
||||
@@ -110,6 +113,41 @@ class MetadataIndexManager:
|
||||
self.auto_save_interval = 500 # 每500次操作自动保存
|
||||
self._operation_count = 0
|
||||
|
||||
@staticmethod
|
||||
def _serialize_index_key(index_type: IndexType, key: Any) -> str:
|
||||
"""将索引键序列化为字符串以便存储"""
|
||||
if isinstance(key, Enum):
|
||||
value = key.value
|
||||
else:
|
||||
value = key
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_index_key(index_type: IndexType, key: str) -> Any:
|
||||
"""根据索引类型反序列化索引键"""
|
||||
try:
|
||||
if index_type == IndexType.MEMORY_TYPE:
|
||||
return MemoryType(key)
|
||||
if index_type == IndexType.CONFIDENCE:
|
||||
return ConfidenceLevel(int(key))
|
||||
if index_type == IndexType.IMPORTANCE:
|
||||
return ImportanceLevel(int(key))
|
||||
# 其他索引键默认使用原始字符串(可能已经是lower后的字符串)
|
||||
return key
|
||||
except Exception:
|
||||
logger.warning("无法反序列化索引键 %s 在索引 %s 中,使用原始字符串", key, index_type.value)
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
serialized = {}
|
||||
for field_name, value in metadata.items():
|
||||
if isinstance(value, Enum):
|
||||
serialized[field_name] = value.value
|
||||
else:
|
||||
serialized[field_name] = value
|
||||
return serialized
|
||||
|
||||
async def index_memories(self, memories: List[MemoryChunk]):
|
||||
"""为记忆建立索引"""
|
||||
if not memories:
|
||||
@@ -142,6 +180,68 @@ class MetadataIndexManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 元数据索引失败: {e}", exc_info=True)
|
||||
|
||||
async def update_memory_entry(self, memory: MemoryChunk):
|
||||
"""更新已存在记忆的索引信息"""
|
||||
if not memory:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
entry = self.memory_metadata_cache.get(memory.memory_id)
|
||||
if entry is None:
|
||||
# 若不存在则作为新记忆索引
|
||||
self._index_single_memory(memory)
|
||||
return
|
||||
|
||||
old_confidence = entry.get("confidence")
|
||||
old_importance = entry.get("importance")
|
||||
old_semantic_hash = entry.get("semantic_hash")
|
||||
|
||||
entry.update(
|
||||
{
|
||||
"user_id": memory.user_id,
|
||||
"memory_type": memory.memory_type,
|
||||
"created_at": memory.metadata.created_at,
|
||||
"last_accessed": memory.metadata.last_accessed,
|
||||
"access_count": memory.metadata.access_count,
|
||||
"confidence": memory.metadata.confidence,
|
||||
"importance": memory.metadata.importance,
|
||||
"relationship_score": memory.metadata.relationship_score,
|
||||
"relevance_score": memory.metadata.relevance_score,
|
||||
"semantic_hash": memory.semantic_hash,
|
||||
"subjects": memory.subjects,
|
||||
}
|
||||
)
|
||||
|
||||
# 更新置信度/重要性索引
|
||||
if isinstance(old_confidence, ConfidenceLevel):
|
||||
self.indices[IndexType.CONFIDENCE][old_confidence].discard(memory.memory_id)
|
||||
if isinstance(old_importance, ImportanceLevel):
|
||||
self.indices[IndexType.IMPORTANCE][old_importance].discard(memory.memory_id)
|
||||
if isinstance(old_semantic_hash, str):
|
||||
self.indices[IndexType.SEMANTIC_HASH][old_semantic_hash].discard(memory.memory_id)
|
||||
|
||||
self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory.memory_id)
|
||||
self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory.memory_id)
|
||||
if memory.semantic_hash:
|
||||
self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory.memory_id)
|
||||
|
||||
# 同步关键词/标签/分类索引
|
||||
for keyword in memory.keywords:
|
||||
if keyword:
|
||||
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory.memory_id)
|
||||
|
||||
for tag in memory.tags:
|
||||
if tag:
|
||||
self.indices[IndexType.TAG][tag.lower()].add(memory.memory_id)
|
||||
|
||||
for category in memory.categories:
|
||||
if category:
|
||||
self.indices[IndexType.CATEGORY][category.lower()].add(memory.memory_id)
|
||||
|
||||
for subject in memory.subjects:
|
||||
if subject:
|
||||
self.indices[IndexType.SUBJECT][subject.strip().lower()].add(memory.memory_id)
|
||||
|
||||
def _index_single_memory(self, memory: MemoryChunk):
|
||||
"""为单个记忆建立索引"""
|
||||
memory_id = memory.memory_id
|
||||
@@ -157,7 +257,8 @@ class MetadataIndexManager:
|
||||
"importance": memory.metadata.importance,
|
||||
"relationship_score": memory.metadata.relationship_score,
|
||||
"relevance_score": memory.metadata.relevance_score,
|
||||
"semantic_hash": memory.semantic_hash
|
||||
"semantic_hash": memory.semantic_hash,
|
||||
"subjects": memory.subjects
|
||||
}
|
||||
|
||||
# 记忆类型索引
|
||||
@@ -166,6 +267,12 @@ class MetadataIndexManager:
|
||||
# 用户ID索引
|
||||
self.indices[IndexType.USER_ID][memory.user_id].add(memory_id)
|
||||
|
||||
# 主体索引
|
||||
for subject in memory.subjects:
|
||||
normalized = subject.strip().lower()
|
||||
if normalized:
|
||||
self.indices[IndexType.SUBJECT][normalized].add(memory_id)
|
||||
|
||||
# 关键词索引
|
||||
for keyword in memory.keywords:
|
||||
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id)
|
||||
@@ -282,13 +389,6 @@ class MetadataIndexManager:
|
||||
# 应用最严格的过滤条件
|
||||
applied_filters = []
|
||||
|
||||
if query.user_ids:
|
||||
user_ids_set = set()
|
||||
for user_id in query.user_ids:
|
||||
user_ids_set.update(self.indices[IndexType.USER_ID].get(user_id, set()))
|
||||
candidate_ids.update(user_ids_set)
|
||||
applied_filters.append("user_ids")
|
||||
|
||||
if query.memory_types:
|
||||
memory_types_set = set()
|
||||
for memory_type in query.memory_types:
|
||||
@@ -302,7 +402,7 @@ class MetadataIndexManager:
|
||||
if query.keywords:
|
||||
keywords_set = set()
|
||||
for keyword in query.keywords:
|
||||
keywords_set.update(self.indices[IndexType.KEYWORD].get(keyword.lower(), set()))
|
||||
keywords_set.update(self._collect_index_matches(IndexType.KEYWORD, keyword))
|
||||
if applied_filters:
|
||||
candidate_ids &= keywords_set
|
||||
else:
|
||||
@@ -329,12 +429,55 @@ class MetadataIndexManager:
|
||||
candidate_ids.update(categories_set)
|
||||
applied_filters.append("categories")
|
||||
|
||||
if query.subjects:
|
||||
subjects_set = set()
|
||||
for subject in query.subjects:
|
||||
subjects_set.update(self._collect_index_matches(IndexType.SUBJECT, subject))
|
||||
if applied_filters:
|
||||
candidate_ids &= subjects_set
|
||||
else:
|
||||
candidate_ids.update(subjects_set)
|
||||
applied_filters.append("subjects")
|
||||
|
||||
# 如果没有应用任何过滤条件,返回所有记忆
|
||||
if not applied_filters:
|
||||
return all_memory_ids
|
||||
|
||||
return candidate_ids
|
||||
|
||||
def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
|
||||
"""根据给定token收集索引匹配,支持部分匹配"""
|
||||
mapping = self.indices.get(index_type)
|
||||
if mapping is None:
|
||||
return set()
|
||||
|
||||
key = ""
|
||||
if isinstance(token, Enum):
|
||||
key = str(token.value).strip().lower()
|
||||
elif isinstance(token, str):
|
||||
key = token.strip().lower()
|
||||
elif token is not None:
|
||||
key = str(token).strip().lower()
|
||||
|
||||
if not key:
|
||||
return set()
|
||||
|
||||
matches: Set[str] = set(mapping.get(key, set()))
|
||||
|
||||
if matches:
|
||||
return set(matches)
|
||||
|
||||
for existing_key, ids in mapping.items():
|
||||
if not existing_key or not isinstance(existing_key, str):
|
||||
continue
|
||||
normalized = existing_key.strip().lower()
|
||||
if not normalized:
|
||||
continue
|
||||
if key in normalized or normalized in key:
|
||||
matches.update(ids)
|
||||
|
||||
return matches
|
||||
|
||||
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
|
||||
"""应用过滤条件"""
|
||||
filtered_ids = list(candidate_ids)
|
||||
@@ -440,10 +583,10 @@ class MetadataIndexManager:
|
||||
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
|
||||
"""获取应用的过滤器列表"""
|
||||
filters = []
|
||||
if query.user_ids:
|
||||
filters.append("user_ids")
|
||||
if query.memory_types:
|
||||
filters.append("memory_types")
|
||||
if query.subjects:
|
||||
filters.append("subjects")
|
||||
if query.keywords:
|
||||
filters.append("keywords")
|
||||
if query.tags:
|
||||
@@ -502,6 +645,18 @@ class MetadataIndexManager:
|
||||
# 从各类索引中移除
|
||||
self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id)
|
||||
self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id)
|
||||
subjects = metadata.get("subjects") or []
|
||||
for subject in subjects:
|
||||
if not isinstance(subject, str):
|
||||
continue
|
||||
normalized = subject.strip().lower()
|
||||
if not normalized:
|
||||
continue
|
||||
subject_bucket = self.indices[IndexType.SUBJECT].get(normalized)
|
||||
if subject_bucket is not None:
|
||||
subject_bucket.discard(memory_id)
|
||||
if not subject_bucket:
|
||||
self.indices[IndexType.SUBJECT].pop(normalized, None)
|
||||
|
||||
# 从时间索引中移除
|
||||
self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id]
|
||||
@@ -625,11 +780,13 @@ class MetadataIndexManager:
|
||||
logger.info("正在保存元数据索引...")
|
||||
|
||||
# 保存各类索引
|
||||
indices_data = {}
|
||||
indices_data: Dict[str, Dict[str, List[str]]] = {}
|
||||
for index_type, index_data in self.indices.items():
|
||||
indices_data[index_type.value] = {
|
||||
key: list(values) for key, values in index_data.items()
|
||||
}
|
||||
serialized_index = {}
|
||||
for key, values in index_data.items():
|
||||
serialized_key = self._serialize_index_key(index_type, key)
|
||||
serialized_index[serialized_key] = list(values)
|
||||
indices_data[index_type.value] = serialized_index
|
||||
|
||||
indices_file = self.index_path / "indices.json"
|
||||
with open(indices_file, 'w', encoding='utf-8') as f:
|
||||
@@ -652,8 +809,12 @@ class MetadataIndexManager:
|
||||
|
||||
# 保存元数据缓存
|
||||
metadata_cache_file = self.index_path / "metadata_cache.json"
|
||||
metadata_serialized = {
|
||||
memory_id: self._serialize_metadata_entry(metadata)
|
||||
for memory_id, metadata in self.memory_metadata_cache.items()
|
||||
}
|
||||
with open(metadata_cache_file, 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(self.memory_metadata_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
||||
f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
||||
|
||||
# 保存统计信息
|
||||
stats_file = self.index_path / "index_stats.json"
|
||||
@@ -679,9 +840,11 @@ class MetadataIndexManager:
|
||||
|
||||
for index_type_value, index_data in indices_data.items():
|
||||
index_type = IndexType(index_type_value)
|
||||
self.indices[index_type] = {
|
||||
key: set(values) for key, values in index_data.items()
|
||||
}
|
||||
restored_index = defaultdict(set)
|
||||
for key_str, values in index_data.items():
|
||||
restored_key = self._deserialize_index_key(index_type, key_str)
|
||||
restored_index[restored_key] = set(values)
|
||||
self.indices[index_type] = restored_index
|
||||
|
||||
# 加载时间索引
|
||||
time_index_file = self.index_path / "time_index.json"
|
||||
@@ -709,10 +872,38 @@ class MetadataIndexManager:
|
||||
|
||||
# 转换置信度和重要性为枚举类型
|
||||
for memory_id, metadata in cache_data.items():
|
||||
if isinstance(metadata["confidence"], str):
|
||||
metadata["confidence"] = ConfidenceLevel(metadata["confidence"])
|
||||
if isinstance(metadata["importance"], str):
|
||||
metadata["importance"] = ImportanceLevel(metadata["importance"])
|
||||
memory_type_value = metadata.get("memory_type")
|
||||
if isinstance(memory_type_value, str):
|
||||
try:
|
||||
metadata["memory_type"] = MemoryType(memory_type_value)
|
||||
except ValueError:
|
||||
logger.warning("无法解析memory_type %s", memory_type_value)
|
||||
|
||||
confidence_value = metadata.get("confidence")
|
||||
if isinstance(confidence_value, (str, int)):
|
||||
try:
|
||||
metadata["confidence"] = ConfidenceLevel(int(confidence_value))
|
||||
except ValueError:
|
||||
logger.warning("无法解析confidence %s", confidence_value)
|
||||
|
||||
importance_value = metadata.get("importance")
|
||||
if isinstance(importance_value, (str, int)):
|
||||
try:
|
||||
metadata["importance"] = ImportanceLevel(int(importance_value))
|
||||
except ValueError:
|
||||
logger.warning("无法解析importance %s", importance_value)
|
||||
|
||||
subjects_value = metadata.get("subjects")
|
||||
if isinstance(subjects_value, str):
|
||||
metadata["subjects"] = [subjects_value]
|
||||
elif isinstance(subjects_value, list):
|
||||
cleaned_subjects = []
|
||||
for item in subjects_value:
|
||||
if isinstance(item, str) and item.strip():
|
||||
cleaned_subjects.append(item.strip())
|
||||
metadata["subjects"] = cleaned_subjects
|
||||
else:
|
||||
metadata["subjects"] = []
|
||||
|
||||
self.memory_metadata_cache = cache_data
|
||||
|
||||
|
||||
@@ -203,11 +203,17 @@ class MultiStageRetrieval:
|
||||
try:
|
||||
from .metadata_index import IndexQuery
|
||||
|
||||
# 构建索引查询
|
||||
query_plan = context.get("query_plan")
|
||||
|
||||
memory_types = self._extract_memory_types_from_context(context)
|
||||
keywords = self._extract_keywords_from_query(query, query_plan)
|
||||
subjects = query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None
|
||||
|
||||
index_query = IndexQuery(
|
||||
user_ids=[user_id],
|
||||
memory_types=self._extract_memory_types_from_context(context),
|
||||
keywords=self._extract_keywords_from_query(query),
|
||||
user_ids=None,
|
||||
memory_types=memory_types,
|
||||
subjects=subjects,
|
||||
keywords=keywords,
|
||||
limit=self.config.metadata_filter_limit,
|
||||
sort_by="last_accessed",
|
||||
sort_order="desc"
|
||||
@@ -215,13 +221,66 @@ class MultiStageRetrieval:
|
||||
|
||||
# 执行查询
|
||||
result = await metadata_index.query_memories(index_query)
|
||||
filtered_count = result.total_count - len(result.memory_ids)
|
||||
result_ids = list(result.memory_ids)
|
||||
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
|
||||
|
||||
logger.debug(f"元数据过滤:{result.total_count} -> {len(result.memory_ids)} 条记忆")
|
||||
# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
|
||||
if not result_ids:
|
||||
sorted_ids = sorted(
|
||||
(memory.memory_id for memory in all_memories_cache.values()),
|
||||
key=lambda mid: all_memories_cache[mid].metadata.last_accessed if mid in all_memories_cache else 0,
|
||||
reverse=True,
|
||||
)
|
||||
if memory_types:
|
||||
type_filtered = [
|
||||
mid for mid in sorted_ids
|
||||
if all_memories_cache[mid].memory_type in memory_types
|
||||
]
|
||||
sorted_ids = type_filtered or sorted_ids
|
||||
if subjects:
|
||||
subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()]
|
||||
if subject_candidates:
|
||||
subject_filtered = [
|
||||
mid for mid in sorted_ids
|
||||
if any(
|
||||
subj.strip().lower() in subject_candidates
|
||||
for subj in all_memories_cache[mid].subjects
|
||||
)
|
||||
]
|
||||
sorted_ids = subject_filtered or sorted_ids
|
||||
|
||||
if keywords:
|
||||
keyword_pool = {kw.lower() for kw in keywords if isinstance(kw, str) and kw.strip()}
|
||||
if keyword_pool:
|
||||
keyword_filtered = []
|
||||
for mid in sorted_ids:
|
||||
memory_text = (
|
||||
(all_memories_cache[mid].display or "")
|
||||
+ "\n"
|
||||
+ (all_memories_cache[mid].text_content or "")
|
||||
).lower()
|
||||
if any(kw in memory_text for kw in keyword_pool):
|
||||
keyword_filtered.append(mid)
|
||||
sorted_ids = keyword_filtered or sorted_ids
|
||||
|
||||
result_ids = sorted_ids[: self.config.metadata_filter_limit]
|
||||
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
|
||||
logger.debug(
|
||||
"元数据过滤未命中索引,使用近似回退: types=%s, subjects=%s, keywords=%s",
|
||||
bool(memory_types),
|
||||
bool(subjects),
|
||||
bool(keywords),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"元数据过滤:候选=%d, 返回=%d",
|
||||
len(all_memories_cache),
|
||||
len(result_ids),
|
||||
)
|
||||
|
||||
return StageResult(
|
||||
stage=RetrievalStage.METADATA_FILTERING,
|
||||
memory_ids=result.memory_ids,
|
||||
memory_ids=result_ids,
|
||||
processing_time=time.time() - start_time,
|
||||
filtered_count=filtered_count,
|
||||
score_threshold=0.0
|
||||
@@ -251,7 +310,7 @@ class MultiStageRetrieval:
|
||||
|
||||
try:
|
||||
# 生成查询向量
|
||||
query_embedding = await self._generate_query_embedding(query, context)
|
||||
query_embedding = await self._generate_query_embedding(query, context, vector_storage)
|
||||
|
||||
if not query_embedding:
|
||||
return StageResult(
|
||||
@@ -263,22 +322,24 @@ class MultiStageRetrieval:
|
||||
)
|
||||
|
||||
# 执行向量搜索
|
||||
search_result = await vector_storage.search_similar(
|
||||
query_embedding,
|
||||
search_result = await vector_storage.search_similar_memories(
|
||||
query_vector=query_embedding,
|
||||
limit=self.config.vector_search_limit
|
||||
)
|
||||
|
||||
candidate_pool = candidate_ids or set(all_memories_cache.keys())
|
||||
|
||||
# 过滤候选记忆
|
||||
filtered_memories = []
|
||||
for memory_id, similarity in search_result:
|
||||
if memory_id in candidate_ids and similarity >= self.config.vector_similarity_threshold:
|
||||
if memory_id in candidate_pool and similarity >= self.config.vector_similarity_threshold:
|
||||
filtered_memories.append((memory_id, similarity))
|
||||
|
||||
# 按相似度排序
|
||||
filtered_memories.sort(key=lambda x: x[1], reverse=True)
|
||||
result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]
|
||||
|
||||
filtered_count = len(candidate_ids) - len(result_ids)
|
||||
filtered_count = max(0, len(candidate_pool) - len(result_ids))
|
||||
|
||||
logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆")
|
||||
|
||||
@@ -407,12 +468,20 @@ class MultiStageRetrieval:
|
||||
score_threshold=0.0
|
||||
)
|
||||
|
||||
async def _generate_query_embedding(self, query: str, context: Dict[str, Any]) -> Optional[List[float]]:
|
||||
async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]:
|
||||
"""生成查询向量"""
|
||||
try:
|
||||
# 这里应该调用embedding模型
|
||||
# 由于我们可能没有直接的embedding模型,返回None或使用简单的方法
|
||||
# 在实际实现中,这里应该调用与记忆存储相同的embedding模型
|
||||
query_plan = context.get("query_plan")
|
||||
query_text = query
|
||||
if query_plan and getattr(query_plan, "semantic_query", None):
|
||||
query_text = query_plan.semantic_query
|
||||
|
||||
if not query_text:
|
||||
return None
|
||||
|
||||
if hasattr(vector_storage, "generate_query_embedding"):
|
||||
return await vector_storage.generate_query_embedding(query_text)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"生成查询向量失败: {e}")
|
||||
@@ -421,9 +490,15 @@ class MultiStageRetrieval:
|
||||
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
|
||||
"""计算语义相似度"""
|
||||
try:
|
||||
query_plan = context.get("query_plan")
|
||||
query_text = query
|
||||
if query_plan and getattr(query_plan, "semantic_query", None):
|
||||
query_text = query_plan.semantic_query
|
||||
|
||||
# 简单的文本相似度计算
|
||||
query_words = set(query.lower().split())
|
||||
memory_words = set(memory.text_content.lower().split())
|
||||
query_words = set(query_text.lower().split())
|
||||
memory_text = (memory.display or memory.text_content or "").lower()
|
||||
memory_words = set(memory_text.split())
|
||||
|
||||
if not query_words or not memory_words:
|
||||
return 0.0
|
||||
@@ -443,10 +518,15 @@ class MultiStageRetrieval:
|
||||
try:
|
||||
score = 0.0
|
||||
|
||||
query_plan = context.get("query_plan")
|
||||
|
||||
# 检查记忆类型是否匹配上下文
|
||||
if context.get("expected_memory_types"):
|
||||
if memory.memory_type in context["expected_memory_types"]:
|
||||
score += 0.3
|
||||
elif query_plan and getattr(query_plan, "memory_types", None):
|
||||
if memory.memory_type in query_plan.memory_types:
|
||||
score += 0.3
|
||||
|
||||
# 检查关键词匹配
|
||||
if context.get("keywords"):
|
||||
@@ -456,6 +536,35 @@ class MultiStageRetrieval:
|
||||
if overlap:
|
||||
score += len(overlap) / max(len(context_keywords), 1) * 0.4
|
||||
|
||||
if query_plan:
|
||||
# 主体匹配
|
||||
subject_score = self._calculate_subject_overlap(memory, getattr(query_plan, "subject_includes", []))
|
||||
score += subject_score * 0.3
|
||||
|
||||
# 对象/描述匹配
|
||||
object_keywords = getattr(query_plan, "object_includes", []) or []
|
||||
if object_keywords:
|
||||
display_text = (memory.display or memory.text_content or "").lower()
|
||||
hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
|
||||
if hits:
|
||||
score += min(0.3, hits * 0.1)
|
||||
|
||||
optional_keywords = getattr(query_plan, "optional_keywords", []) or []
|
||||
if optional_keywords:
|
||||
display_text = (memory.display or memory.text_content or "").lower()
|
||||
hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
|
||||
if hits:
|
||||
score += min(0.2, hits * 0.05)
|
||||
|
||||
# 时间偏好
|
||||
recency_pref = getattr(query_plan, "recency_preference", "")
|
||||
if recency_pref:
|
||||
memory_age = time.time() - memory.metadata.created_at
|
||||
if recency_pref == "recent" and memory_age < 7 * 24 * 3600:
|
||||
score += 0.2
|
||||
elif recency_pref == "historical" and memory_age > 30 * 24 * 3600:
|
||||
score += 0.1
|
||||
|
||||
# 检查时效性
|
||||
if context.get("recent_only", False):
|
||||
memory_age = time.time() - memory.metadata.created_at
|
||||
@@ -471,6 +580,8 @@ class MultiStageRetrieval:
|
||||
async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
|
||||
"""计算最终评分"""
|
||||
try:
|
||||
query_plan = context.get("query_plan")
|
||||
|
||||
# 语义相似度
|
||||
semantic_score = await self._calculate_semantic_similarity(query, memory, context)
|
||||
|
||||
@@ -482,13 +593,29 @@ class MultiStageRetrieval:
|
||||
|
||||
# 时效性评分
|
||||
recency_score = self._calculate_recency_score(memory.metadata.created_at)
|
||||
if query_plan:
|
||||
recency_pref = getattr(query_plan, "recency_preference", "")
|
||||
if recency_pref == "recent":
|
||||
recency_score = max(recency_score, 0.8)
|
||||
elif recency_pref == "historical":
|
||||
recency_score = min(recency_score, 0.5)
|
||||
|
||||
# 权重组合
|
||||
vector_weight = self.config.vector_weight
|
||||
semantic_weight = self.config.semantic_weight
|
||||
context_weight = self.config.context_weight
|
||||
recency_weight = self.config.recency_weight
|
||||
|
||||
if query_plan and getattr(query_plan, "emphasis", None) == "precision":
|
||||
semantic_weight += 0.05
|
||||
elif query_plan and getattr(query_plan, "emphasis", None) == "recall":
|
||||
context_weight += 0.05
|
||||
|
||||
final_score = (
|
||||
semantic_score * self.config.semantic_weight +
|
||||
vector_score * self.config.vector_weight +
|
||||
context_score * self.config.context_weight +
|
||||
recency_score * self.config.recency_weight
|
||||
semantic_score * semantic_weight +
|
||||
vector_score * vector_weight +
|
||||
context_score * context_weight +
|
||||
recency_score * recency_weight
|
||||
)
|
||||
|
||||
# 加入记忆重要性权重
|
||||
@@ -501,6 +628,31 @@ class MultiStageRetrieval:
|
||||
logger.warning(f"计算最终评分失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
|
||||
if not required_subjects:
|
||||
return 0.0
|
||||
|
||||
memory_subjects = {subject.lower() for subject in memory.subjects if isinstance(subject, str)}
|
||||
if not memory_subjects:
|
||||
return 0.0
|
||||
|
||||
hit = 0
|
||||
total = 0
|
||||
for subject in required_subjects:
|
||||
if not isinstance(subject, str):
|
||||
continue
|
||||
total += 1
|
||||
normalized = subject.strip().lower()
|
||||
if not normalized:
|
||||
continue
|
||||
if any(normalized in mem_subject for mem_subject in memory_subjects):
|
||||
hit += 1
|
||||
|
||||
if total == 0:
|
||||
return 0.0
|
||||
|
||||
return hit / total
|
||||
|
||||
def _calculate_recency_score(self, timestamp: float) -> float:
|
||||
"""计算时效性评分"""
|
||||
try:
|
||||
@@ -524,6 +676,10 @@ class MultiStageRetrieval:
|
||||
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
|
||||
"""从上下文中提取记忆类型"""
|
||||
try:
|
||||
query_plan = context.get("query_plan")
|
||||
if query_plan and getattr(query_plan, "memory_types", None):
|
||||
return query_plan.memory_types
|
||||
|
||||
if "expected_memory_types" in context:
|
||||
return context["expected_memory_types"]
|
||||
|
||||
@@ -544,15 +700,30 @@ class MultiStageRetrieval:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _extract_keywords_from_query(self, query: str) -> List[str]:
|
||||
def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
|
||||
"""从查询中提取关键词"""
|
||||
try:
|
||||
extracted: List[str] = []
|
||||
|
||||
if query_plan and getattr(query_plan, "required_keywords", None):
|
||||
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
|
||||
|
||||
# 简单的关键词提取
|
||||
words = query.lower().split()
|
||||
# 过滤停用词
|
||||
stopwords = {"的", "是", "在", "有", "我", "你", "他", "她", "它", "这", "那", "了", "吗", "呢"}
|
||||
keywords = [word for word in words if len(word) > 1 and word not in stopwords]
|
||||
return keywords[:10] # 最多返回10个关键词
|
||||
extracted.extend(word for word in words if len(word) > 1 and word not in stopwords)
|
||||
|
||||
# 去重并保留顺序
|
||||
seen = set()
|
||||
deduplicated = []
|
||||
for word in extracted:
|
||||
if word in seen or not word:
|
||||
continue
|
||||
seen.add(word)
|
||||
deduplicated.append(word)
|
||||
|
||||
return deduplicated[:10]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -36,12 +37,12 @@ except ImportError:
|
||||
@dataclass
|
||||
class VectorStorageConfig:
|
||||
"""向量存储配置"""
|
||||
dimension: int = 768
|
||||
dimension: int = 1024
|
||||
similarity_threshold: float = 0.8
|
||||
index_type: str = "flat" # flat, ivf, hnsw
|
||||
max_index_size: int = 100000
|
||||
storage_path: str = "data/memory_vectors"
|
||||
auto_save_interval: int = 100 # 每N次操作自动保存
|
||||
auto_save_interval: int = 10 # 每N次操作自动保存
|
||||
enable_compression: bool = True
|
||||
|
||||
|
||||
@@ -50,6 +51,15 @@ class VectorStorageManager:
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
self.config = config or VectorStorageConfig()
|
||||
|
||||
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
|
||||
if resolved_dimension and resolved_dimension != self.config.dimension:
|
||||
logger.info(
|
||||
"向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
|
||||
resolved_dimension,
|
||||
self.config.dimension,
|
||||
)
|
||||
self.config.dimension = resolved_dimension
|
||||
self.storage_path = Path(self.config.storage_path)
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -117,6 +127,32 @@ class VectorStorageManager:
|
||||
)
|
||||
logger.info("✅ 嵌入模型初始化完成")
|
||||
|
||||
async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
|
||||
"""生成查询向量,用于记忆召回"""
|
||||
if not query_text:
|
||||
return None
|
||||
|
||||
try:
|
||||
await self.initialize_embedding_model()
|
||||
|
||||
embedding, _ = await self.embedding_model.get_embedding(query_text)
|
||||
if not embedding:
|
||||
return None
|
||||
|
||||
if len(embedding) != self.config.dimension:
|
||||
logger.warning(
|
||||
"查询向量维度不匹配: 期望 %d, 实际 %d",
|
||||
self.config.dimension,
|
||||
len(embedding)
|
||||
)
|
||||
return None
|
||||
|
||||
return self._normalize_vector(embedding)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]):
|
||||
"""存储记忆向量"""
|
||||
if not memories:
|
||||
@@ -213,7 +249,7 @@ class VectorStorageManager:
|
||||
results[memory_id] = embedding
|
||||
else:
|
||||
logger.warning(
|
||||
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)",
|
||||
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。",
|
||||
self.config.dimension,
|
||||
len(embedding) if embedding else 0,
|
||||
memory_id,
|
||||
@@ -299,14 +335,32 @@ class VectorStorageManager:
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_vector: List[float],
|
||||
query_vector: Optional[List[float]] = None,
|
||||
*,
|
||||
query_text: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
user_id: Optional[str] = None
|
||||
scope_id: Optional[str] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""搜索相似记忆"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if query_vector is None:
|
||||
if not query_text:
|
||||
return []
|
||||
|
||||
query_vector = await self.generate_query_embedding(query_text)
|
||||
if not query_vector:
|
||||
return []
|
||||
|
||||
scope_filter: Optional[str] = None
|
||||
if isinstance(scope_id, str):
|
||||
normalized_scope = scope_id.strip().lower()
|
||||
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
|
||||
scope_filter = scope_id
|
||||
elif scope_id:
|
||||
scope_filter = str(scope_id)
|
||||
|
||||
# 规范化查询向量
|
||||
query_vector = self._normalize_vector(query_vector)
|
||||
|
||||
@@ -341,10 +395,9 @@ class VectorStorageManager:
|
||||
|
||||
memory_id = self.index_to_memory_id.get(index)
|
||||
if memory_id:
|
||||
# 应用用户过滤
|
||||
if user_id:
|
||||
if scope_filter:
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if memory and memory.user_id != user_id:
|
||||
if memory and str(memory.user_id) != scope_filter:
|
||||
continue
|
||||
|
||||
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
|
||||
@@ -481,8 +534,14 @@ class VectorStorageManager:
|
||||
# 保存映射关系
|
||||
mapping_file = self.storage_path / "id_mapping.json"
|
||||
mapping_data = {
|
||||
"memory_id_to_index": self.memory_id_to_index,
|
||||
"index_to_memory_id": self.index_to_memory_id
|
||||
"memory_id_to_index": {
|
||||
str(memory_id): int(index)
|
||||
for memory_id, index in self.memory_id_to_index.items()
|
||||
},
|
||||
"index_to_memory_id": {
|
||||
str(index): memory_id
|
||||
for index, memory_id in self.index_to_memory_id.items()
|
||||
}
|
||||
}
|
||||
with open(mapping_file, 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
||||
@@ -529,8 +588,17 @@ class VectorStorageManager:
|
||||
if mapping_file.exists():
|
||||
with open(mapping_file, 'r', encoding='utf-8') as f:
|
||||
mapping_data = orjson.loads(f.read())
|
||||
self.memory_id_to_index = mapping_data.get("memory_id_to_index", {})
|
||||
self.index_to_memory_id = mapping_data.get("index_to_memory_id", {})
|
||||
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
|
||||
self.memory_id_to_index = {
|
||||
str(memory_id): int(index)
|
||||
for memory_id, index in raw_memory_to_index.items()
|
||||
}
|
||||
|
||||
raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
|
||||
self.index_to_memory_id = {
|
||||
int(index): memory_id
|
||||
for index, memory_id in raw_index_to_memory.items()
|
||||
}
|
||||
|
||||
# 加载FAISS索引(如果可用)
|
||||
if FAISS_AVAILABLE:
|
||||
|
||||
@@ -469,14 +469,14 @@ class ChatBot:
|
||||
async def preprocess():
|
||||
# 存储消息到数据库
|
||||
from .storage import MessageStorage
|
||||
|
||||
|
||||
try:
|
||||
await MessageStorage.store_message(message, message.chat_stream)
|
||||
logger.debug(f"消息已存储到数据库: {message.message_info.message_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息到数据库失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 使用消息管理器处理消息(保持原有功能)
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
|
||||
@@ -373,12 +373,12 @@ class Prompt:
|
||||
|
||||
# 性能优化 - 为不同任务设置不同的超时时间
|
||||
task_timeouts = {
|
||||
"memory_block": 5.0, # 记忆系统可能较慢,单独设置超时
|
||||
"tool_info": 3.0, # 工具信息中等速度
|
||||
"relation_info": 2.0, # 关系信息通常较快
|
||||
"knowledge_info": 3.0, # 知识库查询中等速度
|
||||
"cross_context": 2.0, # 上下文处理通常较快
|
||||
"expression_habits": 1.5, # 表达习惯处理很快
|
||||
"memory_block": 15.0, # 记忆系统
|
||||
"tool_info": 15.0, # 工具信息
|
||||
"relation_info": 10.0, # 关系信息
|
||||
"knowledge_info": 10.0, # 知识库查询
|
||||
"cross_context": 10.0, # 上下文处理
|
||||
"expression_habits": 10.0, # 表达习惯
|
||||
}
|
||||
|
||||
# 分别处理每个任务,避免慢任务影响快任务
|
||||
@@ -558,12 +558,8 @@ class Prompt:
|
||||
)
|
||||
]
|
||||
|
||||
# 等待所有记忆查询完成(最多3秒)
|
||||
try:
|
||||
running_memories, instant_memory = await asyncio.wait_for(
|
||||
asyncio.gather(*memory_tasks, return_exceptions=True),
|
||||
timeout=3.0
|
||||
)
|
||||
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
|
||||
|
||||
# 处理可能的异常结果
|
||||
if isinstance(running_memories, Exception):
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.database.sqlalchemy_models import CacheEntries
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save
|
||||
from src.common.vector_db import vector_db_service
|
||||
@@ -40,7 +41,11 @@ class CacheManager:
|
||||
|
||||
# L1 缓存 (内存)
|
||||
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
|
||||
if not embedding_dim:
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
|
||||
self.embedding_dimension = embedding_dim
|
||||
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.l1_vector_id_to_key: Dict[int, str] = {}
|
||||
|
||||
@@ -72,7 +77,7 @@ class CacheManager:
|
||||
embedding_array = embedding_array.flatten()
|
||||
|
||||
# 检查维度是否符合预期
|
||||
expected_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
expected_dim = getattr(CacheManager, "embedding_dimension", None) or global_config.lpmm_knowledge.embedding_dimension
|
||||
if embedding_array.shape[0] != expected_dim:
|
||||
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
|
||||
return None
|
||||
|
||||
42
src/common/config_helpers.py
Normal file
42
src/common/config_helpers.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
|
||||
def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]:
|
||||
"""获取当前配置的嵌入向量维度。
|
||||
|
||||
优先顺序:
|
||||
1. 模型配置中 `model_task_config.embedding.embedding_dimension`
|
||||
2. 机器人配置中 `lpmm_knowledge.embedding_dimension`
|
||||
3. 调用方提供的 fallback
|
||||
"""
|
||||
|
||||
candidates: list[Optional[int]] = []
|
||||
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task is not None:
|
||||
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
||||
except Exception:
|
||||
candidates.append(None)
|
||||
|
||||
try:
|
||||
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
||||
except Exception:
|
||||
candidates.append(None)
|
||||
|
||||
candidates.append(fallback)
|
||||
|
||||
resolved: Optional[int] = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
|
||||
|
||||
if resolved and sync_global:
|
||||
try:
|
||||
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
|
||||
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return resolved
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Dict, Any, Literal, Union
|
||||
from typing import List, Dict, Any, Literal, Union, Optional
|
||||
from pydantic import Field
|
||||
from threading import Lock
|
||||
|
||||
@@ -105,6 +105,11 @@ class TaskConfig(ValidatedConfigBase):
|
||||
max_tokens: int = Field(default=800, description="任务最大输出token数")
|
||||
temperature: float = Field(default=0.7, description="模型温度")
|
||||
concurrency_count: int = Field(default=1, description="并发请求数量")
|
||||
embedding_dimension: Optional[int] = Field(
|
||||
default=None,
|
||||
description="嵌入模型输出向量维度,仅在嵌入任务中使用",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_model_list(cls, v):
|
||||
|
||||
@@ -443,21 +443,6 @@ class MemoryConfig(ValidatedConfigBase):
|
||||
|
||||
enable_memory: bool = Field(default=True, description="启用记忆")
|
||||
memory_build_interval: int = Field(default=600, description="记忆构建间隔")
|
||||
memory_build_distribution: list[float] = Field(
|
||||
default_factory=lambda: [6.0, 3.0, 0.6, 32.0, 12.0, 0.4], description="记忆构建分布"
|
||||
)
|
||||
memory_build_sample_num: int = Field(default=8, description="记忆构建样本数量")
|
||||
memory_build_sample_length: int = Field(default=40, description="记忆构建样本长度")
|
||||
memory_compress_rate: float = Field(default=0.1, description="记忆压缩率")
|
||||
forget_memory_interval: int = Field(default=1000, description="遗忘记忆间隔")
|
||||
memory_forget_time: int = Field(default=24, description="记忆遗忘时间")
|
||||
memory_forget_percentage: float = Field(default=0.01, description="记忆遗忘百分比")
|
||||
consolidate_memory_interval: int = Field(default=1000, description="记忆巩固间隔")
|
||||
consolidation_similarity_threshold: float = Field(default=0.7, description="巩固相似性阈值")
|
||||
consolidate_memory_percentage: float = Field(default=0.01, description="巩固记忆百分比")
|
||||
memory_ban_words: list[str] = Field(
|
||||
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"], description="记忆禁用词"
|
||||
)
|
||||
enable_instant_memory: bool = Field(default=True, description="启用即时记忆")
|
||||
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
|
||||
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
|
||||
@@ -472,8 +457,8 @@ class MemoryConfig(ValidatedConfigBase):
|
||||
memory_value_threshold: float = Field(default=0.7, description="记忆价值阈值")
|
||||
|
||||
# 向量存储配置
|
||||
vector_dimension: int = Field(default=768, description="向量维度")
|
||||
vector_similarity_threshold: float = Field(default=0.8, description="向量相似度阈值")
|
||||
semantic_similarity_threshold: float = Field(default=0.6, description="语义相似度阈值")
|
||||
|
||||
# 多阶段检索配置
|
||||
metadata_filter_limit: int = Field(default=100, description="元数据过滤阶段返回数量")
|
||||
|
||||
15
src/main.py
15
src/main.py
@@ -27,9 +27,8 @@ from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
# 导入新的插件管理器和热重载管理器
|
||||
# 导入新的插件管理器
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
# 导入消息API和traceback模块
|
||||
from src.common.message import get_global_api
|
||||
@@ -116,13 +115,7 @@ class MainSystem:
|
||||
except Exception as e:
|
||||
logger.error(f"停止消息重组器时出错: {e}")
|
||||
|
||||
try:
|
||||
# 停止插件热重载系统
|
||||
hot_reload_manager.stop()
|
||||
logger.info("🛑 插件热重载系统已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止热重载系统时出错: {e}")
|
||||
|
||||
|
||||
try:
|
||||
# 停止增强记忆系统
|
||||
if global_config.memory.enable_memory:
|
||||
@@ -228,9 +221,7 @@ MoFox_Bot(第三方修改版)
|
||||
# 处理所有缓存的事件订阅(插件加载完成后)
|
||||
event_manager.process_all_pending_subscriptions()
|
||||
|
||||
# 启动插件热重载系统
|
||||
hot_reload_manager.start()
|
||||
|
||||
|
||||
# 初始化表情管理器
|
||||
get_emoji_manager().initialize()
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
@@ -8,12 +8,10 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"event_manager",
|
||||
"global_announcement_manager",
|
||||
"hot_reload_manager",
|
||||
]
|
||||
|
||||
@@ -1,512 +0,0 @@
|
||||
"""
|
||||
插件热重载模块
|
||||
|
||||
使用 Watchdog 监听插件目录变化,自动重载插件
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Dict, Set, List, Optional, Tuple
|
||||
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .plugin_manager import plugin_manager
|
||||
|
||||
logger = get_logger("plugin_hot_reload")
|
||||
|
||||
|
||||
class PluginFileHandler(FileSystemEventHandler):
|
||||
"""插件文件变化处理器"""
|
||||
|
||||
def __init__(self, hot_reload_manager):
|
||||
super().__init__()
|
||||
self.hot_reload_manager = hot_reload_manager
|
||||
self.pending_reloads: Set[str] = set() # 待重载的插件名称
|
||||
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
|
||||
self.debounce_delay = 2.0 # 增加防抖延迟到2秒,确保文件写入完成
|
||||
self.file_change_cache: Dict[str, float] = {} # 文件变化缓存
|
||||
|
||||
def on_modified(self, event):
|
||||
"""文件修改事件"""
|
||||
if not event.is_directory:
|
||||
file_path = str(event.src_path)
|
||||
if file_path.endswith((".py", ".toml")):
|
||||
self._handle_file_change(file_path, "modified")
|
||||
|
||||
def on_created(self, event):
|
||||
"""文件创建事件"""
|
||||
if not event.is_directory:
|
||||
file_path = str(event.src_path)
|
||||
if file_path.endswith((".py", ".toml")):
|
||||
self._handle_file_change(file_path, "created")
|
||||
|
||||
def on_deleted(self, event):
|
||||
"""文件删除事件"""
|
||||
if not event.is_directory:
|
||||
file_path = str(event.src_path)
|
||||
if file_path.endswith((".py", ".toml")):
|
||||
self._handle_file_change(file_path, "deleted")
|
||||
|
||||
def _handle_file_change(self, file_path: str, change_type: str):
|
||||
"""处理文件变化"""
|
||||
try:
|
||||
# 获取插件名称
|
||||
plugin_info = self._get_plugin_info_from_path(file_path)
|
||||
if not plugin_info:
|
||||
return
|
||||
|
||||
plugin_name, source_type = plugin_info
|
||||
current_time = time.time()
|
||||
|
||||
# 文件变化缓存,避免重复处理同一文件的快速连续变化
|
||||
file_cache_key = f"{file_path}_{change_type}"
|
||||
last_file_time = self.file_change_cache.get(file_cache_key, 0)
|
||||
if current_time - last_file_time < 0.5: # 0.5秒内的重复文件变化忽略
|
||||
return
|
||||
self.file_change_cache[file_cache_key] = current_time
|
||||
|
||||
# 插件级别的防抖处理
|
||||
last_plugin_time = self.last_reload_time.get(plugin_name, 0)
|
||||
if current_time - last_plugin_time < self.debounce_delay:
|
||||
# 如果在防抖期内,更新待重载标记但不立即处理
|
||||
self.pending_reloads.add(plugin_name)
|
||||
return
|
||||
|
||||
file_name = Path(file_path).name
|
||||
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type}) [{source_type}] -> {plugin_name}")
|
||||
|
||||
# 如果是删除事件,处理关键文件删除
|
||||
if change_type == "deleted":
|
||||
# 解析实际的插件名称
|
||||
actual_plugin_name = self.hot_reload_manager._resolve_plugin_name(plugin_name)
|
||||
|
||||
if file_name == "plugin.py":
|
||||
if actual_plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(
|
||||
f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
self.hot_reload_manager._unload_plugin(actual_plugin_name)
|
||||
else:
|
||||
logger.info(
|
||||
f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
return
|
||||
elif file_name in ("manifest.toml", "_manifest.json"):
|
||||
if actual_plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(
|
||||
f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
self.hot_reload_manager._unload_plugin(actual_plugin_name)
|
||||
else:
|
||||
logger.info(
|
||||
f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
return
|
||||
|
||||
# 对于修改和创建事件,都进行重载
|
||||
# 添加到待重载列表
|
||||
self.pending_reloads.add(plugin_name)
|
||||
self.last_reload_time[plugin_name] = current_time
|
||||
|
||||
# 延迟重载,确保文件写入完成
|
||||
reload_thread = Thread(
|
||||
target=self._delayed_reload, args=(plugin_name, source_type, current_time), daemon=True
|
||||
)
|
||||
reload_thread.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理文件变化时发生错误: {e}", exc_info=True)
|
||||
|
||||
def _delayed_reload(self, plugin_name: str, source_type: str, trigger_time: float):
|
||||
"""延迟重载插件"""
|
||||
try:
|
||||
# 等待文件写入完成
|
||||
time.sleep(self.debounce_delay)
|
||||
|
||||
# 检查是否还需要重载(可能在等待期间有更新的变化)
|
||||
if plugin_name not in self.pending_reloads:
|
||||
return
|
||||
|
||||
# 检查是否有更新的重载请求
|
||||
if self.last_reload_time.get(plugin_name, 0) > trigger_time:
|
||||
return
|
||||
|
||||
self.pending_reloads.discard(plugin_name)
|
||||
logger.info(f"🔄 开始延迟重载插件: {plugin_name} [{source_type}]")
|
||||
|
||||
# 执行深度重载
|
||||
success = self.hot_reload_manager._deep_reload_plugin(plugin_name)
|
||||
if success:
|
||||
logger.info(f"✅ 插件重载成功: {plugin_name} [{source_type}]")
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} [{source_type}]")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
|
||||
def _get_plugin_info_from_path(self, file_path: str) -> Optional[Tuple[str, str]]:
|
||||
"""从文件路径获取插件信息
|
||||
|
||||
Returns:
|
||||
tuple[插件名称, 源类型] 或 None
|
||||
"""
|
||||
try:
|
||||
path = Path(file_path)
|
||||
|
||||
# 检查是否在任何一个监听的插件目录中
|
||||
for watch_dir in self.hot_reload_manager.watch_directories:
|
||||
plugin_root = Path(watch_dir)
|
||||
if path.is_relative_to(plugin_root):
|
||||
# 确定源类型
|
||||
if "src" in str(plugin_root):
|
||||
source_type = "built-in"
|
||||
else:
|
||||
source_type = "external"
|
||||
|
||||
# 获取插件目录名(插件名)
|
||||
relative_path = path.relative_to(plugin_root)
|
||||
if len(relative_path.parts) == 0:
|
||||
continue
|
||||
|
||||
plugin_name = relative_path.parts[0]
|
||||
|
||||
# 确认这是一个有效的插件目录
|
||||
plugin_dir = plugin_root / plugin_name
|
||||
if plugin_dir.is_dir():
|
||||
# 检查是否有插件主文件或配置文件
|
||||
has_plugin_py = (plugin_dir / "plugin.py").exists()
|
||||
has_manifest = (plugin_dir / "manifest.toml").exists() or (
|
||||
plugin_dir / "_manifest.json"
|
||||
).exists()
|
||||
|
||||
if has_plugin_py or has_manifest:
|
||||
return plugin_name, source_type
|
||||
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
class PluginHotReloadManager:
|
||||
"""插件热重载管理器"""
|
||||
|
||||
def __init__(self, watch_directories: Optional[List[str]] = None):
|
||||
if watch_directories is None:
|
||||
# 默认监听两个目录:根目录下的 plugins 和 src 下的插件目录
|
||||
self.watch_directories = [
|
||||
os.path.join(os.getcwd(), "plugins"), # 外部插件目录
|
||||
os.path.join(os.getcwd(), "src", "plugins", "built_in"), # 内置插件目录
|
||||
]
|
||||
else:
|
||||
self.watch_directories = watch_directories
|
||||
|
||||
self.observers = []
|
||||
self.file_handlers = []
|
||||
self.is_running = False
|
||||
|
||||
# 确保监听目录存在
|
||||
for watch_dir in self.watch_directories:
|
||||
if not os.path.exists(watch_dir):
|
||||
os.makedirs(watch_dir, exist_ok=True)
|
||||
logger.info(f"📁 创建插件监听目录: {watch_dir}")
|
||||
|
||||
def start(self):
|
||||
"""启动热重载监听"""
|
||||
if self.is_running:
|
||||
logger.warning("插件热重载已经在运行中")
|
||||
return
|
||||
|
||||
try:
|
||||
# 为每个监听目录创建独立的观察者
|
||||
for watch_dir in self.watch_directories:
|
||||
observer = Observer()
|
||||
file_handler = PluginFileHandler(self)
|
||||
|
||||
observer.schedule(file_handler, watch_dir, recursive=True)
|
||||
|
||||
observer.start()
|
||||
self.observers.append(observer)
|
||||
self.file_handlers.append(file_handler)
|
||||
|
||||
self.is_running = True
|
||||
|
||||
# 打印监听的目录信息
|
||||
dir_info = []
|
||||
for watch_dir in self.watch_directories:
|
||||
if "src" in watch_dir:
|
||||
dir_info.append(f"{watch_dir} (内置插件)")
|
||||
else:
|
||||
dir_info.append(f"{watch_dir} (外部插件)")
|
||||
|
||||
logger.info("🚀 插件热重载已启动,监听目录:")
|
||||
for info in dir_info:
|
||||
logger.info(f" 📂 {info}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动插件热重载失败: {e}")
|
||||
self.stop() # 清理已创建的观察者
|
||||
self.is_running = False
|
||||
|
||||
def stop(self):
|
||||
"""停止热重载监听"""
|
||||
if not self.is_running and not self.observers:
|
||||
return
|
||||
|
||||
# 停止所有观察者
|
||||
for observer in self.observers:
|
||||
try:
|
||||
observer.stop()
|
||||
observer.join()
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 停止观察者时发生错误: {e}")
|
||||
|
||||
self.observers.clear()
|
||||
self.file_handlers.clear()
|
||||
self.is_running = False
|
||||
logger.info("🛑 插件热重载已停止")
|
||||
|
||||
def _reload_plugin(self, plugin_name: str):
|
||||
"""重载指定插件(简单重载)"""
|
||||
try:
|
||||
# 解析实际的插件名称
|
||||
actual_plugin_name = self._resolve_plugin_name(plugin_name)
|
||||
logger.info(f"🔄 开始简单重载插件: {plugin_name} -> {actual_plugin_name}")
|
||||
|
||||
if plugin_manager.reload_plugin(actual_plugin_name):
|
||||
logger.info(f"✅ 插件简单重载成功: {actual_plugin_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ 插件简单重载失败: {actual_plugin_name}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _resolve_plugin_name(folder_name: str) -> str:
|
||||
"""
|
||||
将文件夹名称解析为实际的插件名称
|
||||
通过检查插件管理器中的路径映射来找到对应的插件名
|
||||
"""
|
||||
# 首先检查是否直接匹配
|
||||
if folder_name in plugin_manager.plugin_classes:
|
||||
logger.debug(f"🔍 直接匹配插件名: {folder_name}")
|
||||
return folder_name
|
||||
|
||||
# 如果没有直接匹配,搜索路径映射,并优先返回在插件类中存在的名称
|
||||
matched_plugins = []
|
||||
for plugin_name, plugin_path in plugin_manager.plugin_paths.items():
|
||||
# 检查路径是否包含该文件夹名
|
||||
if folder_name in plugin_path:
|
||||
matched_plugins.append((plugin_name, plugin_path))
|
||||
|
||||
# 在匹配的插件中,优先选择在插件类中存在的
|
||||
for plugin_name, plugin_path in matched_plugins:
|
||||
if plugin_name in plugin_manager.plugin_classes:
|
||||
logger.debug(f"🔍 文件夹名 '{folder_name}' 映射到插件名 '{plugin_name}' (路径: {plugin_path})")
|
||||
return plugin_name
|
||||
|
||||
# 如果还是没找到在插件类中存在的,返回第一个匹配项
|
||||
if matched_plugins:
|
||||
plugin_name, plugin_path = matched_plugins[0]
|
||||
logger.warning(f"⚠️ 文件夹 '{folder_name}' 映射到 '{plugin_name}',但该插件类不存在")
|
||||
return plugin_name
|
||||
|
||||
# 如果还是没找到,返回原文件夹名
|
||||
logger.warning(f"⚠️ 无法找到文件夹 '{folder_name}' 对应的插件名,使用原名称")
|
||||
return folder_name
|
||||
|
||||
def _deep_reload_plugin(self, plugin_name: str):
|
||||
"""深度重载指定插件(清理模块缓存)"""
|
||||
try:
|
||||
# 解析实际的插件名称
|
||||
actual_plugin_name = self._resolve_plugin_name(plugin_name)
|
||||
logger.info(f"🔄 开始深度重载插件: {plugin_name} -> {actual_plugin_name}")
|
||||
|
||||
# 强制清理相关模块缓存
|
||||
self._force_clear_plugin_modules(plugin_name)
|
||||
|
||||
# 使用插件管理器的强制重载功能
|
||||
success = plugin_manager.force_reload_plugin(actual_plugin_name)
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 插件深度重载成功: {actual_plugin_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ 插件深度重载失败,尝试简单重载: {actual_plugin_name}")
|
||||
# 如果深度重载失败,尝试简单重载
|
||||
return self._reload_plugin(actual_plugin_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 深度重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
# 出错时尝试简单重载
|
||||
return self._reload_plugin(plugin_name)
|
||||
|
||||
@staticmethod
|
||||
def _force_clear_plugin_modules(plugin_name: str):
|
||||
"""强制清理插件相关的模块缓存"""
|
||||
|
||||
# 找到所有相关的模块名
|
||||
modules_to_remove = []
|
||||
plugin_module_prefix = f"src.plugins.built_in.{plugin_name}"
|
||||
|
||||
for module_name in list(sys.modules.keys()):
|
||||
if plugin_module_prefix in module_name:
|
||||
modules_to_remove.append(module_name)
|
||||
|
||||
# 删除模块缓存
|
||||
for module_name in modules_to_remove:
|
||||
if module_name in sys.modules:
|
||||
logger.debug(f"🗑️ 清理模块缓存: {module_name}")
|
||||
del sys.modules[module_name]
|
||||
|
||||
@staticmethod
|
||||
def _force_reimport_plugin(plugin_name: str):
|
||||
"""强制重新导入插件(委托给插件管理器)"""
|
||||
try:
|
||||
# 使用插件管理器的重载功能
|
||||
success = plugin_manager.reload_plugin(plugin_name)
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _unload_plugin(plugin_name: str):
|
||||
"""卸载指定插件"""
|
||||
try:
|
||||
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
|
||||
|
||||
if plugin_manager.unload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def reload_all_plugins(self):
|
||||
"""重载所有插件"""
|
||||
try:
|
||||
logger.info("🔄 开始深度重载所有插件...")
|
||||
|
||||
# 获取当前已加载的插件列表
|
||||
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for plugin_name in loaded_plugins:
|
||||
logger.info(f"🔄 重载插件: {plugin_name}")
|
||||
if self._deep_reload_plugin(plugin_name):
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count} 个")
|
||||
|
||||
# 清理全局缓存
|
||||
importlib.invalidate_caches()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载所有插件时发生错误: {e}", exc_info=True)
|
||||
|
||||
def force_reload_plugin(self, plugin_name: str):
|
||||
"""手动强制重载指定插件(委托给插件管理器)"""
|
||||
try:
|
||||
logger.info(f"🔄 手动强制重载插件: {plugin_name}")
|
||||
|
||||
# 清理待重载列表中的该插件(避免重复重载)
|
||||
for handler in self.file_handlers:
|
||||
handler.pending_reloads.discard(plugin_name)
|
||||
|
||||
# 使用插件管理器的强制重载功能
|
||||
success = plugin_manager.force_reload_plugin(plugin_name)
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 手动强制重载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 手动强制重载失败: {plugin_name}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 手动强制重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def add_watch_directory(self, directory: str):
|
||||
"""添加新的监听目录"""
|
||||
if directory in self.watch_directories:
|
||||
logger.info(f"目录 {directory} 已在监听列表中")
|
||||
return
|
||||
|
||||
# 确保目录存在
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.info(f"📁 创建插件监听目录: {directory}")
|
||||
|
||||
self.watch_directories.append(directory)
|
||||
|
||||
# 如果热重载正在运行,为新目录创建观察者
|
||||
if self.is_running:
|
||||
try:
|
||||
observer = Observer()
|
||||
file_handler = PluginFileHandler(self)
|
||||
|
||||
observer.schedule(file_handler, directory, recursive=True)
|
||||
|
||||
observer.start()
|
||||
self.observers.append(observer)
|
||||
self.file_handlers.append(file_handler)
|
||||
|
||||
logger.info(f"📂 已添加新的监听目录: {directory}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 添加监听目录 {directory} 失败: {e}")
|
||||
self.watch_directories.remove(directory)
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""获取热重载状态"""
|
||||
pending_reloads = set()
|
||||
if self.file_handlers:
|
||||
for handler in self.file_handlers:
|
||||
pending_reloads.update(handler.pending_reloads)
|
||||
|
||||
return {
|
||||
"is_running": self.is_running,
|
||||
"watch_directories": self.watch_directories,
|
||||
"active_observers": len(self.observers),
|
||||
"loaded_plugins": len(plugin_manager.loaded_plugins),
|
||||
"failed_plugins": len(plugin_manager.failed_plugins),
|
||||
"pending_reloads": list(pending_reloads),
|
||||
"debounce_delay": self.file_handlers[0].debounce_delay if self.file_handlers else 0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def clear_all_caches():
|
||||
"""清理所有Python模块缓存"""
|
||||
try:
|
||||
logger.info("🧹 开始清理所有Python模块缓存...")
|
||||
|
||||
# 重新扫描所有插件目录,这会重新加载模块
|
||||
plugin_manager.rescan_plugin_directory()
|
||||
logger.info("✅ 模块缓存清理完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 全局热重载管理器实例
|
||||
hot_reload_manager = PluginHotReloadManager()
|
||||
@@ -91,7 +91,6 @@ INSTALL_NAME_TO_IMPORT_NAME = {
|
||||
"pyusb": "usb", # USB访问
|
||||
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
|
||||
"psutil": "psutil", # 系统信息和进程管理
|
||||
"watchdog": "watchdog", # 文件系统事件监控
|
||||
"python-gnupg": "gnupg", # GnuPG的Python接口
|
||||
# ============== 加密与安全 (Cryptography & Security) ==============
|
||||
"pycrypto": "Crypto", # 加密库 (较旧)
|
||||
|
||||
@@ -15,7 +15,6 @@ from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.utils.permission_decorators import require_permission
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
|
||||
class ManagementCommand(PlusCommand):
|
||||
@@ -78,10 +77,6 @@ class ManagementCommand(PlusCommand):
|
||||
await self._force_reload_plugin(args[1])
|
||||
elif action in ["add_dir", "添加目录"] and len(args) > 1:
|
||||
await self._add_dir(args[1])
|
||||
elif action in ["hotreload_status", "热重载状态"]:
|
||||
await self._show_hotreload_status()
|
||||
elif action in ["clear_cache", "清理缓存"]:
|
||||
await self._clear_all_caches()
|
||||
else:
|
||||
await self.send_text("❌ 插件管理命令不合法\n使用 /pm plugin help 查看帮助")
|
||||
return False, "命令不合法", True
|
||||
@@ -179,14 +174,9 @@ class ManagementCommand(PlusCommand):
|
||||
• `/pm plugin force_reload <插件名>` - 强制重载指定插件(深度清理)
|
||||
• `/pm plugin add_dir <目录路径>` - 添加插件目录
|
||||
|
||||
<EFBFBD> 热重载管理:
|
||||
• `/pm plugin hotreload_status` - 查看热重载状态
|
||||
• `/pm plugin clear_cache` - 清理所有模块缓存
|
||||
|
||||
<EFBFBD>📝 示例:
|
||||
• `/pm plugin load echo_example`
|
||||
• `/pm plugin force_reload permission_manager_plugin`
|
||||
• `/pm plugin clear_cache`"""
|
||||
• `/pm plugin force_reload permission_manager_plugin`"""
|
||||
elif target == "component":
|
||||
help_msg = """🧩 组件管理命令帮助
|
||||
|
||||
@@ -262,7 +252,7 @@ class ManagementCommand(PlusCommand):
|
||||
await self.send_text(f"🔄 开始强制重载插件: `{plugin_name}`...")
|
||||
|
||||
try:
|
||||
success = hot_reload_manager.force_reload_plugin(plugin_name)
|
||||
success = plugin_manage_api.force_reload_plugin(plugin_name)
|
||||
if success:
|
||||
await self.send_text(f"✅ 插件强制重载成功: `{plugin_name}`")
|
||||
else:
|
||||
@@ -270,44 +260,7 @@ class ManagementCommand(PlusCommand):
|
||||
except Exception as e:
|
||||
await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}")
|
||||
|
||||
async def _show_hotreload_status(self):
|
||||
"""显示热重载状态"""
|
||||
try:
|
||||
status = hot_reload_manager.get_status()
|
||||
|
||||
status_text = f"""🔄 **热重载系统状态**
|
||||
|
||||
🟢 **运行状态:** {"运行中" if status["is_running"] else "已停止"}
|
||||
📂 **监听目录:** {len(status["watch_directories"])} 个
|
||||
👁️ **活跃观察者:** {status["active_observers"]} 个
|
||||
📦 **已加载插件:** {status["loaded_plugins"]} 个
|
||||
❌ **失败插件:** {status["failed_plugins"]} 个
|
||||
⏱️ **防抖延迟:** {status.get("debounce_delay", 0)} 秒
|
||||
|
||||
📋 **监听的目录:**"""
|
||||
|
||||
for i, watch_dir in enumerate(status["watch_directories"], 1):
|
||||
dir_type = "(内置插件)" if "src" in watch_dir else "(外部插件)"
|
||||
status_text += f"\n{i}. `{watch_dir}` {dir_type}"
|
||||
|
||||
if status.get("pending_reloads"):
|
||||
status_text += f"\n\n⏳ **待重载插件:** {', '.join([f'`{p}`' for p in status['pending_reloads']])}"
|
||||
|
||||
await self.send_text(status_text)
|
||||
|
||||
except Exception as e:
|
||||
await self.send_text(f"❌ 获取热重载状态时发生错误: {str(e)}")
|
||||
|
||||
async def _clear_all_caches(self):
|
||||
"""清理所有模块缓存"""
|
||||
await self.send_text("🧹 开始清理所有Python模块缓存...")
|
||||
|
||||
try:
|
||||
hot_reload_manager.clear_all_caches()
|
||||
await self.send_text("✅ 模块缓存清理完成!建议重载相关插件以确保生效。")
|
||||
except Exception as e:
|
||||
await self.send_text(f"❌ 清理缓存时发生错误: {str(e)}")
|
||||
|
||||
|
||||
async def _add_dir(self, dir_path: str):
|
||||
"""添加插件目录"""
|
||||
await self.send_text(f"📁 正在添加插件目录: `{dir_path}`")
|
||||
|
||||
@@ -255,25 +255,34 @@ max_context_emojis = 30 # 每次随机传递给LLM的表情包详细描述的最
|
||||
[memory]
|
||||
enable_memory = true # 是否启用记忆系统
|
||||
memory_build_interval = 600 # 记忆构建间隔 单位秒 间隔越低,MoFox-Bot学习越多,但是冗余信息也会增多
|
||||
memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
|
||||
memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
|
||||
memory_build_sample_length = 30 # 采样长度,数值越高一段记忆内容越丰富
|
||||
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
|
||||
|
||||
forget_memory_interval = 3000 # 记忆遗忘间隔 单位秒 间隔越低,MoFox-Bot遗忘越频繁,记忆更精简,但更难学习
|
||||
memory_forget_time = 48 #多长时间后的记忆会被遗忘 单位小时
|
||||
memory_forget_percentage = 0.008 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认
|
||||
|
||||
consolidate_memory_interval = 1000 # 记忆整合间隔 单位秒 间隔越低,MoFox-Bot整合越频繁,记忆更精简
|
||||
consolidation_similarity_threshold = 0.7 # 相似度阈值
|
||||
consolidation_check_percentage = 0.05 # 检查节点比例
|
||||
|
||||
enable_instant_memory = false # 是否启用即时记忆,测试功能,可能存在未知问题
|
||||
enable_instant_memory = true # 是否启用即时记忆
|
||||
enable_llm_instant_memory = true # 是否启用基于LLM的瞬时记忆
|
||||
enable_vector_instant_memory = true # 是否启用基于向量的瞬时记忆
|
||||
enable_enhanced_memory = true # 是否启用增强记忆系统
|
||||
enhanced_memory_auto_save = true # 是否自动保存增强记忆
|
||||
|
||||
#不希望记忆的词,已经记忆的不会受到影响,需要手动清理
|
||||
memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]
|
||||
min_memory_length = 10 # 最小记忆长度
|
||||
max_memory_length = 500 # 最大记忆长度
|
||||
memory_value_threshold = 0.7 # 记忆价值阈值,低于该值的记忆会被丢弃
|
||||
vector_similarity_threshold = 0.8 # 向量相似度阈值
|
||||
semantic_similarity_threshold = 0.6 # 语义重排阶段的最低匹配阈值
|
||||
|
||||
metadata_filter_limit = 100 # 元数据过滤阶段返回数量上限
|
||||
vector_search_limit = 50 # 向量搜索阶段返回数量上限
|
||||
semantic_rerank_limit = 20 # 语义重排阶段返回数量上限
|
||||
final_result_limit = 10 # 综合筛选后的最终返回数量
|
||||
|
||||
vector_weight = 0.4 # 综合评分中向量相似度的权重
|
||||
semantic_weight = 0.3 # 综合评分中语义匹配的权重
|
||||
context_weight = 0.2 # 综合评分中上下文关联的权重
|
||||
recency_weight = 0.1 # 综合评分中时效性的权重
|
||||
|
||||
fusion_similarity_threshold = 0.85 # 记忆融合时的相似度阈值
|
||||
deduplication_window_hours = 24 # 记忆去重窗口(小时)
|
||||
|
||||
enable_memory_cache = true # 是否启用本地记忆缓存
|
||||
cache_ttl_seconds = 300 # 缓存有效期(秒)
|
||||
max_cache_size = 1000 # 缓存中允许的最大记忆条数
|
||||
|
||||
[voice]
|
||||
enable_asr = true # 是否启用语音识别,启用后MoFox-Bot可以识别语音消息,启用该功能需要配置语音识别模型[model.voice]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "1.3.5"
|
||||
version = "1.3.6"
|
||||
|
||||
# 配置文件版本号迭代规则同bot_config.toml
|
||||
|
||||
@@ -203,6 +203,7 @@ max_tokens = 1000
|
||||
#嵌入模型
|
||||
[model_task_config.embedding]
|
||||
model_list = ["bge-m3"]
|
||||
embedding_dimension = 1024
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user