This commit is contained in:
tt-P607
2025-10-01 06:04:13 +08:00
29 changed files with 1638 additions and 922 deletions

View File

@@ -70,7 +70,6 @@ dependencies = [
"tqdm>=4.67.1",
"urllib3>=2.5.0",
"uvicorn>=0.35.0",
"watchdog>=6.0.0",
"websockets>=15.0.1",
"aiomysql>=0.2.0",
"aiosqlite>=0.21.0",

View File

@@ -50,7 +50,6 @@ reportportal-client
scikit-learn
seaborn
structlog
watchdog
httpx
requests
beautifulsoup4

View File

@@ -12,6 +12,7 @@ from sqlalchemy import select
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
logger = get_logger("bot_interest_manager")
@@ -28,7 +29,9 @@ class BotInterestManager:
# Embedding客户端配置
self.embedding_request = None
self.embedding_config = None
self.embedding_dimension = 1024 # 默认BGE-M3 embedding维度
configured_dim = resolve_embedding_dimension()
self.embedding_dimension = int(configured_dim) if configured_dim else 0
self._detected_embedding_dimension: Optional[int] = None
@property
def is_initialized(self) -> bool:
@@ -82,8 +85,11 @@ class BotInterestManager:
logger.info("📋 找到embedding模型配置")
self.embedding_config = model_config.model_task_config.embedding
self.embedding_dimension = 1024 # BGE-M3的维度
logger.info(f"📐 使用模型维度: {self.embedding_dimension}")
if self.embedding_dimension:
logger.info(f"📐 配置的embedding维度: {self.embedding_dimension}")
else:
logger.info("📐 未在配置中检测到embedding维度将根据首次返回的向量自动识别")
# 创建LLMRequest实例用于embedding
self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding")
@@ -350,7 +356,27 @@ class BotInterestManager:
if embedding and len(embedding) > 0:
self.embedding_cache[text] = embedding
logger.debug(f"✅ Embedding获取成功维度: {len(embedding)}, 模型: {model_name}")
current_dim = len(embedding)
if self._detected_embedding_dimension is None:
self._detected_embedding_dimension = current_dim
if self.embedding_dimension and self.embedding_dimension != current_dim:
logger.warning(
"⚠️ 实际embedding维度(%d)与配置值(%d)不一致,请在 model_config.model_task_config.embedding.embedding_dimension 中同步更新",
current_dim,
self.embedding_dimension,
)
else:
self.embedding_dimension = current_dim
logger.info(f"📏 检测到embedding维度: {current_dim}")
elif current_dim != self.embedding_dimension:
logger.warning(
"⚠️ 收到的embedding维度发生变化: 之前=%d, 当前=%d。请确认模型配置是否正确。",
self.embedding_dimension,
current_dim,
)
logger.debug(f"✅ Embedding获取成功维度: {current_dim}, 模型: {model_name}")
return embedding
else:
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")

View File

@@ -26,6 +26,7 @@ from rich.progress import (
TextColumn,
)
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
install(extra_lines=3)
@@ -504,7 +505,10 @@ class EmbeddingStore:
# L2归一化
faiss.normalize_L2(embeddings)
# 构建索引
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
if not embedding_dim:
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
self.faiss_index.add(embeddings)
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:

View File

@@ -11,12 +11,27 @@ from dataclasses import dataclass
from src.common.logger import get_logger
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
MEMORY_TYPE_LABELS = {
MemoryType.PERSONAL_FACT: "个人事实",
MemoryType.EVENT: "事件",
MemoryType.PREFERENCE: "偏好",
MemoryType.OPINION: "观点",
MemoryType.RELATIONSHIP: "关系",
MemoryType.EMOTION: "情感",
MemoryType.KNOWLEDGE: "知识",
MemoryType.SKILL: "技能",
MemoryType.GOAL: "目标",
MemoryType.EXPERIENCE: "经验",
MemoryType.CONTEXTUAL: "上下文",
}
@dataclass
class AdapterConfig:
"""适配器配置"""
@@ -85,12 +100,9 @@ class EnhancedMemoryAdapter:
async def process_conversation_memory(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""处理对话记忆"""
"""处理对话记忆,以上下文为唯一输入"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"success": False, "error": "Enhanced memory not available"}
@@ -98,10 +110,30 @@ class EnhancedMemoryAdapter:
self.adapter_stats["total_processed"] += 1
try:
payload_context: Dict[str, Any] = dict(context or {})
conversation_text = payload_context.get("conversation_text")
if not conversation_text:
conversation_candidate = (
payload_context.get("message_content")
or payload_context.get("latest_message")
or payload_context.get("raw_text")
)
if conversation_candidate is not None:
conversation_text = str(conversation_candidate)
payload_context["conversation_text"] = conversation_text
else:
conversation_text = ""
else:
conversation_text = str(conversation_text)
if "timestamp" not in payload_context:
payload_context["timestamp"] = time.time()
logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text))
# 使用集成层处理对话
result = await self.integration_layer.process_conversation(
conversation_text, context, user_id, timestamp
)
result = await self.integration_layer.process_conversation(payload_context)
# 更新统计
processing_time = time.time() - start_time
@@ -132,7 +164,7 @@ class EnhancedMemoryAdapter:
try:
limit = limit or self.config.max_retrieval_results
memories = await self.integration_layer.retrieve_relevant_memories(
query, user_id, context, limit
query, None, context, limit
)
self.adapter_stats["memories_retrieved"] += len(memories)
@@ -157,12 +189,15 @@ class EnhancedMemoryAdapter:
if not memories:
return ""
# 格式化记忆为提示词友好的格式
memory_context_parts = []
for memory in memories:
memory_context_parts.append(f"- {memory.text_content}")
# 格式化记忆为提示词友好的Markdown结构
lines: List[str] = ["### 🧠 相关记忆 (Relevant Memories)", ""]
return "\n".join(memory_context_parts)
for memory in memories:
type_label = MEMORY_TYPE_LABELS.get(memory.memory_type, memory.memory_type.value)
display_text = memory.display or memory.text_content
lines.append(f"- **[{type_label}]** {display_text}")
return "\n".join(lines)
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
"""获取增强记忆系统摘要"""
@@ -270,13 +305,10 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
async def process_conversation_with_enhanced_memory(
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None,
llm_model: Optional[LLMRequest] = None
) -> Dict[str, Any]:
"""使用增强记忆系统处理对话"""
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
if not llm_model:
# 获取默认的LLM模型
from src.llm_models.utils_model import get_global_llm_model
@@ -284,7 +316,18 @@ async def process_conversation_with_enhanced_memory(
try:
adapter = await get_enhanced_memory_adapter(llm_model)
return await adapter.process_conversation_memory(conversation_text, context, user_id, timestamp)
payload_context = dict(context or {})
if "conversation_text" not in payload_context:
conversation_candidate = (
payload_context.get("message_content")
or payload_context.get("latest_message")
or payload_context.get("raw_text")
)
if conversation_candidate is not None:
payload_context["conversation_text"] = str(conversation_candidate)
return await adapter.process_conversation_memory(payload_context)
except Exception as e:
logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}

View File

@@ -1,13 +1,15 @@
# -*- coding: utf-8 -*-
"""
增强型精准记忆系统核心模块
基于文档设计的高效记忆构建、存储与召回优化系统
1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。
2. 内置 LLM 查询规划器与嵌入维度自动解析机制,直接从模型配置推断向量存储参数。
"""
import asyncio
import time
import orjson
import re
import hashlib
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
@@ -22,12 +24,16 @@ from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
from src.chat.memory_system.metadata_index import MetadataIndexManager
from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger(__name__)
# 全局记忆作用域(共享记忆库)
GLOBAL_MEMORY_SCOPE = "global"
class MemorySystemStatus(Enum):
"""记忆系统状态"""
@@ -47,14 +53,20 @@ class MemorySystemConfig:
memory_value_threshold: float = 0.7
min_build_interval_seconds: float = 300.0
# 向量存储配置
vector_dimension: int = 768
# 向量存储配置(嵌入维度自动来自模型配置)
vector_dimension: int = 1024
similarity_threshold: float = 0.8
# 召回配置
coarse_recall_limit: int = 50
fine_recall_limit: int = 10
semantic_rerank_limit: int = 20
final_recall_limit: int = 5
semantic_similarity_threshold: float = 0.6
vector_weight: float = 0.4
semantic_weight: float = 0.3
context_weight: float = 0.2
recency_weight: float = 0.1
# 融合配置
fusion_similarity_threshold: float = 0.85
@@ -64,6 +76,23 @@ class MemorySystemConfig:
def from_global_config(cls):
"""从全局配置创建配置实例"""
embedding_dimension = None
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
if embedding_task is not None:
embedding_dimension = getattr(embedding_task, "embedding_dimension", None)
except Exception:
embedding_dimension = None
if not embedding_dimension:
try:
embedding_dimension = getattr(global_config.lpmm_knowledge, "embedding_dimension", None)
except Exception:
embedding_dimension = None
if not embedding_dimension:
embedding_dimension = 1024
return cls(
# 记忆构建配置
min_memory_length=global_config.memory.min_memory_length,
@@ -72,13 +101,19 @@ class MemorySystemConfig:
min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0),
# 向量存储配置
vector_dimension=global_config.memory.vector_dimension,
vector_dimension=int(embedding_dimension),
similarity_threshold=global_config.memory.vector_similarity_threshold,
# 召回配置
coarse_recall_limit=global_config.memory.metadata_filter_limit,
fine_recall_limit=global_config.memory.final_result_limit,
fine_recall_limit=global_config.memory.vector_search_limit,
semantic_rerank_limit=global_config.memory.semantic_rerank_limit,
final_recall_limit=global_config.memory.final_result_limit,
semantic_similarity_threshold=getattr(global_config.memory, "semantic_similarity_threshold", 0.6),
vector_weight=global_config.memory.vector_weight,
semantic_weight=global_config.memory.semantic_weight,
context_weight=global_config.memory.context_weight,
recency_weight=global_config.memory.recency_weight,
# 融合配置
fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
@@ -104,6 +139,7 @@ class EnhancedMemorySystem:
self.vector_storage: VectorStorageManager = None
self.metadata_index: MetadataIndexManager = None
self.retrieval_system: MultiStageRetrieval = None
self.query_planner: MemoryQueryPlanner = None
# LLM模型
self.value_assessment_model: LLMRequest = None
@@ -117,6 +153,9 @@ class EnhancedMemorySystem:
# 构建节流记录
self._last_memory_build_times: Dict[str, float] = {}
# 记忆指纹缓存,用于快速检测重复记忆
self._memory_fingerprints: Dict[str, str] = {}
logger.info("EnhancedMemorySystem 初始化开始")
async def initialize(self):
@@ -125,19 +164,29 @@ class EnhancedMemorySystem:
logger.info("正在初始化增强型记忆系统...")
# 初始化LLM模型
task_config = (
self.llm_model.model_for_task
if self.llm_model is not None
else model_config.model_task_config.utils_small
)
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
value_task_config = getattr(model_config.model_task_config, "utils_small", None)
extraction_task_config = getattr(model_config.model_task_config, "utils", None)
if value_task_config is None:
logger.warning("未找到 utils_small 模型配置,回退到 utils 或外部提供的模型配置。")
value_task_config = extraction_task_config or fallback_task
if extraction_task_config is None:
logger.warning("未找到 utils 模型配置,回退到 utils_small 或外部提供的模型配置。")
extraction_task_config = value_task_config or fallback_task
if value_task_config is None or extraction_task_config is None:
raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。")
self.value_assessment_model = LLMRequest(
model_set=task_config,
model_set=value_task_config,
request_type="memory.value_assessment"
)
self.memory_extraction_model = LLMRequest(
model_set=task_config,
model_set=extraction_task_config,
request_type="memory.extraction"
)
@@ -155,13 +204,36 @@ class EnhancedMemorySystem:
retrieval_config = RetrievalConfig(
metadata_filter_limit=self.config.coarse_recall_limit,
vector_search_limit=self.config.fine_recall_limit,
final_result_limit=self.config.final_recall_limit
semantic_rerank_limit=self.config.semantic_rerank_limit,
final_result_limit=self.config.final_recall_limit,
vector_similarity_threshold=self.config.similarity_threshold,
semantic_similarity_threshold=self.config.semantic_similarity_threshold,
vector_weight=self.config.vector_weight,
semantic_weight=self.config.semantic_weight,
context_weight=self.config.context_weight,
recency_weight=self.config.recency_weight,
)
self.retrieval_system = MultiStageRetrieval(retrieval_config)
planner_task_config = getattr(model_config.model_task_config, "planner", None)
planner_model: Optional[LLMRequest] = None
try:
planner_model = LLMRequest(
model_set=planner_task_config,
request_type="memory.query_planner"
)
except Exception as planner_exc:
logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True)
self.query_planner = MemoryQueryPlanner(
planner_model,
default_limit=self.config.final_recall_limit
)
# 加载持久化数据
await self.vector_storage.load_storage()
await self.metadata_index.load_index()
self._populate_memory_fingerprints()
self.status = MemorySystemStatus.READY
logger.info("✅ 增强型记忆系统初始化完成")
@@ -174,7 +246,7 @@ class EnhancedMemorySystem:
async def retrieve_memories_for_building(
self,
query_text: str,
user_id: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
) -> List[MemoryChunk]:
@@ -182,7 +254,6 @@ class EnhancedMemorySystem:
Args:
query_text: 查询文本
user_id: 用户ID
context: 上下文信息
limit: 返回结果数量限制
@@ -201,7 +272,6 @@ class EnhancedMemorySystem:
# 执行检索
memories = await self.vector_storage.search_similar_memories(
query_text=query_text,
user_id=user_id,
limit=limit
)
@@ -218,23 +288,18 @@ class EnhancedMemorySystem:
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
) -> List[MemoryChunk]:
"""从对话中构建记忆
Args:
conversation_text: 对话文本
context: 上下文信息(包括用户信息、群组信息等)
user_id: 用户ID
context: 上下文信息
timestamp: 时间戳,默认为当前时间
Returns:
构建的记忆块列表
"""
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
raise RuntimeError("记忆系统未就绪")
original_status = self.status
self.status = MemorySystemStatus.BUILDING
start_time = time.time()
@@ -243,9 +308,9 @@ class EnhancedMemorySystem:
build_marker_time: Optional[float] = None
try:
normalized_context = self._normalize_context(context, user_id, timestamp)
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
build_scope_key = self._get_build_scope_key(normalized_context, user_id)
build_scope_key = self._get_build_scope_key(normalized_context, GLOBAL_MEMORY_SCOPE)
min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0))
current_time = time.time()
@@ -266,7 +331,7 @@ class EnhancedMemorySystem:
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
logger.debug("开始构建记忆,文本长度: %d", len(conversation_text))
# 1. 信息价值评估
value_score = await self._assess_information_value(conversation_text, normalized_context)
@@ -280,7 +345,7 @@ class EnhancedMemorySystem:
memory_chunks = await self.memory_builder.build_memories(
conversation_text,
normalized_context,
user_id,
GLOBAL_MEMORY_SCOPE,
timestamp or time.time()
)
@@ -293,19 +358,24 @@ class EnhancedMemorySystem:
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks)
# 4. 存储记忆
await self._store_memories(fused_chunks)
stored_count = await self._store_memories(fused_chunks)
# 4.1 控制台预览
self._log_memory_preview(fused_chunks)
# 5. 更新统计
self.total_memories += len(fused_chunks)
self.total_memories += stored_count
self.last_build_time = time.time()
if build_scope_key:
self._last_memory_build_times[build_scope_key] = self.last_build_time
build_time = time.time() - start_time
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}")
logger.info(
"✅ 生成 %d 条记忆,成功入库 %d 条,耗时 %.2f",
len(fused_chunks),
stored_count,
build_time,
)
self.status = original_status
return fused_chunks
@@ -347,21 +417,34 @@ class EnhancedMemorySystem:
async def process_conversation_memory(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
context: Dict[str, Any]
) -> Dict[str, Any]:
"""对外暴露的对话记忆处理接口,兼容旧调用方式"""
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
start_time = time.time()
try:
normalized_context = self._normalize_context(context, user_id, timestamp)
context = dict(context or {})
conversation_candidate = (
context.get("conversation_text")
or context.get("message_content")
or context.get("latest_message")
or context.get("raw_text")
or ""
)
conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate)
timestamp = context.get("timestamp")
if timestamp is None:
timestamp = time.time()
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
normalized_context.setdefault("conversation_text", conversation_text)
memories = await self.build_memory_from_conversation(
conversation_text=conversation_text,
context=normalized_context,
user_id=user_id,
timestamp=timestamp
)
@@ -395,52 +478,77 @@ class EnhancedMemorySystem:
**kwargs
) -> List[MemoryChunk]:
"""检索相关记忆,兼容 query/query_text 参数形式"""
if self.status != MemorySystemStatus.READY:
raise RuntimeError("记忆系统未就绪")
query_text = query_text or kwargs.get("query")
if not query_text:
raw_query = query_text or kwargs.get("query")
if not raw_query:
raise ValueError("query_text 或 query 参数不能为空")
context = context or {}
user_id = user_id or kwargs.get("user_id")
resolved_user_id = GLOBAL_MEMORY_SCOPE
if self.retrieval_system is None or self.metadata_index is None:
raise RuntimeError("检索组件未初始化")
all_memories_cache = self.vector_storage.memory_cache
if not all_memories_cache:
logger.debug("记忆缓存为空,返回空结果")
self.last_retrieval_time = time.time()
self.status = MemorySystemStatus.READY
return []
self.status = MemorySystemStatus.RETRIEVING
start_time = time.time()
try:
normalized_context = self._normalize_context(context, user_id, None)
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
candidate_memories = list(self.vector_storage.memory_cache.values())
if user_id:
candidate_memories = [m for m in candidate_memories if m.user_id == user_id]
effective_limit = limit or self.config.final_recall_limit
query_plan = None
planner_ran = False
resolved_query_text = raw_query
if self.query_planner:
try:
planner_ran = True
query_plan = await self.query_planner.plan_query(raw_query, normalized_context)
normalized_context["query_plan"] = query_plan
effective_limit = min(effective_limit, query_plan.limit or effective_limit)
if getattr(query_plan, "semantic_query", None):
resolved_query_text = query_plan.semantic_query
logger.debug(
"查询规划: semantic='%s', types=%s, subjects=%s, limit=%d",
query_plan.semantic_query,
[mt.value for mt in query_plan.memory_types],
query_plan.subject_includes,
query_plan.limit,
)
except Exception as plan_exc:
logger.warning("查询规划失败,使用默认检索策略: %s", plan_exc, exc_info=True)
if not candidate_memories:
self.status = MemorySystemStatus.READY
self.last_retrieval_time = time.time()
logger.debug(f"未找到用户 {user_id} 的候选记忆")
return []
effective_limit = effective_limit or self.config.final_recall_limit
effective_limit = max(1, min(effective_limit, self.config.final_recall_limit))
normalized_context["resolved_query_text"] = resolved_query_text
scored_memories = []
for memory in candidate_memories:
score = self._compute_memory_score(query_text, memory, normalized_context)
if score > 0:
scored_memories.append((memory, score))
if not scored_memories:
# 如果所有分数为0返回最近的记忆作为降级策略
if normalized_context.get("__memory_building__"):
logger.debug("当前处于记忆构建流程,跳过查询规划并进行降级检索")
self.status = MemorySystemStatus.BUILDING
final_memories = []
candidate_memories = list(all_memories_cache.values())
candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True)
scored_memories = [(memory, 0.0) for memory in candidate_memories[:limit]]
final_memories = candidate_memories[:effective_limit]
else:
scored_memories.sort(key=lambda item: item[1], reverse=True)
retrieval_result = await self.retrieval_system.retrieve_memories(
query=resolved_query_text,
user_id=resolved_user_id,
context=normalized_context,
metadata_index=self.metadata_index,
vector_storage=self.vector_storage,
all_memories_cache=all_memories_cache,
limit=effective_limit,
)
top_memories = [memory for memory, _ in scored_memories[:limit]]
final_memories = retrieval_result.final_memories
# 更新访问信息和缓存
for memory, score in scored_memories[:limit]:
for memory in final_memories:
memory.update_access()
memory.update_relevance(score)
cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id)
if cache_entry is not None:
cache_entry["last_accessed"] = memory.metadata.last_accessed
@@ -448,14 +556,34 @@ class EnhancedMemorySystem:
cache_entry["relevance_score"] = memory.metadata.relevance_score
retrieval_time = time.time() - start_time
logger.info(
f"✅ 为用户 {user_id or 'unknown'} 检索到 {len(top_memories)} 条相关记忆,耗时 {retrieval_time:.3f}"
plan_summary = ""
if planner_ran and query_plan:
plan_types = ",".join(mt.value for mt in query_plan.memory_types) or "-"
plan_subjects = ",".join(query_plan.subject_includes) or "-"
plan_summary = (
f" | planner.semantic='{query_plan.semantic_query}'"
f" | planner.limit={query_plan.limit}"
f" | planner.types={plan_types}"
f" | planner.subjects={plan_subjects}"
)
log_message = (
"✅ 记忆检索完成"
f" | user={resolved_user_id}"
f" | count={len(final_memories)}"
f" | duration={retrieval_time:.3f}s"
f" | applied_limit={effective_limit}"
f" | raw_query='{raw_query}'"
f" | semantic_query='{resolved_query_text}'"
f"{plan_summary}"
)
logger.info(log_message)
self.last_retrieval_time = time.time()
self.status = MemorySystemStatus.READY
return top_memories
return final_memories
except Exception as e:
self.status = MemorySystemStatus.ERROR
@@ -499,8 +627,8 @@ class EnhancedMemorySystem:
except Exception:
context = dict(raw_context or {})
# 基础字段
context["user_id"] = context.get("user_id") or user_id or "unknown"
# 基础字段(统一使用全局作用域)
context["user_id"] = GLOBAL_MEMORY_SCOPE
context["timestamp"] = context.get("timestamp") or timestamp or time.time()
context["message_type"] = context.get("message_type") or "normal"
context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"
@@ -523,8 +651,8 @@ class EnhancedMemorySystem:
if stream_id:
context["stream_id"] = stream_id
# chat_id 兜底
context["chat_id"] = context.get("chat_id") or context.get("stream_id") or f"session_{context['user_id']}"
# 全局记忆无需聊天隔离
context["chat_id"] = context.get("chat_id") or "global_chat"
# 历史窗口配置
window_candidate = (
@@ -616,18 +744,7 @@ class EnhancedMemorySystem:
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
"""确定用于节流控制的记忆构建作用域"""
stream_id = context.get("stream_id")
if stream_id:
return f"stream::{stream_id}"
chat_id = context.get("chat_id")
if chat_id:
return f"chat::{chat_id}"
if user_id:
return f"user::{user_id}"
return None
return "global_scope"
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
"""确定历史消息获取数量限制在30-50之间"""
@@ -789,24 +906,134 @@ class EnhancedMemorySystem:
logger.error(f"信息价值评估失败: {e}", exc_info=True)
return 0.5 # 默认中等价值
async def _store_memories(self, memory_chunks: List[MemoryChunk]):
"""存储记忆块到各个存储系统"""
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
"""存储记忆块到各个存储系统,返回成功入库数量"""
if not memory_chunks:
return
return 0
unique_memories: List[MemoryChunk] = []
skipped_duplicates = 0
for memory in memory_chunks:
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
existing_id = self._memory_fingerprints.get(key)
if existing_id:
existing = self.vector_storage.memory_cache.get(existing_id)
if existing:
self._merge_existing_memory(existing, memory)
await self.metadata_index.update_memory_entry(existing)
skipped_duplicates += 1
logger.debug(
"检测到重复记忆,已合并到现有记录 | memory_id=%s",
existing.memory_id,
)
continue
else:
# 指纹存在但缓存缺失,视为新记忆并覆盖旧映射
logger.debug("检测到过期指纹映射,重写现有条目")
unique_memories.append(memory)
if not unique_memories:
if skipped_duplicates:
logger.info("本次记忆全部与现有内容重复,跳过入库")
return 0
# 并行存储到向量数据库和元数据索引
storage_tasks = []
storage_tasks = [
self.vector_storage.store_memories(unique_memories),
self.metadata_index.index_memories(unique_memories),
]
# 向量存储
storage_tasks.append(self.vector_storage.store_memories(memory_chunks))
# 元数据索引
storage_tasks.append(self.metadata_index.index_memories(memory_chunks))
# 等待所有存储任务完成
await asyncio.gather(*storage_tasks, return_exceptions=True)
logger.debug(f"成功存储 {len(memory_chunks)} 条记忆到各个存储系统")
self._register_memory_fingerprints(unique_memories)
logger.debug(
"成功存储 %d 条记忆(跳过重复 %d 条)",
len(unique_memories),
skipped_duplicates,
)
return len(unique_memories)
def _merge_existing_memory(self, existing: MemoryChunk, incoming: MemoryChunk) -> None:
"""将新记忆的信息合并到已存在的记忆中"""
updated = False
for keyword in incoming.keywords:
if keyword not in existing.keywords:
existing.add_keyword(keyword)
updated = True
for tag in incoming.tags:
if tag not in existing.tags:
existing.add_tag(tag)
updated = True
for category in incoming.categories:
if category not in existing.categories:
existing.add_category(category)
updated = True
if incoming.metadata.source_context:
existing.metadata.source_context = incoming.metadata.source_context
if incoming.metadata.importance.value > existing.metadata.importance.value:
existing.metadata.importance = incoming.metadata.importance
updated = True
if incoming.metadata.confidence.value > existing.metadata.confidence.value:
existing.metadata.confidence = incoming.metadata.confidence
updated = True
if incoming.metadata.relevance_score > existing.metadata.relevance_score:
existing.metadata.relevance_score = incoming.metadata.relevance_score
updated = True
if updated:
existing.metadata.last_modified = time.time()
def _populate_memory_fingerprints(self) -> None:
"""基于当前缓存构建记忆指纹映射"""
self._memory_fingerprints.clear()
for memory in self.vector_storage.memory_cache.values():
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
self._memory_fingerprints[key] = memory.memory_id
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
for memory in memories:
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
self._memory_fingerprints[key] = memory.memory_id
def _build_memory_fingerprint(self, memory: MemoryChunk) -> str:
subjects = memory.subjects or []
subject_part = "|".join(sorted(s.strip() for s in subjects if s))
predicate_part = (memory.content.predicate or "").strip()
obj = memory.content.object
if isinstance(obj, (dict, list)):
obj_part = orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8")
else:
obj_part = str(obj).strip()
base = "|".join([
str(memory.user_id or "unknown"),
memory.memory_type.value,
subject_part,
predicate_part,
obj_part,
])
return hashlib.sha256(base.encode("utf-8")).hexdigest()
@staticmethod
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
return f"{str(user_id)}:{fingerprint}"
def get_system_stats(self) -> Dict[str, Any]:
"""获取系统统计信息"""

View File

@@ -241,12 +241,12 @@ class EnhancedMemoryManager:
return []
try:
result = await self.enhanced_system.process_conversation_memory(
conversation_text=conversation_text,
context=context,
user_id=user_id,
timestamp=timestamp
)
payload_context = dict(context or {})
payload_context.setdefault("conversation_text", conversation_text)
if timestamp is not None:
payload_context.setdefault("timestamp", timestamp)
result = await self.enhanced_system.process_conversation_memory(payload_context)
# 从结果中提取记忆块
memory_chunks = []
@@ -274,7 +274,7 @@ class EnhancedMemoryManager:
try:
relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
query=query_text,
user_id=user_id,
user_id=None,
context=context or {},
limit=limit
)
@@ -303,6 +303,9 @@ class EnhancedMemoryManager:
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
"""将记忆块转换为更易读的文本描述"""
structure = memory.content.to_dict()
if memory.display:
return self._clean_text(memory.display), structure
subject = structure.get("subject")
predicate = structure.get("predicate") or ""
obj = structure.get("object")

View File

@@ -114,12 +114,9 @@ class MemoryIntegrationLayer:
async def process_conversation(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
context: Dict[str, Any]
) -> Dict[str, Any]:
"""处理对话记忆"""
"""处理对话记忆,仅使用上下文信息"""
if not self._initialized or not self.enhanced_memory:
return {"success": False, "error": "Memory system not available"}
@@ -128,13 +125,12 @@ class MemoryIntegrationLayer:
self.integration_stats["enhanced_queries"] += 1
try:
payload_context = dict(context or {})
conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text))
# 直接使用增强记忆系统处理
result = await self.enhanced_memory.process_conversation_memory(
conversation_text=conversation_text,
context=context,
user_id=user_id,
timestamp=timestamp
)
result = await self.enhanced_memory.process_conversation_memory(payload_context)
# 更新统计
processing_time = time.time() - start_time
@@ -156,7 +152,7 @@ class MemoryIntegrationLayer:
async def retrieve_relevant_memories(
self,
query: str,
user_id: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None
) -> List[MemoryChunk]:
@@ -168,7 +164,7 @@ class MemoryIntegrationLayer:
limit = limit or self.config.max_retrieval_results
memories = await self.enhanced_memory.retrieve_relevant_memories(
query=query,
user_id=user_id,
user_id=None,
context=context or {},
limit=limit
)

View File

@@ -2,21 +2,48 @@
"""
记忆构建模块
从对话流中提取高质量、结构化记忆单元
输出格式要求:
{{
"memories": [
{{
"type": "记忆类型",
"display": "用于直接展示和检索的自然语言描述",
"subject": ["主体1", "主体2"],
"predicate": "谓语(动作/状态)",
"object": "宾语(对象/属性或结构体)",
"keywords": ["关键词1", "关键词2"],
"importance": "重要性等级(1-4)",
"confidence": "置信度(1-4)",
"reasoning": "提取理由"
}}
]
}}
注意:
1. `subject` 可包含多个主体,请用数组表示;若主体不明确,请根据上下文给出最合理的称呼
2. `display` 必须是一句完整流畅的中文描述,可直接用于用户展示和向量搜索
3. 只提取确实值得记忆的信息,不要提取琐碎内容
4. 确保信息准确、具体、有价值
5. 重要性: 1=低, 2=一般, 3=高, 4=关键;置信度: 1=低, 2=中等, 3=高, 4=已验证
"""
import re
import time
import orjson
from typing import Dict, List, Optional, Any
from datetime import datetime
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union, Type
import orjson
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
create_memory_chunk
MemoryChunk,
MemoryType,
ConfidenceLevel,
ImportanceLevel,
create_memory_chunk,
)
logger = get_logger(__name__)
@@ -24,6 +51,7 @@ logger = get_logger(__name__)
class ExtractionStrategy(Enum):
"""提取策略"""
LLM_BASED = "llm_based" # 基于LLM的智能提取
RULE_BASED = "rule_based" # 基于规则的提取
HYBRID = "hybrid" # 混合策略
@@ -171,18 +199,18 @@ class MemoryBuilder:
"""使用规则提取记忆"""
memories = []
subject_display = self._resolve_user_display(context, user_id)
subjects = self._resolve_conversation_participants(context, user_id)
# 规则1: 检测个人信息
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subject_display)
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subjects)
memories.extend(personal_info)
# 规则2: 检测偏好信息
preferences = self._extract_preferences(text, user_id, timestamp, context, subject_display)
preferences = self._extract_preferences(text, user_id, timestamp, context, subjects)
memories.extend(preferences)
# 规则3: 检测事件信息
events = self._extract_events(text, user_id, timestamp, context, subject_display)
events = self._extract_events(text, user_id, timestamp, context, subjects)
memories.extend(events)
return memories
@@ -258,10 +286,7 @@ class MemoryBuilder:
你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。
当前时间: {current_date}
聊天ID: {chat_id}
消息类型: {message_type}
目标用户ID: {target_user_id_display}
目标用户称呼: {target_user_name}
## 🤖 机器人身份(仅供参考,禁止写入记忆)
- 机器人名称: {bot_name_display}
@@ -272,7 +297,6 @@ class MemoryBuilder:
请务必遵守以下命名规范:
- 当说话者是机器人时,请使用“{bot_name_display}”或其他明确称呼作为主语;
- 如果看到系统自动生成的长ID类似 {target_user_id}),请改用“{target_user_name}”、机器人的称呼或“该用户”描述不要把ID写入记忆
- 记录关键事实时,请准确标记主体是机器人还是用户,避免混淆。
对话内容:
@@ -450,7 +474,7 @@ class MemoryBuilder:
bot_identifiers = self._collect_bot_identifiers(context)
system_identifiers = self._collect_system_identifiers(context)
default_subject = self._resolve_user_display(context, user_id)
default_subjects = self._resolve_conversation_participants(context, user_id)
bot_display = None
if context:
@@ -481,19 +505,33 @@ class MemoryBuilder:
for mem_data in memory_list:
try:
subject_value = mem_data.get("subject")
normalized_subject = self._normalize_subject(
normalized_subject = self._normalize_subjects(
subject_value,
bot_identifiers,
system_identifiers,
default_subject,
default_subjects,
bot_display
)
if normalized_subject is None:
if not normalized_subject:
logger.debug("跳过疑似机器人自身信息的记忆: %s", mem_data)
continue
# 创建记忆块
importance_level = self._parse_enum_value(
ImportanceLevel,
mem_data.get("importance"),
ImportanceLevel.NORMAL,
"importance"
)
confidence_level = self._parse_enum_value(
ConfidenceLevel,
mem_data.get("confidence"),
ConfidenceLevel.MEDIUM,
"confidence"
)
memory = create_memory_chunk(
user_id=user_id,
subject=normalized_subject,
@@ -502,8 +540,9 @@ class MemoryBuilder:
memory_type=MemoryType(mem_data.get("type", "contextual")),
chat_id=context.get("chat_id"),
source_context=mem_data.get("reasoning", ""),
importance=ImportanceLevel(mem_data.get("importance", 2)),
confidence=ConfidenceLevel(mem_data.get("confidence", 2))
importance=importance_level,
confidence=confidence_level,
display=mem_data.get("display")
)
# 添加关键词
@@ -511,13 +550,6 @@ class MemoryBuilder:
for keyword in keywords:
memory.add_keyword(keyword)
subject_text = memory.content.subject.strip() if isinstance(memory.content.subject, str) else str(memory.content.subject)
if not subject_text:
memory.content.subject = default_subject
elif subject_text.lower() in system_identifiers or self._looks_like_system_identifier(subject_text):
logger.debug("将系统标识主语替换为默认用户名称: %s", subject_text)
memory.content.subject = default_subject
memories.append(memory)
except Exception as e:
@@ -526,6 +558,64 @@ class MemoryBuilder:
return memories
def _parse_enum_value(
self,
enum_cls: Type[Enum],
raw_value: Any,
default: Enum,
field_name: str
) -> Enum:
"""解析枚举值,兼容数字/字符串表示"""
if isinstance(raw_value, enum_cls):
return raw_value
if raw_value is None:
return default
# 直接尝试整数转换
if isinstance(raw_value, (int, float)):
int_value = int(raw_value)
try:
return enum_cls(int_value)
except ValueError:
logger.debug("%s=%s 无法解析为 %s", field_name, raw_value, enum_cls.__name__)
return default
if isinstance(raw_value, str):
value_str = raw_value.strip()
if not value_str:
return default
if value_str.isdigit():
try:
return enum_cls(int(value_str))
except ValueError:
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
else:
normalized = value_str.replace("-", "_").replace(" ", "_").upper()
for member in enum_cls:
if member.name == normalized:
return member
for member in enum_cls:
if str(member.value).lower() == value_str.lower():
return member
try:
return enum_cls(value_str)
except ValueError:
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
try:
return enum_cls(raw_value)
except Exception:
logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s",
field_name,
raw_value,
type(raw_value).__name__,
enum_cls.__name__,
default.name)
return default
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
identifiers: set[str] = {"bot", "机器人", "ai助手"}
if not context:
@@ -580,6 +670,58 @@ class MemoryBuilder:
return identifiers
def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
participants: List[str] = []
if context:
candidate_keys = [
"participants",
"participant_names",
"speaker_names",
"members",
"member_names",
"mention_users",
"audiences"
]
for key in candidate_keys:
value = context.get(key)
if isinstance(value, (list, tuple, set)):
for item in value:
if isinstance(item, str):
cleaned = self._clean_subject_text(item)
if cleaned:
participants.append(cleaned)
elif isinstance(value, str):
for part in self._split_subject_string(value):
if part:
participants.append(part)
fallback = self._resolve_user_display(context, user_id)
if fallback:
participants.append(fallback)
if context:
bot_name = context.get("bot_name") or context.get("bot_identity")
if isinstance(bot_name, str):
cleaned = self._clean_subject_text(bot_name)
if cleaned:
participants.append(cleaned)
if not participants:
participants = ["对话参与者"]
deduplicated: List[str] = []
seen = set()
for name in participants:
key = name.lower()
if key in seen:
continue
seen.add(key)
deduplicated.append(name)
return deduplicated
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
candidate_keys = [
"user_display_name",
@@ -626,51 +768,160 @@ class MemoryBuilder:
return False
def _normalize_subject(
def _split_subject_string(self, value: str) -> List[str]:
if not value:
return []
replaced = re.sub(r"\band\b", "", value, flags=re.IGNORECASE)
replaced = replaced.replace("", "").replace("", "").replace("", "")
replaced = replaced.replace("&", "").replace("/", "").replace("+", "")
tokens = [self._clean_subject_text(token) for token in re.split(r"[、,;]+", replaced)]
return [token for token in tokens if token]
def _normalize_subjects(
self,
subject: Any,
bot_identifiers: set[str],
system_identifiers: set[str],
default_subject: str,
default_subjects: List[str],
bot_display: Optional[str] = None
) -> Optional[str]:
if subject is None:
return default_subject
) -> List[str]:
defaults = default_subjects or ["对话参与者"]
subject_str = subject if isinstance(subject, str) else str(subject)
cleaned = self._clean_subject_text(subject_str)
if not cleaned:
return default_subject
raw_candidates: List[str] = []
if isinstance(subject, list):
for item in subject:
if isinstance(item, str):
raw_candidates.extend(self._split_subject_string(item))
elif item is not None:
raw_candidates.extend(self._split_subject_string(str(item)))
elif isinstance(subject, str):
raw_candidates.extend(self._split_subject_string(subject))
elif subject is not None:
raw_candidates.extend(self._split_subject_string(str(subject)))
lowered = cleaned.lower()
normalized: List[str] = []
bot_primary = self._clean_subject_text(bot_display or "")
if lowered in bot_identifiers:
return bot_primary or cleaned
for candidate in raw_candidates:
if not candidate:
continue
if lowered in {"用户", "user", "the user", "对方", "对手"}:
return default_subject
lowered = candidate.lower()
if lowered in bot_identifiers:
normalized.append(bot_primary or candidate)
continue
prefix_match = re.match(r"^(用户|User|user|USER|成员|member|Member|target|Target|TARGET)[\s:\-\u2014_]*?(.*)$", cleaned)
if prefix_match:
remainder = self._clean_subject_text(prefix_match.group(2))
if not remainder:
return default_subject
remainder_lower = remainder.lower()
if remainder_lower in bot_identifiers:
return bot_primary or remainder
if (
remainder_lower in system_identifiers
or self._looks_like_system_identifier(remainder)
):
return default_subject
cleaned = remainder
lowered = cleaned.lower()
if lowered in {"用户", "user", "the user", "对方", "对手"}:
normalized.extend(defaults)
continue
if lowered in system_identifiers or self._looks_like_system_identifier(cleaned):
return default_subject
if lowered in system_identifiers or self._looks_like_system_identifier(candidate):
continue
return cleaned
normalized.append(candidate)
if not normalized:
normalized = list(defaults)
deduplicated: List[str] = []
seen = set()
for name in normalized:
key = name.lower()
if key in seen:
continue
seen.add(key)
deduplicated.append(name)
return deduplicated
def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
if isinstance(obj, dict):
for key in keys:
value = obj.get(key)
if value is None:
continue
if isinstance(value, list):
compact = "".join(str(item) for item in value[:3])
if compact:
return compact
else:
value_str = str(value).strip()
if value_str:
return value_str
elif isinstance(obj, list):
compact = "".join(str(item) for item in obj[:3])
return compact or None
elif isinstance(obj, str):
return obj.strip() or None
return None
def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str:
subject_phrase = "".join(subjects) if subjects else "对话参与者"
predicate = (predicate or "").strip()
if predicate == "is_named":
name = self._extract_value_from_object(obj, ["name", "nickname"]) or ""
name = self._clean_subject_text(name)
if name:
quoted = name if (name.startswith("") and name.endswith("")) else f"{name}"
return f"{subject_phrase}的昵称是{quoted}"
elif predicate == "is_age":
age = self._extract_value_from_object(obj, ["age"]) or ""
age = self._clean_subject_text(age)
if age:
return f"{subject_phrase}今年{age}"
elif predicate == "is_profession":
profession = self._extract_value_from_object(obj, ["profession", "job"]) or ""
profession = self._clean_subject_text(profession)
if profession:
return f"{subject_phrase}的职业是{profession}"
elif predicate == "lives_in":
location = self._extract_value_from_object(obj, ["location", "city", "place"]) or ""
location = self._clean_subject_text(location)
if location:
return f"{subject_phrase}居住在{location}"
elif predicate == "has_phone":
phone = self._extract_value_from_object(obj, ["phone", "number"]) or ""
phone = self._clean_subject_text(phone)
if phone:
return f"{subject_phrase}的电话号码是{phone}"
elif predicate == "has_email":
email = self._extract_value_from_object(obj, ["email"]) or ""
email = self._clean_subject_text(email)
if email:
return f"{subject_phrase}的邮箱是{email}"
elif predicate in {"likes", "likes_food", "favorite_is"}:
liked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
liked = self._clean_subject_text(liked)
if liked:
verb = "喜欢" if predicate != "likes_food" else "爱吃"
if predicate == "favorite_is":
verb = "最喜欢"
return f"{subject_phrase}{verb}{liked}"
elif predicate in {"dislikes", "hates"}:
disliked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
disliked = self._clean_subject_text(disliked)
if disliked:
verb = "不喜欢" if predicate == "dislikes" else "讨厌"
return f"{subject_phrase}{verb}{disliked}"
elif predicate == "mentioned_event":
description = self._extract_value_from_object(obj, ["event_text", "description"]) or ""
description = self._clean_subject_text(description)
if description:
return f"{subject_phrase}提到了:{description}"
obj_text = self._extract_value_from_object(obj, ["value", "detail", "content"]) or ""
obj_text = self._clean_subject_text(obj_text)
if predicate and obj_text:
return f"{subject_phrase}{predicate}{obj_text}".strip()
if obj_text:
return f"{subject_phrase}{obj_text}".strip()
if predicate:
return f"{subject_phrase}{predicate}".strip()
return subject_phrase
def _extract_personal_info(
self,
@@ -678,7 +929,7 @@ class MemoryBuilder:
user_id: str,
timestamp: float,
context: Dict[str, Any],
subject_display: str
subjects: List[str]
) -> List[MemoryChunk]:
"""提取个人信息"""
memories = []
@@ -702,13 +953,14 @@ class MemoryBuilder:
memory = create_memory_chunk(
user_id=user_id,
subject=subject_display,
subject=subjects,
predicate=predicate,
obj=obj,
memory_type=MemoryType.PERSONAL_FACT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.HIGH,
confidence=ConfidenceLevel.HIGH
confidence=ConfidenceLevel.HIGH,
display=self._compose_display_text(subjects, predicate, obj)
)
memories.append(memory)
@@ -721,7 +973,7 @@ class MemoryBuilder:
user_id: str,
timestamp: float,
context: Dict[str, Any],
subject_display: str
subjects: List[str]
) -> List[MemoryChunk]:
"""提取偏好信息"""
memories = []
@@ -740,13 +992,14 @@ class MemoryBuilder:
if match:
memory = create_memory_chunk(
user_id=user_id,
subject=subject_display,
subject=subjects,
predicate=predicate,
obj=match.group(1),
memory_type=MemoryType.PREFERENCE,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM
confidence=ConfidenceLevel.MEDIUM,
display=self._compose_display_text(subjects, predicate, match.group(1))
)
memories.append(memory)
@@ -759,7 +1012,7 @@ class MemoryBuilder:
user_id: str,
timestamp: float,
context: Dict[str, Any],
subject_display: str
subjects: List[str]
) -> List[MemoryChunk]:
"""提取事件信息"""
memories = []
@@ -770,13 +1023,14 @@ class MemoryBuilder:
if any(keyword in text for keyword in event_keywords):
memory = create_memory_chunk(
user_id=user_id,
subject=subject_display,
subject=subjects,
predicate="mentioned_event",
obj={"event_text": text, "timestamp": timestamp},
memory_type=MemoryType.EVENT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM
confidence=ConfidenceLevel.MEDIUM,
display=self._compose_display_text(subjects, "mentioned_event", text)
)
memories.append(memory)

View File

@@ -7,7 +7,7 @@
import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union
from typing import Dict, List, Optional, Any, Union, Iterable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
@@ -52,17 +52,20 @@ class ImportanceLevel(Enum):
@dataclass
class ContentStructure:
"""主谓宾三元组结构"""
subject: str # 主语(通常为用户)
predicate: str # 谓语(动作、状态、关系)
object: Union[str, Dict] # 宾语(对象、属性、值)
"""主谓宾结构,包含自然语言描述"""
subject: Union[str, List[str]]
predicate: str
object: Union[str, Dict]
display: str = ""
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"subject": self.subject,
"predicate": self.predicate,
"object": self.object
"object": self.object,
"display": self.display
}
@classmethod
@@ -71,16 +74,25 @@ class ContentStructure:
return cls(
subject=data.get("subject", ""),
predicate=data.get("predicate", ""),
object=data.get("object", "")
object=data.get("object", ""),
display=data.get("display", "")
)
def to_subject_list(self) -> List[str]:
"""将主语转换为列表形式"""
if isinstance(self.subject, list):
return [s for s in self.subject if isinstance(s, str) and s.strip()]
if isinstance(self.subject, str) and self.subject.strip():
return [self.subject.strip()]
return []
def __str__(self) -> str:
"""字符串表示"""
if isinstance(self.object, dict):
object_str = str(self.object)
else:
object_str = str(self.object)
return f"{self.subject} {self.predicate} {object_str}"
if self.display:
return self.display
subjects = "".join(self.to_subject_list()) or str(self.subject)
object_str = self.object if isinstance(self.object, str) else str(self.object)
return f"{subjects} {self.predicate} {object_str}".strip()
@dataclass
@@ -236,9 +248,19 @@ class MemoryChunk:
@property
def text_content(self) -> str:
"""获取文本内容"""
"""获取文本内容优先使用display"""
return str(self.content)
@property
def display(self) -> str:
"""获取展示文本"""
return self.content.display or str(self.content)
@property
def subjects(self) -> List[str]:
"""获取主语列表"""
return self.content.to_subject_list()
def update_access(self):
"""更新访问信息"""
self.metadata.update_access()
@@ -415,16 +437,42 @@ class MemoryChunk:
confidence_icon = "" * self.metadata.confidence.value
importance_icon = "" * self.metadata.importance.value
return f"{emoji} [{self.memory_type.value}] {self.text_content} {confidence_icon} {importance_icon}"
return f"{emoji} [{self.memory_type.value}] {self.display} {confidence_icon} {importance_icon}"
def __repr__(self) -> str:
"""调试表示"""
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
"""根据主谓宾生成自然语言描述"""
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
subject_part = "".join(subjects_clean) if subjects_clean else "对话参与者"
if isinstance(obj, dict):
object_candidates = []
for key, value in obj.items():
if isinstance(value, (str, int, float)):
object_candidates.append(f"{key}:{value}")
elif isinstance(value, list):
compact = "".join(str(item) for item in value[:3])
object_candidates.append(f"{key}:{compact}")
object_part = "".join(object_candidates) if object_candidates else str(obj)
else:
object_part = str(obj).strip()
predicate_clean = predicate.strip()
if not predicate_clean:
return f"{subject_part} {object_part}".strip()
if object_part:
return f"{subject_part}{predicate_clean}{object_part}".strip()
return f"{subject_part}{predicate_clean}".strip()
def create_memory_chunk(
user_id: str,
subject: str,
subject: Union[str, List[str]],
predicate: str,
obj: Union[str, Dict],
memory_type: MemoryType,
@@ -432,6 +480,7 @@ def create_memory_chunk(
source_context: Optional[str] = None,
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
display: Optional[str] = None,
**kwargs
) -> MemoryChunk:
"""便捷的内存块创建函数"""
@@ -447,10 +496,22 @@ def create_memory_chunk(
source_context=source_context
)
subjects: List[str]
if isinstance(subject, list):
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
subject_payload: Union[str, List[str]] = subjects
else:
cleaned = subject.strip() if isinstance(subject, str) else ""
subjects = [cleaned] if cleaned else []
subject_payload = cleaned
display_text = display or _build_display_text(subjects, predicate, obj)
content = ContentStructure(
subject=subject,
subject=subject_payload,
predicate=predicate,
object=obj
object=obj,
display=display_text
)
chunk = MemoryChunk(

View File

@@ -266,8 +266,12 @@ class MemoryFusionEngine:
consistency_score = 0.0
# 主语一致性
if mem1.content.subject == mem2.content.subject:
consistency_score += 0.4
subjects1 = set(mem1.subjects)
subjects2 = set(mem2.subjects)
if subjects1 or subjects2:
overlap = len(subjects1 & subjects2)
union_count = max(len(subjects1 | subjects2), 1)
consistency_score += (overlap / union_count) * 0.4
# 谓语相似性
predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)

View File

@@ -282,9 +282,11 @@ class MemoryIntegrationHooks:
}
# 使用增强记忆系统处理对话
result = await process_conversation_with_enhanced_memory(
conversation_text, context, user_id
)
memory_context = dict(context)
memory_context["conversation_text"] = conversation_text
memory_context["user_id"] = user_id
result = await process_conversation_with_enhanced_memory(memory_context)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
@@ -336,9 +338,11 @@ class MemoryIntegrationHooks:
}
# 使用增强记忆系统处理对话
result = await process_conversation_with_enhanced_memory(
conversation_text, context, user_id
)
memory_context = dict(context)
memory_context["conversation_text"] = conversation_text
memory_context["user_id"] = user_id
result = await process_conversation_with_enhanced_memory(memory_context)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)

View File

@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
"""记忆检索查询规划器"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import orjson
from src.chat.memory_system.memory_chunk import MemoryType
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@dataclass
class MemoryQueryPlan:
"""查询规划结果"""
semantic_query: str
memory_types: List[MemoryType] = field(default_factory=list)
subject_includes: List[str] = field(default_factory=list)
object_includes: List[str] = field(default_factory=list)
required_keywords: List[str] = field(default_factory=list)
optional_keywords: List[str] = field(default_factory=list)
owner_filters: List[str] = field(default_factory=list)
recency_preference: str = "any"
limit: int = 10
emphasis: Optional[str] = None
raw_plan: Dict[str, Any] = field(default_factory=dict)
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
if not self.semantic_query:
self.semantic_query = fallback_query
if self.limit <= 0:
self.limit = default_limit
self.recency_preference = (self.recency_preference or "any").lower()
if self.recency_preference not in {"any", "recent", "historical"}:
self.recency_preference = "any"
self.emphasis = (self.emphasis or "balanced").lower()
class MemoryQueryPlanner:
"""基于小模型的记忆检索查询规划器"""
def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
self.model = planner_model
self.default_limit = default_limit
async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
if not self.model:
logger.debug("未提供查询规划模型,使用默认规划")
return self._default_plan(query_text)
prompt = self._build_prompt(query_text, context)
try:
response, _ = await self.model.generate_response_async(prompt, temperature=0.2)
payload = self._extract_json_payload(response)
if not payload:
logger.debug("查询规划模型未返回结构化结果,使用默认规划")
return self._default_plan(query_text)
try:
data = orjson.loads(payload)
except orjson.JSONDecodeError as exc:
preview = payload[:200]
logger.warning("解析查询规划JSON失败: %s,片段: %s", exc, preview)
return self._default_plan(query_text)
plan = self._parse_plan_dict(data, query_text)
plan.ensure_defaults(query_text, self.default_limit)
return plan
except Exception as exc:
logger.error("查询规划模型调用失败: %s", exc, exc_info=True)
return self._default_plan(query_text)
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
return MemoryQueryPlan(
semantic_query=query_text,
limit=self.default_limit
)
def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
def _collect_list(key: str) -> List[str]:
value = data.get(key)
if isinstance(value, str):
return [value]
if isinstance(value, list):
return [self._safe_str(item) for item in value if self._safe_str(item)]
return []
memory_type_values = _collect_list("memory_types")
memory_types: List[MemoryType] = []
for item in memory_type_values:
if not item:
continue
try:
memory_types.append(MemoryType(item))
except ValueError:
# 尝试匹配value值
normalized = item.lower()
for mt in MemoryType:
if mt.value == normalized:
memory_types.append(mt)
break
plan = MemoryQueryPlan(
semantic_query=semantic_query,
memory_types=memory_types,
subject_includes=_collect_list("subject_includes"),
object_includes=_collect_list("object_includes"),
required_keywords=_collect_list("required_keywords"),
optional_keywords=_collect_list("optional_keywords"),
owner_filters=_collect_list("owner_filters"),
recency_preference=self._safe_str(data.get("recency")) or "any",
limit=self._safe_int(data.get("limit"), self.default_limit),
emphasis=self._safe_str(data.get("emphasis")) or "balanced",
raw_plan=data
)
return plan
def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
participants = context.get("participants") or context.get("speaker_names") or []
if isinstance(participants, str):
participants = [participants]
participants = [p for p in participants if isinstance(p, str) and p.strip()]
participant_preview = "".join(participants[:5]) or "未知"
persona = context.get("bot_personality") or context.get("bot_identity") or "未知"
return f"""
你是一名记忆检索分析师,将根据对话查询生成结构化的检索计划。
请结合提供的上下文输出一个JSON对象字段含义如下
- semantic_query: 提供给向量检索的自然语言查询,要求清晰具体;
- memory_types: 建议检索的记忆类型数组,取值范围参见 MemoryType 枚举 (personal_fact,event,preference,opinion,relationship,emotion,knowledge,skill,goal,experience,contextual)
- subject_includes: 需要出现在记忆主语中的人或角色列表;
- object_includes: 记忆中需要提到的重要对象或主题关键词列表;
- required_keywords: 检索时必须包含的关键词;
- optional_keywords: 可以提升相关性的附加关键词;
- owner_filters: 如果需要限制检索所属主体请列出用户ID或其它标识
- recency: 建议的时间偏好,可选 recent/any/historical
- emphasis: 检索策略倾向,可选 precision/recall/balanced
- limit: 推荐的最大返回数量(1-15之间)
- notes: 额外说明,可选。
当前查询: "{query_text}"
已知的对话参与者: {participant_preview}
机器人设定: {persona}
请输出符合要求的JSON禁止添加额外说明或Markdown代码块。
"""
def _extract_json_payload(self, response: str) -> Optional[str]:
if not response:
return None
stripped = response.strip()
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
if code_block_match:
candidate = code_block_match.group(1).strip()
if candidate:
return candidate
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1]
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
@staticmethod
def _safe_str(value: Any) -> str:
if isinstance(value, str):
return value.strip()
if value is None:
return ""
return str(value).strip()
@staticmethod
def _safe_int(value: Any, default: int) -> int:
try:
number = int(value)
if number <= 0:
return default
return number
except (TypeError, ValueError):
return default

View File

@@ -25,6 +25,7 @@ class IndexType(Enum):
"""索引类型"""
MEMORY_TYPE = "memory_type" # 记忆类型索引
USER_ID = "user_id" # 用户ID索引
SUBJECT = "subject" # 主体索引
KEYWORD = "keyword" # 关键词索引
TAG = "tag" # 标签索引
CATEGORY = "category" # 分类索引
@@ -41,6 +42,7 @@ class IndexQuery:
"""索引查询条件"""
user_ids: Optional[List[str]] = None
memory_types: Optional[List[MemoryType]] = None
subjects: Optional[List[str]] = None
keywords: Optional[List[str]] = None
tags: Optional[List[str]] = None
categories: Optional[List[str]] = None
@@ -76,6 +78,7 @@ class MetadataIndexManager:
self.indices = {
IndexType.MEMORY_TYPE: defaultdict(set),
IndexType.USER_ID: defaultdict(set),
IndexType.SUBJECT: defaultdict(set),
IndexType.KEYWORD: defaultdict(set),
IndexType.TAG: defaultdict(set),
IndexType.CATEGORY: defaultdict(set),
@@ -110,6 +113,41 @@ class MetadataIndexManager:
self.auto_save_interval = 500 # 每500次操作自动保存
self._operation_count = 0
@staticmethod
def _serialize_index_key(index_type: IndexType, key: Any) -> str:
"""将索引键序列化为字符串以便存储"""
if isinstance(key, Enum):
value = key.value
else:
value = key
return str(value)
@staticmethod
def _deserialize_index_key(index_type: IndexType, key: str) -> Any:
"""根据索引类型反序列化索引键"""
try:
if index_type == IndexType.MEMORY_TYPE:
return MemoryType(key)
if index_type == IndexType.CONFIDENCE:
return ConfidenceLevel(int(key))
if index_type == IndexType.IMPORTANCE:
return ImportanceLevel(int(key))
# 其他索引键默认使用原始字符串可能已经是lower后的字符串
return key
except Exception:
logger.warning("无法反序列化索引键 %s 在索引 %s 中,使用原始字符串", key, index_type.value)
return key
@staticmethod
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
serialized = {}
for field_name, value in metadata.items():
if isinstance(value, Enum):
serialized[field_name] = value.value
else:
serialized[field_name] = value
return serialized
async def index_memories(self, memories: List[MemoryChunk]):
"""为记忆建立索引"""
if not memories:
@@ -142,6 +180,68 @@ class MetadataIndexManager:
except Exception as e:
logger.error(f"❌ 元数据索引失败: {e}", exc_info=True)
async def update_memory_entry(self, memory: MemoryChunk):
"""更新已存在记忆的索引信息"""
if not memory:
return
with self._lock:
entry = self.memory_metadata_cache.get(memory.memory_id)
if entry is None:
# 若不存在则作为新记忆索引
self._index_single_memory(memory)
return
old_confidence = entry.get("confidence")
old_importance = entry.get("importance")
old_semantic_hash = entry.get("semantic_hash")
entry.update(
{
"user_id": memory.user_id,
"memory_type": memory.memory_type,
"created_at": memory.metadata.created_at,
"last_accessed": memory.metadata.last_accessed,
"access_count": memory.metadata.access_count,
"confidence": memory.metadata.confidence,
"importance": memory.metadata.importance,
"relationship_score": memory.metadata.relationship_score,
"relevance_score": memory.metadata.relevance_score,
"semantic_hash": memory.semantic_hash,
"subjects": memory.subjects,
}
)
# 更新置信度/重要性索引
if isinstance(old_confidence, ConfidenceLevel):
self.indices[IndexType.CONFIDENCE][old_confidence].discard(memory.memory_id)
if isinstance(old_importance, ImportanceLevel):
self.indices[IndexType.IMPORTANCE][old_importance].discard(memory.memory_id)
if isinstance(old_semantic_hash, str):
self.indices[IndexType.SEMANTIC_HASH][old_semantic_hash].discard(memory.memory_id)
self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory.memory_id)
self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory.memory_id)
if memory.semantic_hash:
self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory.memory_id)
# 同步关键词/标签/分类索引
for keyword in memory.keywords:
if keyword:
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory.memory_id)
for tag in memory.tags:
if tag:
self.indices[IndexType.TAG][tag.lower()].add(memory.memory_id)
for category in memory.categories:
if category:
self.indices[IndexType.CATEGORY][category.lower()].add(memory.memory_id)
for subject in memory.subjects:
if subject:
self.indices[IndexType.SUBJECT][subject.strip().lower()].add(memory.memory_id)
def _index_single_memory(self, memory: MemoryChunk):
"""为单个记忆建立索引"""
memory_id = memory.memory_id
@@ -157,7 +257,8 @@ class MetadataIndexManager:
"importance": memory.metadata.importance,
"relationship_score": memory.metadata.relationship_score,
"relevance_score": memory.metadata.relevance_score,
"semantic_hash": memory.semantic_hash
"semantic_hash": memory.semantic_hash,
"subjects": memory.subjects
}
# 记忆类型索引
@@ -166,6 +267,12 @@ class MetadataIndexManager:
# 用户ID索引
self.indices[IndexType.USER_ID][memory.user_id].add(memory_id)
# 主体索引
for subject in memory.subjects:
normalized = subject.strip().lower()
if normalized:
self.indices[IndexType.SUBJECT][normalized].add(memory_id)
# 关键词索引
for keyword in memory.keywords:
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id)
@@ -282,13 +389,6 @@ class MetadataIndexManager:
# 应用最严格的过滤条件
applied_filters = []
if query.user_ids:
user_ids_set = set()
for user_id in query.user_ids:
user_ids_set.update(self.indices[IndexType.USER_ID].get(user_id, set()))
candidate_ids.update(user_ids_set)
applied_filters.append("user_ids")
if query.memory_types:
memory_types_set = set()
for memory_type in query.memory_types:
@@ -302,7 +402,7 @@ class MetadataIndexManager:
if query.keywords:
keywords_set = set()
for keyword in query.keywords:
keywords_set.update(self.indices[IndexType.KEYWORD].get(keyword.lower(), set()))
keywords_set.update(self._collect_index_matches(IndexType.KEYWORD, keyword))
if applied_filters:
candidate_ids &= keywords_set
else:
@@ -329,12 +429,55 @@ class MetadataIndexManager:
candidate_ids.update(categories_set)
applied_filters.append("categories")
if query.subjects:
subjects_set = set()
for subject in query.subjects:
subjects_set.update(self._collect_index_matches(IndexType.SUBJECT, subject))
if applied_filters:
candidate_ids &= subjects_set
else:
candidate_ids.update(subjects_set)
applied_filters.append("subjects")
# 如果没有应用任何过滤条件,返回所有记忆
if not applied_filters:
return all_memory_ids
return candidate_ids
def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
"""根据给定token收集索引匹配支持部分匹配"""
mapping = self.indices.get(index_type)
if mapping is None:
return set()
key = ""
if isinstance(token, Enum):
key = str(token.value).strip().lower()
elif isinstance(token, str):
key = token.strip().lower()
elif token is not None:
key = str(token).strip().lower()
if not key:
return set()
matches: Set[str] = set(mapping.get(key, set()))
if matches:
return set(matches)
for existing_key, ids in mapping.items():
if not existing_key or not isinstance(existing_key, str):
continue
normalized = existing_key.strip().lower()
if not normalized:
continue
if key in normalized or normalized in key:
matches.update(ids)
return matches
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
"""应用过滤条件"""
filtered_ids = list(candidate_ids)
@@ -440,10 +583,10 @@ class MetadataIndexManager:
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
"""获取应用的过滤器列表"""
filters = []
if query.user_ids:
filters.append("user_ids")
if query.memory_types:
filters.append("memory_types")
if query.subjects:
filters.append("subjects")
if query.keywords:
filters.append("keywords")
if query.tags:
@@ -502,6 +645,18 @@ class MetadataIndexManager:
# 从各类索引中移除
self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id)
self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id)
subjects = metadata.get("subjects") or []
for subject in subjects:
if not isinstance(subject, str):
continue
normalized = subject.strip().lower()
if not normalized:
continue
subject_bucket = self.indices[IndexType.SUBJECT].get(normalized)
if subject_bucket is not None:
subject_bucket.discard(memory_id)
if not subject_bucket:
self.indices[IndexType.SUBJECT].pop(normalized, None)
# 从时间索引中移除
self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id]
@@ -625,11 +780,13 @@ class MetadataIndexManager:
logger.info("正在保存元数据索引...")
# 保存各类索引
indices_data = {}
indices_data: Dict[str, Dict[str, List[str]]] = {}
for index_type, index_data in self.indices.items():
indices_data[index_type.value] = {
key: list(values) for key, values in index_data.items()
}
serialized_index = {}
for key, values in index_data.items():
serialized_key = self._serialize_index_key(index_type, key)
serialized_index[serialized_key] = list(values)
indices_data[index_type.value] = serialized_index
indices_file = self.index_path / "indices.json"
with open(indices_file, 'w', encoding='utf-8') as f:
@@ -652,8 +809,12 @@ class MetadataIndexManager:
# 保存元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
metadata_serialized = {
memory_id: self._serialize_metadata_entry(metadata)
for memory_id, metadata in self.memory_metadata_cache.items()
}
with open(metadata_cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.memory_metadata_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存统计信息
stats_file = self.index_path / "index_stats.json"
@@ -679,9 +840,11 @@ class MetadataIndexManager:
for index_type_value, index_data in indices_data.items():
index_type = IndexType(index_type_value)
self.indices[index_type] = {
key: set(values) for key, values in index_data.items()
}
restored_index = defaultdict(set)
for key_str, values in index_data.items():
restored_key = self._deserialize_index_key(index_type, key_str)
restored_index[restored_key] = set(values)
self.indices[index_type] = restored_index
# 加载时间索引
time_index_file = self.index_path / "time_index.json"
@@ -709,10 +872,38 @@ class MetadataIndexManager:
# 转换置信度和重要性为枚举类型
for memory_id, metadata in cache_data.items():
if isinstance(metadata["confidence"], str):
metadata["confidence"] = ConfidenceLevel(metadata["confidence"])
if isinstance(metadata["importance"], str):
metadata["importance"] = ImportanceLevel(metadata["importance"])
memory_type_value = metadata.get("memory_type")
if isinstance(memory_type_value, str):
try:
metadata["memory_type"] = MemoryType(memory_type_value)
except ValueError:
logger.warning("无法解析memory_type %s", memory_type_value)
confidence_value = metadata.get("confidence")
if isinstance(confidence_value, (str, int)):
try:
metadata["confidence"] = ConfidenceLevel(int(confidence_value))
except ValueError:
logger.warning("无法解析confidence %s", confidence_value)
importance_value = metadata.get("importance")
if isinstance(importance_value, (str, int)):
try:
metadata["importance"] = ImportanceLevel(int(importance_value))
except ValueError:
logger.warning("无法解析importance %s", importance_value)
subjects_value = metadata.get("subjects")
if isinstance(subjects_value, str):
metadata["subjects"] = [subjects_value]
elif isinstance(subjects_value, list):
cleaned_subjects = []
for item in subjects_value:
if isinstance(item, str) and item.strip():
cleaned_subjects.append(item.strip())
metadata["subjects"] = cleaned_subjects
else:
metadata["subjects"] = []
self.memory_metadata_cache = cache_data

View File

@@ -203,11 +203,17 @@ class MultiStageRetrieval:
try:
from .metadata_index import IndexQuery
# 构建索引查询
query_plan = context.get("query_plan")
memory_types = self._extract_memory_types_from_context(context)
keywords = self._extract_keywords_from_query(query, query_plan)
subjects = query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None
index_query = IndexQuery(
user_ids=[user_id],
memory_types=self._extract_memory_types_from_context(context),
keywords=self._extract_keywords_from_query(query),
user_ids=None,
memory_types=memory_types,
subjects=subjects,
keywords=keywords,
limit=self.config.metadata_filter_limit,
sort_by="last_accessed",
sort_order="desc"
@@ -215,13 +221,66 @@ class MultiStageRetrieval:
# 执行查询
result = await metadata_index.query_memories(index_query)
filtered_count = result.total_count - len(result.memory_ids)
result_ids = list(result.memory_ids)
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
logger.debug(f"元数据过滤:{result.total_count} -> {len(result.memory_ids)} 条记忆")
# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
if not result_ids:
sorted_ids = sorted(
(memory.memory_id for memory in all_memories_cache.values()),
key=lambda mid: all_memories_cache[mid].metadata.last_accessed if mid in all_memories_cache else 0,
reverse=True,
)
if memory_types:
type_filtered = [
mid for mid in sorted_ids
if all_memories_cache[mid].memory_type in memory_types
]
sorted_ids = type_filtered or sorted_ids
if subjects:
subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()]
if subject_candidates:
subject_filtered = [
mid for mid in sorted_ids
if any(
subj.strip().lower() in subject_candidates
for subj in all_memories_cache[mid].subjects
)
]
sorted_ids = subject_filtered or sorted_ids
if keywords:
keyword_pool = {kw.lower() for kw in keywords if isinstance(kw, str) and kw.strip()}
if keyword_pool:
keyword_filtered = []
for mid in sorted_ids:
memory_text = (
(all_memories_cache[mid].display or "")
+ "\n"
+ (all_memories_cache[mid].text_content or "")
).lower()
if any(kw in memory_text for kw in keyword_pool):
keyword_filtered.append(mid)
sorted_ids = keyword_filtered or sorted_ids
result_ids = sorted_ids[: self.config.metadata_filter_limit]
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
logger.debug(
"元数据过滤未命中索引,使用近似回退: types=%s, subjects=%s, keywords=%s",
bool(memory_types),
bool(subjects),
bool(keywords),
)
logger.debug(
"元数据过滤:候选=%d, 返回=%d",
len(all_memories_cache),
len(result_ids),
)
return StageResult(
stage=RetrievalStage.METADATA_FILTERING,
memory_ids=result.memory_ids,
memory_ids=result_ids,
processing_time=time.time() - start_time,
filtered_count=filtered_count,
score_threshold=0.0
@@ -251,7 +310,7 @@ class MultiStageRetrieval:
try:
# 生成查询向量
query_embedding = await self._generate_query_embedding(query, context)
query_embedding = await self._generate_query_embedding(query, context, vector_storage)
if not query_embedding:
return StageResult(
@@ -263,22 +322,24 @@ class MultiStageRetrieval:
)
# 执行向量搜索
search_result = await vector_storage.search_similar(
query_embedding,
search_result = await vector_storage.search_similar_memories(
query_vector=query_embedding,
limit=self.config.vector_search_limit
)
candidate_pool = candidate_ids or set(all_memories_cache.keys())
# 过滤候选记忆
filtered_memories = []
for memory_id, similarity in search_result:
if memory_id in candidate_ids and similarity >= self.config.vector_similarity_threshold:
if memory_id in candidate_pool and similarity >= self.config.vector_similarity_threshold:
filtered_memories.append((memory_id, similarity))
# 按相似度排序
filtered_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]
filtered_count = len(candidate_ids) - len(result_ids)
filtered_count = max(0, len(candidate_pool) - len(result_ids))
logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆")
@@ -407,12 +468,20 @@ class MultiStageRetrieval:
score_threshold=0.0
)
async def _generate_query_embedding(self, query: str, context: Dict[str, Any]) -> Optional[List[float]]:
async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]:
"""生成查询向量"""
try:
# 这里应该调用embedding模型
# 由于我们可能没有直接的embedding模型返回None或使用简单的方法
# 在实际实现中这里应该调用与记忆存储相同的embedding模型
query_plan = context.get("query_plan")
query_text = query
if query_plan and getattr(query_plan, "semantic_query", None):
query_text = query_plan.semantic_query
if not query_text:
return None
if hasattr(vector_storage, "generate_query_embedding"):
return await vector_storage.generate_query_embedding(query_text)
return None
except Exception as e:
logger.warning(f"生成查询向量失败: {e}")
@@ -421,9 +490,15 @@ class MultiStageRetrieval:
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
"""计算语义相似度"""
try:
query_plan = context.get("query_plan")
query_text = query
if query_plan and getattr(query_plan, "semantic_query", None):
query_text = query_plan.semantic_query
# 简单的文本相似度计算
query_words = set(query.lower().split())
memory_words = set(memory.text_content.lower().split())
query_words = set(query_text.lower().split())
memory_text = (memory.display or memory.text_content or "").lower()
memory_words = set(memory_text.split())
if not query_words or not memory_words:
return 0.0
@@ -443,10 +518,15 @@ class MultiStageRetrieval:
try:
score = 0.0
query_plan = context.get("query_plan")
# 检查记忆类型是否匹配上下文
if context.get("expected_memory_types"):
if memory.memory_type in context["expected_memory_types"]:
score += 0.3
elif query_plan and getattr(query_plan, "memory_types", None):
if memory.memory_type in query_plan.memory_types:
score += 0.3
# 检查关键词匹配
if context.get("keywords"):
@@ -456,6 +536,35 @@ class MultiStageRetrieval:
if overlap:
score += len(overlap) / max(len(context_keywords), 1) * 0.4
if query_plan:
# 主体匹配
subject_score = self._calculate_subject_overlap(memory, getattr(query_plan, "subject_includes", []))
score += subject_score * 0.3
# 对象/描述匹配
object_keywords = getattr(query_plan, "object_includes", []) or []
if object_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
if hits:
score += min(0.3, hits * 0.1)
optional_keywords = getattr(query_plan, "optional_keywords", []) or []
if optional_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
if hits:
score += min(0.2, hits * 0.05)
# 时间偏好
recency_pref = getattr(query_plan, "recency_preference", "")
if recency_pref:
memory_age = time.time() - memory.metadata.created_at
if recency_pref == "recent" and memory_age < 7 * 24 * 3600:
score += 0.2
elif recency_pref == "historical" and memory_age > 30 * 24 * 3600:
score += 0.1
# 检查时效性
if context.get("recent_only", False):
memory_age = time.time() - memory.metadata.created_at
@@ -471,6 +580,8 @@ class MultiStageRetrieval:
async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
"""计算最终评分"""
try:
query_plan = context.get("query_plan")
# 语义相似度
semantic_score = await self._calculate_semantic_similarity(query, memory, context)
@@ -482,13 +593,29 @@ class MultiStageRetrieval:
# 时效性评分
recency_score = self._calculate_recency_score(memory.metadata.created_at)
if query_plan:
recency_pref = getattr(query_plan, "recency_preference", "")
if recency_pref == "recent":
recency_score = max(recency_score, 0.8)
elif recency_pref == "historical":
recency_score = min(recency_score, 0.5)
# 权重组合
vector_weight = self.config.vector_weight
semantic_weight = self.config.semantic_weight
context_weight = self.config.context_weight
recency_weight = self.config.recency_weight
if query_plan and getattr(query_plan, "emphasis", None) == "precision":
semantic_weight += 0.05
elif query_plan and getattr(query_plan, "emphasis", None) == "recall":
context_weight += 0.05
final_score = (
semantic_score * self.config.semantic_weight +
vector_score * self.config.vector_weight +
context_score * self.config.context_weight +
recency_score * self.config.recency_weight
semantic_score * semantic_weight +
vector_score * vector_weight +
context_score * context_weight +
recency_score * recency_weight
)
# 加入记忆重要性权重
@@ -501,6 +628,31 @@ class MultiStageRetrieval:
logger.warning(f"计算最终评分失败: {e}")
return 0.0
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
if not required_subjects:
return 0.0
memory_subjects = {subject.lower() for subject in memory.subjects if isinstance(subject, str)}
if not memory_subjects:
return 0.0
hit = 0
total = 0
for subject in required_subjects:
if not isinstance(subject, str):
continue
total += 1
normalized = subject.strip().lower()
if not normalized:
continue
if any(normalized in mem_subject for mem_subject in memory_subjects):
hit += 1
if total == 0:
return 0.0
return hit / total
def _calculate_recency_score(self, timestamp: float) -> float:
"""计算时效性评分"""
try:
@@ -524,6 +676,10 @@ class MultiStageRetrieval:
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
"""从上下文中提取记忆类型"""
try:
query_plan = context.get("query_plan")
if query_plan and getattr(query_plan, "memory_types", None):
return query_plan.memory_types
if "expected_memory_types" in context:
return context["expected_memory_types"]
@@ -544,15 +700,30 @@ class MultiStageRetrieval:
except Exception:
return []
def _extract_keywords_from_query(self, query: str) -> List[str]:
def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
"""从查询中提取关键词"""
try:
extracted: List[str] = []
if query_plan and getattr(query_plan, "required_keywords", None):
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
# 简单的关键词提取
words = query.lower().split()
# 过滤停用词
stopwords = {"", "", "", "", "", "", "", "", "", "", "", "", "", ""}
keywords = [word for word in words if len(word) > 1 and word not in stopwords]
return keywords[:10] # 最多返回10个关键词
extracted.extend(word for word in words if len(word) > 1 and word not in stopwords)
# 去重并保留顺序
seen = set()
deduplicated = []
for word in extracted:
if word in seen or not word:
continue
seen.add(word)
deduplicated.append(word)
return deduplicated[:10]
except Exception:
return []

View File

@@ -20,6 +20,7 @@ from pathlib import Path
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.common.config_helpers import resolve_embedding_dimension
from src.chat.memory_system.memory_chunk import MemoryChunk
logger = get_logger(__name__)
@@ -36,12 +37,12 @@ except ImportError:
@dataclass
class VectorStorageConfig:
"""向量存储配置"""
dimension: int = 768
dimension: int = 1024
similarity_threshold: float = 0.8
index_type: str = "flat" # flat, ivf, hnsw
max_index_size: int = 100000
storage_path: str = "data/memory_vectors"
auto_save_interval: int = 100 # 每N次操作自动保存
auto_save_interval: int = 10 # 每N次操作自动保存
enable_compression: bool = True
@@ -50,6 +51,15 @@ class VectorStorageManager:
def __init__(self, config: Optional[VectorStorageConfig] = None):
self.config = config or VectorStorageConfig()
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
if resolved_dimension and resolved_dimension != self.config.dimension:
logger.info(
"向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
resolved_dimension,
self.config.dimension,
)
self.config.dimension = resolved_dimension
self.storage_path = Path(self.config.storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
@@ -117,6 +127,32 @@ class VectorStorageManager:
)
logger.info("✅ 嵌入模型初始化完成")
async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
"""生成查询向量,用于记忆召回"""
if not query_text:
return None
try:
await self.initialize_embedding_model()
embedding, _ = await self.embedding_model.get_embedding(query_text)
if not embedding:
return None
if len(embedding) != self.config.dimension:
logger.warning(
"查询向量维度不匹配: 期望 %d, 实际 %d",
self.config.dimension,
len(embedding)
)
return None
return self._normalize_vector(embedding)
except Exception as exc:
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
return None
async def store_memories(self, memories: List[MemoryChunk]):
"""存储记忆向量"""
if not memories:
@@ -213,7 +249,7 @@ class VectorStorageManager:
results[memory_id] = embedding
else:
logger.warning(
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)",
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。",
self.config.dimension,
len(embedding) if embedding else 0,
memory_id,
@@ -299,14 +335,32 @@ class VectorStorageManager:
async def search_similar_memories(
self,
query_vector: List[float],
query_vector: Optional[List[float]] = None,
*,
query_text: Optional[str] = None,
limit: int = 10,
user_id: Optional[str] = None
scope_id: Optional[str] = None
) -> List[Tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()
try:
if query_vector is None:
if not query_text:
return []
query_vector = await self.generate_query_embedding(query_text)
if not query_vector:
return []
scope_filter: Optional[str] = None
if isinstance(scope_id, str):
normalized_scope = scope_id.strip().lower()
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
scope_filter = scope_id
elif scope_id:
scope_filter = str(scope_id)
# 规范化查询向量
query_vector = self._normalize_vector(query_vector)
@@ -341,10 +395,9 @@ class VectorStorageManager:
memory_id = self.index_to_memory_id.get(index)
if memory_id:
# 应用用户过滤
if user_id:
if scope_filter:
memory = self.memory_cache.get(memory_id)
if memory and memory.user_id != user_id:
if memory and str(memory.user_id) != scope_filter:
continue
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
@@ -481,8 +534,14 @@ class VectorStorageManager:
# 保存映射关系
mapping_file = self.storage_path / "id_mapping.json"
mapping_data = {
"memory_id_to_index": self.memory_id_to_index,
"index_to_memory_id": self.index_to_memory_id
"memory_id_to_index": {
str(memory_id): int(index)
for memory_id, index in self.memory_id_to_index.items()
},
"index_to_memory_id": {
str(index): memory_id
for index, memory_id in self.index_to_memory_id.items()
}
}
with open(mapping_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
@@ -529,8 +588,17 @@ class VectorStorageManager:
if mapping_file.exists():
with open(mapping_file, 'r', encoding='utf-8') as f:
mapping_data = orjson.loads(f.read())
self.memory_id_to_index = mapping_data.get("memory_id_to_index", {})
self.index_to_memory_id = mapping_data.get("index_to_memory_id", {})
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
str(memory_id): int(index)
for memory_id, index in raw_memory_to_index.items()
}
raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
self.index_to_memory_id = {
int(index): memory_id
for index, memory_id in raw_index_to_memory.items()
}
# 加载FAISS索引如果可用
if FAISS_AVAILABLE:

View File

@@ -469,14 +469,14 @@ class ChatBot:
async def preprocess():
# 存储消息到数据库
from .storage import MessageStorage
try:
await MessageStorage.store_message(message, message.chat_stream)
logger.debug(f"消息已存储到数据库: {message.message_info.message_id}")
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc()
# 使用消息管理器处理消息(保持原有功能)
from src.common.data_models.database_data_model import DatabaseMessages

View File

@@ -373,12 +373,12 @@ class Prompt:
# 性能优化 - 为不同任务设置不同的超时时间
task_timeouts = {
"memory_block": 5.0, # 记忆系统可能较慢,单独设置超时
"tool_info": 3.0, # 工具信息中等速度
"relation_info": 2.0, # 关系信息通常较快
"knowledge_info": 3.0, # 知识库查询中等速度
"cross_context": 2.0, # 上下文处理通常较快
"expression_habits": 1.5, # 表达习惯处理很快
"memory_block": 15.0, # 记忆系统
"tool_info": 15.0, # 工具信息
"relation_info": 10.0, # 关系信息
"knowledge_info": 10.0, # 知识库查询
"cross_context": 10.0, # 上下文处理
"expression_habits": 10.0, # 表达习惯
}
# 分别处理每个任务,避免慢任务影响快任务
@@ -558,12 +558,8 @@ class Prompt:
)
]
# 等待所有记忆查询完成最多3秒
try:
running_memories, instant_memory = await asyncio.wait_for(
asyncio.gather(*memory_tasks, return_exceptions=True),
timeout=3.0
)
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
# 处理可能的异常结果
if isinstance(running_memories, Exception):

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, Union
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.config_helpers import resolve_embedding_dimension
from src.common.database.sqlalchemy_models import CacheEntries
from src.common.database.sqlalchemy_database_api import db_query, db_save
from src.common.vector_db import vector_db_service
@@ -40,7 +41,11 @@ class CacheManager:
# L1 缓存 (内存)
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
if not embedding_dim:
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
self.embedding_dimension = embedding_dim
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
self.l1_vector_id_to_key: Dict[int, str] = {}
@@ -72,7 +77,7 @@ class CacheManager:
embedding_array = embedding_array.flatten()
# 检查维度是否符合预期
expected_dim = global_config.lpmm_knowledge.embedding_dimension
expected_dim = getattr(CacheManager, "embedding_dimension", None) or global_config.lpmm_knowledge.embedding_dimension
if embedding_array.shape[0] != expected_dim:
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
return None

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
from typing import Optional
from src.config.config import global_config, model_config
def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]:
"""获取当前配置的嵌入向量维度。
优先顺序:
1. 模型配置中 `model_task_config.embedding.embedding_dimension`
2. 机器人配置中 `lpmm_knowledge.embedding_dimension`
3. 调用方提供的 fallback
"""
candidates: list[Optional[int]] = []
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
if embedding_task is not None:
candidates.append(getattr(embedding_task, "embedding_dimension", None))
except Exception:
candidates.append(None)
try:
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
except Exception:
candidates.append(None)
candidates.append(fallback)
resolved: Optional[int] = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
if resolved and sync_global:
try:
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]
except Exception:
pass
return resolved

View File

@@ -1,4 +1,4 @@
from typing import List, Dict, Any, Literal, Union
from typing import List, Dict, Any, Literal, Union, Optional
from pydantic import Field
from threading import Lock
@@ -105,6 +105,11 @@ class TaskConfig(ValidatedConfigBase):
max_tokens: int = Field(default=800, description="任务最大输出token数")
temperature: float = Field(default=0.7, description="模型温度")
concurrency_count: int = Field(default=1, description="并发请求数量")
embedding_dimension: Optional[int] = Field(
default=None,
description="嵌入模型输出向量维度,仅在嵌入任务中使用",
ge=1,
)
@classmethod
def validate_model_list(cls, v):

View File

@@ -443,21 +443,6 @@ class MemoryConfig(ValidatedConfigBase):
enable_memory: bool = Field(default=True, description="启用记忆")
memory_build_interval: int = Field(default=600, description="记忆构建间隔")
memory_build_distribution: list[float] = Field(
default_factory=lambda: [6.0, 3.0, 0.6, 32.0, 12.0, 0.4], description="记忆构建分布"
)
memory_build_sample_num: int = Field(default=8, description="记忆构建样本数量")
memory_build_sample_length: int = Field(default=40, description="记忆构建样本长度")
memory_compress_rate: float = Field(default=0.1, description="记忆压缩率")
forget_memory_interval: int = Field(default=1000, description="遗忘记忆间隔")
memory_forget_time: int = Field(default=24, description="记忆遗忘时间")
memory_forget_percentage: float = Field(default=0.01, description="记忆遗忘百分比")
consolidate_memory_interval: int = Field(default=1000, description="记忆巩固间隔")
consolidation_similarity_threshold: float = Field(default=0.7, description="巩固相似性阈值")
consolidate_memory_percentage: float = Field(default=0.01, description="巩固记忆百分比")
memory_ban_words: list[str] = Field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"], description="记忆禁用词"
)
enable_instant_memory: bool = Field(default=True, description="启用即时记忆")
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
@@ -472,8 +457,8 @@ class MemoryConfig(ValidatedConfigBase):
memory_value_threshold: float = Field(default=0.7, description="记忆价值阈值")
# 向量存储配置
vector_dimension: int = Field(default=768, description="向量维度")
vector_similarity_threshold: float = Field(default=0.8, description="向量相似度阈值")
semantic_similarity_threshold: float = Field(default=0.6, description="语义相似度阈值")
# 多阶段检索配置
metadata_filter_limit: int = Field(default=100, description="元数据过滤阶段返回数量")

View File

@@ -27,9 +27,8 @@ from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.component_types import EventType
# from src.api.main import start_api_server
# 导入新的插件管理器和热重载管理器
# 导入新的插件管理器
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
# 导入消息API和traceback模块
from src.common.message import get_global_api
@@ -116,13 +115,7 @@ class MainSystem:
except Exception as e:
logger.error(f"停止消息重组器时出错: {e}")
try:
# 停止插件热重载系统
hot_reload_manager.stop()
logger.info("🛑 插件热重载系统已停止")
except Exception as e:
logger.error(f"停止热重载系统时出错: {e}")
try:
# 停止增强记忆系统
if global_config.memory.enable_memory:
@@ -229,9 +222,7 @@ MoFox_Bot(第三方修改版)
# 处理所有缓存的事件订阅(插件加载完成后)
event_manager.process_all_pending_subscriptions()
# 启动插件热重载系统
hot_reload_manager.start()
# 初始化表情管理器
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")

View File

@@ -8,12 +8,10 @@ from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
__all__ = [
"plugin_manager",
"component_registry",
"event_manager",
"global_announcement_manager",
"hot_reload_manager",
]

View File

@@ -1,512 +0,0 @@
"""
插件热重载模块
使用 Watchdog 监听插件目录变化,自动重载插件
"""
import os
import sys
import time
import importlib
from pathlib import Path
from threading import Thread
from typing import Dict, Set, List, Optional, Tuple
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from src.common.logger import get_logger
from .plugin_manager import plugin_manager
logger = get_logger("plugin_hot_reload")
class PluginFileHandler(FileSystemEventHandler):
"""插件文件变化处理器"""
def __init__(self, hot_reload_manager):
super().__init__()
self.hot_reload_manager = hot_reload_manager
self.pending_reloads: Set[str] = set() # 待重载的插件名称
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
self.debounce_delay = 2.0 # 增加防抖延迟到2秒确保文件写入完成
self.file_change_cache: Dict[str, float] = {} # 文件变化缓存
def on_modified(self, event):
"""文件修改事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "modified")
def on_created(self, event):
"""文件创建事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "created")
def on_deleted(self, event):
"""文件删除事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "deleted")
def _handle_file_change(self, file_path: str, change_type: str):
"""处理文件变化"""
try:
# 获取插件名称
plugin_info = self._get_plugin_info_from_path(file_path)
if not plugin_info:
return
plugin_name, source_type = plugin_info
current_time = time.time()
# 文件变化缓存,避免重复处理同一文件的快速连续变化
file_cache_key = f"{file_path}_{change_type}"
last_file_time = self.file_change_cache.get(file_cache_key, 0)
if current_time - last_file_time < 0.5: # 0.5秒内的重复文件变化忽略
return
self.file_change_cache[file_cache_key] = current_time
# 插件级别的防抖处理
last_plugin_time = self.last_reload_time.get(plugin_name, 0)
if current_time - last_plugin_time < self.debounce_delay:
# 如果在防抖期内,更新待重载标记但不立即处理
self.pending_reloads.add(plugin_name)
return
file_name = Path(file_path).name
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type}) [{source_type}] -> {plugin_name}")
# 如果是删除事件,处理关键文件删除
if change_type == "deleted":
# 解析实际的插件名称
actual_plugin_name = self.hot_reload_manager._resolve_plugin_name(plugin_name)
if file_name == "plugin.py":
if actual_plugin_name in plugin_manager.loaded_plugins:
logger.info(
f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
self.hot_reload_manager._unload_plugin(actual_plugin_name)
else:
logger.info(
f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
return
elif file_name in ("manifest.toml", "_manifest.json"):
if actual_plugin_name in plugin_manager.loaded_plugins:
logger.info(
f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
self.hot_reload_manager._unload_plugin(actual_plugin_name)
else:
logger.info(
f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
return
# 对于修改和创建事件,都进行重载
# 添加到待重载列表
self.pending_reloads.add(plugin_name)
self.last_reload_time[plugin_name] = current_time
# 延迟重载,确保文件写入完成
reload_thread = Thread(
target=self._delayed_reload, args=(plugin_name, source_type, current_time), daemon=True
)
reload_thread.start()
except Exception as e:
logger.error(f"❌ 处理文件变化时发生错误: {e}", exc_info=True)
def _delayed_reload(self, plugin_name: str, source_type: str, trigger_time: float):
"""延迟重载插件"""
try:
# 等待文件写入完成
time.sleep(self.debounce_delay)
# 检查是否还需要重载(可能在等待期间有更新的变化)
if plugin_name not in self.pending_reloads:
return
# 检查是否有更新的重载请求
if self.last_reload_time.get(plugin_name, 0) > trigger_time:
return
self.pending_reloads.discard(plugin_name)
logger.info(f"🔄 开始延迟重载插件: {plugin_name} [{source_type}]")
# 执行深度重载
success = self.hot_reload_manager._deep_reload_plugin(plugin_name)
if success:
logger.info(f"✅ 插件重载成功: {plugin_name} [{source_type}]")
else:
logger.error(f"❌ 插件重载失败: {plugin_name} [{source_type}]")
except Exception as e:
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
def _get_plugin_info_from_path(self, file_path: str) -> Optional[Tuple[str, str]]:
"""从文件路径获取插件信息
Returns:
tuple[插件名称, 源类型] 或 None
"""
try:
path = Path(file_path)
# 检查是否在任何一个监听的插件目录中
for watch_dir in self.hot_reload_manager.watch_directories:
plugin_root = Path(watch_dir)
if path.is_relative_to(plugin_root):
# 确定源类型
if "src" in str(plugin_root):
source_type = "built-in"
else:
source_type = "external"
# 获取插件目录名(插件名)
relative_path = path.relative_to(plugin_root)
if len(relative_path.parts) == 0:
continue
plugin_name = relative_path.parts[0]
# 确认这是一个有效的插件目录
plugin_dir = plugin_root / plugin_name
if plugin_dir.is_dir():
# 检查是否有插件主文件或配置文件
has_plugin_py = (plugin_dir / "plugin.py").exists()
has_manifest = (plugin_dir / "manifest.toml").exists() or (
plugin_dir / "_manifest.json"
).exists()
if has_plugin_py or has_manifest:
return plugin_name, source_type
return None
except Exception:
return None
class PluginHotReloadManager:
"""插件热重载管理器"""
def __init__(self, watch_directories: Optional[List[str]] = None):
if watch_directories is None:
# 默认监听两个目录:根目录下的 plugins 和 src 下的插件目录
self.watch_directories = [
os.path.join(os.getcwd(), "plugins"), # 外部插件目录
os.path.join(os.getcwd(), "src", "plugins", "built_in"), # 内置插件目录
]
else:
self.watch_directories = watch_directories
self.observers = []
self.file_handlers = []
self.is_running = False
# 确保监听目录存在
for watch_dir in self.watch_directories:
if not os.path.exists(watch_dir):
os.makedirs(watch_dir, exist_ok=True)
logger.info(f"📁 创建插件监听目录: {watch_dir}")
def start(self):
"""启动热重载监听"""
if self.is_running:
logger.warning("插件热重载已经在运行中")
return
try:
# 为每个监听目录创建独立的观察者
for watch_dir in self.watch_directories:
observer = Observer()
file_handler = PluginFileHandler(self)
observer.schedule(file_handler, watch_dir, recursive=True)
observer.start()
self.observers.append(observer)
self.file_handlers.append(file_handler)
self.is_running = True
# 打印监听的目录信息
dir_info = []
for watch_dir in self.watch_directories:
if "src" in watch_dir:
dir_info.append(f"{watch_dir} (内置插件)")
else:
dir_info.append(f"{watch_dir} (外部插件)")
logger.info("🚀 插件热重载已启动,监听目录:")
for info in dir_info:
logger.info(f" 📂 {info}")
except Exception as e:
logger.error(f"❌ 启动插件热重载失败: {e}")
self.stop() # 清理已创建的观察者
self.is_running = False
def stop(self):
"""停止热重载监听"""
if not self.is_running and not self.observers:
return
# 停止所有观察者
for observer in self.observers:
try:
observer.stop()
observer.join()
except Exception as e:
logger.error(f"❌ 停止观察者时发生错误: {e}")
self.observers.clear()
self.file_handlers.clear()
self.is_running = False
logger.info("🛑 插件热重载已停止")
def _reload_plugin(self, plugin_name: str):
"""重载指定插件(简单重载)"""
try:
# 解析实际的插件名称
actual_plugin_name = self._resolve_plugin_name(plugin_name)
logger.info(f"🔄 开始简单重载插件: {plugin_name} -> {actual_plugin_name}")
if plugin_manager.reload_plugin(actual_plugin_name):
logger.info(f"✅ 插件简单重载成功: {actual_plugin_name}")
return True
else:
logger.error(f"❌ 插件简单重载失败: {actual_plugin_name}")
return False
except Exception as e:
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
@staticmethod
def _resolve_plugin_name(folder_name: str) -> str:
"""
将文件夹名称解析为实际的插件名称
通过检查插件管理器中的路径映射来找到对应的插件名
"""
# 首先检查是否直接匹配
if folder_name in plugin_manager.plugin_classes:
logger.debug(f"🔍 直接匹配插件名: {folder_name}")
return folder_name
# 如果没有直接匹配,搜索路径映射,并优先返回在插件类中存在的名称
matched_plugins = []
for plugin_name, plugin_path in plugin_manager.plugin_paths.items():
# 检查路径是否包含该文件夹名
if folder_name in plugin_path:
matched_plugins.append((plugin_name, plugin_path))
# 在匹配的插件中,优先选择在插件类中存在的
for plugin_name, plugin_path in matched_plugins:
if plugin_name in plugin_manager.plugin_classes:
logger.debug(f"🔍 文件夹名 '{folder_name}' 映射到插件名 '{plugin_name}' (路径: {plugin_path})")
return plugin_name
# 如果还是没找到在插件类中存在的,返回第一个匹配项
if matched_plugins:
plugin_name, plugin_path = matched_plugins[0]
logger.warning(f"⚠️ 文件夹 '{folder_name}' 映射到 '{plugin_name}',但该插件类不存在")
return plugin_name
# 如果还是没找到,返回原文件夹名
logger.warning(f"⚠️ 无法找到文件夹 '{folder_name}' 对应的插件名,使用原名称")
return folder_name
def _deep_reload_plugin(self, plugin_name: str):
"""深度重载指定插件(清理模块缓存)"""
try:
# 解析实际的插件名称
actual_plugin_name = self._resolve_plugin_name(plugin_name)
logger.info(f"🔄 开始深度重载插件: {plugin_name} -> {actual_plugin_name}")
# 强制清理相关模块缓存
self._force_clear_plugin_modules(plugin_name)
# 使用插件管理器的强制重载功能
success = plugin_manager.force_reload_plugin(actual_plugin_name)
if success:
logger.info(f"✅ 插件深度重载成功: {actual_plugin_name}")
return True
else:
logger.error(f"❌ 插件深度重载失败,尝试简单重载: {actual_plugin_name}")
# 如果深度重载失败,尝试简单重载
return self._reload_plugin(actual_plugin_name)
except Exception as e:
logger.error(f"❌ 深度重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
# 出错时尝试简单重载
return self._reload_plugin(plugin_name)
@staticmethod
def _force_clear_plugin_modules(plugin_name: str):
"""强制清理插件相关的模块缓存"""
# 找到所有相关的模块名
modules_to_remove = []
plugin_module_prefix = f"src.plugins.built_in.{plugin_name}"
for module_name in list(sys.modules.keys()):
if plugin_module_prefix in module_name:
modules_to_remove.append(module_name)
# 删除模块缓存
for module_name in modules_to_remove:
if module_name in sys.modules:
logger.debug(f"🗑️ 清理模块缓存: {module_name}")
del sys.modules[module_name]
@staticmethod
def _force_reimport_plugin(plugin_name: str):
"""强制重新导入插件(委托给插件管理器)"""
try:
# 使用插件管理器的重载功能
success = plugin_manager.reload_plugin(plugin_name)
return success
except Exception as e:
logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
@staticmethod
def _unload_plugin(plugin_name: str):
"""卸载指定插件"""
try:
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
if plugin_manager.unload_plugin(plugin_name):
logger.info(f"✅ 插件卸载成功: {plugin_name}")
return True
else:
logger.error(f"❌ 插件卸载失败: {plugin_name}")
return False
except Exception as e:
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
def reload_all_plugins(self):
"""重载所有插件"""
try:
logger.info("🔄 开始深度重载所有插件...")
# 获取当前已加载的插件列表
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
success_count = 0
fail_count = 0
for plugin_name in loaded_plugins:
logger.info(f"🔄 重载插件: {plugin_name}")
if self._deep_reload_plugin(plugin_name):
success_count += 1
else:
fail_count += 1
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count}")
# 清理全局缓存
importlib.invalidate_caches()
except Exception as e:
logger.error(f"❌ 重载所有插件时发生错误: {e}", exc_info=True)
def force_reload_plugin(self, plugin_name: str):
"""手动强制重载指定插件(委托给插件管理器)"""
try:
logger.info(f"🔄 手动强制重载插件: {plugin_name}")
# 清理待重载列表中的该插件(避免重复重载)
for handler in self.file_handlers:
handler.pending_reloads.discard(plugin_name)
# 使用插件管理器的强制重载功能
success = plugin_manager.force_reload_plugin(plugin_name)
if success:
logger.info(f"✅ 手动强制重载成功: {plugin_name}")
else:
logger.error(f"❌ 手动强制重载失败: {plugin_name}")
return success
except Exception as e:
logger.error(f"❌ 手动强制重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
def add_watch_directory(self, directory: str):
"""添加新的监听目录"""
if directory in self.watch_directories:
logger.info(f"目录 {directory} 已在监听列表中")
return
# 确保目录存在
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"📁 创建插件监听目录: {directory}")
self.watch_directories.append(directory)
# 如果热重载正在运行,为新目录创建观察者
if self.is_running:
try:
observer = Observer()
file_handler = PluginFileHandler(self)
observer.schedule(file_handler, directory, recursive=True)
observer.start()
self.observers.append(observer)
self.file_handlers.append(file_handler)
logger.info(f"📂 已添加新的监听目录: {directory}")
except Exception as e:
logger.error(f"❌ 添加监听目录 {directory} 失败: {e}")
self.watch_directories.remove(directory)
def get_status(self) -> dict:
"""获取热重载状态"""
pending_reloads = set()
if self.file_handlers:
for handler in self.file_handlers:
pending_reloads.update(handler.pending_reloads)
return {
"is_running": self.is_running,
"watch_directories": self.watch_directories,
"active_observers": len(self.observers),
"loaded_plugins": len(plugin_manager.loaded_plugins),
"failed_plugins": len(plugin_manager.failed_plugins),
"pending_reloads": list(pending_reloads),
"debounce_delay": self.file_handlers[0].debounce_delay if self.file_handlers else 0,
}
@staticmethod
def clear_all_caches():
"""清理所有Python模块缓存"""
try:
logger.info("🧹 开始清理所有Python模块缓存...")
# 重新扫描所有插件目录,这会重新加载模块
plugin_manager.rescan_plugin_directory()
logger.info("✅ 模块缓存清理完成")
except Exception as e:
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)
# 全局热重载管理器实例
hot_reload_manager = PluginHotReloadManager()

View File

@@ -91,7 +91,6 @@ INSTALL_NAME_TO_IMPORT_NAME = {
"pyusb": "usb", # USB访问
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
"psutil": "psutil", # 系统信息和进程管理
"watchdog": "watchdog", # 文件系统事件监控
"python-gnupg": "gnupg", # GnuPG的Python接口
# ============== 加密与安全 (Cryptography & Security) ==============
"pycrypto": "Crypto", # 加密库 (较旧)

View File

@@ -15,7 +15,6 @@ from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.utils.permission_decorators import require_permission
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
class ManagementCommand(PlusCommand):
@@ -78,10 +77,6 @@ class ManagementCommand(PlusCommand):
await self._force_reload_plugin(args[1])
elif action in ["add_dir", "添加目录"] and len(args) > 1:
await self._add_dir(args[1])
elif action in ["hotreload_status", "热重载状态"]:
await self._show_hotreload_status()
elif action in ["clear_cache", "清理缓存"]:
await self._clear_all_caches()
else:
await self.send_text("❌ 插件管理命令不合法\n使用 /pm plugin help 查看帮助")
return False, "命令不合法", True
@@ -179,14 +174,9 @@ class ManagementCommand(PlusCommand):
• `/pm plugin force_reload <插件名>` - 强制重载指定插件(深度清理)
• `/pm plugin add_dir <目录路径>` - 添加插件目录
<EFBFBD> 热重载管理:
• `/pm plugin hotreload_status` - 查看热重载状态
• `/pm plugin clear_cache` - 清理所有模块缓存
<EFBFBD>📝 示例:
• `/pm plugin load echo_example`
• `/pm plugin force_reload permission_manager_plugin`
• `/pm plugin clear_cache`"""
• `/pm plugin force_reload permission_manager_plugin`"""
elif target == "component":
help_msg = """🧩 组件管理命令帮助
@@ -262,7 +252,7 @@ class ManagementCommand(PlusCommand):
await self.send_text(f"🔄 开始强制重载插件: `{plugin_name}`...")
try:
success = hot_reload_manager.force_reload_plugin(plugin_name)
success = plugin_manage_api.force_reload_plugin(plugin_name)
if success:
await self.send_text(f"✅ 插件强制重载成功: `{plugin_name}`")
else:
@@ -270,44 +260,7 @@ class ManagementCommand(PlusCommand):
except Exception as e:
await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}")
async def _show_hotreload_status(self):
"""显示热重载状态"""
try:
status = hot_reload_manager.get_status()
status_text = f"""🔄 **热重载系统状态**
🟢 **运行状态:** {"运行中" if status["is_running"] else "已停止"}
📂 **监听目录:** {len(status["watch_directories"])}
👁️ **活跃观察者:** {status["active_observers"]}
📦 **已加载插件:** {status["loaded_plugins"]}
❌ **失败插件:** {status["failed_plugins"]}
⏱️ **防抖延迟:** {status.get("debounce_delay", 0)}
📋 **监听的目录:**"""
for i, watch_dir in enumerate(status["watch_directories"], 1):
dir_type = "(内置插件)" if "src" in watch_dir else "(外部插件)"
status_text += f"\n{i}. `{watch_dir}` {dir_type}"
if status.get("pending_reloads"):
status_text += f"\n\n⏳ **待重载插件:** {', '.join([f'`{p}`' for p in status['pending_reloads']])}"
await self.send_text(status_text)
except Exception as e:
await self.send_text(f"❌ 获取热重载状态时发生错误: {str(e)}")
async def _clear_all_caches(self):
"""清理所有模块缓存"""
await self.send_text("🧹 开始清理所有Python模块缓存...")
try:
hot_reload_manager.clear_all_caches()
await self.send_text("✅ 模块缓存清理完成!建议重载相关插件以确保生效。")
except Exception as e:
await self.send_text(f"❌ 清理缓存时发生错误: {str(e)}")
async def _add_dir(self, dir_path: str):
"""添加插件目录"""
await self.send_text(f"📁 正在添加插件目录: `{dir_path}`")

View File

@@ -255,25 +255,34 @@ max_context_emojis = 30 # 每次随机传递给LLM的表情包详细描述的最
[memory]
enable_memory = true # 是否启用记忆系统
memory_build_interval = 600 # 记忆构建间隔 单位秒 间隔越低MoFox-Bot学习越多但是冗余信息也会增多
memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
memory_build_sample_length = 30 # 采样长度,数值越高一段记忆内容越丰富
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
forget_memory_interval = 3000 # 记忆遗忘间隔 单位秒 间隔越低MoFox-Bot遗忘越频繁记忆更精简但更难学习
memory_forget_time = 48 #多长时间后的记忆会被遗忘 单位小时
memory_forget_percentage = 0.008 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认
consolidate_memory_interval = 1000 # 记忆整合间隔 单位秒 间隔越低MoFox-Bot整合越频繁记忆更精简
consolidation_similarity_threshold = 0.7 # 相似度阈值
consolidation_check_percentage = 0.05 # 检查节点比例
enable_instant_memory = false # 是否启用即时记忆,测试功能,可能存在未知问题
enable_instant_memory = true # 是否启用即时记忆
enable_llm_instant_memory = true # 是否启用基于LLM的瞬时记忆
enable_vector_instant_memory = true # 是否启用基于向量的瞬时记忆
enable_enhanced_memory = true # 是否启用增强记忆系统
enhanced_memory_auto_save = true # 是否自动保存增强记忆
#不希望记忆的词,已经记忆的不会受到影响,需要手动清理
memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]
min_memory_length = 10 # 最小记忆长度
max_memory_length = 500 # 最大记忆长度
memory_value_threshold = 0.7 # 记忆价值阈值,低于该值的记忆会被丢弃
vector_similarity_threshold = 0.8 # 向量相似度阈值
semantic_similarity_threshold = 0.6 # 语义重排阶段的最低匹配阈值
metadata_filter_limit = 100 # 元数据过滤阶段返回数量上限
vector_search_limit = 50 # 向量搜索阶段返回数量上限
semantic_rerank_limit = 20 # 语义重排阶段返回数量上限
final_result_limit = 10 # 综合筛选后的最终返回数量
vector_weight = 0.4 # 综合评分中向量相似度的权重
semantic_weight = 0.3 # 综合评分中语义匹配的权重
context_weight = 0.2 # 综合评分中上下文关联的权重
recency_weight = 0.1 # 综合评分中时效性的权重
fusion_similarity_threshold = 0.85 # 记忆融合时的相似度阈值
deduplication_window_hours = 24 # 记忆去重窗口(小时)
enable_memory_cache = true # 是否启用本地记忆缓存
cache_ttl_seconds = 300 # 缓存有效期(秒)
max_cache_size = 1000 # 缓存中允许的最大记忆条数
[voice]
enable_asr = true # 是否启用语音识别启用后MoFox-Bot可以识别语音消息启用该功能需要配置语音识别模型[model.voice]

View File

@@ -1,5 +1,5 @@
[inner]
version = "1.3.5"
version = "1.3.6"
# 配置文件版本号迭代规则同bot_config.toml
@@ -203,6 +203,7 @@ max_tokens = 1000
#嵌入模型
[model_task_config.embedding]
model_list = ["bge-m3"]
embedding_dimension = 1024