diff --git a/pyproject.toml b/pyproject.toml index cf3c3a844..a67f28472 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,6 @@ dependencies = [ "tqdm>=4.67.1", "urllib3>=2.5.0", "uvicorn>=0.35.0", - "watchdog>=6.0.0", "websockets>=15.0.1", "aiomysql>=0.2.0", "aiosqlite>=0.21.0", diff --git a/requirements.txt b/requirements.txt index 91811fd28..a96737f69 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,7 +50,6 @@ reportportal-client scikit-learn seaborn structlog -watchdog httpx requests beautifulsoup4 diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 5c70fa744..37a315197 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -12,6 +12,7 @@ from sqlalchemy import select from src.common.logger import get_logger from src.config.config import global_config +from src.common.config_helpers import resolve_embedding_dimension from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult logger = get_logger("bot_interest_manager") @@ -28,7 +29,9 @@ class BotInterestManager: # Embedding客户端配置 self.embedding_request = None self.embedding_config = None - self.embedding_dimension = 1024 # 默认BGE-M3 embedding维度 + configured_dim = resolve_embedding_dimension() + self.embedding_dimension = int(configured_dim) if configured_dim else 0 + self._detected_embedding_dimension: Optional[int] = None @property def is_initialized(self) -> bool: @@ -82,8 +85,11 @@ class BotInterestManager: logger.info("📋 找到embedding模型配置") self.embedding_config = model_config.model_task_config.embedding - self.embedding_dimension = 1024 # BGE-M3的维度 - logger.info(f"📐 使用模型维度: {self.embedding_dimension}") + + if self.embedding_dimension: + logger.info(f"📐 配置的embedding维度: {self.embedding_dimension}") + else: + logger.info("📐 未在配置中检测到embedding维度,将根据首次返回的向量自动识别") # 创建LLMRequest实例用于embedding self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding") @@ -350,7 +356,27 @@ class BotInterestManager: if embedding and len(embedding) > 0: self.embedding_cache[text] = embedding - logger.debug(f"✅ Embedding获取成功,维度: {len(embedding)}, 模型: {model_name}") + + current_dim = len(embedding) + if self._detected_embedding_dimension is None: + self._detected_embedding_dimension = current_dim + if self.embedding_dimension and self.embedding_dimension != current_dim: + logger.warning( + "⚠️ 实际embedding维度(%d)与配置值(%d)不一致,请在 model_config.model_task_config.embedding.embedding_dimension 中同步更新", + current_dim, + self.embedding_dimension, + ) + else: + self.embedding_dimension = current_dim + logger.info(f"📏 检测到embedding维度: {current_dim}") + elif current_dim != self.embedding_dimension: + logger.warning( + "⚠️ 收到的embedding维度发生变化: 之前=%d, 当前=%d。请确认模型配置是否正确。", + self.embedding_dimension, + current_dim, + ) + + logger.debug(f"✅ Embedding获取成功,维度: {current_dim}, 模型: {model_name}") return embedding else: raise RuntimeError(f"❌ 返回的embedding为空: {embedding}") diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 67296c0c9..162a00b7f 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -26,6 +26,7 @@ from rich.progress import ( TextColumn, ) from src.config.config import global_config +from src.common.config_helpers import resolve_embedding_dimension install(extra_lines=3) @@ -504,7 +505,10 @@ class EmbeddingStore: # L2归一化 faiss.normalize_L2(embeddings) # 构建索引 - self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension) + embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension) + if not embedding_dim: + embedding_dim = global_config.lpmm_knowledge.embedding_dimension + self.faiss_index = faiss.IndexFlatIP(embedding_dim) self.faiss_index.add(embeddings) def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]: diff --git a/src/chat/memory_system/enhanced_memory_adapter.py b/src/chat/memory_system/enhanced_memory_adapter.py index 0d73b11f5..b48a723fb 100644 --- a/src/chat/memory_system/enhanced_memory_adapter.py +++ b/src/chat/memory_system/enhanced_memory_adapter.py @@ -11,12 +11,27 @@ from dataclasses import dataclass from src.common.logger import get_logger from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode -from src.chat.memory_system.memory_chunk import MemoryChunk +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType from src.llm_models.utils_model import LLMRequest logger = get_logger(__name__) +MEMORY_TYPE_LABELS = { + MemoryType.PERSONAL_FACT: "个人事实", + MemoryType.EVENT: "事件", + MemoryType.PREFERENCE: "偏好", + MemoryType.OPINION: "观点", + MemoryType.RELATIONSHIP: "关系", + MemoryType.EMOTION: "情感", + MemoryType.KNOWLEDGE: "知识", + MemoryType.SKILL: "技能", + MemoryType.GOAL: "目标", + MemoryType.EXPERIENCE: "经验", + MemoryType.CONTEXTUAL: "上下文", +} + + @dataclass class AdapterConfig: """适配器配置""" @@ -85,12 +100,9 @@ class EnhancedMemoryAdapter: async def process_conversation_memory( self, - conversation_text: str, - context: Dict[str, Any], - user_id: str, - timestamp: Optional[float] = None + context: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - """处理对话记忆""" + """处理对话记忆,以上下文为唯一输入""" if not self._initialized or not self.config.enable_enhanced_memory: return {"success": False, "error": "Enhanced memory not available"} @@ -98,10 +110,30 @@ class EnhancedMemoryAdapter: self.adapter_stats["total_processed"] += 1 try: + payload_context: Dict[str, Any] = dict(context or {}) + + conversation_text = payload_context.get("conversation_text") + if not conversation_text: + conversation_candidate = ( + payload_context.get("message_content") + or payload_context.get("latest_message") + or payload_context.get("raw_text") + ) + if conversation_candidate is not None: + conversation_text = str(conversation_candidate) + payload_context["conversation_text"] = conversation_text + else: + conversation_text = "" + else: + conversation_text = str(conversation_text) + + if "timestamp" not in payload_context: + payload_context["timestamp"] = time.time() + + logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text)) + # 使用集成层处理对话 - result = await self.integration_layer.process_conversation( - conversation_text, context, user_id, timestamp - ) + result = await self.integration_layer.process_conversation(payload_context) # 更新统计 processing_time = time.time() - start_time @@ -132,7 +164,7 @@ class EnhancedMemoryAdapter: try: limit = limit or self.config.max_retrieval_results memories = await self.integration_layer.retrieve_relevant_memories( - query, user_id, context, limit + query, None, context, limit ) self.adapter_stats["memories_retrieved"] += len(memories) @@ -157,12 +189,15 @@ class EnhancedMemoryAdapter: if not memories: return "" - # 格式化记忆为提示词友好的格式 - memory_context_parts = [] - for memory in memories: - memory_context_parts.append(f"- {memory.text_content}") + # 格式化记忆为提示词友好的Markdown结构 + lines: List[str] = ["### 🧠 相关记忆 (Relevant Memories)", ""] - return "\n".join(memory_context_parts) + for memory in memories: + type_label = MEMORY_TYPE_LABELS.get(memory.memory_type, memory.memory_type.value) + display_text = memory.display or memory.text_content + lines.append(f"- **[{type_label}]** {display_text}") + + return "\n".join(lines) async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]: """获取增强记忆系统摘要""" @@ -270,13 +305,10 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest): async def process_conversation_with_enhanced_memory( - conversation_text: str, context: Dict[str, Any], - user_id: str, - timestamp: Optional[float] = None, llm_model: Optional[LLMRequest] = None ) -> Dict[str, Any]: - """使用增强记忆系统处理对话""" + """使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息""" if not llm_model: # 获取默认的LLM模型 from src.llm_models.utils_model import get_global_llm_model @@ -284,7 +316,18 @@ async def process_conversation_with_enhanced_memory( try: adapter = await get_enhanced_memory_adapter(llm_model) - return await adapter.process_conversation_memory(conversation_text, context, user_id, timestamp) + payload_context = dict(context or {}) + + if "conversation_text" not in payload_context: + conversation_candidate = ( + payload_context.get("message_content") + or payload_context.get("latest_message") + or payload_context.get("raw_text") + ) + if conversation_candidate is not None: + payload_context["conversation_text"] = str(conversation_candidate) + + return await adapter.process_conversation_memory(payload_context) except Exception as e: logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True) return {"success": False, "error": str(e)} diff --git a/src/chat/memory_system/enhanced_memory_core.py b/src/chat/memory_system/enhanced_memory_core.py index d8a9910f6..08bebe2f0 100644 --- a/src/chat/memory_system/enhanced_memory_core.py +++ b/src/chat/memory_system/enhanced_memory_core.py @@ -1,13 +1,15 @@ # -*- coding: utf-8 -*- """ 增强型精准记忆系统核心模块 -基于文档设计的高效记忆构建、存储与召回优化系统 +1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。 +2. 内置 LLM 查询规划器与嵌入维度自动解析机制,直接从模型配置推断向量存储参数。 """ import asyncio import time import orjson import re +import hashlib from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING from datetime import datetime, timedelta from dataclasses import dataclass, asdict @@ -22,12 +24,16 @@ from src.chat.memory_system.memory_fusion import MemoryFusionEngine from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig from src.chat.memory_system.metadata_index import MetadataIndexManager from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig +from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner if TYPE_CHECKING: from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger(__name__) +# 全局记忆作用域(共享记忆库) +GLOBAL_MEMORY_SCOPE = "global" + class MemorySystemStatus(Enum): """记忆系统状态""" @@ -47,14 +53,20 @@ class MemorySystemConfig: memory_value_threshold: float = 0.7 min_build_interval_seconds: float = 300.0 - # 向量存储配置 - vector_dimension: int = 768 + # 向量存储配置(嵌入维度自动来自模型配置) + vector_dimension: int = 1024 similarity_threshold: float = 0.8 # 召回配置 coarse_recall_limit: int = 50 fine_recall_limit: int = 10 + semantic_rerank_limit: int = 20 final_recall_limit: int = 5 + semantic_similarity_threshold: float = 0.6 + vector_weight: float = 0.4 + semantic_weight: float = 0.3 + context_weight: float = 0.2 + recency_weight: float = 0.1 # 融合配置 fusion_similarity_threshold: float = 0.85 @@ -64,6 +76,23 @@ class MemorySystemConfig: def from_global_config(cls): """从全局配置创建配置实例""" + embedding_dimension = None + try: + embedding_task = getattr(model_config.model_task_config, "embedding", None) + if embedding_task is not None: + embedding_dimension = getattr(embedding_task, "embedding_dimension", None) + except Exception: + embedding_dimension = None + + if not embedding_dimension: + try: + embedding_dimension = getattr(global_config.lpmm_knowledge, "embedding_dimension", None) + except Exception: + embedding_dimension = None + + if not embedding_dimension: + embedding_dimension = 1024 + return cls( # 记忆构建配置 min_memory_length=global_config.memory.min_memory_length, @@ -72,13 +101,19 @@ class MemorySystemConfig: min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0), # 向量存储配置 - vector_dimension=global_config.memory.vector_dimension, + vector_dimension=int(embedding_dimension), similarity_threshold=global_config.memory.vector_similarity_threshold, # 召回配置 coarse_recall_limit=global_config.memory.metadata_filter_limit, - fine_recall_limit=global_config.memory.final_result_limit, + fine_recall_limit=global_config.memory.vector_search_limit, + semantic_rerank_limit=global_config.memory.semantic_rerank_limit, final_recall_limit=global_config.memory.final_result_limit, + semantic_similarity_threshold=getattr(global_config.memory, "semantic_similarity_threshold", 0.6), + vector_weight=global_config.memory.vector_weight, + semantic_weight=global_config.memory.semantic_weight, + context_weight=global_config.memory.context_weight, + recency_weight=global_config.memory.recency_weight, # 融合配置 fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold, @@ -104,6 +139,7 @@ class EnhancedMemorySystem: self.vector_storage: VectorStorageManager = None self.metadata_index: MetadataIndexManager = None self.retrieval_system: MultiStageRetrieval = None + self.query_planner: MemoryQueryPlanner = None # LLM模型 self.value_assessment_model: LLMRequest = None @@ -117,6 +153,9 @@ class EnhancedMemorySystem: # 构建节流记录 self._last_memory_build_times: Dict[str, float] = {} + # 记忆指纹缓存,用于快速检测重复记忆 + self._memory_fingerprints: Dict[str, str] = {} + logger.info("EnhancedMemorySystem 初始化开始") async def initialize(self): @@ -125,19 +164,29 @@ class EnhancedMemorySystem: logger.info("正在初始化增强型记忆系统...") # 初始化LLM模型 - task_config = ( - self.llm_model.model_for_task - if self.llm_model is not None - else model_config.model_task_config.utils_small - ) + fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None + + value_task_config = getattr(model_config.model_task_config, "utils_small", None) + extraction_task_config = getattr(model_config.model_task_config, "utils", None) + + if value_task_config is None: + logger.warning("未找到 utils_small 模型配置,回退到 utils 或外部提供的模型配置。") + value_task_config = extraction_task_config or fallback_task + + if extraction_task_config is None: + logger.warning("未找到 utils 模型配置,回退到 utils_small 或外部提供的模型配置。") + extraction_task_config = value_task_config or fallback_task + + if value_task_config is None or extraction_task_config is None: + raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。") self.value_assessment_model = LLMRequest( - model_set=task_config, + model_set=value_task_config, request_type="memory.value_assessment" ) self.memory_extraction_model = LLMRequest( - model_set=task_config, + model_set=extraction_task_config, request_type="memory.extraction" ) @@ -155,13 +204,36 @@ class EnhancedMemorySystem: retrieval_config = RetrievalConfig( metadata_filter_limit=self.config.coarse_recall_limit, vector_search_limit=self.config.fine_recall_limit, - final_result_limit=self.config.final_recall_limit + semantic_rerank_limit=self.config.semantic_rerank_limit, + final_result_limit=self.config.final_recall_limit, + vector_similarity_threshold=self.config.similarity_threshold, + semantic_similarity_threshold=self.config.semantic_similarity_threshold, + vector_weight=self.config.vector_weight, + semantic_weight=self.config.semantic_weight, + context_weight=self.config.context_weight, + recency_weight=self.config.recency_weight, ) self.retrieval_system = MultiStageRetrieval(retrieval_config) + planner_task_config = getattr(model_config.model_task_config, "planner", None) + planner_model: Optional[LLMRequest] = None + try: + planner_model = LLMRequest( + model_set=planner_task_config, + request_type="memory.query_planner" + ) + except Exception as planner_exc: + logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True) + + self.query_planner = MemoryQueryPlanner( + planner_model, + default_limit=self.config.final_recall_limit + ) + # 加载持久化数据 await self.vector_storage.load_storage() await self.metadata_index.load_index() + self._populate_memory_fingerprints() self.status = MemorySystemStatus.READY logger.info("✅ 增强型记忆系统初始化完成") @@ -174,7 +246,7 @@ class EnhancedMemorySystem: async def retrieve_memories_for_building( self, query_text: str, - user_id: str, + user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5 ) -> List[MemoryChunk]: @@ -182,7 +254,6 @@ class EnhancedMemorySystem: Args: query_text: 查询文本 - user_id: 用户ID context: 上下文信息 limit: 返回结果数量限制 @@ -201,7 +272,6 @@ class EnhancedMemorySystem: # 执行检索 memories = await self.vector_storage.search_similar_memories( query_text=query_text, - user_id=user_id, limit=limit ) @@ -218,23 +288,18 @@ class EnhancedMemorySystem: self, conversation_text: str, context: Dict[str, Any], - user_id: str, timestamp: Optional[float] = None ) -> List[MemoryChunk]: """从对话中构建记忆 Args: conversation_text: 对话文本 - context: 上下文信息(包括用户信息、群组信息等) - user_id: 用户ID + context: 上下文信息 timestamp: 时间戳,默认为当前时间 Returns: 构建的记忆块列表 """ - if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]: - raise RuntimeError("记忆系统未就绪") - original_status = self.status self.status = MemorySystemStatus.BUILDING start_time = time.time() @@ -243,9 +308,9 @@ class EnhancedMemorySystem: build_marker_time: Optional[float] = None try: - normalized_context = self._normalize_context(context, user_id, timestamp) + normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp) - build_scope_key = self._get_build_scope_key(normalized_context, user_id) + build_scope_key = self._get_build_scope_key(normalized_context, GLOBAL_MEMORY_SCOPE) min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0)) current_time = time.time() @@ -266,7 +331,7 @@ class EnhancedMemorySystem: conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context) - logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}") + logger.debug("开始构建记忆,文本长度: %d", len(conversation_text)) # 1. 信息价值评估 value_score = await self._assess_information_value(conversation_text, normalized_context) @@ -280,7 +345,7 @@ class EnhancedMemorySystem: memory_chunks = await self.memory_builder.build_memories( conversation_text, normalized_context, - user_id, + GLOBAL_MEMORY_SCOPE, timestamp or time.time() ) @@ -293,19 +358,24 @@ class EnhancedMemorySystem: fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks) # 4. 存储记忆 - await self._store_memories(fused_chunks) + stored_count = await self._store_memories(fused_chunks) # 4.1 控制台预览 self._log_memory_preview(fused_chunks) # 5. 更新统计 - self.total_memories += len(fused_chunks) + self.total_memories += stored_count self.last_build_time = time.time() if build_scope_key: self._last_memory_build_times[build_scope_key] = self.last_build_time build_time = time.time() - start_time - logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}秒") + logger.info( + "✅ 生成 %d 条记忆,成功入库 %d 条,耗时 %.2f秒", + len(fused_chunks), + stored_count, + build_time, + ) self.status = original_status return fused_chunks @@ -347,21 +417,34 @@ class EnhancedMemorySystem: async def process_conversation_memory( self, - conversation_text: str, - context: Dict[str, Any], - user_id: str, - timestamp: Optional[float] = None + context: Dict[str, Any] ) -> Dict[str, Any]: - """对外暴露的对话记忆处理接口,兼容旧调用方式""" + """对外暴露的对话记忆处理接口,仅依赖上下文信息""" start_time = time.time() try: - normalized_context = self._normalize_context(context, user_id, timestamp) + context = dict(context or {}) + + conversation_candidate = ( + context.get("conversation_text") + or context.get("message_content") + or context.get("latest_message") + or context.get("raw_text") + or "" + ) + + conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate) + + timestamp = context.get("timestamp") + if timestamp is None: + timestamp = time.time() + + normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp) + normalized_context.setdefault("conversation_text", conversation_text) memories = await self.build_memory_from_conversation( conversation_text=conversation_text, context=normalized_context, - user_id=user_id, timestamp=timestamp ) @@ -395,52 +478,77 @@ class EnhancedMemorySystem: **kwargs ) -> List[MemoryChunk]: """检索相关记忆,兼容 query/query_text 参数形式""" - if self.status != MemorySystemStatus.READY: - raise RuntimeError("记忆系统未就绪") - - query_text = query_text or kwargs.get("query") - if not query_text: + raw_query = query_text or kwargs.get("query") + if not raw_query: raise ValueError("query_text 或 query 参数不能为空") context = context or {} - user_id = user_id or kwargs.get("user_id") + resolved_user_id = GLOBAL_MEMORY_SCOPE + + if self.retrieval_system is None or self.metadata_index is None: + raise RuntimeError("检索组件未初始化") + + all_memories_cache = self.vector_storage.memory_cache + if not all_memories_cache: + logger.debug("记忆缓存为空,返回空结果") + self.last_retrieval_time = time.time() + self.status = MemorySystemStatus.READY + return [] self.status = MemorySystemStatus.RETRIEVING start_time = time.time() try: - normalized_context = self._normalize_context(context, user_id, None) + normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None) - candidate_memories = list(self.vector_storage.memory_cache.values()) - if user_id: - candidate_memories = [m for m in candidate_memories if m.user_id == user_id] + effective_limit = limit or self.config.final_recall_limit + query_plan = None + planner_ran = False + resolved_query_text = raw_query + if self.query_planner: + try: + planner_ran = True + query_plan = await self.query_planner.plan_query(raw_query, normalized_context) + normalized_context["query_plan"] = query_plan + effective_limit = min(effective_limit, query_plan.limit or effective_limit) + if getattr(query_plan, "semantic_query", None): + resolved_query_text = query_plan.semantic_query + logger.debug( + "查询规划: semantic='%s', types=%s, subjects=%s, limit=%d", + query_plan.semantic_query, + [mt.value for mt in query_plan.memory_types], + query_plan.subject_includes, + query_plan.limit, + ) + except Exception as plan_exc: + logger.warning("查询规划失败,使用默认检索策略: %s", plan_exc, exc_info=True) - if not candidate_memories: - self.status = MemorySystemStatus.READY - self.last_retrieval_time = time.time() - logger.debug(f"未找到用户 {user_id} 的候选记忆") - return [] + effective_limit = effective_limit or self.config.final_recall_limit + effective_limit = max(1, min(effective_limit, self.config.final_recall_limit)) + normalized_context["resolved_query_text"] = resolved_query_text - scored_memories = [] - for memory in candidate_memories: - score = self._compute_memory_score(query_text, memory, normalized_context) - if score > 0: - scored_memories.append((memory, score)) - - if not scored_memories: - # 如果所有分数为0,返回最近的记忆作为降级策略 + if normalized_context.get("__memory_building__"): + logger.debug("当前处于记忆构建流程,跳过查询规划并进行降级检索") + self.status = MemorySystemStatus.BUILDING + final_memories = [] + candidate_memories = list(all_memories_cache.values()) candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True) - scored_memories = [(memory, 0.0) for memory in candidate_memories[:limit]] + final_memories = candidate_memories[:effective_limit] else: - scored_memories.sort(key=lambda item: item[1], reverse=True) + retrieval_result = await self.retrieval_system.retrieve_memories( + query=resolved_query_text, + user_id=resolved_user_id, + context=normalized_context, + metadata_index=self.metadata_index, + vector_storage=self.vector_storage, + all_memories_cache=all_memories_cache, + limit=effective_limit, + ) - top_memories = [memory for memory, _ in scored_memories[:limit]] + final_memories = retrieval_result.final_memories - # 更新访问信息和缓存 - for memory, score in scored_memories[:limit]: + for memory in final_memories: memory.update_access() - memory.update_relevance(score) - cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id) if cache_entry is not None: cache_entry["last_accessed"] = memory.metadata.last_accessed @@ -448,14 +556,34 @@ class EnhancedMemorySystem: cache_entry["relevance_score"] = memory.metadata.relevance_score retrieval_time = time.time() - start_time - logger.info( - f"✅ 为用户 {user_id or 'unknown'} 检索到 {len(top_memories)} 条相关记忆,耗时 {retrieval_time:.3f}秒" + plan_summary = "" + if planner_ran and query_plan: + plan_types = ",".join(mt.value for mt in query_plan.memory_types) or "-" + plan_subjects = ",".join(query_plan.subject_includes) or "-" + plan_summary = ( + f" | planner.semantic='{query_plan.semantic_query}'" + f" | planner.limit={query_plan.limit}" + f" | planner.types={plan_types}" + f" | planner.subjects={plan_subjects}" + ) + + log_message = ( + "✅ 记忆检索完成" + f" | user={resolved_user_id}" + f" | count={len(final_memories)}" + f" | duration={retrieval_time:.3f}s" + f" | applied_limit={effective_limit}" + f" | raw_query='{raw_query}'" + f" | semantic_query='{resolved_query_text}'" + f"{plan_summary}" ) + logger.info(log_message) + self.last_retrieval_time = time.time() self.status = MemorySystemStatus.READY - return top_memories + return final_memories except Exception as e: self.status = MemorySystemStatus.ERROR @@ -499,8 +627,8 @@ class EnhancedMemorySystem: except Exception: context = dict(raw_context or {}) - # 基础字段 - context["user_id"] = context.get("user_id") or user_id or "unknown" + # 基础字段(统一使用全局作用域) + context["user_id"] = GLOBAL_MEMORY_SCOPE context["timestamp"] = context.get("timestamp") or timestamp or time.time() context["message_type"] = context.get("message_type") or "normal" context["platform"] = context.get("platform") or context.get("source_platform") or "unknown" @@ -523,8 +651,8 @@ class EnhancedMemorySystem: if stream_id: context["stream_id"] = stream_id - # chat_id 兜底 - context["chat_id"] = context.get("chat_id") or context.get("stream_id") or f"session_{context['user_id']}" + # 全局记忆无需聊天隔离 + context["chat_id"] = context.get("chat_id") or "global_chat" # 历史窗口配置 window_candidate = ( @@ -616,18 +744,7 @@ class EnhancedMemorySystem: def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]: """确定用于节流控制的记忆构建作用域""" - stream_id = context.get("stream_id") - if stream_id: - return f"stream::{stream_id}" - - chat_id = context.get("chat_id") - if chat_id: - return f"chat::{chat_id}" - - if user_id: - return f"user::{user_id}" - - return None + return "global_scope" def _determine_history_limit(self, context: Dict[str, Any]) -> int: """确定历史消息获取数量,限制在30-50之间""" @@ -789,24 +906,134 @@ class EnhancedMemorySystem: logger.error(f"信息价值评估失败: {e}", exc_info=True) return 0.5 # 默认中等价值 - async def _store_memories(self, memory_chunks: List[MemoryChunk]): - """存储记忆块到各个存储系统""" + async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int: + """存储记忆块到各个存储系统,返回成功入库数量""" if not memory_chunks: - return + return 0 + + unique_memories: List[MemoryChunk] = [] + skipped_duplicates = 0 + + for memory in memory_chunks: + fingerprint = self._build_memory_fingerprint(memory) + key = self._fingerprint_key(memory.user_id, fingerprint) + + existing_id = self._memory_fingerprints.get(key) + if existing_id: + existing = self.vector_storage.memory_cache.get(existing_id) + if existing: + self._merge_existing_memory(existing, memory) + await self.metadata_index.update_memory_entry(existing) + skipped_duplicates += 1 + logger.debug( + "检测到重复记忆,已合并到现有记录 | memory_id=%s", + existing.memory_id, + ) + continue + else: + # 指纹存在但缓存缺失,视为新记忆并覆盖旧映射 + logger.debug("检测到过期指纹映射,重写现有条目") + + unique_memories.append(memory) + + if not unique_memories: + if skipped_duplicates: + logger.info("本次记忆全部与现有内容重复,跳过入库") + return 0 # 并行存储到向量数据库和元数据索引 - storage_tasks = [] + storage_tasks = [ + self.vector_storage.store_memories(unique_memories), + self.metadata_index.index_memories(unique_memories), + ] - # 向量存储 - storage_tasks.append(self.vector_storage.store_memories(memory_chunks)) - - # 元数据索引 - storage_tasks.append(self.metadata_index.index_memories(memory_chunks)) - - # 等待所有存储任务完成 await asyncio.gather(*storage_tasks, return_exceptions=True) - logger.debug(f"成功存储 {len(memory_chunks)} 条记忆到各个存储系统") + self._register_memory_fingerprints(unique_memories) + + logger.debug( + "成功存储 %d 条记忆(跳过重复 %d 条)", + len(unique_memories), + skipped_duplicates, + ) + + return len(unique_memories) + + def _merge_existing_memory(self, existing: MemoryChunk, incoming: MemoryChunk) -> None: + """将新记忆的信息合并到已存在的记忆中""" + updated = False + + for keyword in incoming.keywords: + if keyword not in existing.keywords: + existing.add_keyword(keyword) + updated = True + + for tag in incoming.tags: + if tag not in existing.tags: + existing.add_tag(tag) + updated = True + + for category in incoming.categories: + if category not in existing.categories: + existing.add_category(category) + updated = True + + if incoming.metadata.source_context: + existing.metadata.source_context = incoming.metadata.source_context + + if incoming.metadata.importance.value > existing.metadata.importance.value: + existing.metadata.importance = incoming.metadata.importance + updated = True + + if incoming.metadata.confidence.value > existing.metadata.confidence.value: + existing.metadata.confidence = incoming.metadata.confidence + updated = True + + if incoming.metadata.relevance_score > existing.metadata.relevance_score: + existing.metadata.relevance_score = incoming.metadata.relevance_score + updated = True + + if updated: + existing.metadata.last_modified = time.time() + + def _populate_memory_fingerprints(self) -> None: + """基于当前缓存构建记忆指纹映射""" + self._memory_fingerprints.clear() + for memory in self.vector_storage.memory_cache.values(): + fingerprint = self._build_memory_fingerprint(memory) + key = self._fingerprint_key(memory.user_id, fingerprint) + self._memory_fingerprints[key] = memory.memory_id + + def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None: + for memory in memories: + fingerprint = self._build_memory_fingerprint(memory) + key = self._fingerprint_key(memory.user_id, fingerprint) + self._memory_fingerprints[key] = memory.memory_id + + def _build_memory_fingerprint(self, memory: MemoryChunk) -> str: + subjects = memory.subjects or [] + subject_part = "|".join(sorted(s.strip() for s in subjects if s)) + predicate_part = (memory.content.predicate or "").strip() + + obj = memory.content.object + if isinstance(obj, (dict, list)): + obj_part = orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8") + else: + obj_part = str(obj).strip() + + base = "|".join([ + str(memory.user_id or "unknown"), + memory.memory_type.value, + subject_part, + predicate_part, + obj_part, + ]) + + return hashlib.sha256(base.encode("utf-8")).hexdigest() + + @staticmethod + def _fingerprint_key(user_id: str, fingerprint: str) -> str: + return f"{str(user_id)}:{fingerprint}" def get_system_stats(self) -> Dict[str, Any]: """获取系统统计信息""" diff --git a/src/chat/memory_system/enhanced_memory_manager.py b/src/chat/memory_system/enhanced_memory_manager.py index a676c46c2..344f6a417 100644 --- a/src/chat/memory_system/enhanced_memory_manager.py +++ b/src/chat/memory_system/enhanced_memory_manager.py @@ -241,12 +241,12 @@ class EnhancedMemoryManager: return [] try: - result = await self.enhanced_system.process_conversation_memory( - conversation_text=conversation_text, - context=context, - user_id=user_id, - timestamp=timestamp - ) + payload_context = dict(context or {}) + payload_context.setdefault("conversation_text", conversation_text) + if timestamp is not None: + payload_context.setdefault("timestamp", timestamp) + + result = await self.enhanced_system.process_conversation_memory(payload_context) # 从结果中提取记忆块 memory_chunks = [] @@ -274,7 +274,7 @@ class EnhancedMemoryManager: try: relevant_memories = await self.enhanced_system.retrieve_relevant_memories( query=query_text, - user_id=user_id, + user_id=None, context=context or {}, limit=limit ) @@ -303,6 +303,9 @@ class EnhancedMemoryManager: def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]: """将记忆块转换为更易读的文本描述""" structure = memory.content.to_dict() + if memory.display: + return self._clean_text(memory.display), structure + subject = structure.get("subject") predicate = structure.get("predicate") or "" obj = structure.get("object") diff --git a/src/chat/memory_system/integration_layer.py b/src/chat/memory_system/integration_layer.py index db3be9d6d..3f8b7f1ce 100644 --- a/src/chat/memory_system/integration_layer.py +++ b/src/chat/memory_system/integration_layer.py @@ -114,12 +114,9 @@ class MemoryIntegrationLayer: async def process_conversation( self, - conversation_text: str, - context: Dict[str, Any], - user_id: str, - timestamp: Optional[float] = None + context: Dict[str, Any] ) -> Dict[str, Any]: - """处理对话记忆""" + """处理对话记忆,仅使用上下文信息""" if not self._initialized or not self.enhanced_memory: return {"success": False, "error": "Memory system not available"} @@ -128,13 +125,12 @@ class MemoryIntegrationLayer: self.integration_stats["enhanced_queries"] += 1 try: + payload_context = dict(context or {}) + conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or "" + logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text)) + # 直接使用增强记忆系统处理 - result = await self.enhanced_memory.process_conversation_memory( - conversation_text=conversation_text, - context=context, - user_id=user_id, - timestamp=timestamp - ) + result = await self.enhanced_memory.process_conversation_memory(payload_context) # 更新统计 processing_time = time.time() - start_time @@ -156,7 +152,7 @@ class MemoryIntegrationLayer: async def retrieve_relevant_memories( self, query: str, - user_id: str, + user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None ) -> List[MemoryChunk]: @@ -168,7 +164,7 @@ class MemoryIntegrationLayer: limit = limit or self.config.max_retrieval_results memories = await self.enhanced_memory.retrieve_relevant_memories( query=query, - user_id=user_id, + user_id=None, context=context or {}, limit=limit ) diff --git a/src/chat/memory_system/memory_builder.py b/src/chat/memory_system/memory_builder.py index 540a52ded..592172e33 100644 --- a/src/chat/memory_system/memory_builder.py +++ b/src/chat/memory_system/memory_builder.py @@ -2,21 +2,48 @@ """ 记忆构建模块 从对话流中提取高质量、结构化记忆单元 +输出格式要求: +{{ + "memories": [ + {{ + "type": "记忆类型", + "display": "用于直接展示和检索的自然语言描述", + "subject": ["主体1", "主体2"], + "predicate": "谓语(动作/状态)", + "object": "宾语(对象/属性或结构体)", + "keywords": ["关键词1", "关键词2"], + "importance": "重要性等级(1-4)", + "confidence": "置信度(1-4)", + "reasoning": "提取理由" + }} + ] +}} + +注意: +1. `subject` 可包含多个主体,请用数组表示;若主体不明确,请根据上下文给出最合理的称呼 +2. `display` 必须是一句完整流畅的中文描述,可直接用于用户展示和向量搜索 +3. 只提取确实值得记忆的信息,不要提取琐碎内容 +4. 确保信息准确、具体、有价值 +5. 重要性: 1=低, 2=一般, 3=高, 4=关键;置信度: 1=低, 2=中等, 3=高, 4=已验证 """ import re import time -import orjson -from typing import Dict, List, Optional, Any -from datetime import datetime from dataclasses import dataclass +from datetime import datetime from enum import Enum +from typing import Any, Dict, Iterable, List, Optional, Union, Type + +import orjson from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.chat.memory_system.memory_chunk import ( - MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel, - create_memory_chunk + MemoryChunk, + MemoryType, + ConfidenceLevel, + ImportanceLevel, + create_memory_chunk, ) logger = get_logger(__name__) @@ -24,6 +51,7 @@ logger = get_logger(__name__) class ExtractionStrategy(Enum): """提取策略""" + LLM_BASED = "llm_based" # 基于LLM的智能提取 RULE_BASED = "rule_based" # 基于规则的提取 HYBRID = "hybrid" # 混合策略 @@ -171,18 +199,18 @@ class MemoryBuilder: """使用规则提取记忆""" memories = [] - subject_display = self._resolve_user_display(context, user_id) + subjects = self._resolve_conversation_participants(context, user_id) # 规则1: 检测个人信息 - personal_info = self._extract_personal_info(text, user_id, timestamp, context, subject_display) + personal_info = self._extract_personal_info(text, user_id, timestamp, context, subjects) memories.extend(personal_info) # 规则2: 检测偏好信息 - preferences = self._extract_preferences(text, user_id, timestamp, context, subject_display) + preferences = self._extract_preferences(text, user_id, timestamp, context, subjects) memories.extend(preferences) # 规则3: 检测事件信息 - events = self._extract_events(text, user_id, timestamp, context, subject_display) + events = self._extract_events(text, user_id, timestamp, context, subjects) memories.extend(events) return memories @@ -258,10 +286,7 @@ class MemoryBuilder: 你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。 当前时间: {current_date} -聊天ID: {chat_id} 消息类型: {message_type} -目标用户ID: {target_user_id_display} -目标用户称呼: {target_user_name} ## 🤖 机器人身份(仅供参考,禁止写入记忆) - 机器人名称: {bot_name_display} @@ -272,7 +297,6 @@ class MemoryBuilder: 请务必遵守以下命名规范: - 当说话者是机器人时,请使用“{bot_name_display}”或其他明确称呼作为主语; -- 如果看到系统自动生成的长ID(类似 {target_user_id}),请改用“{target_user_name}”、机器人的称呼或“该用户”描述,不要把ID写入记忆; - 记录关键事实时,请准确标记主体是机器人还是用户,避免混淆。 对话内容: @@ -450,7 +474,7 @@ class MemoryBuilder: bot_identifiers = self._collect_bot_identifiers(context) system_identifiers = self._collect_system_identifiers(context) - default_subject = self._resolve_user_display(context, user_id) + default_subjects = self._resolve_conversation_participants(context, user_id) bot_display = None if context: @@ -481,19 +505,33 @@ class MemoryBuilder: for mem_data in memory_list: try: subject_value = mem_data.get("subject") - normalized_subject = self._normalize_subject( + normalized_subject = self._normalize_subjects( subject_value, bot_identifiers, system_identifiers, - default_subject, + default_subjects, bot_display ) - if normalized_subject is None: + if not normalized_subject: logger.debug("跳过疑似机器人自身信息的记忆: %s", mem_data) continue # 创建记忆块 + importance_level = self._parse_enum_value( + ImportanceLevel, + mem_data.get("importance"), + ImportanceLevel.NORMAL, + "importance" + ) + + confidence_level = self._parse_enum_value( + ConfidenceLevel, + mem_data.get("confidence"), + ConfidenceLevel.MEDIUM, + "confidence" + ) + memory = create_memory_chunk( user_id=user_id, subject=normalized_subject, @@ -502,8 +540,9 @@ class MemoryBuilder: memory_type=MemoryType(mem_data.get("type", "contextual")), chat_id=context.get("chat_id"), source_context=mem_data.get("reasoning", ""), - importance=ImportanceLevel(mem_data.get("importance", 2)), - confidence=ConfidenceLevel(mem_data.get("confidence", 2)) + importance=importance_level, + confidence=confidence_level, + display=mem_data.get("display") ) # 添加关键词 @@ -511,13 +550,6 @@ class MemoryBuilder: for keyword in keywords: memory.add_keyword(keyword) - subject_text = memory.content.subject.strip() if isinstance(memory.content.subject, str) else str(memory.content.subject) - if not subject_text: - memory.content.subject = default_subject - elif subject_text.lower() in system_identifiers or self._looks_like_system_identifier(subject_text): - logger.debug("将系统标识主语替换为默认用户名称: %s", subject_text) - memory.content.subject = default_subject - memories.append(memory) except Exception as e: @@ -526,6 +558,64 @@ class MemoryBuilder: return memories + def _parse_enum_value( + self, + enum_cls: Type[Enum], + raw_value: Any, + default: Enum, + field_name: str + ) -> Enum: + """解析枚举值,兼容数字/字符串表示""" + if isinstance(raw_value, enum_cls): + return raw_value + + if raw_value is None: + return default + + # 直接尝试整数转换 + if isinstance(raw_value, (int, float)): + int_value = int(raw_value) + try: + return enum_cls(int_value) + except ValueError: + logger.debug("%s=%s 无法解析为 %s", field_name, raw_value, enum_cls.__name__) + return default + + if isinstance(raw_value, str): + value_str = raw_value.strip() + if not value_str: + return default + + if value_str.isdigit(): + try: + return enum_cls(int(value_str)) + except ValueError: + logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__) + else: + normalized = value_str.replace("-", "_").replace(" ", "_").upper() + for member in enum_cls: + if member.name == normalized: + return member + for member in enum_cls: + if str(member.value).lower() == value_str.lower(): + return member + + try: + return enum_cls(value_str) + except ValueError: + logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__) + + try: + return enum_cls(raw_value) + except Exception: + logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s", + field_name, + raw_value, + type(raw_value).__name__, + enum_cls.__name__, + default.name) + return default + def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]: identifiers: set[str] = {"bot", "机器人", "ai助手"} if not context: @@ -580,6 +670,58 @@ class MemoryBuilder: return identifiers + def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]: + participants: List[str] = [] + + if context: + candidate_keys = [ + "participants", + "participant_names", + "speaker_names", + "members", + "member_names", + "mention_users", + "audiences" + ] + + for key in candidate_keys: + value = context.get(key) + if isinstance(value, (list, tuple, set)): + for item in value: + if isinstance(item, str): + cleaned = self._clean_subject_text(item) + if cleaned: + participants.append(cleaned) + elif isinstance(value, str): + for part in self._split_subject_string(value): + if part: + participants.append(part) + + fallback = self._resolve_user_display(context, user_id) + if fallback: + participants.append(fallback) + + if context: + bot_name = context.get("bot_name") or context.get("bot_identity") + if isinstance(bot_name, str): + cleaned = self._clean_subject_text(bot_name) + if cleaned: + participants.append(cleaned) + + if not participants: + participants = ["对话参与者"] + + deduplicated: List[str] = [] + seen = set() + for name in participants: + key = name.lower() + if key in seen: + continue + seen.add(key) + deduplicated.append(name) + + return deduplicated + def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str: candidate_keys = [ "user_display_name", @@ -626,51 +768,160 @@ class MemoryBuilder: return False - def _normalize_subject( + def _split_subject_string(self, value: str) -> List[str]: + if not value: + return [] + + replaced = re.sub(r"\band\b", "、", value, flags=re.IGNORECASE) + replaced = replaced.replace("和", "、").replace("与", "、").replace("及", "、") + replaced = replaced.replace("&", "、").replace("/", "、").replace("+", "、") + + tokens = [self._clean_subject_text(token) for token in re.split(r"[、,,;;]+", replaced)] + return [token for token in tokens if token] + + def _normalize_subjects( self, subject: Any, bot_identifiers: set[str], system_identifiers: set[str], - default_subject: str, + default_subjects: List[str], bot_display: Optional[str] = None - ) -> Optional[str]: - if subject is None: - return default_subject + ) -> List[str]: + defaults = default_subjects or ["对话参与者"] - subject_str = subject if isinstance(subject, str) else str(subject) - cleaned = self._clean_subject_text(subject_str) - if not cleaned: - return default_subject + raw_candidates: List[str] = [] + if isinstance(subject, list): + for item in subject: + if isinstance(item, str): + raw_candidates.extend(self._split_subject_string(item)) + elif item is not None: + raw_candidates.extend(self._split_subject_string(str(item))) + elif isinstance(subject, str): + raw_candidates.extend(self._split_subject_string(subject)) + elif subject is not None: + raw_candidates.extend(self._split_subject_string(str(subject))) - lowered = cleaned.lower() + normalized: List[str] = [] bot_primary = self._clean_subject_text(bot_display or "") - if lowered in bot_identifiers: - return bot_primary or cleaned + for candidate in raw_candidates: + if not candidate: + continue - if lowered in {"用户", "user", "the user", "对方", "对手"}: - return default_subject + lowered = candidate.lower() + if lowered in bot_identifiers: + normalized.append(bot_primary or candidate) + continue - prefix_match = re.match(r"^(用户|User|user|USER|成员|member|Member|target|Target|TARGET)[\s::\-\u2014_]*?(.*)$", cleaned) - if prefix_match: - remainder = self._clean_subject_text(prefix_match.group(2)) - if not remainder: - return default_subject - remainder_lower = remainder.lower() - if remainder_lower in bot_identifiers: - return bot_primary or remainder - if ( - remainder_lower in system_identifiers - or self._looks_like_system_identifier(remainder) - ): - return default_subject - cleaned = remainder - lowered = cleaned.lower() + if lowered in {"用户", "user", "the user", "对方", "对手"}: + normalized.extend(defaults) + continue - if lowered in system_identifiers or self._looks_like_system_identifier(cleaned): - return default_subject + if lowered in system_identifiers or self._looks_like_system_identifier(candidate): + continue - return cleaned + normalized.append(candidate) + + if not normalized: + normalized = list(defaults) + + deduplicated: List[str] = [] + seen = set() + for name in normalized: + key = name.lower() + if key in seen: + continue + seen.add(key) + deduplicated.append(name) + + return deduplicated + + def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]: + if isinstance(obj, dict): + for key in keys: + value = obj.get(key) + if value is None: + continue + if isinstance(value, list): + compact = "、".join(str(item) for item in value[:3]) + if compact: + return compact + else: + value_str = str(value).strip() + if value_str: + return value_str + elif isinstance(obj, list): + compact = "、".join(str(item) for item in obj[:3]) + return compact or None + elif isinstance(obj, str): + return obj.strip() or None + return None + + def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str: + subject_phrase = "、".join(subjects) if subjects else "对话参与者" + predicate = (predicate or "").strip() + + if predicate == "is_named": + name = self._extract_value_from_object(obj, ["name", "nickname"]) or "" + name = self._clean_subject_text(name) + if name: + quoted = name if (name.startswith("「") and name.endswith("」")) else f"「{name}」" + return f"{subject_phrase}的昵称是{quoted}" + elif predicate == "is_age": + age = self._extract_value_from_object(obj, ["age"]) or "" + age = self._clean_subject_text(age) + if age: + return f"{subject_phrase}今年{age}岁" + elif predicate == "is_profession": + profession = self._extract_value_from_object(obj, ["profession", "job"]) or "" + profession = self._clean_subject_text(profession) + if profession: + return f"{subject_phrase}的职业是{profession}" + elif predicate == "lives_in": + location = self._extract_value_from_object(obj, ["location", "city", "place"]) or "" + location = self._clean_subject_text(location) + if location: + return f"{subject_phrase}居住在{location}" + elif predicate == "has_phone": + phone = self._extract_value_from_object(obj, ["phone", "number"]) or "" + phone = self._clean_subject_text(phone) + if phone: + return f"{subject_phrase}的电话号码是{phone}" + elif predicate == "has_email": + email = self._extract_value_from_object(obj, ["email"]) or "" + email = self._clean_subject_text(email) + if email: + return f"{subject_phrase}的邮箱是{email}" + elif predicate in {"likes", "likes_food", "favorite_is"}: + liked = self._extract_value_from_object(obj, ["item", "value", "name"]) or "" + liked = self._clean_subject_text(liked) + if liked: + verb = "喜欢" if predicate != "likes_food" else "爱吃" + if predicate == "favorite_is": + verb = "最喜欢" + return f"{subject_phrase}{verb}{liked}" + elif predicate in {"dislikes", "hates"}: + disliked = self._extract_value_from_object(obj, ["item", "value", "name"]) or "" + disliked = self._clean_subject_text(disliked) + if disliked: + verb = "不喜欢" if predicate == "dislikes" else "讨厌" + return f"{subject_phrase}{verb}{disliked}" + elif predicate == "mentioned_event": + description = self._extract_value_from_object(obj, ["event_text", "description"]) or "" + description = self._clean_subject_text(description) + if description: + return f"{subject_phrase}提到了:{description}" + + obj_text = self._extract_value_from_object(obj, ["value", "detail", "content"]) or "" + obj_text = self._clean_subject_text(obj_text) + + if predicate and obj_text: + return f"{subject_phrase}{predicate}{obj_text}".strip() + if obj_text: + return f"{subject_phrase}{obj_text}".strip() + if predicate: + return f"{subject_phrase}{predicate}".strip() + return subject_phrase def _extract_personal_info( self, @@ -678,7 +929,7 @@ class MemoryBuilder: user_id: str, timestamp: float, context: Dict[str, Any], - subject_display: str + subjects: List[str] ) -> List[MemoryChunk]: """提取个人信息""" memories = [] @@ -702,13 +953,14 @@ class MemoryBuilder: memory = create_memory_chunk( user_id=user_id, - subject=subject_display, + subject=subjects, predicate=predicate, obj=obj, memory_type=MemoryType.PERSONAL_FACT, chat_id=context.get("chat_id"), importance=ImportanceLevel.HIGH, - confidence=ConfidenceLevel.HIGH + confidence=ConfidenceLevel.HIGH, + display=self._compose_display_text(subjects, predicate, obj) ) memories.append(memory) @@ -721,7 +973,7 @@ class MemoryBuilder: user_id: str, timestamp: float, context: Dict[str, Any], - subject_display: str + subjects: List[str] ) -> List[MemoryChunk]: """提取偏好信息""" memories = [] @@ -740,13 +992,14 @@ class MemoryBuilder: if match: memory = create_memory_chunk( user_id=user_id, - subject=subject_display, + subject=subjects, predicate=predicate, obj=match.group(1), memory_type=MemoryType.PREFERENCE, chat_id=context.get("chat_id"), importance=ImportanceLevel.NORMAL, - confidence=ConfidenceLevel.MEDIUM + confidence=ConfidenceLevel.MEDIUM, + display=self._compose_display_text(subjects, predicate, match.group(1)) ) memories.append(memory) @@ -759,7 +1012,7 @@ class MemoryBuilder: user_id: str, timestamp: float, context: Dict[str, Any], - subject_display: str + subjects: List[str] ) -> List[MemoryChunk]: """提取事件信息""" memories = [] @@ -770,13 +1023,14 @@ class MemoryBuilder: if any(keyword in text for keyword in event_keywords): memory = create_memory_chunk( user_id=user_id, - subject=subject_display, + subject=subjects, predicate="mentioned_event", obj={"event_text": text, "timestamp": timestamp}, memory_type=MemoryType.EVENT, chat_id=context.get("chat_id"), importance=ImportanceLevel.NORMAL, - confidence=ConfidenceLevel.MEDIUM + confidence=ConfidenceLevel.MEDIUM, + display=self._compose_display_text(subjects, "mentioned_event", text) ) memories.append(memory) diff --git a/src/chat/memory_system/memory_chunk.py b/src/chat/memory_system/memory_chunk.py index 0b9da0180..7e40ee55c 100644 --- a/src/chat/memory_system/memory_chunk.py +++ b/src/chat/memory_system/memory_chunk.py @@ -7,7 +7,7 @@ import time import uuid import orjson -from typing import Dict, List, Optional, Any, Union +from typing import Dict, List, Optional, Any, Union, Iterable from dataclasses import dataclass, field, asdict from datetime import datetime from enum import Enum @@ -52,17 +52,20 @@ class ImportanceLevel(Enum): @dataclass class ContentStructure: - """主谓宾三元组结构""" - subject: str # 主语(通常为用户) - predicate: str # 谓语(动作、状态、关系) - object: Union[str, Dict] # 宾语(对象、属性、值) + """主谓宾结构,包含自然语言描述""" + + subject: Union[str, List[str]] + predicate: str + object: Union[str, Dict] + display: str = "" def to_dict(self) -> Dict[str, Any]: """转换为字典格式""" return { "subject": self.subject, "predicate": self.predicate, - "object": self.object + "object": self.object, + "display": self.display } @classmethod @@ -71,16 +74,25 @@ class ContentStructure: return cls( subject=data.get("subject", ""), predicate=data.get("predicate", ""), - object=data.get("object", "") + object=data.get("object", ""), + display=data.get("display", "") ) + def to_subject_list(self) -> List[str]: + """将主语转换为列表形式""" + if isinstance(self.subject, list): + return [s for s in self.subject if isinstance(s, str) and s.strip()] + if isinstance(self.subject, str) and self.subject.strip(): + return [self.subject.strip()] + return [] + def __str__(self) -> str: """字符串表示""" - if isinstance(self.object, dict): - object_str = str(self.object) - else: - object_str = str(self.object) - return f"{self.subject} {self.predicate} {object_str}" + if self.display: + return self.display + subjects = "、".join(self.to_subject_list()) or str(self.subject) + object_str = self.object if isinstance(self.object, str) else str(self.object) + return f"{subjects} {self.predicate} {object_str}".strip() @dataclass @@ -236,9 +248,19 @@ class MemoryChunk: @property def text_content(self) -> str: - """获取文本内容""" + """获取文本内容(优先使用display)""" return str(self.content) + @property + def display(self) -> str: + """获取展示文本""" + return self.content.display or str(self.content) + + @property + def subjects(self) -> List[str]: + """获取主语列表""" + return self.content.to_subject_list() + def update_access(self): """更新访问信息""" self.metadata.update_access() @@ -415,16 +437,42 @@ class MemoryChunk: confidence_icon = "●" * self.metadata.confidence.value importance_icon = "★" * self.metadata.importance.value - return f"{emoji} [{self.memory_type.value}] {self.text_content} {confidence_icon} {importance_icon}" + return f"{emoji} [{self.memory_type.value}] {self.display} {confidence_icon} {importance_icon}" def __repr__(self) -> str: """调试表示""" return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})" +def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str: + """根据主谓宾生成自然语言描述""" + subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)] + subject_part = "、".join(subjects_clean) if subjects_clean else "对话参与者" + + if isinstance(obj, dict): + object_candidates = [] + for key, value in obj.items(): + if isinstance(value, (str, int, float)): + object_candidates.append(f"{key}:{value}") + elif isinstance(value, list): + compact = "、".join(str(item) for item in value[:3]) + object_candidates.append(f"{key}:{compact}") + object_part = ",".join(object_candidates) if object_candidates else str(obj) + else: + object_part = str(obj).strip() + + predicate_clean = predicate.strip() + if not predicate_clean: + return f"{subject_part} {object_part}".strip() + + if object_part: + return f"{subject_part}{predicate_clean}{object_part}".strip() + return f"{subject_part}{predicate_clean}".strip() + + def create_memory_chunk( user_id: str, - subject: str, + subject: Union[str, List[str]], predicate: str, obj: Union[str, Dict], memory_type: MemoryType, @@ -432,6 +480,7 @@ def create_memory_chunk( source_context: Optional[str] = None, importance: ImportanceLevel = ImportanceLevel.NORMAL, confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM, + display: Optional[str] = None, **kwargs ) -> MemoryChunk: """便捷的内存块创建函数""" @@ -447,10 +496,22 @@ def create_memory_chunk( source_context=source_context ) + subjects: List[str] + if isinstance(subject, list): + subjects = [s for s in subject if isinstance(s, str) and s.strip()] + subject_payload: Union[str, List[str]] = subjects + else: + cleaned = subject.strip() if isinstance(subject, str) else "" + subjects = [cleaned] if cleaned else [] + subject_payload = cleaned + + display_text = display or _build_display_text(subjects, predicate, obj) + content = ContentStructure( - subject=subject, + subject=subject_payload, predicate=predicate, - object=obj + object=obj, + display=display_text ) chunk = MemoryChunk( diff --git a/src/chat/memory_system/memory_fusion.py b/src/chat/memory_system/memory_fusion.py index 26790f318..bd47bb84c 100644 --- a/src/chat/memory_system/memory_fusion.py +++ b/src/chat/memory_system/memory_fusion.py @@ -266,8 +266,12 @@ class MemoryFusionEngine: consistency_score = 0.0 # 主语一致性 - if mem1.content.subject == mem2.content.subject: - consistency_score += 0.4 + subjects1 = set(mem1.subjects) + subjects2 = set(mem2.subjects) + if subjects1 or subjects2: + overlap = len(subjects1 & subjects2) + union_count = max(len(subjects1 | subjects2), 1) + consistency_score += (overlap / union_count) * 0.4 # 谓语相似性 predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate) diff --git a/src/chat/memory_system/memory_integration_hooks.py b/src/chat/memory_system/memory_integration_hooks.py index 6728613e4..2dab63b7a 100644 --- a/src/chat/memory_system/memory_integration_hooks.py +++ b/src/chat/memory_system/memory_integration_hooks.py @@ -282,9 +282,11 @@ class MemoryIntegrationHooks: } # 使用增强记忆系统处理对话 - result = await process_conversation_with_enhanced_memory( - conversation_text, context, user_id - ) + memory_context = dict(context) + memory_context["conversation_text"] = conversation_text + memory_context["user_id"] = user_id + + result = await process_conversation_with_enhanced_memory(memory_context) processing_time = time.time() - start_time self._update_hook_stats(processing_time) @@ -336,9 +338,11 @@ class MemoryIntegrationHooks: } # 使用增强记忆系统处理对话 - result = await process_conversation_with_enhanced_memory( - conversation_text, context, user_id - ) + memory_context = dict(context) + memory_context["conversation_text"] = conversation_text + memory_context["user_id"] = user_id + + result = await process_conversation_with_enhanced_memory(memory_context) processing_time = time.time() - start_time self._update_hook_stats(processing_time) diff --git a/src/chat/memory_system/memory_query_planner.py b/src/chat/memory_system/memory_query_planner.py new file mode 100644 index 000000000..d2e80a4a5 --- /dev/null +++ b/src/chat/memory_system/memory_query_planner.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +"""记忆检索查询规划器""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import orjson + +from src.chat.memory_system.memory_chunk import MemoryType +from src.common.logger import get_logger +from src.llm_models.utils_model import LLMRequest + +logger = get_logger(__name__) + + +@dataclass +class MemoryQueryPlan: + """查询规划结果""" + + semantic_query: str + memory_types: List[MemoryType] = field(default_factory=list) + subject_includes: List[str] = field(default_factory=list) + object_includes: List[str] = field(default_factory=list) + required_keywords: List[str] = field(default_factory=list) + optional_keywords: List[str] = field(default_factory=list) + owner_filters: List[str] = field(default_factory=list) + recency_preference: str = "any" + limit: int = 10 + emphasis: Optional[str] = None + raw_plan: Dict[str, Any] = field(default_factory=dict) + + def ensure_defaults(self, fallback_query: str, default_limit: int) -> None: + if not self.semantic_query: + self.semantic_query = fallback_query + if self.limit <= 0: + self.limit = default_limit + self.recency_preference = (self.recency_preference or "any").lower() + if self.recency_preference not in {"any", "recent", "historical"}: + self.recency_preference = "any" + self.emphasis = (self.emphasis or "balanced").lower() + + +class MemoryQueryPlanner: + """基于小模型的记忆检索查询规划器""" + + def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10): + self.model = planner_model + self.default_limit = default_limit + + async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan: + if not self.model: + logger.debug("未提供查询规划模型,使用默认规划") + return self._default_plan(query_text) + + prompt = self._build_prompt(query_text, context) + + try: + response, _ = await self.model.generate_response_async(prompt, temperature=0.2) + payload = self._extract_json_payload(response) + if not payload: + logger.debug("查询规划模型未返回结构化结果,使用默认规划") + return self._default_plan(query_text) + + try: + data = orjson.loads(payload) + except orjson.JSONDecodeError as exc: + preview = payload[:200] + logger.warning("解析查询规划JSON失败: %s,片段: %s", exc, preview) + return self._default_plan(query_text) + + plan = self._parse_plan_dict(data, query_text) + plan.ensure_defaults(query_text, self.default_limit) + return plan + + except Exception as exc: + logger.error("查询规划模型调用失败: %s", exc, exc_info=True) + return self._default_plan(query_text) + + def _default_plan(self, query_text: str) -> MemoryQueryPlan: + return MemoryQueryPlan( + semantic_query=query_text, + limit=self.default_limit + ) + + def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan: + semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query + + def _collect_list(key: str) -> List[str]: + value = data.get(key) + if isinstance(value, str): + return [value] + if isinstance(value, list): + return [self._safe_str(item) for item in value if self._safe_str(item)] + return [] + + memory_type_values = _collect_list("memory_types") + memory_types: List[MemoryType] = [] + for item in memory_type_values: + if not item: + continue + try: + memory_types.append(MemoryType(item)) + except ValueError: + # 尝试匹配value值 + normalized = item.lower() + for mt in MemoryType: + if mt.value == normalized: + memory_types.append(mt) + break + + plan = MemoryQueryPlan( + semantic_query=semantic_query, + memory_types=memory_types, + subject_includes=_collect_list("subject_includes"), + object_includes=_collect_list("object_includes"), + required_keywords=_collect_list("required_keywords"), + optional_keywords=_collect_list("optional_keywords"), + owner_filters=_collect_list("owner_filters"), + recency_preference=self._safe_str(data.get("recency")) or "any", + limit=self._safe_int(data.get("limit"), self.default_limit), + emphasis=self._safe_str(data.get("emphasis")) or "balanced", + raw_plan=data + ) + return plan + + def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str: + participants = context.get("participants") or context.get("speaker_names") or [] + if isinstance(participants, str): + participants = [participants] + participants = [p for p in participants if isinstance(p, str) and p.strip()] + participant_preview = "、".join(participants[:5]) or "未知" + + persona = context.get("bot_personality") or context.get("bot_identity") or "未知" + + return f""" +你是一名记忆检索分析师,将根据对话查询生成结构化的检索计划。 +请结合提供的上下文,输出一个JSON对象,字段含义如下: +- semantic_query: 提供给向量检索的自然语言查询,要求清晰具体; +- memory_types: 建议检索的记忆类型数组,取值范围参见 MemoryType 枚举 (personal_fact,event,preference,opinion,relationship,emotion,knowledge,skill,goal,experience,contextual); +- subject_includes: 需要出现在记忆主语中的人或角色列表; +- object_includes: 记忆中需要提到的重要对象或主题关键词列表; +- required_keywords: 检索时必须包含的关键词; +- optional_keywords: 可以提升相关性的附加关键词; +- owner_filters: 如果需要限制检索所属主体,请列出用户ID或其它标识; +- recency: 建议的时间偏好,可选 recent/any/historical; +- emphasis: 检索策略倾向,可选 precision/recall/balanced; +- limit: 推荐的最大返回数量(1-15之间); +- notes: 额外说明,可选。 + +当前查询: "{query_text}" +已知的对话参与者: {participant_preview} +机器人设定: {persona} + +请输出符合要求的JSON,禁止添加额外说明或Markdown代码块。 +""" + + def _extract_json_payload(self, response: str) -> Optional[str]: + if not response: + return None + + stripped = response.strip() + code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL) + if code_block_match: + candidate = code_block_match.group(1).strip() + if candidate: + return candidate + + start = stripped.find("{") + end = stripped.rfind("}") + if start != -1 and end != -1 and end > start: + return stripped[start:end + 1] + + return stripped if stripped.startswith("{") and stripped.endswith("}") else None + + @staticmethod + def _safe_str(value: Any) -> str: + if isinstance(value, str): + return value.strip() + if value is None: + return "" + return str(value).strip() + + @staticmethod + def _safe_int(value: Any, default: int) -> int: + try: + number = int(value) + if number <= 0: + return default + return number + except (TypeError, ValueError): + return default \ No newline at end of file diff --git a/src/chat/memory_system/metadata_index.py b/src/chat/memory_system/metadata_index.py index 77aa8c995..10f8ff266 100644 --- a/src/chat/memory_system/metadata_index.py +++ b/src/chat/memory_system/metadata_index.py @@ -25,6 +25,7 @@ class IndexType(Enum): """索引类型""" MEMORY_TYPE = "memory_type" # 记忆类型索引 USER_ID = "user_id" # 用户ID索引 + SUBJECT = "subject" # 主体索引 KEYWORD = "keyword" # 关键词索引 TAG = "tag" # 标签索引 CATEGORY = "category" # 分类索引 @@ -41,6 +42,7 @@ class IndexQuery: """索引查询条件""" user_ids: Optional[List[str]] = None memory_types: Optional[List[MemoryType]] = None + subjects: Optional[List[str]] = None keywords: Optional[List[str]] = None tags: Optional[List[str]] = None categories: Optional[List[str]] = None @@ -76,6 +78,7 @@ class MetadataIndexManager: self.indices = { IndexType.MEMORY_TYPE: defaultdict(set), IndexType.USER_ID: defaultdict(set), + IndexType.SUBJECT: defaultdict(set), IndexType.KEYWORD: defaultdict(set), IndexType.TAG: defaultdict(set), IndexType.CATEGORY: defaultdict(set), @@ -110,6 +113,41 @@ class MetadataIndexManager: self.auto_save_interval = 500 # 每500次操作自动保存 self._operation_count = 0 + @staticmethod + def _serialize_index_key(index_type: IndexType, key: Any) -> str: + """将索引键序列化为字符串以便存储""" + if isinstance(key, Enum): + value = key.value + else: + value = key + return str(value) + + @staticmethod + def _deserialize_index_key(index_type: IndexType, key: str) -> Any: + """根据索引类型反序列化索引键""" + try: + if index_type == IndexType.MEMORY_TYPE: + return MemoryType(key) + if index_type == IndexType.CONFIDENCE: + return ConfidenceLevel(int(key)) + if index_type == IndexType.IMPORTANCE: + return ImportanceLevel(int(key)) + # 其他索引键默认使用原始字符串(可能已经是lower后的字符串) + return key + except Exception: + logger.warning("无法反序列化索引键 %s 在索引 %s 中,使用原始字符串", key, index_type.value) + return key + + @staticmethod + def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]: + serialized = {} + for field_name, value in metadata.items(): + if isinstance(value, Enum): + serialized[field_name] = value.value + else: + serialized[field_name] = value + return serialized + async def index_memories(self, memories: List[MemoryChunk]): """为记忆建立索引""" if not memories: @@ -142,6 +180,68 @@ class MetadataIndexManager: except Exception as e: logger.error(f"❌ 元数据索引失败: {e}", exc_info=True) + async def update_memory_entry(self, memory: MemoryChunk): + """更新已存在记忆的索引信息""" + if not memory: + return + + with self._lock: + entry = self.memory_metadata_cache.get(memory.memory_id) + if entry is None: + # 若不存在则作为新记忆索引 + self._index_single_memory(memory) + return + + old_confidence = entry.get("confidence") + old_importance = entry.get("importance") + old_semantic_hash = entry.get("semantic_hash") + + entry.update( + { + "user_id": memory.user_id, + "memory_type": memory.memory_type, + "created_at": memory.metadata.created_at, + "last_accessed": memory.metadata.last_accessed, + "access_count": memory.metadata.access_count, + "confidence": memory.metadata.confidence, + "importance": memory.metadata.importance, + "relationship_score": memory.metadata.relationship_score, + "relevance_score": memory.metadata.relevance_score, + "semantic_hash": memory.semantic_hash, + "subjects": memory.subjects, + } + ) + + # 更新置信度/重要性索引 + if isinstance(old_confidence, ConfidenceLevel): + self.indices[IndexType.CONFIDENCE][old_confidence].discard(memory.memory_id) + if isinstance(old_importance, ImportanceLevel): + self.indices[IndexType.IMPORTANCE][old_importance].discard(memory.memory_id) + if isinstance(old_semantic_hash, str): + self.indices[IndexType.SEMANTIC_HASH][old_semantic_hash].discard(memory.memory_id) + + self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory.memory_id) + self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory.memory_id) + if memory.semantic_hash: + self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory.memory_id) + + # 同步关键词/标签/分类索引 + for keyword in memory.keywords: + if keyword: + self.indices[IndexType.KEYWORD][keyword.lower()].add(memory.memory_id) + + for tag in memory.tags: + if tag: + self.indices[IndexType.TAG][tag.lower()].add(memory.memory_id) + + for category in memory.categories: + if category: + self.indices[IndexType.CATEGORY][category.lower()].add(memory.memory_id) + + for subject in memory.subjects: + if subject: + self.indices[IndexType.SUBJECT][subject.strip().lower()].add(memory.memory_id) + def _index_single_memory(self, memory: MemoryChunk): """为单个记忆建立索引""" memory_id = memory.memory_id @@ -157,7 +257,8 @@ class MetadataIndexManager: "importance": memory.metadata.importance, "relationship_score": memory.metadata.relationship_score, "relevance_score": memory.metadata.relevance_score, - "semantic_hash": memory.semantic_hash + "semantic_hash": memory.semantic_hash, + "subjects": memory.subjects } # 记忆类型索引 @@ -166,6 +267,12 @@ class MetadataIndexManager: # 用户ID索引 self.indices[IndexType.USER_ID][memory.user_id].add(memory_id) + # 主体索引 + for subject in memory.subjects: + normalized = subject.strip().lower() + if normalized: + self.indices[IndexType.SUBJECT][normalized].add(memory_id) + # 关键词索引 for keyword in memory.keywords: self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id) @@ -282,13 +389,6 @@ class MetadataIndexManager: # 应用最严格的过滤条件 applied_filters = [] - if query.user_ids: - user_ids_set = set() - for user_id in query.user_ids: - user_ids_set.update(self.indices[IndexType.USER_ID].get(user_id, set())) - candidate_ids.update(user_ids_set) - applied_filters.append("user_ids") - if query.memory_types: memory_types_set = set() for memory_type in query.memory_types: @@ -302,7 +402,7 @@ class MetadataIndexManager: if query.keywords: keywords_set = set() for keyword in query.keywords: - keywords_set.update(self.indices[IndexType.KEYWORD].get(keyword.lower(), set())) + keywords_set.update(self._collect_index_matches(IndexType.KEYWORD, keyword)) if applied_filters: candidate_ids &= keywords_set else: @@ -329,12 +429,55 @@ class MetadataIndexManager: candidate_ids.update(categories_set) applied_filters.append("categories") + if query.subjects: + subjects_set = set() + for subject in query.subjects: + subjects_set.update(self._collect_index_matches(IndexType.SUBJECT, subject)) + if applied_filters: + candidate_ids &= subjects_set + else: + candidate_ids.update(subjects_set) + applied_filters.append("subjects") + # 如果没有应用任何过滤条件,返回所有记忆 if not applied_filters: return all_memory_ids return candidate_ids + def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]: + """根据给定token收集索引匹配,支持部分匹配""" + mapping = self.indices.get(index_type) + if mapping is None: + return set() + + key = "" + if isinstance(token, Enum): + key = str(token.value).strip().lower() + elif isinstance(token, str): + key = token.strip().lower() + elif token is not None: + key = str(token).strip().lower() + + if not key: + return set() + + matches: Set[str] = set(mapping.get(key, set())) + + if matches: + return set(matches) + + for existing_key, ids in mapping.items(): + if not existing_key or not isinstance(existing_key, str): + continue + normalized = existing_key.strip().lower() + if not normalized: + continue + if key in normalized or normalized in key: + matches.update(ids) + + return matches + def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]: """应用过滤条件""" filtered_ids = list(candidate_ids) @@ -440,10 +583,10 @@ class MetadataIndexManager: def _get_applied_filters(self, query: IndexQuery) -> List[str]: """获取应用的过滤器列表""" filters = [] - if query.user_ids: - filters.append("user_ids") if query.memory_types: filters.append("memory_types") + if query.subjects: + filters.append("subjects") if query.keywords: filters.append("keywords") if query.tags: @@ -502,6 +645,18 @@ class MetadataIndexManager: # 从各类索引中移除 self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id) self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id) + subjects = metadata.get("subjects") or [] + for subject in subjects: + if not isinstance(subject, str): + continue + normalized = subject.strip().lower() + if not normalized: + continue + subject_bucket = self.indices[IndexType.SUBJECT].get(normalized) + if subject_bucket is not None: + subject_bucket.discard(memory_id) + if not subject_bucket: + self.indices[IndexType.SUBJECT].pop(normalized, None) # 从时间索引中移除 self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id] @@ -625,11 +780,13 @@ class MetadataIndexManager: logger.info("正在保存元数据索引...") # 保存各类索引 - indices_data = {} + indices_data: Dict[str, Dict[str, List[str]]] = {} for index_type, index_data in self.indices.items(): - indices_data[index_type.value] = { - key: list(values) for key, values in index_data.items() - } + serialized_index = {} + for key, values in index_data.items(): + serialized_key = self._serialize_index_key(index_type, key) + serialized_index[serialized_key] = list(values) + indices_data[index_type.value] = serialized_index indices_file = self.index_path / "indices.json" with open(indices_file, 'w', encoding='utf-8') as f: @@ -652,8 +809,12 @@ class MetadataIndexManager: # 保存元数据缓存 metadata_cache_file = self.index_path / "metadata_cache.json" + metadata_serialized = { + memory_id: self._serialize_metadata_entry(metadata) + for memory_id, metadata in self.memory_metadata_cache.items() + } with open(metadata_cache_file, 'w', encoding='utf-8') as f: - f.write(orjson.dumps(self.memory_metadata_cache, option=orjson.OPT_INDENT_2).decode('utf-8')) + f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8')) # 保存统计信息 stats_file = self.index_path / "index_stats.json" @@ -679,9 +840,11 @@ class MetadataIndexManager: for index_type_value, index_data in indices_data.items(): index_type = IndexType(index_type_value) - self.indices[index_type] = { - key: set(values) for key, values in index_data.items() - } + restored_index = defaultdict(set) + for key_str, values in index_data.items(): + restored_key = self._deserialize_index_key(index_type, key_str) + restored_index[restored_key] = set(values) + self.indices[index_type] = restored_index # 加载时间索引 time_index_file = self.index_path / "time_index.json" @@ -709,10 +872,38 @@ class MetadataIndexManager: # 转换置信度和重要性为枚举类型 for memory_id, metadata in cache_data.items(): - if isinstance(metadata["confidence"], str): - metadata["confidence"] = ConfidenceLevel(metadata["confidence"]) - if isinstance(metadata["importance"], str): - metadata["importance"] = ImportanceLevel(metadata["importance"]) + memory_type_value = metadata.get("memory_type") + if isinstance(memory_type_value, str): + try: + metadata["memory_type"] = MemoryType(memory_type_value) + except ValueError: + logger.warning("无法解析memory_type %s", memory_type_value) + + confidence_value = metadata.get("confidence") + if isinstance(confidence_value, (str, int)): + try: + metadata["confidence"] = ConfidenceLevel(int(confidence_value)) + except ValueError: + logger.warning("无法解析confidence %s", confidence_value) + + importance_value = metadata.get("importance") + if isinstance(importance_value, (str, int)): + try: + metadata["importance"] = ImportanceLevel(int(importance_value)) + except ValueError: + logger.warning("无法解析importance %s", importance_value) + + subjects_value = metadata.get("subjects") + if isinstance(subjects_value, str): + metadata["subjects"] = [subjects_value] + elif isinstance(subjects_value, list): + cleaned_subjects = [] + for item in subjects_value: + if isinstance(item, str) and item.strip(): + cleaned_subjects.append(item.strip()) + metadata["subjects"] = cleaned_subjects + else: + metadata["subjects"] = [] self.memory_metadata_cache = cache_data diff --git a/src/chat/memory_system/multi_stage_retrieval.py b/src/chat/memory_system/multi_stage_retrieval.py index d8e7afe7b..44bb5e62e 100644 --- a/src/chat/memory_system/multi_stage_retrieval.py +++ b/src/chat/memory_system/multi_stage_retrieval.py @@ -203,11 +203,17 @@ class MultiStageRetrieval: try: from .metadata_index import IndexQuery - # 构建索引查询 + query_plan = context.get("query_plan") + + memory_types = self._extract_memory_types_from_context(context) + keywords = self._extract_keywords_from_query(query, query_plan) + subjects = query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None + index_query = IndexQuery( - user_ids=[user_id], - memory_types=self._extract_memory_types_from_context(context), - keywords=self._extract_keywords_from_query(query), + user_ids=None, + memory_types=memory_types, + subjects=subjects, + keywords=keywords, limit=self.config.metadata_filter_limit, sort_by="last_accessed", sort_order="desc" @@ -215,13 +221,66 @@ class MultiStageRetrieval: # 执行查询 result = await metadata_index.query_memories(index_query) - filtered_count = result.total_count - len(result.memory_ids) + result_ids = list(result.memory_ids) + filtered_count = max(0, len(all_memories_cache) - len(result_ids)) - logger.debug(f"元数据过滤:{result.total_count} -> {len(result.memory_ids)} 条记忆") + # 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆 + if not result_ids: + sorted_ids = sorted( + (memory.memory_id for memory in all_memories_cache.values()), + key=lambda mid: all_memories_cache[mid].metadata.last_accessed if mid in all_memories_cache else 0, + reverse=True, + ) + if memory_types: + type_filtered = [ + mid for mid in sorted_ids + if all_memories_cache[mid].memory_type in memory_types + ] + sorted_ids = type_filtered or sorted_ids + if subjects: + subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()] + if subject_candidates: + subject_filtered = [ + mid for mid in sorted_ids + if any( + subj.strip().lower() in subject_candidates + for subj in all_memories_cache[mid].subjects + ) + ] + sorted_ids = subject_filtered or sorted_ids + + if keywords: + keyword_pool = {kw.lower() for kw in keywords if isinstance(kw, str) and kw.strip()} + if keyword_pool: + keyword_filtered = [] + for mid in sorted_ids: + memory_text = ( + (all_memories_cache[mid].display or "") + + "\n" + + (all_memories_cache[mid].text_content or "") + ).lower() + if any(kw in memory_text for kw in keyword_pool): + keyword_filtered.append(mid) + sorted_ids = keyword_filtered or sorted_ids + + result_ids = sorted_ids[: self.config.metadata_filter_limit] + filtered_count = max(0, len(all_memories_cache) - len(result_ids)) + logger.debug( + "元数据过滤未命中索引,使用近似回退: types=%s, subjects=%s, keywords=%s", + bool(memory_types), + bool(subjects), + bool(keywords), + ) + + logger.debug( + "元数据过滤:候选=%d, 返回=%d", + len(all_memories_cache), + len(result_ids), + ) return StageResult( stage=RetrievalStage.METADATA_FILTERING, - memory_ids=result.memory_ids, + memory_ids=result_ids, processing_time=time.time() - start_time, filtered_count=filtered_count, score_threshold=0.0 @@ -251,7 +310,7 @@ class MultiStageRetrieval: try: # 生成查询向量 - query_embedding = await self._generate_query_embedding(query, context) + query_embedding = await self._generate_query_embedding(query, context, vector_storage) if not query_embedding: return StageResult( @@ -263,22 +322,24 @@ class MultiStageRetrieval: ) # 执行向量搜索 - search_result = await vector_storage.search_similar( - query_embedding, + search_result = await vector_storage.search_similar_memories( + query_vector=query_embedding, limit=self.config.vector_search_limit ) + candidate_pool = candidate_ids or set(all_memories_cache.keys()) + # 过滤候选记忆 filtered_memories = [] for memory_id, similarity in search_result: - if memory_id in candidate_ids and similarity >= self.config.vector_similarity_threshold: + if memory_id in candidate_pool and similarity >= self.config.vector_similarity_threshold: filtered_memories.append((memory_id, similarity)) # 按相似度排序 filtered_memories.sort(key=lambda x: x[1], reverse=True) result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]] - filtered_count = len(candidate_ids) - len(result_ids) + filtered_count = max(0, len(candidate_pool) - len(result_ids)) logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆") @@ -407,12 +468,20 @@ class MultiStageRetrieval: score_threshold=0.0 ) - async def _generate_query_embedding(self, query: str, context: Dict[str, Any]) -> Optional[List[float]]: + async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]: """生成查询向量""" try: - # 这里应该调用embedding模型 - # 由于我们可能没有直接的embedding模型,返回None或使用简单的方法 - # 在实际实现中,这里应该调用与记忆存储相同的embedding模型 + query_plan = context.get("query_plan") + query_text = query + if query_plan and getattr(query_plan, "semantic_query", None): + query_text = query_plan.semantic_query + + if not query_text: + return None + + if hasattr(vector_storage, "generate_query_embedding"): + return await vector_storage.generate_query_embedding(query_text) + return None except Exception as e: logger.warning(f"生成查询向量失败: {e}") @@ -421,9 +490,15 @@ class MultiStageRetrieval: async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: """计算语义相似度""" try: + query_plan = context.get("query_plan") + query_text = query + if query_plan and getattr(query_plan, "semantic_query", None): + query_text = query_plan.semantic_query + # 简单的文本相似度计算 - query_words = set(query.lower().split()) - memory_words = set(memory.text_content.lower().split()) + query_words = set(query_text.lower().split()) + memory_text = (memory.display or memory.text_content or "").lower() + memory_words = set(memory_text.split()) if not query_words or not memory_words: return 0.0 @@ -443,10 +518,15 @@ class MultiStageRetrieval: try: score = 0.0 + query_plan = context.get("query_plan") + # 检查记忆类型是否匹配上下文 if context.get("expected_memory_types"): if memory.memory_type in context["expected_memory_types"]: score += 0.3 + elif query_plan and getattr(query_plan, "memory_types", None): + if memory.memory_type in query_plan.memory_types: + score += 0.3 # 检查关键词匹配 if context.get("keywords"): @@ -456,6 +536,35 @@ class MultiStageRetrieval: if overlap: score += len(overlap) / max(len(context_keywords), 1) * 0.4 + if query_plan: + # 主体匹配 + subject_score = self._calculate_subject_overlap(memory, getattr(query_plan, "subject_includes", [])) + score += subject_score * 0.3 + + # 对象/描述匹配 + object_keywords = getattr(query_plan, "object_includes", []) or [] + if object_keywords: + display_text = (memory.display or memory.text_content or "").lower() + hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text) + if hits: + score += min(0.3, hits * 0.1) + + optional_keywords = getattr(query_plan, "optional_keywords", []) or [] + if optional_keywords: + display_text = (memory.display or memory.text_content or "").lower() + hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text) + if hits: + score += min(0.2, hits * 0.05) + + # 时间偏好 + recency_pref = getattr(query_plan, "recency_preference", "") + if recency_pref: + memory_age = time.time() - memory.metadata.created_at + if recency_pref == "recent" and memory_age < 7 * 24 * 3600: + score += 0.2 + elif recency_pref == "historical" and memory_age > 30 * 24 * 3600: + score += 0.1 + # 检查时效性 if context.get("recent_only", False): memory_age = time.time() - memory.metadata.created_at @@ -471,6 +580,8 @@ class MultiStageRetrieval: async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float: """计算最终评分""" try: + query_plan = context.get("query_plan") + # 语义相似度 semantic_score = await self._calculate_semantic_similarity(query, memory, context) @@ -482,13 +593,29 @@ class MultiStageRetrieval: # 时效性评分 recency_score = self._calculate_recency_score(memory.metadata.created_at) + if query_plan: + recency_pref = getattr(query_plan, "recency_preference", "") + if recency_pref == "recent": + recency_score = max(recency_score, 0.8) + elif recency_pref == "historical": + recency_score = min(recency_score, 0.5) # 权重组合 + vector_weight = self.config.vector_weight + semantic_weight = self.config.semantic_weight + context_weight = self.config.context_weight + recency_weight = self.config.recency_weight + + if query_plan and getattr(query_plan, "emphasis", None) == "precision": + semantic_weight += 0.05 + elif query_plan and getattr(query_plan, "emphasis", None) == "recall": + context_weight += 0.05 + final_score = ( - semantic_score * self.config.semantic_weight + - vector_score * self.config.vector_weight + - context_score * self.config.context_weight + - recency_score * self.config.recency_weight + semantic_score * semantic_weight + + vector_score * vector_weight + + context_score * context_weight + + recency_score * recency_weight ) # 加入记忆重要性权重 @@ -501,6 +628,31 @@ class MultiStageRetrieval: logger.warning(f"计算最终评分失败: {e}") return 0.0 + def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float: + if not required_subjects: + return 0.0 + + memory_subjects = {subject.lower() for subject in memory.subjects if isinstance(subject, str)} + if not memory_subjects: + return 0.0 + + hit = 0 + total = 0 + for subject in required_subjects: + if not isinstance(subject, str): + continue + total += 1 + normalized = subject.strip().lower() + if not normalized: + continue + if any(normalized in mem_subject for mem_subject in memory_subjects): + hit += 1 + + if total == 0: + return 0.0 + + return hit / total + def _calculate_recency_score(self, timestamp: float) -> float: """计算时效性评分""" try: @@ -524,6 +676,10 @@ class MultiStageRetrieval: def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]: """从上下文中提取记忆类型""" try: + query_plan = context.get("query_plan") + if query_plan and getattr(query_plan, "memory_types", None): + return query_plan.memory_types + if "expected_memory_types" in context: return context["expected_memory_types"] @@ -544,15 +700,30 @@ class MultiStageRetrieval: except Exception: return [] - def _extract_keywords_from_query(self, query: str) -> List[str]: + def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]: """从查询中提取关键词""" try: + extracted: List[str] = [] + + if query_plan and getattr(query_plan, "required_keywords", None): + extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)]) + # 简单的关键词提取 words = query.lower().split() # 过滤停用词 stopwords = {"的", "是", "在", "有", "我", "你", "他", "她", "它", "这", "那", "了", "吗", "呢"} - keywords = [word for word in words if len(word) > 1 and word not in stopwords] - return keywords[:10] # 最多返回10个关键词 + extracted.extend(word for word in words if len(word) > 1 and word not in stopwords) + + # 去重并保留顺序 + seen = set() + deduplicated = [] + for word in extracted: + if word in seen or not word: + continue + seen.add(word) + deduplicated.append(word) + + return deduplicated[:10] except Exception: return [] diff --git a/src/chat/memory_system/vector_storage.py b/src/chat/memory_system/vector_storage.py index aeafe13aa..5fd2357f3 100644 --- a/src/chat/memory_system/vector_storage.py +++ b/src/chat/memory_system/vector_storage.py @@ -20,6 +20,7 @@ from pathlib import Path from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config +from src.common.config_helpers import resolve_embedding_dimension from src.chat.memory_system.memory_chunk import MemoryChunk logger = get_logger(__name__) @@ -36,12 +37,12 @@ except ImportError: @dataclass class VectorStorageConfig: """向量存储配置""" - dimension: int = 768 + dimension: int = 1024 similarity_threshold: float = 0.8 index_type: str = "flat" # flat, ivf, hnsw max_index_size: int = 100000 storage_path: str = "data/memory_vectors" - auto_save_interval: int = 100 # 每N次操作自动保存 + auto_save_interval: int = 10 # 每N次操作自动保存 enable_compression: bool = True @@ -50,6 +51,15 @@ class VectorStorageManager: def __init__(self, config: Optional[VectorStorageConfig] = None): self.config = config or VectorStorageConfig() + + resolved_dimension = resolve_embedding_dimension(self.config.dimension) + if resolved_dimension and resolved_dimension != self.config.dimension: + logger.info( + "向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)", + resolved_dimension, + self.config.dimension, + ) + self.config.dimension = resolved_dimension self.storage_path = Path(self.config.storage_path) self.storage_path.mkdir(parents=True, exist_ok=True) @@ -117,6 +127,32 @@ class VectorStorageManager: ) logger.info("✅ 嵌入模型初始化完成") + async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]: + """生成查询向量,用于记忆召回""" + if not query_text: + return None + + try: + await self.initialize_embedding_model() + + embedding, _ = await self.embedding_model.get_embedding(query_text) + if not embedding: + return None + + if len(embedding) != self.config.dimension: + logger.warning( + "查询向量维度不匹配: 期望 %d, 实际 %d", + self.config.dimension, + len(embedding) + ) + return None + + return self._normalize_vector(embedding) + + except Exception as exc: + logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True) + return None + async def store_memories(self, memories: List[MemoryChunk]): """存储记忆向量""" if not memories: @@ -213,7 +249,7 @@ class VectorStorageManager: results[memory_id] = embedding else: logger.warning( - "嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)", + "嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。", self.config.dimension, len(embedding) if embedding else 0, memory_id, @@ -299,14 +335,32 @@ class VectorStorageManager: async def search_similar_memories( self, - query_vector: List[float], + query_vector: Optional[List[float]] = None, + *, + query_text: Optional[str] = None, limit: int = 10, - user_id: Optional[str] = None + scope_id: Optional[str] = None ) -> List[Tuple[str, float]]: """搜索相似记忆""" start_time = time.time() try: + if query_vector is None: + if not query_text: + return [] + + query_vector = await self.generate_query_embedding(query_text) + if not query_vector: + return [] + + scope_filter: Optional[str] = None + if isinstance(scope_id, str): + normalized_scope = scope_id.strip().lower() + if normalized_scope and normalized_scope not in {"global", "global_memory"}: + scope_filter = scope_id + elif scope_id: + scope_filter = str(scope_id) + # 规范化查询向量 query_vector = self._normalize_vector(query_vector) @@ -341,10 +395,9 @@ class VectorStorageManager: memory_id = self.index_to_memory_id.get(index) if memory_id: - # 应用用户过滤 - if user_id: + if scope_filter: memory = self.memory_cache.get(memory_id) - if memory and memory.user_id != user_id: + if memory and str(memory.user_id) != scope_filter: continue similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内 @@ -481,8 +534,14 @@ class VectorStorageManager: # 保存映射关系 mapping_file = self.storage_path / "id_mapping.json" mapping_data = { - "memory_id_to_index": self.memory_id_to_index, - "index_to_memory_id": self.index_to_memory_id + "memory_id_to_index": { + str(memory_id): int(index) + for memory_id, index in self.memory_id_to_index.items() + }, + "index_to_memory_id": { + str(index): memory_id + for index, memory_id in self.index_to_memory_id.items() + } } with open(mapping_file, 'w', encoding='utf-8') as f: f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8')) @@ -529,8 +588,17 @@ class VectorStorageManager: if mapping_file.exists(): with open(mapping_file, 'r', encoding='utf-8') as f: mapping_data = orjson.loads(f.read()) - self.memory_id_to_index = mapping_data.get("memory_id_to_index", {}) - self.index_to_memory_id = mapping_data.get("index_to_memory_id", {}) + raw_memory_to_index = mapping_data.get("memory_id_to_index", {}) + self.memory_id_to_index = { + str(memory_id): int(index) + for memory_id, index in raw_memory_to_index.items() + } + + raw_index_to_memory = mapping_data.get("index_to_memory_id", {}) + self.index_to_memory_id = { + int(index): memory_id + for index, memory_id in raw_index_to_memory.items() + } # 加载FAISS索引(如果可用) if FAISS_AVAILABLE: diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index d8b9c6857..d30600b70 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -469,14 +469,14 @@ class ChatBot: async def preprocess(): # 存储消息到数据库 from .storage import MessageStorage - + try: await MessageStorage.store_message(message, message.chat_stream) logger.debug(f"消息已存储到数据库: {message.message_info.message_id}") except Exception as e: logger.error(f"存储消息到数据库失败: {e}") traceback.print_exc() - + # 使用消息管理器处理消息(保持原有功能) from src.common.data_models.database_data_model import DatabaseMessages diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index e9f77f2c4..3a499766a 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -373,12 +373,12 @@ class Prompt: # 性能优化 - 为不同任务设置不同的超时时间 task_timeouts = { - "memory_block": 5.0, # 记忆系统可能较慢,单独设置超时 - "tool_info": 3.0, # 工具信息中等速度 - "relation_info": 2.0, # 关系信息通常较快 - "knowledge_info": 3.0, # 知识库查询中等速度 - "cross_context": 2.0, # 上下文处理通常较快 - "expression_habits": 1.5, # 表达习惯处理很快 + "memory_block": 15.0, # 记忆系统 + "tool_info": 15.0, # 工具信息 + "relation_info": 10.0, # 关系信息 + "knowledge_info": 10.0, # 知识库查询 + "cross_context": 10.0, # 上下文处理 + "expression_habits": 10.0, # 表达习惯 } # 分别处理每个任务,避免慢任务影响快任务 @@ -558,12 +558,8 @@ class Prompt: ) ] - # 等待所有记忆查询完成(最多3秒) try: - running_memories, instant_memory = await asyncio.wait_for( - asyncio.gather(*memory_tasks, return_exceptions=True), - timeout=3.0 - ) + running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True) # 处理可能的异常结果 if isinstance(running_memories, Exception): diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 6ecd599af..bbe489336 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -207,6 +207,9 @@ class VideoAnalyzer: """检查视频是否已经分析过""" try: async with get_db_session() as session: + if not session: + logger.warning("无法获取数据库会话,跳过视频存在性检查。") + return None # 明确刷新会话以确保看到其他事务的最新提交 await session.expire_all() stmt = select(Videos).where(Videos.video_hash == video_hash) @@ -227,6 +230,9 @@ class VideoAnalyzer: try: async with get_db_session() as session: + if not session: + logger.warning("无法获取数据库会话,跳过视频结果存储。") + return None # 只根据video_hash查找 stmt = select(Videos).where(Videos.video_hash == video_hash) result = await session.execute(stmt) @@ -540,11 +546,14 @@ class VideoAnalyzer: # logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片") # 获取模型信息和客户端 - model_info, api_provider, client = self.video_llm._select_model() + selection_result = self.video_llm._model_selector.select_best_available_model(set(), "response") + if not selection_result: + raise RuntimeError("无法为视频分析选择可用模型。") + model_info, api_provider, client = selection_result # logger.info(f"使用模型: {model_info.name} 进行多帧分析") # 直接执行多图片请求 - api_response = await self.video_llm._execute_request( + api_response = await self.video_llm._executor.execute_request( api_provider=api_provider, client=client, request_type=RequestType.RESPONSE, diff --git a/src/chat/utils/utils_video_legacy.py b/src/chat/utils/utils_video_legacy.py index 4d8e06681..77ca88142 100644 --- a/src/chat/utils/utils_video_legacy.py +++ b/src/chat/utils/utils_video_legacy.py @@ -461,11 +461,14 @@ class LegacyVideoAnalyzer: # logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片") # 获取模型信息和客户端 - model_info, api_provider, client = self.video_llm._select_model() + selection_result = self.video_llm._model_selector.select_best_available_model(set(), "response") + if not selection_result: + raise RuntimeError("无法为视频分析选择可用模型 (legacy)。") + model_info, api_provider, client = selection_result # logger.info(f"使用模型: {model_info.name} 进行多帧分析") # 直接执行多图片请求 - api_response = await self.video_llm._execute_request( + api_response = await self.video_llm._executor.execute_request( api_provider=api_provider, client=client, request_type=RequestType.RESPONSE, diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index ec940ec6c..6ebe5c3d8 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, Union from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config +from src.common.config_helpers import resolve_embedding_dimension from src.common.database.sqlalchemy_models import CacheEntries from src.common.database.sqlalchemy_database_api import db_query, db_save from src.common.vector_db import vector_db_service @@ -40,7 +41,11 @@ class CacheManager: # L1 缓存 (内存) self.l1_kv_cache: Dict[str, Dict[str, Any]] = {} - embedding_dim = global_config.lpmm_knowledge.embedding_dimension + embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension) + if not embedding_dim: + embedding_dim = global_config.lpmm_knowledge.embedding_dimension + + self.embedding_dimension = embedding_dim self.l1_vector_index = faiss.IndexFlatIP(embedding_dim) self.l1_vector_id_to_key: Dict[int, str] = {} @@ -72,7 +77,7 @@ class CacheManager: embedding_array = embedding_array.flatten() # 检查维度是否符合预期 - expected_dim = global_config.lpmm_knowledge.embedding_dimension + expected_dim = getattr(CacheManager, "embedding_dimension", None) or global_config.lpmm_knowledge.embedding_dimension if embedding_array.shape[0] != expected_dim: logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}") return None diff --git a/src/common/config_helpers.py b/src/common/config_helpers.py new file mode 100644 index 000000000..5a2134fe1 --- /dev/null +++ b/src/common/config_helpers.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import Optional + +from src.config.config import global_config, model_config + + +def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]: + """获取当前配置的嵌入向量维度。 + + 优先顺序: + 1. 模型配置中 `model_task_config.embedding.embedding_dimension` + 2. 机器人配置中 `lpmm_knowledge.embedding_dimension` + 3. 调用方提供的 fallback + """ + + candidates: list[Optional[int]] = [] + + try: + embedding_task = getattr(model_config.model_task_config, "embedding", None) + if embedding_task is not None: + candidates.append(getattr(embedding_task, "embedding_dimension", None)) + except Exception: + candidates.append(None) + + try: + candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None)) + except Exception: + candidates.append(None) + + candidates.append(fallback) + + resolved: Optional[int] = next((int(dim) for dim in candidates if dim and int(dim) > 0), None) + + if resolved and sync_global: + try: + if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved: + global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined] + except Exception: + pass + + return resolved diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index cf74eedb5..64c1fd66a 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -759,30 +759,38 @@ async def initialize_database(): @asynccontextmanager -async def get_db_session() -> AsyncGenerator[AsyncSession, None]: - """异步数据库会话上下文管理器""" +async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]: + """ + 异步数据库会话上下文管理器。 + 在初始化失败时会yield None,调用方需要检查会话是否为None。 + """ session: Optional[AsyncSession] = None + SessionLocal = None try: - engine, SessionLocal = await initialize_database() + _, SessionLocal = await initialize_database() if not SessionLocal: - raise RuntimeError("Database session not initialized") - session = SessionLocal() + logger.error("数据库会话工厂 (_SessionLocal) 未初始化。") + yield None + return + except Exception as e: + logger.error(f"数据库初始化失败,无法创建会话: {e}") + yield None + return + try: + session = SessionLocal() # 对于 SQLite,在会话开始时设置 PRAGMA from src.config.config import global_config if global_config.database.database_type == "sqlite": - try: - await session.execute(text("PRAGMA busy_timeout = 60000")) - await session.execute(text("PRAGMA foreign_keys = ON")) - except Exception as e: - logger.warning(f"[SQLite] 设置会话 PRAGMA 失败: {e}") + await session.execute(text("PRAGMA busy_timeout = 60000")) + await session.execute(text("PRAGMA foreign_keys = ON")) yield session except Exception as e: - logger.error(f"数据库会话错误: {e}") + logger.error(f"数据库会话期间发生错误: {e}") if session: await session.rollback() - raise + raise # 将会话期间的错误重新抛出给调用者 finally: if session: await session.close() diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index f7e9fe514..eb5d1a1f1 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Literal, Union +from typing import List, Dict, Any, Literal, Union, Optional from pydantic import Field from threading import Lock @@ -105,6 +105,11 @@ class TaskConfig(ValidatedConfigBase): max_tokens: int = Field(default=800, description="任务最大输出token数") temperature: float = Field(default=0.7, description="模型温度") concurrency_count: int = Field(default=1, description="并发请求数量") + embedding_dimension: Optional[int] = Field( + default=None, + description="嵌入模型输出向量维度,仅在嵌入任务中使用", + ge=1, + ) @classmethod def validate_model_list(cls, v): diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 3615f49a7..53aa5a6f5 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -443,21 +443,6 @@ class MemoryConfig(ValidatedConfigBase): enable_memory: bool = Field(default=True, description="启用记忆") memory_build_interval: int = Field(default=600, description="记忆构建间隔") - memory_build_distribution: list[float] = Field( - default_factory=lambda: [6.0, 3.0, 0.6, 32.0, 12.0, 0.4], description="记忆构建分布" - ) - memory_build_sample_num: int = Field(default=8, description="记忆构建样本数量") - memory_build_sample_length: int = Field(default=40, description="记忆构建样本长度") - memory_compress_rate: float = Field(default=0.1, description="记忆压缩率") - forget_memory_interval: int = Field(default=1000, description="遗忘记忆间隔") - memory_forget_time: int = Field(default=24, description="记忆遗忘时间") - memory_forget_percentage: float = Field(default=0.01, description="记忆遗忘百分比") - consolidate_memory_interval: int = Field(default=1000, description="记忆巩固间隔") - consolidation_similarity_threshold: float = Field(default=0.7, description="巩固相似性阈值") - consolidate_memory_percentage: float = Field(default=0.01, description="巩固记忆百分比") - memory_ban_words: list[str] = Field( - default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"], description="记忆禁用词" - ) enable_instant_memory: bool = Field(default=True, description="启用即时记忆") enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆") enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆") @@ -472,8 +457,8 @@ class MemoryConfig(ValidatedConfigBase): memory_value_threshold: float = Field(default=0.7, description="记忆价值阈值") # 向量存储配置 - vector_dimension: int = Field(default=768, description="向量维度") vector_similarity_threshold: float = Field(default=0.8, description="向量相似度阈值") + semantic_similarity_threshold: float = Field(default=0.6, description="语义相似度阈值") # 多阶段检索配置 metadata_filter_limit: int = Field(default=100, description="元数据过滤阶段返回数量") diff --git a/src/main.py b/src/main.py index 9ae7a197d..cbb4ddc7b 100644 --- a/src/main.py +++ b/src/main.py @@ -27,9 +27,8 @@ from src.plugin_system.core.event_manager import event_manager from src.plugin_system.base.component_types import EventType # from src.api.main import start_api_server -# 导入新的插件管理器和热重载管理器 +# 导入新的插件管理器 from src.plugin_system.core.plugin_manager import plugin_manager -from src.plugin_system.core.plugin_hot_reload import hot_reload_manager # 导入消息API和traceback模块 from src.common.message import get_global_api @@ -116,13 +115,7 @@ class MainSystem: except Exception as e: logger.error(f"停止消息重组器时出错: {e}") - try: - # 停止插件热重载系统 - hot_reload_manager.stop() - logger.info("🛑 插件热重载系统已停止") - except Exception as e: - logger.error(f"停止热重载系统时出错: {e}") - + try: # 停止增强记忆系统 if global_config.memory.enable_memory: @@ -229,9 +222,7 @@ MoFox_Bot(第三方修改版) # 处理所有缓存的事件订阅(插件加载完成后) event_manager.process_all_pending_subscriptions() - # 启动插件热重载系统 - hot_reload_manager.start() - + # 初始化表情管理器 get_emoji_manager().initialize() logger.info("表情包管理器初始化成功") diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 725619adb..26b79d4df 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -2,7 +2,7 @@ import time import asyncio from abc import ABC, abstractmethod -from typing import Tuple, Optional +from typing import Tuple, Optional, List, Dict, Any from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream @@ -27,8 +27,21 @@ class BaseAction(ABC): - parallel_action: 是否允许并行执行 - random_activation_probability: 随机激活概率 - llm_judge_prompt: LLM判断提示词 + + 二步Action相关属性: + - is_two_step_action: 是否为二步Action + - step_one_description: 第一步的描述 + - sub_actions: 子Action列表 """ + # 二步Action相关类属性 + is_two_step_action: bool = False + """是否为二步Action。如果为True,Action将分两步执行:第一步选择操作,第二步执行具体操作""" + step_one_description: str = "" + """第一步的描述,用于向LLM展示Action的基本功能""" + sub_actions: List[Tuple[str, str, Dict[str, str]]] = [] + """子Action列表,格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用""" + def __init__( self, action_data: dict, @@ -93,6 +106,13 @@ class BaseAction(ABC): self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy() self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL) + # 二步Action相关实例属性 + self.is_two_step_action: bool = getattr(self.__class__, "is_two_step_action", False) + self.step_one_description: str = getattr(self.__class__, "step_one_description", "") + self.sub_actions: List[Tuple[str, str, Dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy() + self._selected_sub_action: Optional[str] = None + """当前选择的子Action名称,用于二步Action的状态管理""" + # ============================================================================= # 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解) # ============================================================================= @@ -412,23 +432,32 @@ class BaseAction(ABC): logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}") return False, f"未找到Action组件信息: {action_name}" - plugin_config = component_registry.get_plugin_config(component_info.plugin_name) + # 确保获取的是Action组件 + if component_info.component_type != ComponentType.ACTION: + logger.error(f"{log_prefix} 尝试调用的组件 '{action_name}' 不是一个Action,而是一个 '{component_info.component_type.value}'") + return False, f"组件 '{action_name}' 不是一个有效的Action" + plugin_config = component_registry.get_plugin_config(component_info.plugin_name) # 3. 实例化被调用的Action - action_instance = action_class( - action_data=called_action_data, - reasoning=f"Called by {self.action_name}", - cycle_timers=self.cycle_timers, - thinking_id=self.thinking_id, - chat_stream=self.chat_stream, - log_prefix=log_prefix, - plugin_config=plugin_config, - action_message=self.action_message, - ) + action_params = { + "action_data": called_action_data, + "reasoning": f"Called by {self.action_name}", + "cycle_timers": self.cycle_timers, + "thinking_id": self.thinking_id, + "chat_stream": self.chat_stream, + "log_prefix": log_prefix, + "plugin_config": plugin_config, + "action_message": self.action_message, + } + action_instance = action_class(**action_params) # 4. 执行Action logger.debug(f"{log_prefix} 开始执行...") - result = await action_instance.execute() + execute_result = await action_instance.execute() + # 确保返回类型符合 (bool, str) 格式 + is_success = execute_result[0] if isinstance(execute_result, tuple) and len(execute_result) > 0 else False + message = execute_result[1] if isinstance(execute_result, tuple) and len(execute_result) > 1 else "" + result = (is_success, str(message)) logger.info(f"{log_prefix} 执行完成,结果: {result}") return result @@ -477,15 +506,73 @@ class BaseAction(ABC): action_require=getattr(cls, "action_require", []).copy(), associated_types=getattr(cls, "associated_types", []).copy(), chat_type_allow=getattr(cls, "chat_type_allow", ChatType.ALL), + # 二步Action相关属性 + is_two_step_action=getattr(cls, "is_two_step_action", False), + step_one_description=getattr(cls, "step_one_description", ""), + sub_actions=getattr(cls, "sub_actions", []).copy(), ) + async def handle_step_one(self) -> Tuple[bool, str]: + """处理二步Action的第一步 + + Returns: + Tuple[bool, str]: (是否执行成功, 回复文本) + """ + if not self.is_two_step_action: + return False, "此Action不是二步Action" + + # 检查action_data中是否包含选择的子Action + selected_action = self.action_data.get("selected_action") + if not selected_action: + # 第一步:展示可用的子Action + available_actions = [sub_action[0] for sub_action in self.sub_actions] + description = self.step_one_description or f"{self.action_name}支持以下操作" + + actions_list = "\n".join([f"- {action}: {desc}" for action, desc, _ in self.sub_actions]) + response = f"{description}\n\n可用操作:\n{actions_list}\n\n请选择要执行的操作。" + + return True, response + else: + # 验证选择的子Action是否有效 + valid_actions = [sub_action[0] for sub_action in self.sub_actions] + if selected_action not in valid_actions: + return False, f"无效的操作选择: {selected_action}。可用操作: {valid_actions}" + + # 保存选择的子Action + self._selected_sub_action = selected_action + + # 调用第二步执行 + return await self.execute_step_two(selected_action) + + async def execute_step_two(self, sub_action_name: str) -> Tuple[bool, str]: + """执行二步Action的第二步 + + Args: + sub_action_name: 子Action名称 + + Returns: + Tuple[bool, str]: (是否执行成功, 回复文本) + """ + if not self.is_two_step_action: + return False, "此Action不是二步Action" + + # 子类需要重写此方法来实现具体的第二步逻辑 + return False, f"二步Action必须实现execute_step_two方法来处理操作: {sub_action_name}" + @abstractmethod async def execute(self) -> Tuple[bool, str]: """执行Action的抽象方法,子类必须实现 + 对于二步Action,会自动处理第一步逻辑 + Returns: Tuple[bool, str]: (是否执行成功, 回复文本) """ + # 如果是二步Action,自动处理第一步 + if self.is_two_step_action: + return await self.handle_step_one() + + # 普通Action由子类实现 pass async def handle_action(self) -> Tuple[bool, str]: diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index b5022ea2a..84dc8b150 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -38,6 +38,14 @@ class BaseTool(ABC): semantic_cache_query_key: Optional[str] = None """用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索""" + # 二步工具调用相关属性 + is_two_step_tool: bool = False + """是否为二步工具。如果为True,工具将分两步调用:第一步展示工具信息,第二步执行具体操作""" + step_one_description: str = "" + """第一步的描述,用于向LLM展示工具的基本功能""" + sub_tools: List[Tuple[str, str, List[Tuple[str, ToolParamType, str, bool, List[str] | None]]]] = [] + """子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用""" + def __init__(self, plugin_config: Optional[dict] = None): self.plugin_config = plugin_config or {} # 直接存储插件配置字典 @@ -48,10 +56,64 @@ class BaseTool(ABC): Returns: dict: 工具定义字典 """ - if not cls.name or not cls.description or not cls.parameters: - raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") + if not cls.name or not cls.description: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name 和 description 属性") - return {"name": cls.name, "description": cls.description, "parameters": cls.parameters} + # 如果是二步工具,第一步只返回基本信息 + if cls.is_two_step_tool: + return { + "name": cls.name, + "description": cls.step_one_description or cls.description, + "parameters": [("action", ToolParamType.STRING, "选择要执行的操作", True, [sub_tool[0] for sub_tool in cls.sub_tools])] + } + else: + # 普通工具需要parameters + if not cls.parameters: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 parameters 属性") + return {"name": cls.name, "description": cls.description, "parameters": cls.parameters} + + @classmethod + def get_step_two_tool_definition(cls, sub_tool_name: str) -> dict[str, Any]: + """获取二步工具的第二步定义 + + Args: + sub_tool_name: 子工具名称 + + Returns: + dict: 第二步工具定义字典 + """ + if not cls.is_two_step_tool: + raise ValueError(f"工具 {cls.name} 不是二步工具") + + # 查找对应的子工具 + for sub_name, sub_desc, sub_params in cls.sub_tools: + if sub_name == sub_tool_name: + return { + "name": f"{cls.name}_{sub_tool_name}", + "description": sub_desc, + "parameters": sub_params + } + + raise ValueError(f"未找到子工具: {sub_tool_name}") + + @classmethod + def get_all_sub_tool_definitions(cls) -> List[dict[str, Any]]: + """获取所有子工具的定义 + + Returns: + List[dict]: 所有子工具定义列表 + """ + if not cls.is_two_step_tool: + return [] + + definitions = [] + for sub_name, sub_desc, sub_params in cls.sub_tools: + definitions.append({ + "name": f"{cls.name}_{sub_name}", + "description": sub_desc, + "parameters": sub_params + }) + return definitions @classmethod def get_tool_info(cls) -> ToolInfo: @@ -79,8 +141,68 @@ class BaseTool(ABC): Returns: dict: 工具执行结果 """ + # 如果是二步工具,处理第一步调用 + if self.is_two_step_tool and "action" in function_args: + return await self._handle_step_one(function_args) + raise NotImplementedError("子类必须实现execute方法") + async def _handle_step_one(self, function_args: dict[str, Any]) -> dict[str, Any]: + """处理二步工具的第一步调用 + + Args: + function_args: 包含action参数的函数参数 + + Returns: + dict: 第一步执行结果,包含第二步的工具定义 + """ + action = function_args.get("action") + if not action: + return {"error": "缺少action参数"} + + # 查找对应的子工具 + sub_tool_found = None + for sub_name, sub_desc, sub_params in self.sub_tools: + if sub_name == action: + sub_tool_found = (sub_name, sub_desc, sub_params) + break + + if not sub_tool_found: + available_actions = [sub_tool[0] for sub_tool in self.sub_tools] + return {"error": f"未知的操作: {action}。可用操作: {available_actions}"} + + sub_name, sub_desc, sub_params = sub_tool_found + + # 返回第二步工具定义 + step_two_definition = { + "name": f"{self.name}_{sub_name}", + "description": sub_desc, + "parameters": sub_params + } + + return { + "type": "two_step_tool_step_one", + "content": f"已选择操作: {action}。请使用以下工具进行具体调用:", + "next_tool_definition": step_two_definition, + "selected_action": action + } + + async def execute_step_two(self, sub_tool_name: str, function_args: dict[str, Any]) -> dict[str, Any]: + """执行二步工具的第二步 + + Args: + sub_tool_name: 子工具名称 + function_args: 工具调用参数 + + Returns: + dict: 工具执行结果 + """ + if not self.is_two_step_tool: + raise ValueError(f"工具 {self.name} 不是二步工具") + + # 子类需要重写此方法来实现具体的第二步逻辑 + raise NotImplementedError("二步工具必须实现execute_step_two方法") + async def direct_execute(self, **kwargs: dict[str, Any]) -> dict[str, Any]: """直接执行工具函数(供插件调用) 通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数 diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 3fc943bd5..6d0590d43 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -142,6 +142,10 @@ class ActionInfo(ComponentInfo): mode_enable: ChatMode = ChatMode.ALL parallel_action: bool = False chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型 + # 二步Action相关属性 + is_two_step_action: bool = False # 是否为二步Action + step_one_description: str = "" # 第一步的描述 + sub_actions: List[Tuple[str, str, Dict[str, str]]] = field(default_factory=list) # 子Action列表 def __post_init__(self): super().__post_init__() @@ -153,6 +157,8 @@ class ActionInfo(ComponentInfo): self.action_require = [] if self.associated_types is None: self.associated_types = [] + if self.sub_actions is None: + self.sub_actions = [] self.component_type = ComponentType.ACTION diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index a73b99acb..46aa5a96c 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -8,12 +8,10 @@ from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager -from src.plugin_system.core.plugin_hot_reload import hot_reload_manager __all__ = [ "plugin_manager", "component_registry", "event_manager", "global_announcement_manager", - "hot_reload_manager", ] diff --git a/src/plugin_system/core/plugin_hot_reload.py b/src/plugin_system/core/plugin_hot_reload.py deleted file mode 100644 index 12c87a6ef..000000000 --- a/src/plugin_system/core/plugin_hot_reload.py +++ /dev/null @@ -1,512 +0,0 @@ -""" -插件热重载模块 - -使用 Watchdog 监听插件目录变化,自动重载插件 -""" - -import os -import sys -import time -import importlib -from pathlib import Path -from threading import Thread -from typing import Dict, Set, List, Optional, Tuple - -from watchdog.observers import Observer -from watchdog.events import FileSystemEventHandler - -from src.common.logger import get_logger -from .plugin_manager import plugin_manager - -logger = get_logger("plugin_hot_reload") - - -class PluginFileHandler(FileSystemEventHandler): - """插件文件变化处理器""" - - def __init__(self, hot_reload_manager): - super().__init__() - self.hot_reload_manager = hot_reload_manager - self.pending_reloads: Set[str] = set() # 待重载的插件名称 - self.last_reload_time: Dict[str, float] = {} # 上次重载时间 - self.debounce_delay = 2.0 # 增加防抖延迟到2秒,确保文件写入完成 - self.file_change_cache: Dict[str, float] = {} # 文件变化缓存 - - def on_modified(self, event): - """文件修改事件""" - if not event.is_directory: - file_path = str(event.src_path) - if file_path.endswith((".py", ".toml")): - self._handle_file_change(file_path, "modified") - - def on_created(self, event): - """文件创建事件""" - if not event.is_directory: - file_path = str(event.src_path) - if file_path.endswith((".py", ".toml")): - self._handle_file_change(file_path, "created") - - def on_deleted(self, event): - """文件删除事件""" - if not event.is_directory: - file_path = str(event.src_path) - if file_path.endswith((".py", ".toml")): - self._handle_file_change(file_path, "deleted") - - def _handle_file_change(self, file_path: str, change_type: str): - """处理文件变化""" - try: - # 获取插件名称 - plugin_info = self._get_plugin_info_from_path(file_path) - if not plugin_info: - return - - plugin_name, source_type = plugin_info - current_time = time.time() - - # 文件变化缓存,避免重复处理同一文件的快速连续变化 - file_cache_key = f"{file_path}_{change_type}" - last_file_time = self.file_change_cache.get(file_cache_key, 0) - if current_time - last_file_time < 0.5: # 0.5秒内的重复文件变化忽略 - return - self.file_change_cache[file_cache_key] = current_time - - # 插件级别的防抖处理 - last_plugin_time = self.last_reload_time.get(plugin_name, 0) - if current_time - last_plugin_time < self.debounce_delay: - # 如果在防抖期内,更新待重载标记但不立即处理 - self.pending_reloads.add(plugin_name) - return - - file_name = Path(file_path).name - logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type}) [{source_type}] -> {plugin_name}") - - # 如果是删除事件,处理关键文件删除 - if change_type == "deleted": - # 解析实际的插件名称 - actual_plugin_name = self.hot_reload_manager._resolve_plugin_name(plugin_name) - - if file_name == "plugin.py": - if actual_plugin_name in plugin_manager.loaded_plugins: - logger.info( - f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]" - ) - self.hot_reload_manager._unload_plugin(actual_plugin_name) - else: - logger.info( - f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]" - ) - return - elif file_name in ("manifest.toml", "_manifest.json"): - if actual_plugin_name in plugin_manager.loaded_plugins: - logger.info( - f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]" - ) - self.hot_reload_manager._unload_plugin(actual_plugin_name) - else: - logger.info( - f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]" - ) - return - - # 对于修改和创建事件,都进行重载 - # 添加到待重载列表 - self.pending_reloads.add(plugin_name) - self.last_reload_time[plugin_name] = current_time - - # 延迟重载,确保文件写入完成 - reload_thread = Thread( - target=self._delayed_reload, args=(plugin_name, source_type, current_time), daemon=True - ) - reload_thread.start() - - except Exception as e: - logger.error(f"❌ 处理文件变化时发生错误: {e}", exc_info=True) - - def _delayed_reload(self, plugin_name: str, source_type: str, trigger_time: float): - """延迟重载插件""" - try: - # 等待文件写入完成 - time.sleep(self.debounce_delay) - - # 检查是否还需要重载(可能在等待期间有更新的变化) - if plugin_name not in self.pending_reloads: - return - - # 检查是否有更新的重载请求 - if self.last_reload_time.get(plugin_name, 0) > trigger_time: - return - - self.pending_reloads.discard(plugin_name) - logger.info(f"🔄 开始延迟重载插件: {plugin_name} [{source_type}]") - - # 执行深度重载 - success = self.hot_reload_manager._deep_reload_plugin(plugin_name) - if success: - logger.info(f"✅ 插件重载成功: {plugin_name} [{source_type}]") - else: - logger.error(f"❌ 插件重载失败: {plugin_name} [{source_type}]") - - except Exception as e: - logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}", exc_info=True) - - def _get_plugin_info_from_path(self, file_path: str) -> Optional[Tuple[str, str]]: - """从文件路径获取插件信息 - - Returns: - tuple[插件名称, 源类型] 或 None - """ - try: - path = Path(file_path) - - # 检查是否在任何一个监听的插件目录中 - for watch_dir in self.hot_reload_manager.watch_directories: - plugin_root = Path(watch_dir) - if path.is_relative_to(plugin_root): - # 确定源类型 - if "src" in str(plugin_root): - source_type = "built-in" - else: - source_type = "external" - - # 获取插件目录名(插件名) - relative_path = path.relative_to(plugin_root) - if len(relative_path.parts) == 0: - continue - - plugin_name = relative_path.parts[0] - - # 确认这是一个有效的插件目录 - plugin_dir = plugin_root / plugin_name - if plugin_dir.is_dir(): - # 检查是否有插件主文件或配置文件 - has_plugin_py = (plugin_dir / "plugin.py").exists() - has_manifest = (plugin_dir / "manifest.toml").exists() or ( - plugin_dir / "_manifest.json" - ).exists() - - if has_plugin_py or has_manifest: - return plugin_name, source_type - - return None - - except Exception: - return None - - -class PluginHotReloadManager: - """插件热重载管理器""" - - def __init__(self, watch_directories: Optional[List[str]] = None): - if watch_directories is None: - # 默认监听两个目录:根目录下的 plugins 和 src 下的插件目录 - self.watch_directories = [ - os.path.join(os.getcwd(), "plugins"), # 外部插件目录 - os.path.join(os.getcwd(), "src", "plugins", "built_in"), # 内置插件目录 - ] - else: - self.watch_directories = watch_directories - - self.observers = [] - self.file_handlers = [] - self.is_running = False - - # 确保监听目录存在 - for watch_dir in self.watch_directories: - if not os.path.exists(watch_dir): - os.makedirs(watch_dir, exist_ok=True) - logger.info(f"📁 创建插件监听目录: {watch_dir}") - - def start(self): - """启动热重载监听""" - if self.is_running: - logger.warning("插件热重载已经在运行中") - return - - try: - # 为每个监听目录创建独立的观察者 - for watch_dir in self.watch_directories: - observer = Observer() - file_handler = PluginFileHandler(self) - - observer.schedule(file_handler, watch_dir, recursive=True) - - observer.start() - self.observers.append(observer) - self.file_handlers.append(file_handler) - - self.is_running = True - - # 打印监听的目录信息 - dir_info = [] - for watch_dir in self.watch_directories: - if "src" in watch_dir: - dir_info.append(f"{watch_dir} (内置插件)") - else: - dir_info.append(f"{watch_dir} (外部插件)") - - logger.info("🚀 插件热重载已启动,监听目录:") - for info in dir_info: - logger.info(f" 📂 {info}") - - except Exception as e: - logger.error(f"❌ 启动插件热重载失败: {e}") - self.stop() # 清理已创建的观察者 - self.is_running = False - - def stop(self): - """停止热重载监听""" - if not self.is_running and not self.observers: - return - - # 停止所有观察者 - for observer in self.observers: - try: - observer.stop() - observer.join() - except Exception as e: - logger.error(f"❌ 停止观察者时发生错误: {e}") - - self.observers.clear() - self.file_handlers.clear() - self.is_running = False - logger.info("🛑 插件热重载已停止") - - def _reload_plugin(self, plugin_name: str): - """重载指定插件(简单重载)""" - try: - # 解析实际的插件名称 - actual_plugin_name = self._resolve_plugin_name(plugin_name) - logger.info(f"🔄 开始简单重载插件: {plugin_name} -> {actual_plugin_name}") - - if plugin_manager.reload_plugin(actual_plugin_name): - logger.info(f"✅ 插件简单重载成功: {actual_plugin_name}") - return True - else: - logger.error(f"❌ 插件简单重载失败: {actual_plugin_name}") - return False - - except Exception as e: - logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}", exc_info=True) - return False - - @staticmethod - def _resolve_plugin_name(folder_name: str) -> str: - """ - 将文件夹名称解析为实际的插件名称 - 通过检查插件管理器中的路径映射来找到对应的插件名 - """ - # 首先检查是否直接匹配 - if folder_name in plugin_manager.plugin_classes: - logger.debug(f"🔍 直接匹配插件名: {folder_name}") - return folder_name - - # 如果没有直接匹配,搜索路径映射,并优先返回在插件类中存在的名称 - matched_plugins = [] - for plugin_name, plugin_path in plugin_manager.plugin_paths.items(): - # 检查路径是否包含该文件夹名 - if folder_name in plugin_path: - matched_plugins.append((plugin_name, plugin_path)) - - # 在匹配的插件中,优先选择在插件类中存在的 - for plugin_name, plugin_path in matched_plugins: - if plugin_name in plugin_manager.plugin_classes: - logger.debug(f"🔍 文件夹名 '{folder_name}' 映射到插件名 '{plugin_name}' (路径: {plugin_path})") - return plugin_name - - # 如果还是没找到在插件类中存在的,返回第一个匹配项 - if matched_plugins: - plugin_name, plugin_path = matched_plugins[0] - logger.warning(f"⚠️ 文件夹 '{folder_name}' 映射到 '{plugin_name}',但该插件类不存在") - return plugin_name - - # 如果还是没找到,返回原文件夹名 - logger.warning(f"⚠️ 无法找到文件夹 '{folder_name}' 对应的插件名,使用原名称") - return folder_name - - def _deep_reload_plugin(self, plugin_name: str): - """深度重载指定插件(清理模块缓存)""" - try: - # 解析实际的插件名称 - actual_plugin_name = self._resolve_plugin_name(plugin_name) - logger.info(f"🔄 开始深度重载插件: {plugin_name} -> {actual_plugin_name}") - - # 强制清理相关模块缓存 - self._force_clear_plugin_modules(plugin_name) - - # 使用插件管理器的强制重载功能 - success = plugin_manager.force_reload_plugin(actual_plugin_name) - - if success: - logger.info(f"✅ 插件深度重载成功: {actual_plugin_name}") - return True - else: - logger.error(f"❌ 插件深度重载失败,尝试简单重载: {actual_plugin_name}") - # 如果深度重载失败,尝试简单重载 - return self._reload_plugin(actual_plugin_name) - - except Exception as e: - logger.error(f"❌ 深度重载插件 {plugin_name} 时发生错误: {e}", exc_info=True) - # 出错时尝试简单重载 - return self._reload_plugin(plugin_name) - - @staticmethod - def _force_clear_plugin_modules(plugin_name: str): - """强制清理插件相关的模块缓存""" - - # 找到所有相关的模块名 - modules_to_remove = [] - plugin_module_prefix = f"src.plugins.built_in.{plugin_name}" - - for module_name in list(sys.modules.keys()): - if plugin_module_prefix in module_name: - modules_to_remove.append(module_name) - - # 删除模块缓存 - for module_name in modules_to_remove: - if module_name in sys.modules: - logger.debug(f"🗑️ 清理模块缓存: {module_name}") - del sys.modules[module_name] - - @staticmethod - def _force_reimport_plugin(plugin_name: str): - """强制重新导入插件(委托给插件管理器)""" - try: - # 使用插件管理器的重载功能 - success = plugin_manager.reload_plugin(plugin_name) - return success - - except Exception as e: - logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True) - return False - - @staticmethod - def _unload_plugin(plugin_name: str): - """卸载指定插件""" - try: - logger.info(f"🗑️ 开始卸载插件: {plugin_name}") - - if plugin_manager.unload_plugin(plugin_name): - logger.info(f"✅ 插件卸载成功: {plugin_name}") - return True - else: - logger.error(f"❌ 插件卸载失败: {plugin_name}") - return False - - except Exception as e: - logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}", exc_info=True) - return False - - def reload_all_plugins(self): - """重载所有插件""" - try: - logger.info("🔄 开始深度重载所有插件...") - - # 获取当前已加载的插件列表 - loaded_plugins = list(plugin_manager.loaded_plugins.keys()) - - success_count = 0 - fail_count = 0 - - for plugin_name in loaded_plugins: - logger.info(f"🔄 重载插件: {plugin_name}") - if self._deep_reload_plugin(plugin_name): - success_count += 1 - else: - fail_count += 1 - - logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count} 个") - - # 清理全局缓存 - importlib.invalidate_caches() - - except Exception as e: - logger.error(f"❌ 重载所有插件时发生错误: {e}", exc_info=True) - - def force_reload_plugin(self, plugin_name: str): - """手动强制重载指定插件(委托给插件管理器)""" - try: - logger.info(f"🔄 手动强制重载插件: {plugin_name}") - - # 清理待重载列表中的该插件(避免重复重载) - for handler in self.file_handlers: - handler.pending_reloads.discard(plugin_name) - - # 使用插件管理器的强制重载功能 - success = plugin_manager.force_reload_plugin(plugin_name) - - if success: - logger.info(f"✅ 手动强制重载成功: {plugin_name}") - else: - logger.error(f"❌ 手动强制重载失败: {plugin_name}") - - return success - - except Exception as e: - logger.error(f"❌ 手动强制重载插件 {plugin_name} 时发生错误: {e}", exc_info=True) - return False - - def add_watch_directory(self, directory: str): - """添加新的监听目录""" - if directory in self.watch_directories: - logger.info(f"目录 {directory} 已在监听列表中") - return - - # 确保目录存在 - if not os.path.exists(directory): - os.makedirs(directory, exist_ok=True) - logger.info(f"📁 创建插件监听目录: {directory}") - - self.watch_directories.append(directory) - - # 如果热重载正在运行,为新目录创建观察者 - if self.is_running: - try: - observer = Observer() - file_handler = PluginFileHandler(self) - - observer.schedule(file_handler, directory, recursive=True) - - observer.start() - self.observers.append(observer) - self.file_handlers.append(file_handler) - - logger.info(f"📂 已添加新的监听目录: {directory}") - - except Exception as e: - logger.error(f"❌ 添加监听目录 {directory} 失败: {e}") - self.watch_directories.remove(directory) - - def get_status(self) -> dict: - """获取热重载状态""" - pending_reloads = set() - if self.file_handlers: - for handler in self.file_handlers: - pending_reloads.update(handler.pending_reloads) - - return { - "is_running": self.is_running, - "watch_directories": self.watch_directories, - "active_observers": len(self.observers), - "loaded_plugins": len(plugin_manager.loaded_plugins), - "failed_plugins": len(plugin_manager.failed_plugins), - "pending_reloads": list(pending_reloads), - "debounce_delay": self.file_handlers[0].debounce_delay if self.file_handlers else 0, - } - - @staticmethod - def clear_all_caches(): - """清理所有Python模块缓存""" - try: - logger.info("🧹 开始清理所有Python模块缓存...") - - # 重新扫描所有插件目录,这会重新加载模块 - plugin_manager.rescan_plugin_directory() - logger.info("✅ 模块缓存清理完成") - - except Exception as e: - logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True) - - -# 全局热重载管理器实例 -hot_reload_manager = PluginHotReloadManager() diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 1b2618f43..daa8244cf 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -55,6 +55,10 @@ class ToolExecutor: self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") + # 二步工具调用状态管理 + self._pending_step_two_tools: Dict[str, Dict[str, Any]] = {} + """待处理的第二步工具调用,格式为 {tool_name: step_two_definition}""" + logger.info(f"{self.log_prefix}工具执行器初始化完成") async def execute_from_chat_message( @@ -112,7 +116,18 @@ class ToolExecutor: def _get_tool_definitions(self) -> List[Dict[str, Any]]: all_tools = get_llm_available_tool_definitions() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) - return [definition for name, definition in all_tools if name not in user_disabled_tools] + + # 获取基础工具定义(包括二步工具的第一步) + tool_definitions = [definition for name, definition in all_tools if name not in user_disabled_tools] + + # 检查是否有待处理的二步工具第二步调用 + pending_step_two = getattr(self, '_pending_step_two_tools', {}) + if pending_step_two: + # 添加第二步工具定义 + for tool_name, step_two_def in pending_step_two.items(): + tool_definitions.append(step_two_def) + + return tool_definitions async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: """执行工具调用 @@ -251,6 +266,32 @@ class ToolExecutor: f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}" ) function_args["llm_called"] = True # 标记为LLM调用 + + # 检查是否是二步工具的第二步调用 + if "_" in function_name and function_name.count("_") >= 1: + # 可能是二步工具的第二步调用,格式为 "tool_name_sub_tool_name" + parts = function_name.split("_", 1) + if len(parts) == 2: + base_tool_name, sub_tool_name = parts + base_tool_instance = get_tool_instance(base_tool_name) + + if base_tool_instance and base_tool_instance.is_two_step_tool: + logger.info(f"{self.log_prefix}执行二步工具第二步: {base_tool_name}.{sub_tool_name}") + result = await base_tool_instance.execute_step_two(sub_tool_name, function_args) + + # 清理待处理的第二步工具 + self._pending_step_two_tools.pop(base_tool_name, None) + + if result: + logger.debug(f"{self.log_prefix}二步工具第二步 {function_name} 执行成功") + return { + "tool_call_id": tool_call.call_id, + "role": "tool", + "name": function_name, + "type": "function", + "content": result.get("content", ""), + } + # 获取对应工具实例 tool_instance = tool_instance or get_tool_instance(function_name) if not tool_instance: @@ -260,6 +301,16 @@ class ToolExecutor: # 执行工具并记录日志 logger.debug(f"{self.log_prefix}执行工具 {function_name},参数: {function_args}") result = await tool_instance.execute(function_args) + + # 检查是否是二步工具的第一步结果 + if result and result.get("type") == "two_step_tool_step_one": + logger.info(f"{self.log_prefix}二步工具第一步完成: {function_name}") + # 保存第二步工具定义 + next_tool_def = result.get("next_tool_definition") + if next_tool_def: + self._pending_step_two_tools[function_name] = next_tool_def + logger.debug(f"{self.log_prefix}已保存第二步工具定义: {next_tool_def['name']}") + if result: logger.debug(f"{self.log_prefix}工具 {function_name} 执行成功,结果: {result}") return { diff --git a/src/plugin_system/utils/dependency_alias.py b/src/plugin_system/utils/dependency_alias.py index b5bf669e1..7a2aa1d80 100644 --- a/src/plugin_system/utils/dependency_alias.py +++ b/src/plugin_system/utils/dependency_alias.py @@ -91,7 +91,6 @@ INSTALL_NAME_TO_IMPORT_NAME = { "pyusb": "usb", # USB访问 "pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异) "psutil": "psutil", # 系统信息和进程管理 - "watchdog": "watchdog", # 文件系统事件监控 "python-gnupg": "gnupg", # GnuPG的Python接口 # ============== 加密与安全 (Cryptography & Security) ============== "pycrypto": "Crypto", # 加密库 (较旧) diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index de644c31b..4ea543f68 100644 --- a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -88,25 +88,27 @@ class MaiZoneRefactoredPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) async def on_plugin_loaded(self): + """插件加载完成后的回调,初始化服务并启动后台任务""" + # --- 注册权限节点 --- await permission_api.register_permission_node( "plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False ) await permission_api.register_permission_node( "plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True ) - # 创建所有服务实例 + + # --- 创建并注册所有服务实例 --- content_service = ContentService(self.get_config) image_service = ImageService(self.get_config) cookie_service = CookieService(self.get_config) reply_tracker_service = ReplyTrackerService() - # 使用已创建的 reply_tracker_service 实例 qzone_service = QZoneService( self.get_config, content_service, image_service, cookie_service, - reply_tracker_service, # 传入已创建的实例 + reply_tracker_service, ) scheduler_service = SchedulerService(self.get_config, qzone_service) monitor_service = MonitorService(self.get_config, qzone_service) @@ -115,18 +117,12 @@ class MaiZoneRefactoredPlugin(BasePlugin): register_service("reply_tracker", reply_tracker_service) register_service("get_config", self.get_config) - # 保存服务引用以便后续启动 - self.scheduler_service = scheduler_service - self.monitor_service = monitor_service + logger.info("MaiZone重构版插件服务已注册。") - logger.info("MaiZone重构版插件已加载,服务已注册。") - - async def on_plugin_loaded(self): - """插件加载完成后的回调,启动异步服务""" - if hasattr(self, "scheduler_service") and hasattr(self, "monitor_service"): - asyncio.create_task(self.scheduler_service.start()) - asyncio.create_task(self.monitor_service.start()) - logger.info("MaiZone后台任务已启动。") + # --- 启动后台任务 --- + asyncio.create_task(scheduler_service.start()) + asyncio.create_task(monitor_service.start()) + logger.info("MaiZone后台监控和定时任务已启动。") def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: return [ diff --git a/src/plugins/built_in/maizone_refactored/services/cookie_service.py b/src/plugins/built_in/maizone_refactored/services/cookie_service.py index b4aedf322..e3692f883 100644 --- a/src/plugins/built_in/maizone_refactored/services/cookie_service.py +++ b/src/plugins/built_in/maizone_refactored/services/cookie_service.py @@ -113,31 +113,32 @@ class CookieService: async def get_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]: """ 获取Cookie,按以下顺序尝试: - 1. Adapter API - 2. HTTP备用端点 - 3. 本地文件缓存 + 1. HTTP备用端点 (更稳定) + 2. 本地文件缓存 + 3. Adapter API (作为最后手段) """ - # 1. 尝试从Adapter获取 - cookies = await self._get_cookies_from_adapter(stream_id) - if cookies: - logger.info("成功从Adapter获取Cookie。") - self._save_cookies_to_file(qq_account, cookies) - return cookies - - # 2. 尝试从HTTP备用端点获取 - logger.warning("从Adapter获取Cookie失败,尝试使用HTTP备用地址。") + # 1. 尝试从HTTP备用端点获取 + logger.info(f"开始尝试从HTTP备用地址获取 {qq_account} 的Cookie...") cookies = await self._get_cookies_from_http() if cookies: - logger.info("成功从HTTP备用地址获取Cookie。") + logger.info(f"成功从HTTP备用地址为 {qq_account} 获取Cookie。") self._save_cookies_to_file(qq_account, cookies) return cookies - # 3. 尝试从本地文件加载 - logger.warning("从HTTP备用地址获取Cookie失败,尝试加载本地缓存。") + # 2. 尝试从本地文件加载 + logger.warning(f"从HTTP备用地址获取 {qq_account} 的Cookie失败,尝试加载本地缓存。") cookies = self._load_cookies_from_file(qq_account) if cookies: - logger.info("成功从本地文件加载缓存的Cookie。") + logger.info(f"成功从本地文件为 {qq_account} 加载缓存的Cookie。") return cookies - logger.error("所有Cookie获取方法均失败。") + # 3. 尝试从Adapter获取 (作为最后的备用方案) + logger.warning(f"从本地缓存加载 {qq_account} 的Cookie失败,最后尝试使用Adapter API。") + cookies = await self._get_cookies_from_adapter(stream_id) + if cookies: + logger.info(f"成功从Adapter API为 {qq_account} 获取Cookie。") + self._save_cookies_to_file(qq_account, cookies) + return cookies + + logger.error(f"为 {qq_account} 获取Cookie的所有方法均失败。请确保Napcat HTTP服务或Adapter连接至少有一个正常工作,或存在有效的本地Cookie文件。") return None diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 09e6f5e53..40c0d424a 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -409,8 +409,9 @@ class QZoneService: cookie_dir.mkdir(exist_ok=True) cookie_file_path = cookie_dir / f"cookies-{qq_account}.json" + # 优先尝试通过Napcat HTTP服务获取最新的Cookie try: - # 使用HTTP服务器方式获取Cookie + logger.info("尝试通过Napcat HTTP服务获取Cookie...") host = self.get_config("cookie.http_fallback_host", "172.20.130.55") port = self.get_config("cookie.http_fallback_port", "9999") napcat_token = self.get_config("cookie.napcat_token", "") @@ -421,23 +422,43 @@ class QZoneService: parsed_cookies = { k.strip(): v.strip() for k, v in (p.split("=", 1) for p in cookie_str.split("; ") if "=" in p) } - with open(cookie_file_path, "wb") as f: - f.write(orjson.dumps(parsed_cookies)) - logger.info(f"Cookie已更新并保存至: {cookie_file_path}") + # 成功获取后,异步写入本地文件作为备份 + try: + with open(cookie_file_path, "wb") as f: + f.write(orjson.dumps(parsed_cookies)) + logger.info(f"通过Napcat服务成功更新Cookie,并已保存至: {cookie_file_path}") + except Exception as e: + logger.warning(f"保存Cookie到文件时出错: {e}") return parsed_cookies + else: + logger.warning("通过Napcat服务未能获取有效Cookie。") - # 如果HTTP获取失败,尝试读取本地文件 - if cookie_file_path.exists(): - with open(cookie_file_path, "rb") as f: - return orjson.loads(f.read()) - return None except Exception as e: - logger.error(f"更新或加载Cookie时发生异常: {e}") - return None + logger.warning(f"通过Napcat HTTP服务获取Cookie时发生异常: {e}。将尝试从本地文件加载。") - async def _fetch_cookies_http(self, host: str, port: str, napcat_token: str) -> Optional[Dict]: + # 如果通过服务获取失败,则尝试从本地文件加载 + logger.info("尝试从本地Cookie文件加载...") + if cookie_file_path.exists(): + try: + with open(cookie_file_path, "rb") as f: + cookies = orjson.loads(f.read()) + logger.info(f"成功从本地文件加载Cookie: {cookie_file_path}") + return cookies + except Exception as e: + logger.error(f"从本地文件 {cookie_file_path} 读取或解析Cookie失败: {e}") + else: + logger.warning(f"本地Cookie文件不存在: {cookie_file_path}") + + logger.error("所有获取Cookie的方式均失败。") + return None + + async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> Optional[Dict]: """通过HTTP服务器获取Cookie""" - url = f"http://{host}:{port}/get_cookies" + # 从配置中读取主机和端口,如果未提供则使用传入的参数 + final_host = self.get_config("cookie.http_fallback_host", host) + final_port = self.get_config("cookie.http_fallback_port", port) + url = f"http://{final_host}:{final_port}/get_cookies" + max_retries = 5 retry_delay = 1 @@ -481,14 +502,19 @@ class QZoneService: async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]: cookies = await self.cookie_service.get_cookies(qq_account, stream_id) if not cookies: + logger.error("获取API客户端失败:未能获取到Cookie。请检查Napcat连接是否正常,或是否存在有效的本地Cookie文件。") return None p_skey = cookies.get("p_skey") or cookies.get("p_skey".upper()) if not p_skey: + logger.error(f"获取API客户端失败:Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}") return None gtk = self._generate_gtk(p_skey) uin = cookies.get("uin", "").lstrip("o") + if not uin: + logger.error(f"获取API客户端失败:Cookie中缺少关键的 'uin'。Cookie内容: {cookies}") + return None async def _request(method, url, params=None, data=None, headers=None): final_headers = {"referer": f"https://user.qzone.qq.com/{uin}", "origin": "https://user.qzone.qq.com"} diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index 4e4aa3e10..c8bc267af 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -185,9 +185,13 @@ class SendHandler: logger.info(f"执行适配器命令: {action}") - # 直接向Napcat发送命令并获取响应 - response_task = asyncio.create_task(self.send_message_to_napcat(action, params)) - response = await response_task + # 根据action决定处理方式 + if action == "get_cookies": + # 对于get_cookies,我们需要一个更长的超时时间 + response = await self.send_message_to_napcat(action, params, timeout=40.0) + else: + # 对于其他命令,使用默认超时 + response = await self.send_message_to_napcat(action, params) # 发送响应回MaiBot await self.send_adapter_command_response(raw_message_base, response, request_id) @@ -196,6 +200,8 @@ class SendHandler: logger.info(f"适配器命令 {action} 执行成功") else: logger.warning(f"适配器命令 {action} 执行失败,napcat返回:{str(response)}") + # 无论成功失败,都记录下完整的响应内容以供调试 + logger.debug(f"适配器命令 {action} 的完整响应: {response}") except Exception as e: logger.error(f"处理适配器命令时发生错误: {e}") @@ -583,7 +589,7 @@ class SendHandler: }, ) - async def send_message_to_napcat(self, action: str, params: dict) -> dict: + async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict: request_uuid = str(uuid.uuid4()) payload = json.dumps({"action": action, "params": params, "echo": request_uuid}) @@ -595,9 +601,9 @@ class SendHandler: try: await connection.send(payload) - response = await get_response(request_uuid) + response = await get_response(request_uuid, timeout=timeout) # 使用传入的超时时间 except TimeoutError: - logger.error("发送消息超时,未收到响应") + logger.error(f"发送消息超时({timeout}秒),未收到响应: action={action}, params={params}") return {"status": "error", "message": "timeout"} except Exception as e: logger.error(f"发送消息失败: {e}") diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index c9550500b..0f575e13f 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -15,7 +15,6 @@ from src.plugin_system.base.command_args import CommandArgs from src.plugin_system.base.component_types import PlusCommandInfo, ChatType from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.utils.permission_decorators import require_permission -from src.plugin_system.core.plugin_hot_reload import hot_reload_manager class ManagementCommand(PlusCommand): @@ -78,10 +77,6 @@ class ManagementCommand(PlusCommand): await self._force_reload_plugin(args[1]) elif action in ["add_dir", "添加目录"] and len(args) > 1: await self._add_dir(args[1]) - elif action in ["hotreload_status", "热重载状态"]: - await self._show_hotreload_status() - elif action in ["clear_cache", "清理缓存"]: - await self._clear_all_caches() else: await self.send_text("❌ 插件管理命令不合法\n使用 /pm plugin help 查看帮助") return False, "命令不合法", True @@ -179,14 +174,9 @@ class ManagementCommand(PlusCommand): • `/pm plugin force_reload <插件名>` - 强制重载指定插件(深度清理) • `/pm plugin add_dir <目录路径>` - 添加插件目录 -� 热重载管理: -• `/pm plugin hotreload_status` - 查看热重载状态 -• `/pm plugin clear_cache` - 清理所有模块缓存 - �📝 示例: • `/pm plugin load echo_example` -• `/pm plugin force_reload permission_manager_plugin` -• `/pm plugin clear_cache`""" +• `/pm plugin force_reload permission_manager_plugin`""" elif target == "component": help_msg = """🧩 组件管理命令帮助 @@ -262,7 +252,7 @@ class ManagementCommand(PlusCommand): await self.send_text(f"🔄 开始强制重载插件: `{plugin_name}`...") try: - success = hot_reload_manager.force_reload_plugin(plugin_name) + success = plugin_manage_api.force_reload_plugin(plugin_name) if success: await self.send_text(f"✅ 插件强制重载成功: `{plugin_name}`") else: @@ -270,44 +260,7 @@ class ManagementCommand(PlusCommand): except Exception as e: await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}") - async def _show_hotreload_status(self): - """显示热重载状态""" - try: - status = hot_reload_manager.get_status() - - status_text = f"""🔄 **热重载系统状态** - -🟢 **运行状态:** {"运行中" if status["is_running"] else "已停止"} -📂 **监听目录:** {len(status["watch_directories"])} 个 -👁️ **活跃观察者:** {status["active_observers"]} 个 -📦 **已加载插件:** {status["loaded_plugins"]} 个 -❌ **失败插件:** {status["failed_plugins"]} 个 -⏱️ **防抖延迟:** {status.get("debounce_delay", 0)} 秒 - -📋 **监听的目录:**""" - - for i, watch_dir in enumerate(status["watch_directories"], 1): - dir_type = "(内置插件)" if "src" in watch_dir else "(外部插件)" - status_text += f"\n{i}. `{watch_dir}` {dir_type}" - - if status.get("pending_reloads"): - status_text += f"\n\n⏳ **待重载插件:** {', '.join([f'`{p}`' for p in status['pending_reloads']])}" - - await self.send_text(status_text) - - except Exception as e: - await self.send_text(f"❌ 获取热重载状态时发生错误: {str(e)}") - - async def _clear_all_caches(self): - """清理所有模块缓存""" - await self.send_text("🧹 开始清理所有Python模块缓存...") - - try: - hot_reload_manager.clear_all_caches() - await self.send_text("✅ 模块缓存清理完成!建议重载相关插件以确保生效。") - except Exception as e: - await self.send_text(f"❌ 清理缓存时发生错误: {str(e)}") - + async def _add_dir(self, dir_path: str): """添加插件目录""" await self.send_text(f"📁 正在添加插件目录: `{dir_path}`") diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 53eb6dd0f..7bca6ab08 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -60,10 +60,12 @@ class ReminderTask(AsyncTask): logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒") extra_info = f"现在是提醒时间,请你以一种符合你人设的、俏皮的方式提醒 {self.target_user_name}。\n提醒内容: {self.event_details}\n设置提醒的人: {self.creator_name}" + last_message = self.chat_stream.context_manager.context.get_last_message() + reply_message_dict = last_message.flatten() if last_message else None success, reply_set, _ = await generator_api.generate_reply( chat_stream=self.chat_stream, extra_info=extra_info, - reply_message=self.chat_stream.context_manager.context.get_last_message().to_dict(), + reply_message=reply_message_dict, request_type="plugin.reminder.remind_message", ) @@ -150,9 +152,11 @@ class PokeAction(BaseAction): action_require = ["当需要戳某个用户时使用", "当你想提醒特定用户时使用"] llm_judge_prompt = """ 判定是否需要使用戳一戳动作的条件: - 1. 用户明确要求使用戳一戳。 - 2. 你想以一种有趣的方式提醒或与某人互动。 - 3. 上下文明确需要你戳一个或多个人。 + 1. **关键**: 这是一个高消耗的动作,请仅在绝对必要时使用,例如用户明确要求或作为提醒的关键部分。请极其谨慎地使用。 + 2. **用户请求**: 用户明确要求使用戳一戳。 + 3. **互动提醒**: 你想以一种有趣的方式提醒或与某人互动,但请确保这是对话的自然延伸,而不是无故打扰。 + 4. **上下文需求**: 上下文明确需要你戳一个或多个人。 + 5. **频率限制**: 如果最近已经戳过,或者用户情绪不高,请绝对不要使用。 请回答"是"或"否"。 """ @@ -217,7 +221,6 @@ class SetEmojiLikeAction(BaseAction): emoji_options.append(match.group(1)) action_parameters = { - "emoji": f"要回应的表情,必须从以下表情中选择: {', '.join(emoji_options)}", "set": "是否设置回应 (True/False)", } action_require = [ @@ -238,6 +241,7 @@ class SetEmojiLikeAction(BaseAction): async def execute(self) -> Tuple[bool, str]: """执行设置表情回应的动作""" message_id = None + set_like = self.action_data.get("set", True) if self.has_action_message: logger.debug(str(self.action_message)) if isinstance(self.action_message, dict): @@ -251,24 +255,49 @@ class SetEmojiLikeAction(BaseAction): action_done=False, ) return False, "未提供消息ID" + available_models = llm_api.get_available_models() + if "utils_small" not in available_models: + logger.error("未找到 'utils_small' 模型配置,无法选择表情") + return False, "表情选择功能配置错误" - emoji_input = self.action_data.get("emoji") - set_like = self.action_data.get("set", True) - - if not emoji_input: - logger.error("未提供表情") - return False, "未提供表情" - logger.info(f"设置表情回应: {emoji_input}, 是否设置: {set_like}") - - emoji_id = get_emoji_id(emoji_input) - if not emoji_id: - logger.error(f"找不到表情: '{emoji_input}'。请从可用列表中选择。") - await self.store_action_info( - action_build_into_prompt=True, - action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: 找不到表情: '{emoji_input}'", - action_done=False, + model_to_use = available_models["utils_small"] + + # 获取最近的对话历史作为上下文 + context_text = "" + if self.action_message: + context_text = self.action_message.get("processed_plain_text", "") + else: + logger.error("无法找到动作选择的原始消息") + return False, "无法找到动作选择的原始消息" + + prompt = ( + f"根据以下这条消息,从列表中选择一个最合适的表情名称来回应这条消息。\n" + f"消息内容: '{context_text}'\n" + f"可用表情列表: {', '.join(self.emoji_options)}\n" + f"你的任务是:只输出你选择的表情的名称,不要包含任何其他文字或标点。\n" + f"例如,如果觉得应该用'赞',就只输出'赞'。" ) - return False, f"找不到表情: '{emoji_input}'。请从可用列表中选择。" + + success, response, _, _ = await llm_api.generate_with_model( + prompt, model_config=model_to_use, request_type="plugin.set_emoji_like.select_emoji" + ) + + if not success or not response: + logger.error("二级LLM未能选择有效的表情。") + return False, "无法找到合适的表情。" + + chosen_emoji_name = response.strip() + logger.info(f"二级LLM选择的表情是: '{chosen_emoji_name}'") + emoji_id = get_emoji_id(chosen_emoji_name) + + if not emoji_id: + logger.error(f"二级LLM选择的表情 '{chosen_emoji_name}' 仍然无法匹配到有效的表情ID。") + await self.store_action_info( + action_build_into_prompt=True, + action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: 找不到表情: '{chosen_emoji_name}'", + action_done=False, + ) + return False, f"找不到表情: '{chosen_emoji_name}'。" # 4. 使用适配器API发送命令 if not message_id: @@ -291,7 +320,7 @@ class SetEmojiLikeAction(BaseAction): logger.info("设置表情回应成功") await self.store_action_info( action_build_into_prompt=True, - action_prompt_display=f"执行了set_emoji_like动作,{emoji_input},设置表情回应: {emoji_id}, 是否设置: {set_like}", + action_prompt_display=f"执行了set_emoji_like动作,{chosen_emoji_name},设置表情回应: {emoji_id}, 是否设置: {set_like}", action_done=True, ) return True, "成功设置表情回应" diff --git a/src/schedule/plan_manager.py b/src/schedule/plan_manager.py index d72f55275..513a907d5 100644 --- a/src/schedule/plan_manager.py +++ b/src/schedule/plan_manager.py @@ -28,20 +28,20 @@ class PlanManager: if target_month is None: target_month = datetime.now().strftime("%Y-%m") - if not has_active_plans(target_month): + if not await has_active_plans(target_month): logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。") generation_successful = await self._generate_monthly_plans_logic(target_month) return generation_successful else: logger.info(f"{target_month} 已存在有效的月度计划。") - plans = get_active_plans_for_month(target_month) + plans = await get_active_plans_for_month(target_month) max_plans = global_config.planning_system.max_plans_per_month if len(plans) > max_plans: logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。") plans_to_delete = plans[: len(plans) - max_plans] delete_ids = [p.id for p in plans_to_delete] - delete_plans_by_ids(delete_ids) # type: ignore - plans = get_active_plans_for_month(target_month) + await delete_plans_by_ids(delete_ids) # type: ignore + plans = await get_active_plans_for_month(target_month) if plans: plan_texts = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans)]) @@ -64,11 +64,11 @@ class PlanManager: return False last_month = self._get_previous_month(target_month) - archived_plans = get_archived_plans_for_month(last_month) + archived_plans = await get_archived_plans_for_month(last_month) plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans) if plans: - add_new_plans(plans, target_month) + await add_new_plans(plans, target_month) logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。") return True else: @@ -95,11 +95,11 @@ class PlanManager: if target_month is None: target_month = datetime.now().strftime("%Y-%m") logger.info(f" 开始归档 {target_month} 的活跃月度计划...") - archived_count = archive_active_plans_for_month(target_month) + archived_count = await archive_active_plans_for_month(target_month) logger.info(f" 成功归档了 {archived_count} 条 {target_month} 的月度计划。") except Exception as e: logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}") - def get_plans_for_schedule(self, month: str, max_count: int) -> List: + async def get_plans_for_schedule(self, month: str, max_count: int) -> List: avoid_days = global_config.planning_system.avoid_repetition_days - return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) + return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index d9d94a06d..a8755c09e 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -255,25 +255,34 @@ max_context_emojis = 30 # 每次随机传递给LLM的表情包详细描述的最 [memory] enable_memory = true # 是否启用记忆系统 memory_build_interval = 600 # 记忆构建间隔 单位秒 间隔越低,MoFox-Bot学习越多,但是冗余信息也会增多 -memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重 -memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多 -memory_build_sample_length = 30 # 采样长度,数值越高一段记忆内容越丰富 -memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多 - -forget_memory_interval = 3000 # 记忆遗忘间隔 单位秒 间隔越低,MoFox-Bot遗忘越频繁,记忆更精简,但更难学习 -memory_forget_time = 48 #多长时间后的记忆会被遗忘 单位小时 -memory_forget_percentage = 0.008 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认 - -consolidate_memory_interval = 1000 # 记忆整合间隔 单位秒 间隔越低,MoFox-Bot整合越频繁,记忆更精简 -consolidation_similarity_threshold = 0.7 # 相似度阈值 -consolidation_check_percentage = 0.05 # 检查节点比例 - -enable_instant_memory = false # 是否启用即时记忆,测试功能,可能存在未知问题 +enable_instant_memory = true # 是否启用即时记忆 enable_llm_instant_memory = true # 是否启用基于LLM的瞬时记忆 enable_vector_instant_memory = true # 是否启用基于向量的瞬时记忆 +enable_enhanced_memory = true # 是否启用增强记忆系统 +enhanced_memory_auto_save = true # 是否自动保存增强记忆 -#不希望记忆的词,已经记忆的不会受到影响,需要手动清理 -memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ] +min_memory_length = 10 # 最小记忆长度 +max_memory_length = 500 # 最大记忆长度 +memory_value_threshold = 0.7 # 记忆价值阈值,低于该值的记忆会被丢弃 +vector_similarity_threshold = 0.8 # 向量相似度阈值 +semantic_similarity_threshold = 0.6 # 语义重排阶段的最低匹配阈值 + +metadata_filter_limit = 100 # 元数据过滤阶段返回数量上限 +vector_search_limit = 50 # 向量搜索阶段返回数量上限 +semantic_rerank_limit = 20 # 语义重排阶段返回数量上限 +final_result_limit = 10 # 综合筛选后的最终返回数量 + +vector_weight = 0.4 # 综合评分中向量相似度的权重 +semantic_weight = 0.3 # 综合评分中语义匹配的权重 +context_weight = 0.2 # 综合评分中上下文关联的权重 +recency_weight = 0.1 # 综合评分中时效性的权重 + +fusion_similarity_threshold = 0.85 # 记忆融合时的相似度阈值 +deduplication_window_hours = 24 # 记忆去重窗口(小时) + +enable_memory_cache = true # 是否启用本地记忆缓存 +cache_ttl_seconds = 300 # 缓存有效期(秒) +max_cache_size = 1000 # 缓存中允许的最大记忆条数 [voice] enable_asr = true # 是否启用语音识别,启用后MoFox-Bot可以识别语音消息,启用该功能需要配置语音识别模型[model.voice] diff --git a/template/model_config_template.toml b/template/model_config_template.toml index 9ee2d5e98..69e992a96 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -203,6 +203,7 @@ max_tokens = 1000 #嵌入模型 [model_task_config.embedding] model_list = ["bge-m3"] +embedding_dimension = 1024