Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -70,7 +70,6 @@ dependencies = [
|
|||||||
"tqdm>=4.67.1",
|
"tqdm>=4.67.1",
|
||||||
"urllib3>=2.5.0",
|
"urllib3>=2.5.0",
|
||||||
"uvicorn>=0.35.0",
|
"uvicorn>=0.35.0",
|
||||||
"watchdog>=6.0.0",
|
|
||||||
"websockets>=15.0.1",
|
"websockets>=15.0.1",
|
||||||
"aiomysql>=0.2.0",
|
"aiomysql>=0.2.0",
|
||||||
"aiosqlite>=0.21.0",
|
"aiosqlite>=0.21.0",
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ reportportal-client
|
|||||||
scikit-learn
|
scikit-learn
|
||||||
seaborn
|
seaborn
|
||||||
structlog
|
structlog
|
||||||
watchdog
|
|
||||||
httpx
|
httpx
|
||||||
requests
|
requests
|
||||||
beautifulsoup4
|
beautifulsoup4
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from sqlalchemy import select
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.common.config_helpers import resolve_embedding_dimension
|
||||||
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||||
|
|
||||||
logger = get_logger("bot_interest_manager")
|
logger = get_logger("bot_interest_manager")
|
||||||
@@ -28,7 +29,9 @@ class BotInterestManager:
|
|||||||
# Embedding客户端配置
|
# Embedding客户端配置
|
||||||
self.embedding_request = None
|
self.embedding_request = None
|
||||||
self.embedding_config = None
|
self.embedding_config = None
|
||||||
self.embedding_dimension = 1024 # 默认BGE-M3 embedding维度
|
configured_dim = resolve_embedding_dimension()
|
||||||
|
self.embedding_dimension = int(configured_dim) if configured_dim else 0
|
||||||
|
self._detected_embedding_dimension: Optional[int] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_initialized(self) -> bool:
|
def is_initialized(self) -> bool:
|
||||||
@@ -82,8 +85,11 @@ class BotInterestManager:
|
|||||||
|
|
||||||
logger.info("📋 找到embedding模型配置")
|
logger.info("📋 找到embedding模型配置")
|
||||||
self.embedding_config = model_config.model_task_config.embedding
|
self.embedding_config = model_config.model_task_config.embedding
|
||||||
self.embedding_dimension = 1024 # BGE-M3的维度
|
|
||||||
logger.info(f"📐 使用模型维度: {self.embedding_dimension}")
|
if self.embedding_dimension:
|
||||||
|
logger.info(f"📐 配置的embedding维度: {self.embedding_dimension}")
|
||||||
|
else:
|
||||||
|
logger.info("📐 未在配置中检测到embedding维度,将根据首次返回的向量自动识别")
|
||||||
|
|
||||||
# 创建LLMRequest实例用于embedding
|
# 创建LLMRequest实例用于embedding
|
||||||
self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding")
|
self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding")
|
||||||
@@ -350,7 +356,27 @@ class BotInterestManager:
|
|||||||
|
|
||||||
if embedding and len(embedding) > 0:
|
if embedding and len(embedding) > 0:
|
||||||
self.embedding_cache[text] = embedding
|
self.embedding_cache[text] = embedding
|
||||||
logger.debug(f"✅ Embedding获取成功,维度: {len(embedding)}, 模型: {model_name}")
|
|
||||||
|
current_dim = len(embedding)
|
||||||
|
if self._detected_embedding_dimension is None:
|
||||||
|
self._detected_embedding_dimension = current_dim
|
||||||
|
if self.embedding_dimension and self.embedding_dimension != current_dim:
|
||||||
|
logger.warning(
|
||||||
|
"⚠️ 实际embedding维度(%d)与配置值(%d)不一致,请在 model_config.model_task_config.embedding.embedding_dimension 中同步更新",
|
||||||
|
current_dim,
|
||||||
|
self.embedding_dimension,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.embedding_dimension = current_dim
|
||||||
|
logger.info(f"📏 检测到embedding维度: {current_dim}")
|
||||||
|
elif current_dim != self.embedding_dimension:
|
||||||
|
logger.warning(
|
||||||
|
"⚠️ 收到的embedding维度发生变化: 之前=%d, 当前=%d。请确认模型配置是否正确。",
|
||||||
|
self.embedding_dimension,
|
||||||
|
current_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"✅ Embedding获取成功,维度: {current_dim}, 模型: {model_name}")
|
||||||
return embedding
|
return embedding
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
|
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from rich.progress import (
|
|||||||
TextColumn,
|
TextColumn,
|
||||||
)
|
)
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.common.config_helpers import resolve_embedding_dimension
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -504,7 +505,10 @@ class EmbeddingStore:
|
|||||||
# L2归一化
|
# L2归一化
|
||||||
faiss.normalize_L2(embeddings)
|
faiss.normalize_L2(embeddings)
|
||||||
# 构建索引
|
# 构建索引
|
||||||
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
|
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
|
||||||
|
if not embedding_dim:
|
||||||
|
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||||
|
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
|
||||||
self.faiss_index.add(embeddings)
|
self.faiss_index.add(embeddings)
|
||||||
|
|
||||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||||
|
|||||||
@@ -11,12 +11,27 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
|
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
|
||||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
MEMORY_TYPE_LABELS = {
|
||||||
|
MemoryType.PERSONAL_FACT: "个人事实",
|
||||||
|
MemoryType.EVENT: "事件",
|
||||||
|
MemoryType.PREFERENCE: "偏好",
|
||||||
|
MemoryType.OPINION: "观点",
|
||||||
|
MemoryType.RELATIONSHIP: "关系",
|
||||||
|
MemoryType.EMOTION: "情感",
|
||||||
|
MemoryType.KNOWLEDGE: "知识",
|
||||||
|
MemoryType.SKILL: "技能",
|
||||||
|
MemoryType.GOAL: "目标",
|
||||||
|
MemoryType.EXPERIENCE: "经验",
|
||||||
|
MemoryType.CONTEXTUAL: "上下文",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdapterConfig:
|
class AdapterConfig:
|
||||||
"""适配器配置"""
|
"""适配器配置"""
|
||||||
@@ -85,12 +100,9 @@ class EnhancedMemoryAdapter:
|
|||||||
|
|
||||||
async def process_conversation_memory(
|
async def process_conversation_memory(
|
||||||
self,
|
self,
|
||||||
conversation_text: str,
|
context: Optional[Dict[str, Any]] = None
|
||||||
context: Dict[str, Any],
|
|
||||||
user_id: str,
|
|
||||||
timestamp: Optional[float] = None
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""处理对话记忆"""
|
"""处理对话记忆,以上下文为唯一输入"""
|
||||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||||
return {"success": False, "error": "Enhanced memory not available"}
|
return {"success": False, "error": "Enhanced memory not available"}
|
||||||
|
|
||||||
@@ -98,10 +110,30 @@ class EnhancedMemoryAdapter:
|
|||||||
self.adapter_stats["total_processed"] += 1
|
self.adapter_stats["total_processed"] += 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
payload_context: Dict[str, Any] = dict(context or {})
|
||||||
|
|
||||||
|
conversation_text = payload_context.get("conversation_text")
|
||||||
|
if not conversation_text:
|
||||||
|
conversation_candidate = (
|
||||||
|
payload_context.get("message_content")
|
||||||
|
or payload_context.get("latest_message")
|
||||||
|
or payload_context.get("raw_text")
|
||||||
|
)
|
||||||
|
if conversation_candidate is not None:
|
||||||
|
conversation_text = str(conversation_candidate)
|
||||||
|
payload_context["conversation_text"] = conversation_text
|
||||||
|
else:
|
||||||
|
conversation_text = ""
|
||||||
|
else:
|
||||||
|
conversation_text = str(conversation_text)
|
||||||
|
|
||||||
|
if "timestamp" not in payload_context:
|
||||||
|
payload_context["timestamp"] = time.time()
|
||||||
|
|
||||||
|
logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text))
|
||||||
|
|
||||||
# 使用集成层处理对话
|
# 使用集成层处理对话
|
||||||
result = await self.integration_layer.process_conversation(
|
result = await self.integration_layer.process_conversation(payload_context)
|
||||||
conversation_text, context, user_id, timestamp
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新统计
|
# 更新统计
|
||||||
processing_time = time.time() - start_time
|
processing_time = time.time() - start_time
|
||||||
@@ -132,7 +164,7 @@ class EnhancedMemoryAdapter:
|
|||||||
try:
|
try:
|
||||||
limit = limit or self.config.max_retrieval_results
|
limit = limit or self.config.max_retrieval_results
|
||||||
memories = await self.integration_layer.retrieve_relevant_memories(
|
memories = await self.integration_layer.retrieve_relevant_memories(
|
||||||
query, user_id, context, limit
|
query, None, context, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
self.adapter_stats["memories_retrieved"] += len(memories)
|
self.adapter_stats["memories_retrieved"] += len(memories)
|
||||||
@@ -157,12 +189,15 @@ class EnhancedMemoryAdapter:
|
|||||||
if not memories:
|
if not memories:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 格式化记忆为提示词友好的格式
|
# 格式化记忆为提示词友好的Markdown结构
|
||||||
memory_context_parts = []
|
lines: List[str] = ["### 🧠 相关记忆 (Relevant Memories)", ""]
|
||||||
for memory in memories:
|
|
||||||
memory_context_parts.append(f"- {memory.text_content}")
|
|
||||||
|
|
||||||
return "\n".join(memory_context_parts)
|
for memory in memories:
|
||||||
|
type_label = MEMORY_TYPE_LABELS.get(memory.memory_type, memory.memory_type.value)
|
||||||
|
display_text = memory.display or memory.text_content
|
||||||
|
lines.append(f"- **[{type_label}]** {display_text}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
|
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
|
||||||
"""获取增强记忆系统摘要"""
|
"""获取增强记忆系统摘要"""
|
||||||
@@ -270,13 +305,10 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
|
|||||||
|
|
||||||
|
|
||||||
async def process_conversation_with_enhanced_memory(
|
async def process_conversation_with_enhanced_memory(
|
||||||
conversation_text: str,
|
|
||||||
context: Dict[str, Any],
|
context: Dict[str, Any],
|
||||||
user_id: str,
|
|
||||||
timestamp: Optional[float] = None,
|
|
||||||
llm_model: Optional[LLMRequest] = None
|
llm_model: Optional[LLMRequest] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""使用增强记忆系统处理对话"""
|
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
|
||||||
if not llm_model:
|
if not llm_model:
|
||||||
# 获取默认的LLM模型
|
# 获取默认的LLM模型
|
||||||
from src.llm_models.utils_model import get_global_llm_model
|
from src.llm_models.utils_model import get_global_llm_model
|
||||||
@@ -284,7 +316,18 @@ async def process_conversation_with_enhanced_memory(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
adapter = await get_enhanced_memory_adapter(llm_model)
|
adapter = await get_enhanced_memory_adapter(llm_model)
|
||||||
return await adapter.process_conversation_memory(conversation_text, context, user_id, timestamp)
|
payload_context = dict(context or {})
|
||||||
|
|
||||||
|
if "conversation_text" not in payload_context:
|
||||||
|
conversation_candidate = (
|
||||||
|
payload_context.get("message_content")
|
||||||
|
or payload_context.get("latest_message")
|
||||||
|
or payload_context.get("raw_text")
|
||||||
|
)
|
||||||
|
if conversation_candidate is not None:
|
||||||
|
payload_context["conversation_text"] = str(conversation_candidate)
|
||||||
|
|
||||||
|
return await adapter.process_conversation_memory(payload_context)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
|
logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
增强型精准记忆系统核心模块
|
增强型精准记忆系统核心模块
|
||||||
基于文档设计的高效记忆构建、存储与召回优化系统
|
1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。
|
||||||
|
2. 内置 LLM 查询规划器与嵌入维度自动解析机制,直接从模型配置推断向量存储参数。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import orjson
|
import orjson
|
||||||
import re
|
import re
|
||||||
|
import hashlib
|
||||||
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
@@ -22,12 +24,16 @@ from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
|||||||
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
|
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
|
||||||
from src.chat.memory_system.metadata_index import MetadataIndexManager
|
from src.chat.memory_system.metadata_index import MetadataIndexManager
|
||||||
from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig
|
from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig
|
||||||
|
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# 全局记忆作用域(共享记忆库)
|
||||||
|
GLOBAL_MEMORY_SCOPE = "global"
|
||||||
|
|
||||||
|
|
||||||
class MemorySystemStatus(Enum):
|
class MemorySystemStatus(Enum):
|
||||||
"""记忆系统状态"""
|
"""记忆系统状态"""
|
||||||
@@ -47,14 +53,20 @@ class MemorySystemConfig:
|
|||||||
memory_value_threshold: float = 0.7
|
memory_value_threshold: float = 0.7
|
||||||
min_build_interval_seconds: float = 300.0
|
min_build_interval_seconds: float = 300.0
|
||||||
|
|
||||||
# 向量存储配置
|
# 向量存储配置(嵌入维度自动来自模型配置)
|
||||||
vector_dimension: int = 768
|
vector_dimension: int = 1024
|
||||||
similarity_threshold: float = 0.8
|
similarity_threshold: float = 0.8
|
||||||
|
|
||||||
# 召回配置
|
# 召回配置
|
||||||
coarse_recall_limit: int = 50
|
coarse_recall_limit: int = 50
|
||||||
fine_recall_limit: int = 10
|
fine_recall_limit: int = 10
|
||||||
|
semantic_rerank_limit: int = 20
|
||||||
final_recall_limit: int = 5
|
final_recall_limit: int = 5
|
||||||
|
semantic_similarity_threshold: float = 0.6
|
||||||
|
vector_weight: float = 0.4
|
||||||
|
semantic_weight: float = 0.3
|
||||||
|
context_weight: float = 0.2
|
||||||
|
recency_weight: float = 0.1
|
||||||
|
|
||||||
# 融合配置
|
# 融合配置
|
||||||
fusion_similarity_threshold: float = 0.85
|
fusion_similarity_threshold: float = 0.85
|
||||||
@@ -64,6 +76,23 @@ class MemorySystemConfig:
|
|||||||
def from_global_config(cls):
|
def from_global_config(cls):
|
||||||
"""从全局配置创建配置实例"""
|
"""从全局配置创建配置实例"""
|
||||||
|
|
||||||
|
embedding_dimension = None
|
||||||
|
try:
|
||||||
|
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||||
|
if embedding_task is not None:
|
||||||
|
embedding_dimension = getattr(embedding_task, "embedding_dimension", None)
|
||||||
|
except Exception:
|
||||||
|
embedding_dimension = None
|
||||||
|
|
||||||
|
if not embedding_dimension:
|
||||||
|
try:
|
||||||
|
embedding_dimension = getattr(global_config.lpmm_knowledge, "embedding_dimension", None)
|
||||||
|
except Exception:
|
||||||
|
embedding_dimension = None
|
||||||
|
|
||||||
|
if not embedding_dimension:
|
||||||
|
embedding_dimension = 1024
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
# 记忆构建配置
|
# 记忆构建配置
|
||||||
min_memory_length=global_config.memory.min_memory_length,
|
min_memory_length=global_config.memory.min_memory_length,
|
||||||
@@ -72,13 +101,19 @@ class MemorySystemConfig:
|
|||||||
min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0),
|
min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0),
|
||||||
|
|
||||||
# 向量存储配置
|
# 向量存储配置
|
||||||
vector_dimension=global_config.memory.vector_dimension,
|
vector_dimension=int(embedding_dimension),
|
||||||
similarity_threshold=global_config.memory.vector_similarity_threshold,
|
similarity_threshold=global_config.memory.vector_similarity_threshold,
|
||||||
|
|
||||||
# 召回配置
|
# 召回配置
|
||||||
coarse_recall_limit=global_config.memory.metadata_filter_limit,
|
coarse_recall_limit=global_config.memory.metadata_filter_limit,
|
||||||
fine_recall_limit=global_config.memory.final_result_limit,
|
fine_recall_limit=global_config.memory.vector_search_limit,
|
||||||
|
semantic_rerank_limit=global_config.memory.semantic_rerank_limit,
|
||||||
final_recall_limit=global_config.memory.final_result_limit,
|
final_recall_limit=global_config.memory.final_result_limit,
|
||||||
|
semantic_similarity_threshold=getattr(global_config.memory, "semantic_similarity_threshold", 0.6),
|
||||||
|
vector_weight=global_config.memory.vector_weight,
|
||||||
|
semantic_weight=global_config.memory.semantic_weight,
|
||||||
|
context_weight=global_config.memory.context_weight,
|
||||||
|
recency_weight=global_config.memory.recency_weight,
|
||||||
|
|
||||||
# 融合配置
|
# 融合配置
|
||||||
fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
|
fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
|
||||||
@@ -104,6 +139,7 @@ class EnhancedMemorySystem:
|
|||||||
self.vector_storage: VectorStorageManager = None
|
self.vector_storage: VectorStorageManager = None
|
||||||
self.metadata_index: MetadataIndexManager = None
|
self.metadata_index: MetadataIndexManager = None
|
||||||
self.retrieval_system: MultiStageRetrieval = None
|
self.retrieval_system: MultiStageRetrieval = None
|
||||||
|
self.query_planner: MemoryQueryPlanner = None
|
||||||
|
|
||||||
# LLM模型
|
# LLM模型
|
||||||
self.value_assessment_model: LLMRequest = None
|
self.value_assessment_model: LLMRequest = None
|
||||||
@@ -117,6 +153,9 @@ class EnhancedMemorySystem:
|
|||||||
# 构建节流记录
|
# 构建节流记录
|
||||||
self._last_memory_build_times: Dict[str, float] = {}
|
self._last_memory_build_times: Dict[str, float] = {}
|
||||||
|
|
||||||
|
# 记忆指纹缓存,用于快速检测重复记忆
|
||||||
|
self._memory_fingerprints: Dict[str, str] = {}
|
||||||
|
|
||||||
logger.info("EnhancedMemorySystem 初始化开始")
|
logger.info("EnhancedMemorySystem 初始化开始")
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
@@ -125,19 +164,29 @@ class EnhancedMemorySystem:
|
|||||||
logger.info("正在初始化增强型记忆系统...")
|
logger.info("正在初始化增强型记忆系统...")
|
||||||
|
|
||||||
# 初始化LLM模型
|
# 初始化LLM模型
|
||||||
task_config = (
|
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
|
||||||
self.llm_model.model_for_task
|
|
||||||
if self.llm_model is not None
|
value_task_config = getattr(model_config.model_task_config, "utils_small", None)
|
||||||
else model_config.model_task_config.utils_small
|
extraction_task_config = getattr(model_config.model_task_config, "utils", None)
|
||||||
)
|
|
||||||
|
if value_task_config is None:
|
||||||
|
logger.warning("未找到 utils_small 模型配置,回退到 utils 或外部提供的模型配置。")
|
||||||
|
value_task_config = extraction_task_config or fallback_task
|
||||||
|
|
||||||
|
if extraction_task_config is None:
|
||||||
|
logger.warning("未找到 utils 模型配置,回退到 utils_small 或外部提供的模型配置。")
|
||||||
|
extraction_task_config = value_task_config or fallback_task
|
||||||
|
|
||||||
|
if value_task_config is None or extraction_task_config is None:
|
||||||
|
raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。")
|
||||||
|
|
||||||
self.value_assessment_model = LLMRequest(
|
self.value_assessment_model = LLMRequest(
|
||||||
model_set=task_config,
|
model_set=value_task_config,
|
||||||
request_type="memory.value_assessment"
|
request_type="memory.value_assessment"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.memory_extraction_model = LLMRequest(
|
self.memory_extraction_model = LLMRequest(
|
||||||
model_set=task_config,
|
model_set=extraction_task_config,
|
||||||
request_type="memory.extraction"
|
request_type="memory.extraction"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -155,13 +204,36 @@ class EnhancedMemorySystem:
|
|||||||
retrieval_config = RetrievalConfig(
|
retrieval_config = RetrievalConfig(
|
||||||
metadata_filter_limit=self.config.coarse_recall_limit,
|
metadata_filter_limit=self.config.coarse_recall_limit,
|
||||||
vector_search_limit=self.config.fine_recall_limit,
|
vector_search_limit=self.config.fine_recall_limit,
|
||||||
final_result_limit=self.config.final_recall_limit
|
semantic_rerank_limit=self.config.semantic_rerank_limit,
|
||||||
|
final_result_limit=self.config.final_recall_limit,
|
||||||
|
vector_similarity_threshold=self.config.similarity_threshold,
|
||||||
|
semantic_similarity_threshold=self.config.semantic_similarity_threshold,
|
||||||
|
vector_weight=self.config.vector_weight,
|
||||||
|
semantic_weight=self.config.semantic_weight,
|
||||||
|
context_weight=self.config.context_weight,
|
||||||
|
recency_weight=self.config.recency_weight,
|
||||||
)
|
)
|
||||||
self.retrieval_system = MultiStageRetrieval(retrieval_config)
|
self.retrieval_system = MultiStageRetrieval(retrieval_config)
|
||||||
|
|
||||||
|
planner_task_config = getattr(model_config.model_task_config, "planner", None)
|
||||||
|
planner_model: Optional[LLMRequest] = None
|
||||||
|
try:
|
||||||
|
planner_model = LLMRequest(
|
||||||
|
model_set=planner_task_config,
|
||||||
|
request_type="memory.query_planner"
|
||||||
|
)
|
||||||
|
except Exception as planner_exc:
|
||||||
|
logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True)
|
||||||
|
|
||||||
|
self.query_planner = MemoryQueryPlanner(
|
||||||
|
planner_model,
|
||||||
|
default_limit=self.config.final_recall_limit
|
||||||
|
)
|
||||||
|
|
||||||
# 加载持久化数据
|
# 加载持久化数据
|
||||||
await self.vector_storage.load_storage()
|
await self.vector_storage.load_storage()
|
||||||
await self.metadata_index.load_index()
|
await self.metadata_index.load_index()
|
||||||
|
self._populate_memory_fingerprints()
|
||||||
|
|
||||||
self.status = MemorySystemStatus.READY
|
self.status = MemorySystemStatus.READY
|
||||||
logger.info("✅ 增强型记忆系统初始化完成")
|
logger.info("✅ 增强型记忆系统初始化完成")
|
||||||
@@ -174,7 +246,7 @@ class EnhancedMemorySystem:
|
|||||||
async def retrieve_memories_for_building(
|
async def retrieve_memories_for_building(
|
||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
user_id: str,
|
user_id: Optional[str] = None,
|
||||||
context: Optional[Dict[str, Any]] = None,
|
context: Optional[Dict[str, Any]] = None,
|
||||||
limit: int = 5
|
limit: int = 5
|
||||||
) -> List[MemoryChunk]:
|
) -> List[MemoryChunk]:
|
||||||
@@ -182,7 +254,6 @@ class EnhancedMemorySystem:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
user_id: 用户ID
|
|
||||||
context: 上下文信息
|
context: 上下文信息
|
||||||
limit: 返回结果数量限制
|
limit: 返回结果数量限制
|
||||||
|
|
||||||
@@ -201,7 +272,6 @@ class EnhancedMemorySystem:
|
|||||||
# 执行检索
|
# 执行检索
|
||||||
memories = await self.vector_storage.search_similar_memories(
|
memories = await self.vector_storage.search_similar_memories(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
user_id=user_id,
|
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -218,23 +288,18 @@ class EnhancedMemorySystem:
|
|||||||
self,
|
self,
|
||||||
conversation_text: str,
|
conversation_text: str,
|
||||||
context: Dict[str, Any],
|
context: Dict[str, Any],
|
||||||
user_id: str,
|
|
||||||
timestamp: Optional[float] = None
|
timestamp: Optional[float] = None
|
||||||
) -> List[MemoryChunk]:
|
) -> List[MemoryChunk]:
|
||||||
"""从对话中构建记忆
|
"""从对话中构建记忆
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
conversation_text: 对话文本
|
conversation_text: 对话文本
|
||||||
context: 上下文信息(包括用户信息、群组信息等)
|
context: 上下文信息
|
||||||
user_id: 用户ID
|
|
||||||
timestamp: 时间戳,默认为当前时间
|
timestamp: 时间戳,默认为当前时间
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
构建的记忆块列表
|
构建的记忆块列表
|
||||||
"""
|
"""
|
||||||
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
|
|
||||||
raise RuntimeError("记忆系统未就绪")
|
|
||||||
|
|
||||||
original_status = self.status
|
original_status = self.status
|
||||||
self.status = MemorySystemStatus.BUILDING
|
self.status = MemorySystemStatus.BUILDING
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -243,9 +308,9 @@ class EnhancedMemorySystem:
|
|||||||
build_marker_time: Optional[float] = None
|
build_marker_time: Optional[float] = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
normalized_context = self._normalize_context(context, user_id, timestamp)
|
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
|
||||||
|
|
||||||
build_scope_key = self._get_build_scope_key(normalized_context, user_id)
|
build_scope_key = self._get_build_scope_key(normalized_context, GLOBAL_MEMORY_SCOPE)
|
||||||
min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0))
|
min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0))
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
@@ -266,7 +331,7 @@ class EnhancedMemorySystem:
|
|||||||
|
|
||||||
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
|
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
|
||||||
|
|
||||||
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
|
logger.debug("开始构建记忆,文本长度: %d", len(conversation_text))
|
||||||
|
|
||||||
# 1. 信息价值评估
|
# 1. 信息价值评估
|
||||||
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
||||||
@@ -280,7 +345,7 @@ class EnhancedMemorySystem:
|
|||||||
memory_chunks = await self.memory_builder.build_memories(
|
memory_chunks = await self.memory_builder.build_memories(
|
||||||
conversation_text,
|
conversation_text,
|
||||||
normalized_context,
|
normalized_context,
|
||||||
user_id,
|
GLOBAL_MEMORY_SCOPE,
|
||||||
timestamp or time.time()
|
timestamp or time.time()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -293,19 +358,24 @@ class EnhancedMemorySystem:
|
|||||||
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks)
|
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks)
|
||||||
|
|
||||||
# 4. 存储记忆
|
# 4. 存储记忆
|
||||||
await self._store_memories(fused_chunks)
|
stored_count = await self._store_memories(fused_chunks)
|
||||||
|
|
||||||
# 4.1 控制台预览
|
# 4.1 控制台预览
|
||||||
self._log_memory_preview(fused_chunks)
|
self._log_memory_preview(fused_chunks)
|
||||||
|
|
||||||
# 5. 更新统计
|
# 5. 更新统计
|
||||||
self.total_memories += len(fused_chunks)
|
self.total_memories += stored_count
|
||||||
self.last_build_time = time.time()
|
self.last_build_time = time.time()
|
||||||
if build_scope_key:
|
if build_scope_key:
|
||||||
self._last_memory_build_times[build_scope_key] = self.last_build_time
|
self._last_memory_build_times[build_scope_key] = self.last_build_time
|
||||||
|
|
||||||
build_time = time.time() - start_time
|
build_time = time.time() - start_time
|
||||||
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}秒")
|
logger.info(
|
||||||
|
"✅ 生成 %d 条记忆,成功入库 %d 条,耗时 %.2f秒",
|
||||||
|
len(fused_chunks),
|
||||||
|
stored_count,
|
||||||
|
build_time,
|
||||||
|
)
|
||||||
|
|
||||||
self.status = original_status
|
self.status = original_status
|
||||||
return fused_chunks
|
return fused_chunks
|
||||||
@@ -347,21 +417,34 @@ class EnhancedMemorySystem:
|
|||||||
|
|
||||||
async def process_conversation_memory(
|
async def process_conversation_memory(
|
||||||
self,
|
self,
|
||||||
conversation_text: str,
|
context: Dict[str, Any]
|
||||||
context: Dict[str, Any],
|
|
||||||
user_id: str,
|
|
||||||
timestamp: Optional[float] = None
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""对外暴露的对话记忆处理接口,兼容旧调用方式"""
|
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
normalized_context = self._normalize_context(context, user_id, timestamp)
|
context = dict(context or {})
|
||||||
|
|
||||||
|
conversation_candidate = (
|
||||||
|
context.get("conversation_text")
|
||||||
|
or context.get("message_content")
|
||||||
|
or context.get("latest_message")
|
||||||
|
or context.get("raw_text")
|
||||||
|
or ""
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate)
|
||||||
|
|
||||||
|
timestamp = context.get("timestamp")
|
||||||
|
if timestamp is None:
|
||||||
|
timestamp = time.time()
|
||||||
|
|
||||||
|
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
|
||||||
|
normalized_context.setdefault("conversation_text", conversation_text)
|
||||||
|
|
||||||
memories = await self.build_memory_from_conversation(
|
memories = await self.build_memory_from_conversation(
|
||||||
conversation_text=conversation_text,
|
conversation_text=conversation_text,
|
||||||
context=normalized_context,
|
context=normalized_context,
|
||||||
user_id=user_id,
|
|
||||||
timestamp=timestamp
|
timestamp=timestamp
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -395,52 +478,77 @@ class EnhancedMemorySystem:
|
|||||||
**kwargs
|
**kwargs
|
||||||
) -> List[MemoryChunk]:
|
) -> List[MemoryChunk]:
|
||||||
"""检索相关记忆,兼容 query/query_text 参数形式"""
|
"""检索相关记忆,兼容 query/query_text 参数形式"""
|
||||||
if self.status != MemorySystemStatus.READY:
|
raw_query = query_text or kwargs.get("query")
|
||||||
raise RuntimeError("记忆系统未就绪")
|
if not raw_query:
|
||||||
|
|
||||||
query_text = query_text or kwargs.get("query")
|
|
||||||
if not query_text:
|
|
||||||
raise ValueError("query_text 或 query 参数不能为空")
|
raise ValueError("query_text 或 query 参数不能为空")
|
||||||
|
|
||||||
context = context or {}
|
context = context or {}
|
||||||
user_id = user_id or kwargs.get("user_id")
|
resolved_user_id = GLOBAL_MEMORY_SCOPE
|
||||||
|
|
||||||
|
if self.retrieval_system is None or self.metadata_index is None:
|
||||||
|
raise RuntimeError("检索组件未初始化")
|
||||||
|
|
||||||
|
all_memories_cache = self.vector_storage.memory_cache
|
||||||
|
if not all_memories_cache:
|
||||||
|
logger.debug("记忆缓存为空,返回空结果")
|
||||||
|
self.last_retrieval_time = time.time()
|
||||||
|
self.status = MemorySystemStatus.READY
|
||||||
|
return []
|
||||||
|
|
||||||
self.status = MemorySystemStatus.RETRIEVING
|
self.status = MemorySystemStatus.RETRIEVING
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
normalized_context = self._normalize_context(context, user_id, None)
|
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
|
||||||
|
|
||||||
candidate_memories = list(self.vector_storage.memory_cache.values())
|
effective_limit = limit or self.config.final_recall_limit
|
||||||
if user_id:
|
query_plan = None
|
||||||
candidate_memories = [m for m in candidate_memories if m.user_id == user_id]
|
planner_ran = False
|
||||||
|
resolved_query_text = raw_query
|
||||||
|
if self.query_planner:
|
||||||
|
try:
|
||||||
|
planner_ran = True
|
||||||
|
query_plan = await self.query_planner.plan_query(raw_query, normalized_context)
|
||||||
|
normalized_context["query_plan"] = query_plan
|
||||||
|
effective_limit = min(effective_limit, query_plan.limit or effective_limit)
|
||||||
|
if getattr(query_plan, "semantic_query", None):
|
||||||
|
resolved_query_text = query_plan.semantic_query
|
||||||
|
logger.debug(
|
||||||
|
"查询规划: semantic='%s', types=%s, subjects=%s, limit=%d",
|
||||||
|
query_plan.semantic_query,
|
||||||
|
[mt.value for mt in query_plan.memory_types],
|
||||||
|
query_plan.subject_includes,
|
||||||
|
query_plan.limit,
|
||||||
|
)
|
||||||
|
except Exception as plan_exc:
|
||||||
|
logger.warning("查询规划失败,使用默认检索策略: %s", plan_exc, exc_info=True)
|
||||||
|
|
||||||
if not candidate_memories:
|
effective_limit = effective_limit or self.config.final_recall_limit
|
||||||
self.status = MemorySystemStatus.READY
|
effective_limit = max(1, min(effective_limit, self.config.final_recall_limit))
|
||||||
self.last_retrieval_time = time.time()
|
normalized_context["resolved_query_text"] = resolved_query_text
|
||||||
logger.debug(f"未找到用户 {user_id} 的候选记忆")
|
|
||||||
return []
|
|
||||||
|
|
||||||
scored_memories = []
|
if normalized_context.get("__memory_building__"):
|
||||||
for memory in candidate_memories:
|
logger.debug("当前处于记忆构建流程,跳过查询规划并进行降级检索")
|
||||||
score = self._compute_memory_score(query_text, memory, normalized_context)
|
self.status = MemorySystemStatus.BUILDING
|
||||||
if score > 0:
|
final_memories = []
|
||||||
scored_memories.append((memory, score))
|
candidate_memories = list(all_memories_cache.values())
|
||||||
|
|
||||||
if not scored_memories:
|
|
||||||
# 如果所有分数为0,返回最近的记忆作为降级策略
|
|
||||||
candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True)
|
candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True)
|
||||||
scored_memories = [(memory, 0.0) for memory in candidate_memories[:limit]]
|
final_memories = candidate_memories[:effective_limit]
|
||||||
else:
|
else:
|
||||||
scored_memories.sort(key=lambda item: item[1], reverse=True)
|
retrieval_result = await self.retrieval_system.retrieve_memories(
|
||||||
|
query=resolved_query_text,
|
||||||
|
user_id=resolved_user_id,
|
||||||
|
context=normalized_context,
|
||||||
|
metadata_index=self.metadata_index,
|
||||||
|
vector_storage=self.vector_storage,
|
||||||
|
all_memories_cache=all_memories_cache,
|
||||||
|
limit=effective_limit,
|
||||||
|
)
|
||||||
|
|
||||||
top_memories = [memory for memory, _ in scored_memories[:limit]]
|
final_memories = retrieval_result.final_memories
|
||||||
|
|
||||||
# 更新访问信息和缓存
|
for memory in final_memories:
|
||||||
for memory, score in scored_memories[:limit]:
|
|
||||||
memory.update_access()
|
memory.update_access()
|
||||||
memory.update_relevance(score)
|
|
||||||
|
|
||||||
cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id)
|
cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id)
|
||||||
if cache_entry is not None:
|
if cache_entry is not None:
|
||||||
cache_entry["last_accessed"] = memory.metadata.last_accessed
|
cache_entry["last_accessed"] = memory.metadata.last_accessed
|
||||||
@@ -448,14 +556,34 @@ class EnhancedMemorySystem:
|
|||||||
cache_entry["relevance_score"] = memory.metadata.relevance_score
|
cache_entry["relevance_score"] = memory.metadata.relevance_score
|
||||||
|
|
||||||
retrieval_time = time.time() - start_time
|
retrieval_time = time.time() - start_time
|
||||||
logger.info(
|
plan_summary = ""
|
||||||
f"✅ 为用户 {user_id or 'unknown'} 检索到 {len(top_memories)} 条相关记忆,耗时 {retrieval_time:.3f}秒"
|
if planner_ran and query_plan:
|
||||||
|
plan_types = ",".join(mt.value for mt in query_plan.memory_types) or "-"
|
||||||
|
plan_subjects = ",".join(query_plan.subject_includes) or "-"
|
||||||
|
plan_summary = (
|
||||||
|
f" | planner.semantic='{query_plan.semantic_query}'"
|
||||||
|
f" | planner.limit={query_plan.limit}"
|
||||||
|
f" | planner.types={plan_types}"
|
||||||
|
f" | planner.subjects={plan_subjects}"
|
||||||
|
)
|
||||||
|
|
||||||
|
log_message = (
|
||||||
|
"✅ 记忆检索完成"
|
||||||
|
f" | user={resolved_user_id}"
|
||||||
|
f" | count={len(final_memories)}"
|
||||||
|
f" | duration={retrieval_time:.3f}s"
|
||||||
|
f" | applied_limit={effective_limit}"
|
||||||
|
f" | raw_query='{raw_query}'"
|
||||||
|
f" | semantic_query='{resolved_query_text}'"
|
||||||
|
f"{plan_summary}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(log_message)
|
||||||
|
|
||||||
self.last_retrieval_time = time.time()
|
self.last_retrieval_time = time.time()
|
||||||
self.status = MemorySystemStatus.READY
|
self.status = MemorySystemStatus.READY
|
||||||
|
|
||||||
return top_memories
|
return final_memories
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.status = MemorySystemStatus.ERROR
|
self.status = MemorySystemStatus.ERROR
|
||||||
@@ -499,8 +627,8 @@ class EnhancedMemorySystem:
|
|||||||
except Exception:
|
except Exception:
|
||||||
context = dict(raw_context or {})
|
context = dict(raw_context or {})
|
||||||
|
|
||||||
# 基础字段
|
# 基础字段(统一使用全局作用域)
|
||||||
context["user_id"] = context.get("user_id") or user_id or "unknown"
|
context["user_id"] = GLOBAL_MEMORY_SCOPE
|
||||||
context["timestamp"] = context.get("timestamp") or timestamp or time.time()
|
context["timestamp"] = context.get("timestamp") or timestamp or time.time()
|
||||||
context["message_type"] = context.get("message_type") or "normal"
|
context["message_type"] = context.get("message_type") or "normal"
|
||||||
context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"
|
context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"
|
||||||
@@ -523,8 +651,8 @@ class EnhancedMemorySystem:
|
|||||||
if stream_id:
|
if stream_id:
|
||||||
context["stream_id"] = stream_id
|
context["stream_id"] = stream_id
|
||||||
|
|
||||||
# chat_id 兜底
|
# 全局记忆无需聊天隔离
|
||||||
context["chat_id"] = context.get("chat_id") or context.get("stream_id") or f"session_{context['user_id']}"
|
context["chat_id"] = context.get("chat_id") or "global_chat"
|
||||||
|
|
||||||
# 历史窗口配置
|
# 历史窗口配置
|
||||||
window_candidate = (
|
window_candidate = (
|
||||||
@@ -616,18 +744,7 @@ class EnhancedMemorySystem:
|
|||||||
|
|
||||||
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
||||||
"""确定用于节流控制的记忆构建作用域"""
|
"""确定用于节流控制的记忆构建作用域"""
|
||||||
stream_id = context.get("stream_id")
|
return "global_scope"
|
||||||
if stream_id:
|
|
||||||
return f"stream::{stream_id}"
|
|
||||||
|
|
||||||
chat_id = context.get("chat_id")
|
|
||||||
if chat_id:
|
|
||||||
return f"chat::{chat_id}"
|
|
||||||
|
|
||||||
if user_id:
|
|
||||||
return f"user::{user_id}"
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
|
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
|
||||||
"""确定历史消息获取数量,限制在30-50之间"""
|
"""确定历史消息获取数量,限制在30-50之间"""
|
||||||
@@ -789,24 +906,134 @@ class EnhancedMemorySystem:
|
|||||||
logger.error(f"信息价值评估失败: {e}", exc_info=True)
|
logger.error(f"信息价值评估失败: {e}", exc_info=True)
|
||||||
return 0.5 # 默认中等价值
|
return 0.5 # 默认中等价值
|
||||||
|
|
||||||
async def _store_memories(self, memory_chunks: List[MemoryChunk]):
|
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
|
||||||
"""存储记忆块到各个存储系统"""
|
"""存储记忆块到各个存储系统,返回成功入库数量"""
|
||||||
if not memory_chunks:
|
if not memory_chunks:
|
||||||
return
|
return 0
|
||||||
|
|
||||||
|
unique_memories: List[MemoryChunk] = []
|
||||||
|
skipped_duplicates = 0
|
||||||
|
|
||||||
|
for memory in memory_chunks:
|
||||||
|
fingerprint = self._build_memory_fingerprint(memory)
|
||||||
|
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||||
|
|
||||||
|
existing_id = self._memory_fingerprints.get(key)
|
||||||
|
if existing_id:
|
||||||
|
existing = self.vector_storage.memory_cache.get(existing_id)
|
||||||
|
if existing:
|
||||||
|
self._merge_existing_memory(existing, memory)
|
||||||
|
await self.metadata_index.update_memory_entry(existing)
|
||||||
|
skipped_duplicates += 1
|
||||||
|
logger.debug(
|
||||||
|
"检测到重复记忆,已合并到现有记录 | memory_id=%s",
|
||||||
|
existing.memory_id,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 指纹存在但缓存缺失,视为新记忆并覆盖旧映射
|
||||||
|
logger.debug("检测到过期指纹映射,重写现有条目")
|
||||||
|
|
||||||
|
unique_memories.append(memory)
|
||||||
|
|
||||||
|
if not unique_memories:
|
||||||
|
if skipped_duplicates:
|
||||||
|
logger.info("本次记忆全部与现有内容重复,跳过入库")
|
||||||
|
return 0
|
||||||
|
|
||||||
# 并行存储到向量数据库和元数据索引
|
# 并行存储到向量数据库和元数据索引
|
||||||
storage_tasks = []
|
storage_tasks = [
|
||||||
|
self.vector_storage.store_memories(unique_memories),
|
||||||
|
self.metadata_index.index_memories(unique_memories),
|
||||||
|
]
|
||||||
|
|
||||||
# 向量存储
|
|
||||||
storage_tasks.append(self.vector_storage.store_memories(memory_chunks))
|
|
||||||
|
|
||||||
# 元数据索引
|
|
||||||
storage_tasks.append(self.metadata_index.index_memories(memory_chunks))
|
|
||||||
|
|
||||||
# 等待所有存储任务完成
|
|
||||||
await asyncio.gather(*storage_tasks, return_exceptions=True)
|
await asyncio.gather(*storage_tasks, return_exceptions=True)
|
||||||
|
|
||||||
logger.debug(f"成功存储 {len(memory_chunks)} 条记忆到各个存储系统")
|
self._register_memory_fingerprints(unique_memories)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"成功存储 %d 条记忆(跳过重复 %d 条)",
|
||||||
|
len(unique_memories),
|
||||||
|
skipped_duplicates,
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(unique_memories)
|
||||||
|
|
||||||
|
def _merge_existing_memory(self, existing: MemoryChunk, incoming: MemoryChunk) -> None:
|
||||||
|
"""将新记忆的信息合并到已存在的记忆中"""
|
||||||
|
updated = False
|
||||||
|
|
||||||
|
for keyword in incoming.keywords:
|
||||||
|
if keyword not in existing.keywords:
|
||||||
|
existing.add_keyword(keyword)
|
||||||
|
updated = True
|
||||||
|
|
||||||
|
for tag in incoming.tags:
|
||||||
|
if tag not in existing.tags:
|
||||||
|
existing.add_tag(tag)
|
||||||
|
updated = True
|
||||||
|
|
||||||
|
for category in incoming.categories:
|
||||||
|
if category not in existing.categories:
|
||||||
|
existing.add_category(category)
|
||||||
|
updated = True
|
||||||
|
|
||||||
|
if incoming.metadata.source_context:
|
||||||
|
existing.metadata.source_context = incoming.metadata.source_context
|
||||||
|
|
||||||
|
if incoming.metadata.importance.value > existing.metadata.importance.value:
|
||||||
|
existing.metadata.importance = incoming.metadata.importance
|
||||||
|
updated = True
|
||||||
|
|
||||||
|
if incoming.metadata.confidence.value > existing.metadata.confidence.value:
|
||||||
|
existing.metadata.confidence = incoming.metadata.confidence
|
||||||
|
updated = True
|
||||||
|
|
||||||
|
if incoming.metadata.relevance_score > existing.metadata.relevance_score:
|
||||||
|
existing.metadata.relevance_score = incoming.metadata.relevance_score
|
||||||
|
updated = True
|
||||||
|
|
||||||
|
if updated:
|
||||||
|
existing.metadata.last_modified = time.time()
|
||||||
|
|
||||||
|
def _populate_memory_fingerprints(self) -> None:
|
||||||
|
"""基于当前缓存构建记忆指纹映射"""
|
||||||
|
self._memory_fingerprints.clear()
|
||||||
|
for memory in self.vector_storage.memory_cache.values():
|
||||||
|
fingerprint = self._build_memory_fingerprint(memory)
|
||||||
|
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||||
|
self._memory_fingerprints[key] = memory.memory_id
|
||||||
|
|
||||||
|
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
|
||||||
|
for memory in memories:
|
||||||
|
fingerprint = self._build_memory_fingerprint(memory)
|
||||||
|
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||||
|
self._memory_fingerprints[key] = memory.memory_id
|
||||||
|
|
||||||
|
def _build_memory_fingerprint(self, memory: MemoryChunk) -> str:
|
||||||
|
subjects = memory.subjects or []
|
||||||
|
subject_part = "|".join(sorted(s.strip() for s in subjects if s))
|
||||||
|
predicate_part = (memory.content.predicate or "").strip()
|
||||||
|
|
||||||
|
obj = memory.content.object
|
||||||
|
if isinstance(obj, (dict, list)):
|
||||||
|
obj_part = orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||||
|
else:
|
||||||
|
obj_part = str(obj).strip()
|
||||||
|
|
||||||
|
base = "|".join([
|
||||||
|
str(memory.user_id or "unknown"),
|
||||||
|
memory.memory_type.value,
|
||||||
|
subject_part,
|
||||||
|
predicate_part,
|
||||||
|
obj_part,
|
||||||
|
])
|
||||||
|
|
||||||
|
return hashlib.sha256(base.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
|
||||||
|
return f"{str(user_id)}:{fingerprint}"
|
||||||
|
|
||||||
def get_system_stats(self) -> Dict[str, Any]:
|
def get_system_stats(self) -> Dict[str, Any]:
|
||||||
"""获取系统统计信息"""
|
"""获取系统统计信息"""
|
||||||
|
|||||||
@@ -241,12 +241,12 @@ class EnhancedMemoryManager:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await self.enhanced_system.process_conversation_memory(
|
payload_context = dict(context or {})
|
||||||
conversation_text=conversation_text,
|
payload_context.setdefault("conversation_text", conversation_text)
|
||||||
context=context,
|
if timestamp is not None:
|
||||||
user_id=user_id,
|
payload_context.setdefault("timestamp", timestamp)
|
||||||
timestamp=timestamp
|
|
||||||
)
|
result = await self.enhanced_system.process_conversation_memory(payload_context)
|
||||||
|
|
||||||
# 从结果中提取记忆块
|
# 从结果中提取记忆块
|
||||||
memory_chunks = []
|
memory_chunks = []
|
||||||
@@ -274,7 +274,7 @@ class EnhancedMemoryManager:
|
|||||||
try:
|
try:
|
||||||
relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
|
relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
|
||||||
query=query_text,
|
query=query_text,
|
||||||
user_id=user_id,
|
user_id=None,
|
||||||
context=context or {},
|
context=context or {},
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
@@ -303,6 +303,9 @@ class EnhancedMemoryManager:
|
|||||||
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
|
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
|
||||||
"""将记忆块转换为更易读的文本描述"""
|
"""将记忆块转换为更易读的文本描述"""
|
||||||
structure = memory.content.to_dict()
|
structure = memory.content.to_dict()
|
||||||
|
if memory.display:
|
||||||
|
return self._clean_text(memory.display), structure
|
||||||
|
|
||||||
subject = structure.get("subject")
|
subject = structure.get("subject")
|
||||||
predicate = structure.get("predicate") or ""
|
predicate = structure.get("predicate") or ""
|
||||||
obj = structure.get("object")
|
obj = structure.get("object")
|
||||||
|
|||||||
@@ -114,12 +114,9 @@ class MemoryIntegrationLayer:
|
|||||||
|
|
||||||
async def process_conversation(
|
async def process_conversation(
|
||||||
self,
|
self,
|
||||||
conversation_text: str,
|
context: Dict[str, Any]
|
||||||
context: Dict[str, Any],
|
|
||||||
user_id: str,
|
|
||||||
timestamp: Optional[float] = None
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""处理对话记忆"""
|
"""处理对话记忆,仅使用上下文信息"""
|
||||||
if not self._initialized or not self.enhanced_memory:
|
if not self._initialized or not self.enhanced_memory:
|
||||||
return {"success": False, "error": "Memory system not available"}
|
return {"success": False, "error": "Memory system not available"}
|
||||||
|
|
||||||
@@ -128,13 +125,12 @@ class MemoryIntegrationLayer:
|
|||||||
self.integration_stats["enhanced_queries"] += 1
|
self.integration_stats["enhanced_queries"] += 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
payload_context = dict(context or {})
|
||||||
|
conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
|
||||||
|
logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text))
|
||||||
|
|
||||||
# 直接使用增强记忆系统处理
|
# 直接使用增强记忆系统处理
|
||||||
result = await self.enhanced_memory.process_conversation_memory(
|
result = await self.enhanced_memory.process_conversation_memory(payload_context)
|
||||||
conversation_text=conversation_text,
|
|
||||||
context=context,
|
|
||||||
user_id=user_id,
|
|
||||||
timestamp=timestamp
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新统计
|
# 更新统计
|
||||||
processing_time = time.time() - start_time
|
processing_time = time.time() - start_time
|
||||||
@@ -156,7 +152,7 @@ class MemoryIntegrationLayer:
|
|||||||
async def retrieve_relevant_memories(
|
async def retrieve_relevant_memories(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
user_id: str,
|
user_id: Optional[str] = None,
|
||||||
context: Optional[Dict[str, Any]] = None,
|
context: Optional[Dict[str, Any]] = None,
|
||||||
limit: Optional[int] = None
|
limit: Optional[int] = None
|
||||||
) -> List[MemoryChunk]:
|
) -> List[MemoryChunk]:
|
||||||
@@ -168,7 +164,7 @@ class MemoryIntegrationLayer:
|
|||||||
limit = limit or self.config.max_retrieval_results
|
limit = limit or self.config.max_retrieval_results
|
||||||
memories = await self.enhanced_memory.retrieve_relevant_memories(
|
memories = await self.enhanced_memory.retrieve_relevant_memories(
|
||||||
query=query,
|
query=query,
|
||||||
user_id=user_id,
|
user_id=None,
|
||||||
context=context or {},
|
context=context or {},
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,21 +2,48 @@
|
|||||||
"""
|
"""
|
||||||
记忆构建模块
|
记忆构建模块
|
||||||
从对话流中提取高质量、结构化记忆单元
|
从对话流中提取高质量、结构化记忆单元
|
||||||
|
输出格式要求:
|
||||||
|
{{
|
||||||
|
"memories": [
|
||||||
|
{{
|
||||||
|
"type": "记忆类型",
|
||||||
|
"display": "用于直接展示和检索的自然语言描述",
|
||||||
|
"subject": ["主体1", "主体2"],
|
||||||
|
"predicate": "谓语(动作/状态)",
|
||||||
|
"object": "宾语(对象/属性或结构体)",
|
||||||
|
"keywords": ["关键词1", "关键词2"],
|
||||||
|
"importance": "重要性等级(1-4)",
|
||||||
|
"confidence": "置信度(1-4)",
|
||||||
|
"reasoning": "提取理由"
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
|
||||||
|
注意:
|
||||||
|
1. `subject` 可包含多个主体,请用数组表示;若主体不明确,请根据上下文给出最合理的称呼
|
||||||
|
2. `display` 必须是一句完整流畅的中文描述,可直接用于用户展示和向量搜索
|
||||||
|
3. 只提取确实值得记忆的信息,不要提取琐碎内容
|
||||||
|
4. 确保信息准确、具体、有价值
|
||||||
|
5. 重要性: 1=低, 2=一般, 3=高, 4=关键;置信度: 1=低, 2=中等, 3=高, 4=已验证
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import orjson
|
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
from datetime import datetime
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Union, Type
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.chat.memory_system.memory_chunk import (
|
from src.chat.memory_system.memory_chunk import (
|
||||||
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
|
MemoryChunk,
|
||||||
create_memory_chunk
|
MemoryType,
|
||||||
|
ConfidenceLevel,
|
||||||
|
ImportanceLevel,
|
||||||
|
create_memory_chunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -24,6 +51,7 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
class ExtractionStrategy(Enum):
|
class ExtractionStrategy(Enum):
|
||||||
"""提取策略"""
|
"""提取策略"""
|
||||||
|
|
||||||
LLM_BASED = "llm_based" # 基于LLM的智能提取
|
LLM_BASED = "llm_based" # 基于LLM的智能提取
|
||||||
RULE_BASED = "rule_based" # 基于规则的提取
|
RULE_BASED = "rule_based" # 基于规则的提取
|
||||||
HYBRID = "hybrid" # 混合策略
|
HYBRID = "hybrid" # 混合策略
|
||||||
@@ -171,18 +199,18 @@ class MemoryBuilder:
|
|||||||
"""使用规则提取记忆"""
|
"""使用规则提取记忆"""
|
||||||
memories = []
|
memories = []
|
||||||
|
|
||||||
subject_display = self._resolve_user_display(context, user_id)
|
subjects = self._resolve_conversation_participants(context, user_id)
|
||||||
|
|
||||||
# 规则1: 检测个人信息
|
# 规则1: 检测个人信息
|
||||||
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subject_display)
|
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subjects)
|
||||||
memories.extend(personal_info)
|
memories.extend(personal_info)
|
||||||
|
|
||||||
# 规则2: 检测偏好信息
|
# 规则2: 检测偏好信息
|
||||||
preferences = self._extract_preferences(text, user_id, timestamp, context, subject_display)
|
preferences = self._extract_preferences(text, user_id, timestamp, context, subjects)
|
||||||
memories.extend(preferences)
|
memories.extend(preferences)
|
||||||
|
|
||||||
# 规则3: 检测事件信息
|
# 规则3: 检测事件信息
|
||||||
events = self._extract_events(text, user_id, timestamp, context, subject_display)
|
events = self._extract_events(text, user_id, timestamp, context, subjects)
|
||||||
memories.extend(events)
|
memories.extend(events)
|
||||||
|
|
||||||
return memories
|
return memories
|
||||||
@@ -258,10 +286,7 @@ class MemoryBuilder:
|
|||||||
你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。
|
你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。
|
||||||
|
|
||||||
当前时间: {current_date}
|
当前时间: {current_date}
|
||||||
聊天ID: {chat_id}
|
|
||||||
消息类型: {message_type}
|
消息类型: {message_type}
|
||||||
目标用户ID: {target_user_id_display}
|
|
||||||
目标用户称呼: {target_user_name}
|
|
||||||
|
|
||||||
## 🤖 机器人身份(仅供参考,禁止写入记忆)
|
## 🤖 机器人身份(仅供参考,禁止写入记忆)
|
||||||
- 机器人名称: {bot_name_display}
|
- 机器人名称: {bot_name_display}
|
||||||
@@ -272,7 +297,6 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
请务必遵守以下命名规范:
|
请务必遵守以下命名规范:
|
||||||
- 当说话者是机器人时,请使用“{bot_name_display}”或其他明确称呼作为主语;
|
- 当说话者是机器人时,请使用“{bot_name_display}”或其他明确称呼作为主语;
|
||||||
- 如果看到系统自动生成的长ID(类似 {target_user_id}),请改用“{target_user_name}”、机器人的称呼或“该用户”描述,不要把ID写入记忆;
|
|
||||||
- 记录关键事实时,请准确标记主体是机器人还是用户,避免混淆。
|
- 记录关键事实时,请准确标记主体是机器人还是用户,避免混淆。
|
||||||
|
|
||||||
对话内容:
|
对话内容:
|
||||||
@@ -450,7 +474,7 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
bot_identifiers = self._collect_bot_identifiers(context)
|
bot_identifiers = self._collect_bot_identifiers(context)
|
||||||
system_identifiers = self._collect_system_identifiers(context)
|
system_identifiers = self._collect_system_identifiers(context)
|
||||||
default_subject = self._resolve_user_display(context, user_id)
|
default_subjects = self._resolve_conversation_participants(context, user_id)
|
||||||
|
|
||||||
bot_display = None
|
bot_display = None
|
||||||
if context:
|
if context:
|
||||||
@@ -481,19 +505,33 @@ class MemoryBuilder:
|
|||||||
for mem_data in memory_list:
|
for mem_data in memory_list:
|
||||||
try:
|
try:
|
||||||
subject_value = mem_data.get("subject")
|
subject_value = mem_data.get("subject")
|
||||||
normalized_subject = self._normalize_subject(
|
normalized_subject = self._normalize_subjects(
|
||||||
subject_value,
|
subject_value,
|
||||||
bot_identifiers,
|
bot_identifiers,
|
||||||
system_identifiers,
|
system_identifiers,
|
||||||
default_subject,
|
default_subjects,
|
||||||
bot_display
|
bot_display
|
||||||
)
|
)
|
||||||
|
|
||||||
if normalized_subject is None:
|
if not normalized_subject:
|
||||||
logger.debug("跳过疑似机器人自身信息的记忆: %s", mem_data)
|
logger.debug("跳过疑似机器人自身信息的记忆: %s", mem_data)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 创建记忆块
|
# 创建记忆块
|
||||||
|
importance_level = self._parse_enum_value(
|
||||||
|
ImportanceLevel,
|
||||||
|
mem_data.get("importance"),
|
||||||
|
ImportanceLevel.NORMAL,
|
||||||
|
"importance"
|
||||||
|
)
|
||||||
|
|
||||||
|
confidence_level = self._parse_enum_value(
|
||||||
|
ConfidenceLevel,
|
||||||
|
mem_data.get("confidence"),
|
||||||
|
ConfidenceLevel.MEDIUM,
|
||||||
|
"confidence"
|
||||||
|
)
|
||||||
|
|
||||||
memory = create_memory_chunk(
|
memory = create_memory_chunk(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
subject=normalized_subject,
|
subject=normalized_subject,
|
||||||
@@ -502,8 +540,9 @@ class MemoryBuilder:
|
|||||||
memory_type=MemoryType(mem_data.get("type", "contextual")),
|
memory_type=MemoryType(mem_data.get("type", "contextual")),
|
||||||
chat_id=context.get("chat_id"),
|
chat_id=context.get("chat_id"),
|
||||||
source_context=mem_data.get("reasoning", ""),
|
source_context=mem_data.get("reasoning", ""),
|
||||||
importance=ImportanceLevel(mem_data.get("importance", 2)),
|
importance=importance_level,
|
||||||
confidence=ConfidenceLevel(mem_data.get("confidence", 2))
|
confidence=confidence_level,
|
||||||
|
display=mem_data.get("display")
|
||||||
)
|
)
|
||||||
|
|
||||||
# 添加关键词
|
# 添加关键词
|
||||||
@@ -511,13 +550,6 @@ class MemoryBuilder:
|
|||||||
for keyword in keywords:
|
for keyword in keywords:
|
||||||
memory.add_keyword(keyword)
|
memory.add_keyword(keyword)
|
||||||
|
|
||||||
subject_text = memory.content.subject.strip() if isinstance(memory.content.subject, str) else str(memory.content.subject)
|
|
||||||
if not subject_text:
|
|
||||||
memory.content.subject = default_subject
|
|
||||||
elif subject_text.lower() in system_identifiers or self._looks_like_system_identifier(subject_text):
|
|
||||||
logger.debug("将系统标识主语替换为默认用户名称: %s", subject_text)
|
|
||||||
memory.content.subject = default_subject
|
|
||||||
|
|
||||||
memories.append(memory)
|
memories.append(memory)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -526,6 +558,64 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
return memories
|
return memories
|
||||||
|
|
||||||
|
def _parse_enum_value(
|
||||||
|
self,
|
||||||
|
enum_cls: Type[Enum],
|
||||||
|
raw_value: Any,
|
||||||
|
default: Enum,
|
||||||
|
field_name: str
|
||||||
|
) -> Enum:
|
||||||
|
"""解析枚举值,兼容数字/字符串表示"""
|
||||||
|
if isinstance(raw_value, enum_cls):
|
||||||
|
return raw_value
|
||||||
|
|
||||||
|
if raw_value is None:
|
||||||
|
return default
|
||||||
|
|
||||||
|
# 直接尝试整数转换
|
||||||
|
if isinstance(raw_value, (int, float)):
|
||||||
|
int_value = int(raw_value)
|
||||||
|
try:
|
||||||
|
return enum_cls(int_value)
|
||||||
|
except ValueError:
|
||||||
|
logger.debug("%s=%s 无法解析为 %s", field_name, raw_value, enum_cls.__name__)
|
||||||
|
return default
|
||||||
|
|
||||||
|
if isinstance(raw_value, str):
|
||||||
|
value_str = raw_value.strip()
|
||||||
|
if not value_str:
|
||||||
|
return default
|
||||||
|
|
||||||
|
if value_str.isdigit():
|
||||||
|
try:
|
||||||
|
return enum_cls(int(value_str))
|
||||||
|
except ValueError:
|
||||||
|
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
|
||||||
|
else:
|
||||||
|
normalized = value_str.replace("-", "_").replace(" ", "_").upper()
|
||||||
|
for member in enum_cls:
|
||||||
|
if member.name == normalized:
|
||||||
|
return member
|
||||||
|
for member in enum_cls:
|
||||||
|
if str(member.value).lower() == value_str.lower():
|
||||||
|
return member
|
||||||
|
|
||||||
|
try:
|
||||||
|
return enum_cls(value_str)
|
||||||
|
except ValueError:
|
||||||
|
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return enum_cls(raw_value)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s",
|
||||||
|
field_name,
|
||||||
|
raw_value,
|
||||||
|
type(raw_value).__name__,
|
||||||
|
enum_cls.__name__,
|
||||||
|
default.name)
|
||||||
|
return default
|
||||||
|
|
||||||
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
|
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
|
||||||
identifiers: set[str] = {"bot", "机器人", "ai助手"}
|
identifiers: set[str] = {"bot", "机器人", "ai助手"}
|
||||||
if not context:
|
if not context:
|
||||||
@@ -580,6 +670,58 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
return identifiers
|
return identifiers
|
||||||
|
|
||||||
|
def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
|
||||||
|
participants: List[str] = []
|
||||||
|
|
||||||
|
if context:
|
||||||
|
candidate_keys = [
|
||||||
|
"participants",
|
||||||
|
"participant_names",
|
||||||
|
"speaker_names",
|
||||||
|
"members",
|
||||||
|
"member_names",
|
||||||
|
"mention_users",
|
||||||
|
"audiences"
|
||||||
|
]
|
||||||
|
|
||||||
|
for key in candidate_keys:
|
||||||
|
value = context.get(key)
|
||||||
|
if isinstance(value, (list, tuple, set)):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, str):
|
||||||
|
cleaned = self._clean_subject_text(item)
|
||||||
|
if cleaned:
|
||||||
|
participants.append(cleaned)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
for part in self._split_subject_string(value):
|
||||||
|
if part:
|
||||||
|
participants.append(part)
|
||||||
|
|
||||||
|
fallback = self._resolve_user_display(context, user_id)
|
||||||
|
if fallback:
|
||||||
|
participants.append(fallback)
|
||||||
|
|
||||||
|
if context:
|
||||||
|
bot_name = context.get("bot_name") or context.get("bot_identity")
|
||||||
|
if isinstance(bot_name, str):
|
||||||
|
cleaned = self._clean_subject_text(bot_name)
|
||||||
|
if cleaned:
|
||||||
|
participants.append(cleaned)
|
||||||
|
|
||||||
|
if not participants:
|
||||||
|
participants = ["对话参与者"]
|
||||||
|
|
||||||
|
deduplicated: List[str] = []
|
||||||
|
seen = set()
|
||||||
|
for name in participants:
|
||||||
|
key = name.lower()
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
deduplicated.append(name)
|
||||||
|
|
||||||
|
return deduplicated
|
||||||
|
|
||||||
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
|
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
|
||||||
candidate_keys = [
|
candidate_keys = [
|
||||||
"user_display_name",
|
"user_display_name",
|
||||||
@@ -626,51 +768,160 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _normalize_subject(
|
def _split_subject_string(self, value: str) -> List[str]:
|
||||||
|
if not value:
|
||||||
|
return []
|
||||||
|
|
||||||
|
replaced = re.sub(r"\band\b", "、", value, flags=re.IGNORECASE)
|
||||||
|
replaced = replaced.replace("和", "、").replace("与", "、").replace("及", "、")
|
||||||
|
replaced = replaced.replace("&", "、").replace("/", "、").replace("+", "、")
|
||||||
|
|
||||||
|
tokens = [self._clean_subject_text(token) for token in re.split(r"[、,,;;]+", replaced)]
|
||||||
|
return [token for token in tokens if token]
|
||||||
|
|
||||||
|
def _normalize_subjects(
|
||||||
self,
|
self,
|
||||||
subject: Any,
|
subject: Any,
|
||||||
bot_identifiers: set[str],
|
bot_identifiers: set[str],
|
||||||
system_identifiers: set[str],
|
system_identifiers: set[str],
|
||||||
default_subject: str,
|
default_subjects: List[str],
|
||||||
bot_display: Optional[str] = None
|
bot_display: Optional[str] = None
|
||||||
) -> Optional[str]:
|
) -> List[str]:
|
||||||
if subject is None:
|
defaults = default_subjects or ["对话参与者"]
|
||||||
return default_subject
|
|
||||||
|
|
||||||
subject_str = subject if isinstance(subject, str) else str(subject)
|
raw_candidates: List[str] = []
|
||||||
cleaned = self._clean_subject_text(subject_str)
|
if isinstance(subject, list):
|
||||||
if not cleaned:
|
for item in subject:
|
||||||
return default_subject
|
if isinstance(item, str):
|
||||||
|
raw_candidates.extend(self._split_subject_string(item))
|
||||||
|
elif item is not None:
|
||||||
|
raw_candidates.extend(self._split_subject_string(str(item)))
|
||||||
|
elif isinstance(subject, str):
|
||||||
|
raw_candidates.extend(self._split_subject_string(subject))
|
||||||
|
elif subject is not None:
|
||||||
|
raw_candidates.extend(self._split_subject_string(str(subject)))
|
||||||
|
|
||||||
lowered = cleaned.lower()
|
normalized: List[str] = []
|
||||||
bot_primary = self._clean_subject_text(bot_display or "")
|
bot_primary = self._clean_subject_text(bot_display or "")
|
||||||
|
|
||||||
if lowered in bot_identifiers:
|
for candidate in raw_candidates:
|
||||||
return bot_primary or cleaned
|
if not candidate:
|
||||||
|
continue
|
||||||
|
|
||||||
if lowered in {"用户", "user", "the user", "对方", "对手"}:
|
lowered = candidate.lower()
|
||||||
return default_subject
|
if lowered in bot_identifiers:
|
||||||
|
normalized.append(bot_primary or candidate)
|
||||||
|
continue
|
||||||
|
|
||||||
prefix_match = re.match(r"^(用户|User|user|USER|成员|member|Member|target|Target|TARGET)[\s::\-\u2014_]*?(.*)$", cleaned)
|
if lowered in {"用户", "user", "the user", "对方", "对手"}:
|
||||||
if prefix_match:
|
normalized.extend(defaults)
|
||||||
remainder = self._clean_subject_text(prefix_match.group(2))
|
continue
|
||||||
if not remainder:
|
|
||||||
return default_subject
|
|
||||||
remainder_lower = remainder.lower()
|
|
||||||
if remainder_lower in bot_identifiers:
|
|
||||||
return bot_primary or remainder
|
|
||||||
if (
|
|
||||||
remainder_lower in system_identifiers
|
|
||||||
or self._looks_like_system_identifier(remainder)
|
|
||||||
):
|
|
||||||
return default_subject
|
|
||||||
cleaned = remainder
|
|
||||||
lowered = cleaned.lower()
|
|
||||||
|
|
||||||
if lowered in system_identifiers or self._looks_like_system_identifier(cleaned):
|
if lowered in system_identifiers or self._looks_like_system_identifier(candidate):
|
||||||
return default_subject
|
continue
|
||||||
|
|
||||||
return cleaned
|
normalized.append(candidate)
|
||||||
|
|
||||||
|
if not normalized:
|
||||||
|
normalized = list(defaults)
|
||||||
|
|
||||||
|
deduplicated: List[str] = []
|
||||||
|
seen = set()
|
||||||
|
for name in normalized:
|
||||||
|
key = name.lower()
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
deduplicated.append(name)
|
||||||
|
|
||||||
|
return deduplicated
|
||||||
|
|
||||||
|
def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
for key in keys:
|
||||||
|
value = obj.get(key)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
if isinstance(value, list):
|
||||||
|
compact = "、".join(str(item) for item in value[:3])
|
||||||
|
if compact:
|
||||||
|
return compact
|
||||||
|
else:
|
||||||
|
value_str = str(value).strip()
|
||||||
|
if value_str:
|
||||||
|
return value_str
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
compact = "、".join(str(item) for item in obj[:3])
|
||||||
|
return compact or None
|
||||||
|
elif isinstance(obj, str):
|
||||||
|
return obj.strip() or None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str:
|
||||||
|
subject_phrase = "、".join(subjects) if subjects else "对话参与者"
|
||||||
|
predicate = (predicate or "").strip()
|
||||||
|
|
||||||
|
if predicate == "is_named":
|
||||||
|
name = self._extract_value_from_object(obj, ["name", "nickname"]) or ""
|
||||||
|
name = self._clean_subject_text(name)
|
||||||
|
if name:
|
||||||
|
quoted = name if (name.startswith("「") and name.endswith("」")) else f"「{name}」"
|
||||||
|
return f"{subject_phrase}的昵称是{quoted}"
|
||||||
|
elif predicate == "is_age":
|
||||||
|
age = self._extract_value_from_object(obj, ["age"]) or ""
|
||||||
|
age = self._clean_subject_text(age)
|
||||||
|
if age:
|
||||||
|
return f"{subject_phrase}今年{age}岁"
|
||||||
|
elif predicate == "is_profession":
|
||||||
|
profession = self._extract_value_from_object(obj, ["profession", "job"]) or ""
|
||||||
|
profession = self._clean_subject_text(profession)
|
||||||
|
if profession:
|
||||||
|
return f"{subject_phrase}的职业是{profession}"
|
||||||
|
elif predicate == "lives_in":
|
||||||
|
location = self._extract_value_from_object(obj, ["location", "city", "place"]) or ""
|
||||||
|
location = self._clean_subject_text(location)
|
||||||
|
if location:
|
||||||
|
return f"{subject_phrase}居住在{location}"
|
||||||
|
elif predicate == "has_phone":
|
||||||
|
phone = self._extract_value_from_object(obj, ["phone", "number"]) or ""
|
||||||
|
phone = self._clean_subject_text(phone)
|
||||||
|
if phone:
|
||||||
|
return f"{subject_phrase}的电话号码是{phone}"
|
||||||
|
elif predicate == "has_email":
|
||||||
|
email = self._extract_value_from_object(obj, ["email"]) or ""
|
||||||
|
email = self._clean_subject_text(email)
|
||||||
|
if email:
|
||||||
|
return f"{subject_phrase}的邮箱是{email}"
|
||||||
|
elif predicate in {"likes", "likes_food", "favorite_is"}:
|
||||||
|
liked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
|
||||||
|
liked = self._clean_subject_text(liked)
|
||||||
|
if liked:
|
||||||
|
verb = "喜欢" if predicate != "likes_food" else "爱吃"
|
||||||
|
if predicate == "favorite_is":
|
||||||
|
verb = "最喜欢"
|
||||||
|
return f"{subject_phrase}{verb}{liked}"
|
||||||
|
elif predicate in {"dislikes", "hates"}:
|
||||||
|
disliked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
|
||||||
|
disliked = self._clean_subject_text(disliked)
|
||||||
|
if disliked:
|
||||||
|
verb = "不喜欢" if predicate == "dislikes" else "讨厌"
|
||||||
|
return f"{subject_phrase}{verb}{disliked}"
|
||||||
|
elif predicate == "mentioned_event":
|
||||||
|
description = self._extract_value_from_object(obj, ["event_text", "description"]) or ""
|
||||||
|
description = self._clean_subject_text(description)
|
||||||
|
if description:
|
||||||
|
return f"{subject_phrase}提到了:{description}"
|
||||||
|
|
||||||
|
obj_text = self._extract_value_from_object(obj, ["value", "detail", "content"]) or ""
|
||||||
|
obj_text = self._clean_subject_text(obj_text)
|
||||||
|
|
||||||
|
if predicate and obj_text:
|
||||||
|
return f"{subject_phrase}{predicate}{obj_text}".strip()
|
||||||
|
if obj_text:
|
||||||
|
return f"{subject_phrase}{obj_text}".strip()
|
||||||
|
if predicate:
|
||||||
|
return f"{subject_phrase}{predicate}".strip()
|
||||||
|
return subject_phrase
|
||||||
|
|
||||||
def _extract_personal_info(
|
def _extract_personal_info(
|
||||||
self,
|
self,
|
||||||
@@ -678,7 +929,7 @@ class MemoryBuilder:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
timestamp: float,
|
timestamp: float,
|
||||||
context: Dict[str, Any],
|
context: Dict[str, Any],
|
||||||
subject_display: str
|
subjects: List[str]
|
||||||
) -> List[MemoryChunk]:
|
) -> List[MemoryChunk]:
|
||||||
"""提取个人信息"""
|
"""提取个人信息"""
|
||||||
memories = []
|
memories = []
|
||||||
@@ -702,13 +953,14 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
memory = create_memory_chunk(
|
memory = create_memory_chunk(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
subject=subject_display,
|
subject=subjects,
|
||||||
predicate=predicate,
|
predicate=predicate,
|
||||||
obj=obj,
|
obj=obj,
|
||||||
memory_type=MemoryType.PERSONAL_FACT,
|
memory_type=MemoryType.PERSONAL_FACT,
|
||||||
chat_id=context.get("chat_id"),
|
chat_id=context.get("chat_id"),
|
||||||
importance=ImportanceLevel.HIGH,
|
importance=ImportanceLevel.HIGH,
|
||||||
confidence=ConfidenceLevel.HIGH
|
confidence=ConfidenceLevel.HIGH,
|
||||||
|
display=self._compose_display_text(subjects, predicate, obj)
|
||||||
)
|
)
|
||||||
|
|
||||||
memories.append(memory)
|
memories.append(memory)
|
||||||
@@ -721,7 +973,7 @@ class MemoryBuilder:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
timestamp: float,
|
timestamp: float,
|
||||||
context: Dict[str, Any],
|
context: Dict[str, Any],
|
||||||
subject_display: str
|
subjects: List[str]
|
||||||
) -> List[MemoryChunk]:
|
) -> List[MemoryChunk]:
|
||||||
"""提取偏好信息"""
|
"""提取偏好信息"""
|
||||||
memories = []
|
memories = []
|
||||||
@@ -740,13 +992,14 @@ class MemoryBuilder:
|
|||||||
if match:
|
if match:
|
||||||
memory = create_memory_chunk(
|
memory = create_memory_chunk(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
subject=subject_display,
|
subject=subjects,
|
||||||
predicate=predicate,
|
predicate=predicate,
|
||||||
obj=match.group(1),
|
obj=match.group(1),
|
||||||
memory_type=MemoryType.PREFERENCE,
|
memory_type=MemoryType.PREFERENCE,
|
||||||
chat_id=context.get("chat_id"),
|
chat_id=context.get("chat_id"),
|
||||||
importance=ImportanceLevel.NORMAL,
|
importance=ImportanceLevel.NORMAL,
|
||||||
confidence=ConfidenceLevel.MEDIUM
|
confidence=ConfidenceLevel.MEDIUM,
|
||||||
|
display=self._compose_display_text(subjects, predicate, match.group(1))
|
||||||
)
|
)
|
||||||
|
|
||||||
memories.append(memory)
|
memories.append(memory)
|
||||||
@@ -759,7 +1012,7 @@ class MemoryBuilder:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
timestamp: float,
|
timestamp: float,
|
||||||
context: Dict[str, Any],
|
context: Dict[str, Any],
|
||||||
subject_display: str
|
subjects: List[str]
|
||||||
) -> List[MemoryChunk]:
|
) -> List[MemoryChunk]:
|
||||||
"""提取事件信息"""
|
"""提取事件信息"""
|
||||||
memories = []
|
memories = []
|
||||||
@@ -770,13 +1023,14 @@ class MemoryBuilder:
|
|||||||
if any(keyword in text for keyword in event_keywords):
|
if any(keyword in text for keyword in event_keywords):
|
||||||
memory = create_memory_chunk(
|
memory = create_memory_chunk(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
subject=subject_display,
|
subject=subjects,
|
||||||
predicate="mentioned_event",
|
predicate="mentioned_event",
|
||||||
obj={"event_text": text, "timestamp": timestamp},
|
obj={"event_text": text, "timestamp": timestamp},
|
||||||
memory_type=MemoryType.EVENT,
|
memory_type=MemoryType.EVENT,
|
||||||
chat_id=context.get("chat_id"),
|
chat_id=context.get("chat_id"),
|
||||||
importance=ImportanceLevel.NORMAL,
|
importance=ImportanceLevel.NORMAL,
|
||||||
confidence=ConfidenceLevel.MEDIUM
|
confidence=ConfidenceLevel.MEDIUM,
|
||||||
|
display=self._compose_display_text(subjects, "mentioned_event", text)
|
||||||
)
|
)
|
||||||
|
|
||||||
memories.append(memory)
|
memories.append(memory)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import orjson
|
import orjson
|
||||||
from typing import Dict, List, Optional, Any, Union
|
from typing import Dict, List, Optional, Any, Union, Iterable
|
||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -52,17 +52,20 @@ class ImportanceLevel(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ContentStructure:
|
class ContentStructure:
|
||||||
"""主谓宾三元组结构"""
|
"""主谓宾结构,包含自然语言描述"""
|
||||||
subject: str # 主语(通常为用户)
|
|
||||||
predicate: str # 谓语(动作、状态、关系)
|
subject: Union[str, List[str]]
|
||||||
object: Union[str, Dict] # 宾语(对象、属性、值)
|
predicate: str
|
||||||
|
object: Union[str, Dict]
|
||||||
|
display: str = ""
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
return {
|
return {
|
||||||
"subject": self.subject,
|
"subject": self.subject,
|
||||||
"predicate": self.predicate,
|
"predicate": self.predicate,
|
||||||
"object": self.object
|
"object": self.object,
|
||||||
|
"display": self.display
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -71,16 +74,25 @@ class ContentStructure:
|
|||||||
return cls(
|
return cls(
|
||||||
subject=data.get("subject", ""),
|
subject=data.get("subject", ""),
|
||||||
predicate=data.get("predicate", ""),
|
predicate=data.get("predicate", ""),
|
||||||
object=data.get("object", "")
|
object=data.get("object", ""),
|
||||||
|
display=data.get("display", "")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_subject_list(self) -> List[str]:
    """Return the subject as a list of clean, non-empty names.

    ``subject`` may be a single string or a list. List entries that are not
    strings, or that are blank, are dropped. Every returned name is stripped
    of surrounding whitespace in both the list and the single-string case,
    so names compare equal regardless of stray padding.

    Returns:
        The stripped subject names; an empty list when there is none.
    """
    if isinstance(self.subject, list):
        stripped = (item.strip() for item in self.subject if isinstance(item, str))
        return [name for name in stripped if name]

    if isinstance(self.subject, str):
        name = self.subject.strip()
        if name:
            return [name]

    return []
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""字符串表示"""
|
"""字符串表示"""
|
||||||
if isinstance(self.object, dict):
|
if self.display:
|
||||||
object_str = str(self.object)
|
return self.display
|
||||||
else:
|
subjects = "、".join(self.to_subject_list()) or str(self.subject)
|
||||||
object_str = str(self.object)
|
object_str = self.object if isinstance(self.object, str) else str(self.object)
|
||||||
return f"{self.subject} {self.predicate} {object_str}"
|
return f"{subjects} {self.predicate} {object_str}".strip()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -236,9 +248,19 @@ class MemoryChunk:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def text_content(self) -> str:
|
def text_content(self) -> str:
|
||||||
"""获取文本内容"""
|
"""获取文本内容(优先使用display)"""
|
||||||
return str(self.content)
|
return str(self.content)
|
||||||
|
|
||||||
|
@property
def display(self) -> str:
    """Text used to present this memory.

    Returns the ``display`` value stored on the content when it is
    non-empty, otherwise the string form of the content.
    """
    stored = self.content.display
    if stored:
        return stored
    return str(self.content)
|
||||||
|
|
||||||
|
@property
def subjects(self) -> List[str]:
    """Subject names of this memory, as produced by the content's
    ``to_subject_list``."""
    subject_names = self.content.to_subject_list()
    return subject_names
|
||||||
|
|
||||||
def update_access(self):
|
def update_access(self):
|
||||||
"""更新访问信息"""
|
"""更新访问信息"""
|
||||||
self.metadata.update_access()
|
self.metadata.update_access()
|
||||||
@@ -415,16 +437,42 @@ class MemoryChunk:
|
|||||||
confidence_icon = "●" * self.metadata.confidence.value
|
confidence_icon = "●" * self.metadata.confidence.value
|
||||||
importance_icon = "★" * self.metadata.importance.value
|
importance_icon = "★" * self.metadata.importance.value
|
||||||
|
|
||||||
return f"{emoji} [{self.memory_type.value}] {self.text_content} {confidence_icon} {importance_icon}"
|
return f"{emoji} [{self.memory_type.value}] {self.display} {confidence_icon} {importance_icon}"
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""调试表示"""
|
"""调试表示"""
|
||||||
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
|
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
|
||||||
|
"""根据主谓宾生成自然语言描述"""
|
||||||
|
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
|
||||||
|
subject_part = "、".join(subjects_clean) if subjects_clean else "对话参与者"
|
||||||
|
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
object_candidates = []
|
||||||
|
for key, value in obj.items():
|
||||||
|
if isinstance(value, (str, int, float)):
|
||||||
|
object_candidates.append(f"{key}:{value}")
|
||||||
|
elif isinstance(value, list):
|
||||||
|
compact = "、".join(str(item) for item in value[:3])
|
||||||
|
object_candidates.append(f"{key}:{compact}")
|
||||||
|
object_part = ",".join(object_candidates) if object_candidates else str(obj)
|
||||||
|
else:
|
||||||
|
object_part = str(obj).strip()
|
||||||
|
|
||||||
|
predicate_clean = predicate.strip()
|
||||||
|
if not predicate_clean:
|
||||||
|
return f"{subject_part} {object_part}".strip()
|
||||||
|
|
||||||
|
if object_part:
|
||||||
|
return f"{subject_part}{predicate_clean}{object_part}".strip()
|
||||||
|
return f"{subject_part}{predicate_clean}".strip()
|
||||||
|
|
||||||
|
|
||||||
def create_memory_chunk(
|
def create_memory_chunk(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
subject: str,
|
subject: Union[str, List[str]],
|
||||||
predicate: str,
|
predicate: str,
|
||||||
obj: Union[str, Dict],
|
obj: Union[str, Dict],
|
||||||
memory_type: MemoryType,
|
memory_type: MemoryType,
|
||||||
@@ -432,6 +480,7 @@ def create_memory_chunk(
|
|||||||
source_context: Optional[str] = None,
|
source_context: Optional[str] = None,
|
||||||
importance: ImportanceLevel = ImportanceLevel.NORMAL,
|
importance: ImportanceLevel = ImportanceLevel.NORMAL,
|
||||||
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
|
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
|
||||||
|
display: Optional[str] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> MemoryChunk:
|
) -> MemoryChunk:
|
||||||
"""便捷的内存块创建函数"""
|
"""便捷的内存块创建函数"""
|
||||||
@@ -447,10 +496,22 @@ def create_memory_chunk(
|
|||||||
source_context=source_context
|
source_context=source_context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
subjects: List[str]
|
||||||
|
if isinstance(subject, list):
|
||||||
|
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
|
||||||
|
subject_payload: Union[str, List[str]] = subjects
|
||||||
|
else:
|
||||||
|
cleaned = subject.strip() if isinstance(subject, str) else ""
|
||||||
|
subjects = [cleaned] if cleaned else []
|
||||||
|
subject_payload = cleaned
|
||||||
|
|
||||||
|
display_text = display or _build_display_text(subjects, predicate, obj)
|
||||||
|
|
||||||
content = ContentStructure(
|
content = ContentStructure(
|
||||||
subject=subject,
|
subject=subject_payload,
|
||||||
predicate=predicate,
|
predicate=predicate,
|
||||||
object=obj
|
object=obj,
|
||||||
|
display=display_text
|
||||||
)
|
)
|
||||||
|
|
||||||
chunk = MemoryChunk(
|
chunk = MemoryChunk(
|
||||||
|
|||||||
@@ -266,8 +266,12 @@ class MemoryFusionEngine:
|
|||||||
consistency_score = 0.0
|
consistency_score = 0.0
|
||||||
|
|
||||||
# 主语一致性
|
# 主语一致性
|
||||||
if mem1.content.subject == mem2.content.subject:
|
subjects1 = set(mem1.subjects)
|
||||||
consistency_score += 0.4
|
subjects2 = set(mem2.subjects)
|
||||||
|
if subjects1 or subjects2:
|
||||||
|
overlap = len(subjects1 & subjects2)
|
||||||
|
union_count = max(len(subjects1 | subjects2), 1)
|
||||||
|
consistency_score += (overlap / union_count) * 0.4
|
||||||
|
|
||||||
# 谓语相似性
|
# 谓语相似性
|
||||||
predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)
|
predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)
|
||||||
|
|||||||
@@ -282,9 +282,11 @@ class MemoryIntegrationHooks:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 使用增强记忆系统处理对话
|
# 使用增强记忆系统处理对话
|
||||||
result = await process_conversation_with_enhanced_memory(
|
memory_context = dict(context)
|
||||||
conversation_text, context, user_id
|
memory_context["conversation_text"] = conversation_text
|
||||||
)
|
memory_context["user_id"] = user_id
|
||||||
|
|
||||||
|
result = await process_conversation_with_enhanced_memory(memory_context)
|
||||||
|
|
||||||
processing_time = time.time() - start_time
|
processing_time = time.time() - start_time
|
||||||
self._update_hook_stats(processing_time)
|
self._update_hook_stats(processing_time)
|
||||||
@@ -336,9 +338,11 @@ class MemoryIntegrationHooks:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 使用增强记忆系统处理对话
|
# 使用增强记忆系统处理对话
|
||||||
result = await process_conversation_with_enhanced_memory(
|
memory_context = dict(context)
|
||||||
conversation_text, context, user_id
|
memory_context["conversation_text"] = conversation_text
|
||||||
)
|
memory_context["user_id"] = user_id
|
||||||
|
|
||||||
|
result = await process_conversation_with_enhanced_memory(memory_context)
|
||||||
|
|
||||||
processing_time = time.time() - start_time
|
processing_time = time.time() - start_time
|
||||||
self._update_hook_stats(processing_time)
|
self._update_hook_stats(processing_time)
|
||||||
|
|||||||
194
src/chat/memory_system/memory_query_planner.py
Normal file
194
src/chat/memory_system/memory_query_planner.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""记忆检索查询规划器"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
from src.chat.memory_system.memory_chunk import MemoryType
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MemoryQueryPlan:
    """Result of query planning: a structured description of what to retrieve."""

    # Natural-language query handed to vector retrieval.
    semantic_query: str
    # Memory types the retrieval should be restricted to.
    memory_types: List[MemoryType] = field(default_factory=list)
    # People/roles expected in the memory subject.
    subject_includes: List[str] = field(default_factory=list)
    # Objects/topics expected in the memory.
    object_includes: List[str] = field(default_factory=list)
    # Keywords that must / may appear.
    required_keywords: List[str] = field(default_factory=list)
    optional_keywords: List[str] = field(default_factory=list)
    # Owner identifiers used to restrict the search.
    owner_filters: List[str] = field(default_factory=list)
    # One of "any", "recent", "historical" after ensure_defaults().
    recency_preference: str = "any"
    # Maximum number of results.
    limit: int = 10
    # One of "precision", "recall", "balanced" after ensure_defaults().
    emphasis: Optional[str] = None
    # The parsed planner output this plan was built from.
    raw_plan: Dict[str, Any] = field(default_factory=dict)

    def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
        """Normalise the plan in place so every field holds a usable value.

        Args:
            fallback_query: Used when ``semantic_query`` is empty.
            default_limit: Used when ``limit`` is not positive.

        ``recency_preference`` and ``emphasis`` are stripped and lower-cased;
        values outside their supported sets fall back to ``"any"`` and
        ``"balanced"`` respectively.
        """
        if not self.semantic_query:
            self.semantic_query = fallback_query
        if self.limit <= 0:
            self.limit = default_limit

        recency = (self.recency_preference or "any").strip().lower()
        self.recency_preference = recency if recency in {"any", "recent", "historical"} else "any"

        # Validate emphasis the same way as recency so an arbitrary
        # model-provided value cannot leak into the plan.
        emphasis = (self.emphasis or "balanced").strip().lower()
        self.emphasis = emphasis if emphasis in {"precision", "recall", "balanced"} else "balanced"
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryQueryPlanner:
    """Memory-retrieval query planner backed by a (small) language model.

    The planner asks the model for a JSON retrieval plan and converts it
    into a ``MemoryQueryPlan``. Whenever the model is missing, fails, or
    returns something unusable, a default plan built from the raw query is
    returned instead.
    """

    def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
        """Store the planning model and the default result limit.

        Args:
            planner_model: Model used for planning; ``None`` disables
                model-based planning.
            default_limit: Result limit used when the plan gives none.
        """
        self.model = planner_model
        self.default_limit = default_limit

    async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
        """Produce a retrieval plan for ``query_text``.

        Args:
            query_text: The user/dialogue query to plan for.
            context: Conversation context used to enrich the prompt.

        Returns:
            The parsed and normalised plan, or the default plan when no
            model is configured, the response holds no JSON object, the
            JSON cannot be decoded, or any other error occurs.
        """
        if not self.model:
            logger.debug("未提供查询规划模型,使用默认规划")
            return self._default_plan(query_text)

        prompt = self._build_prompt(query_text, context)

        try:
            response, _ = await self.model.generate_response_async(prompt, temperature=0.2)
            payload = self._extract_json_payload(response)
            if not payload:
                logger.debug("查询规划模型未返回结构化结果,使用默认规划")
                return self._default_plan(query_text)

            try:
                data = orjson.loads(payload)
            except orjson.JSONDecodeError as exc:
                preview = payload[:200]
                logger.warning("解析查询规划JSON失败: %s,片段: %s", exc, preview)
                return self._default_plan(query_text)

            # Valid JSON that is not an object (e.g. an array) cannot be a
            # plan; treat it like a missing result rather than an error.
            if not isinstance(data, dict):
                logger.debug("查询规划结果不是JSON对象,使用默认规划")
                return self._default_plan(query_text)

            plan = self._parse_plan_dict(data, query_text)
            plan.ensure_defaults(query_text, self.default_limit)
            return plan

        except Exception as exc:
            # Deliberate best-effort boundary: planning must never break retrieval.
            logger.error("查询规划模型调用失败: %s", exc, exc_info=True)
            return self._default_plan(query_text)

    def _default_plan(self, query_text: str) -> MemoryQueryPlan:
        """Return a plan that only carries the raw query and the default limit."""
        return MemoryQueryPlan(
            semantic_query=query_text,
            limit=self.default_limit
        )

    def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
        """Convert the decoded planner JSON object into a ``MemoryQueryPlan``.

        Args:
            data: Decoded JSON object returned by the model.
            fallback_query: Used when ``semantic_query`` is missing or blank.

        Memory type names that match no ``MemoryType`` value, even after
        lower-casing, are skipped.
        """
        semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query

        memory_types: List[MemoryType] = []
        for item in self._as_str_list(data.get("memory_types")):
            try:
                memory_types.append(MemoryType(item))
            except ValueError:
                # Retry with a case-insensitive match on the enum values.
                normalized = item.lower()
                for mt in MemoryType:
                    if mt.value == normalized:
                        memory_types.append(mt)
                        break

        return MemoryQueryPlan(
            semantic_query=semantic_query,
            memory_types=memory_types,
            subject_includes=self._as_str_list(data.get("subject_includes")),
            object_includes=self._as_str_list(data.get("object_includes")),
            required_keywords=self._as_str_list(data.get("required_keywords")),
            optional_keywords=self._as_str_list(data.get("optional_keywords")),
            owner_filters=self._as_str_list(data.get("owner_filters")),
            recency_preference=self._safe_str(data.get("recency")) or "any",
            limit=self._safe_int(data.get("limit"), self.default_limit),
            emphasis=self._safe_str(data.get("emphasis")) or "balanced",
            raw_plan=data
        )

    def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
        """Build the planning prompt from the query and conversation context.

        Participants come from ``context["participants"]`` or
        ``context["speaker_names"]`` (at most five are shown); the persona
        comes from ``context["bot_personality"]`` or
        ``context["bot_identity"]``. Missing values are shown as "未知".
        A ``None`` context is treated as empty.
        """
        context = context or {}

        participants = context.get("participants") or context.get("speaker_names") or []
        if isinstance(participants, str):
            participants = [participants]
        elif not isinstance(participants, (list, tuple, set)):
            # Anything else cannot be previewed as a list of names.
            participants = []
        participants = [p for p in participants if isinstance(p, str) and p.strip()]
        participant_preview = "、".join(participants[:5]) or "未知"

        persona = context.get("bot_personality") or context.get("bot_identity") or "未知"

        return f"""
你是一名记忆检索分析师,将根据对话查询生成结构化的检索计划。
请结合提供的上下文,输出一个JSON对象,字段含义如下:
- semantic_query: 提供给向量检索的自然语言查询,要求清晰具体;
- memory_types: 建议检索的记忆类型数组,取值范围参见 MemoryType 枚举 (personal_fact,event,preference,opinion,relationship,emotion,knowledge,skill,goal,experience,contextual);
- subject_includes: 需要出现在记忆主语中的人或角色列表;
- object_includes: 记忆中需要提到的重要对象或主题关键词列表;
- required_keywords: 检索时必须包含的关键词;
- optional_keywords: 可以提升相关性的附加关键词;
- owner_filters: 如果需要限制检索所属主体,请列出用户ID或其它标识;
- recency: 建议的时间偏好,可选 recent/any/historical;
- emphasis: 检索策略倾向,可选 precision/recall/balanced;
- limit: 推荐的最大返回数量(1-15之间);
- notes: 额外说明,可选。

当前查询: "{query_text}"
已知的对话参与者: {participant_preview}
机器人设定: {persona}

请输出符合要求的JSON,禁止添加额外说明或Markdown代码块。
"""

    def _extract_json_payload(self, response: str) -> Optional[str]:
        """Locate the JSON text inside a model response.

        The content of the first Markdown code fence is preferred; otherwise
        the span from the first "{" to the last "}" is used.

        Returns:
            The candidate JSON text, or ``None`` when nothing is found.
        """
        if not response:
            return None

        stripped = response.strip()
        code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
        if code_block_match:
            candidate = code_block_match.group(1).strip()
            if candidate:
                return candidate

        start = stripped.find("{")
        end = stripped.rfind("}")
        if start != -1 and end > start:
            return stripped[start:end + 1]

        return None

    @staticmethod
    def _as_str_list(value: Any) -> List[str]:
        """Coerce a JSON field into a list of stripped, non-empty strings.

        A single string is treated as a one-element list; anything that is
        neither a string nor a list yields an empty list.
        """
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, list):
            return []
        cleaned = (MemoryQueryPlanner._safe_str(item) for item in value)
        return [text for text in cleaned if text]

    @staticmethod
    def _safe_str(value: Any) -> str:
        """Return ``value`` as a stripped string; ``None`` becomes ``""``."""
        if isinstance(value, str):
            return value.strip()
        if value is None:
            return ""
        return str(value).strip()

    @staticmethod
    def _safe_int(value: Any, default: int) -> int:
        """Return ``value`` as a positive int, or ``default``.

        ``default`` is returned when the value cannot be converted (including
        non-finite floats) or is not greater than zero.
        """
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            return default
        return number if number > 0 else default
|
||||||
@@ -25,6 +25,7 @@ class IndexType(Enum):
|
|||||||
"""索引类型"""
|
"""索引类型"""
|
||||||
MEMORY_TYPE = "memory_type" # 记忆类型索引
|
MEMORY_TYPE = "memory_type" # 记忆类型索引
|
||||||
USER_ID = "user_id" # 用户ID索引
|
USER_ID = "user_id" # 用户ID索引
|
||||||
|
SUBJECT = "subject" # 主体索引
|
||||||
KEYWORD = "keyword" # 关键词索引
|
KEYWORD = "keyword" # 关键词索引
|
||||||
TAG = "tag" # 标签索引
|
TAG = "tag" # 标签索引
|
||||||
CATEGORY = "category" # 分类索引
|
CATEGORY = "category" # 分类索引
|
||||||
@@ -41,6 +42,7 @@ class IndexQuery:
|
|||||||
"""索引查询条件"""
|
"""索引查询条件"""
|
||||||
user_ids: Optional[List[str]] = None
|
user_ids: Optional[List[str]] = None
|
||||||
memory_types: Optional[List[MemoryType]] = None
|
memory_types: Optional[List[MemoryType]] = None
|
||||||
|
subjects: Optional[List[str]] = None
|
||||||
keywords: Optional[List[str]] = None
|
keywords: Optional[List[str]] = None
|
||||||
tags: Optional[List[str]] = None
|
tags: Optional[List[str]] = None
|
||||||
categories: Optional[List[str]] = None
|
categories: Optional[List[str]] = None
|
||||||
@@ -76,6 +78,7 @@ class MetadataIndexManager:
|
|||||||
self.indices = {
|
self.indices = {
|
||||||
IndexType.MEMORY_TYPE: defaultdict(set),
|
IndexType.MEMORY_TYPE: defaultdict(set),
|
||||||
IndexType.USER_ID: defaultdict(set),
|
IndexType.USER_ID: defaultdict(set),
|
||||||
|
IndexType.SUBJECT: defaultdict(set),
|
||||||
IndexType.KEYWORD: defaultdict(set),
|
IndexType.KEYWORD: defaultdict(set),
|
||||||
IndexType.TAG: defaultdict(set),
|
IndexType.TAG: defaultdict(set),
|
||||||
IndexType.CATEGORY: defaultdict(set),
|
IndexType.CATEGORY: defaultdict(set),
|
||||||
@@ -110,6 +113,41 @@ class MetadataIndexManager:
|
|||||||
self.auto_save_interval = 500 # 每500次操作自动保存
|
self.auto_save_interval = 500 # 每500次操作自动保存
|
||||||
self._operation_count = 0
|
self._operation_count = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_index_key(index_type: IndexType, key: Any) -> str:
|
||||||
|
"""将索引键序列化为字符串以便存储"""
|
||||||
|
if isinstance(key, Enum):
|
||||||
|
value = key.value
|
||||||
|
else:
|
||||||
|
value = key
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
@staticmethod
def _deserialize_index_key(index_type: IndexType, key: str) -> Any:
    """Restore a stored index key according to its index type.

    Memory-type keys become ``MemoryType`` members; confidence and
    importance keys are parsed as integers and become ``ConfidenceLevel`` /
    ``ImportanceLevel`` members. Keys of every other index type are returned
    unchanged (they may already be lower-cased strings). If conversion fails
    a warning is logged and the raw string is returned.
    """
    try:
        if index_type == IndexType.MEMORY_TYPE:
            restored = MemoryType(key)
        elif index_type == IndexType.CONFIDENCE:
            restored = ConfidenceLevel(int(key))
        elif index_type == IndexType.IMPORTANCE:
            restored = ImportanceLevel(int(key))
        else:
            restored = key
    except Exception:
        logger.warning("无法反序列化索引键 %s 在索引 %s 中,使用原始字符串", key, index_type.value)
        restored = key
    return restored
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
serialized = {}
|
||||||
|
for field_name, value in metadata.items():
|
||||||
|
if isinstance(value, Enum):
|
||||||
|
serialized[field_name] = value.value
|
||||||
|
else:
|
||||||
|
serialized[field_name] = value
|
||||||
|
return serialized
|
||||||
|
|
||||||
async def index_memories(self, memories: List[MemoryChunk]):
|
async def index_memories(self, memories: List[MemoryChunk]):
|
||||||
"""为记忆建立索引"""
|
"""为记忆建立索引"""
|
||||||
if not memories:
|
if not memories:
|
||||||
@@ -142,6 +180,68 @@ class MetadataIndexManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ 元数据索引失败: {e}", exc_info=True)
|
logger.error(f"❌ 元数据索引失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def update_memory_entry(self, memory: MemoryChunk):
|
||||||
|
"""更新已存在记忆的索引信息"""
|
||||||
|
if not memory:
|
||||||
|
return
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
entry = self.memory_metadata_cache.get(memory.memory_id)
|
||||||
|
if entry is None:
|
||||||
|
# 若不存在则作为新记忆索引
|
||||||
|
self._index_single_memory(memory)
|
||||||
|
return
|
||||||
|
|
||||||
|
old_confidence = entry.get("confidence")
|
||||||
|
old_importance = entry.get("importance")
|
||||||
|
old_semantic_hash = entry.get("semantic_hash")
|
||||||
|
|
||||||
|
entry.update(
|
||||||
|
{
|
||||||
|
"user_id": memory.user_id,
|
||||||
|
"memory_type": memory.memory_type,
|
||||||
|
"created_at": memory.metadata.created_at,
|
||||||
|
"last_accessed": memory.metadata.last_accessed,
|
||||||
|
"access_count": memory.metadata.access_count,
|
||||||
|
"confidence": memory.metadata.confidence,
|
||||||
|
"importance": memory.metadata.importance,
|
||||||
|
"relationship_score": memory.metadata.relationship_score,
|
||||||
|
"relevance_score": memory.metadata.relevance_score,
|
||||||
|
"semantic_hash": memory.semantic_hash,
|
||||||
|
"subjects": memory.subjects,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新置信度/重要性索引
|
||||||
|
if isinstance(old_confidence, ConfidenceLevel):
|
||||||
|
self.indices[IndexType.CONFIDENCE][old_confidence].discard(memory.memory_id)
|
||||||
|
if isinstance(old_importance, ImportanceLevel):
|
||||||
|
self.indices[IndexType.IMPORTANCE][old_importance].discard(memory.memory_id)
|
||||||
|
if isinstance(old_semantic_hash, str):
|
||||||
|
self.indices[IndexType.SEMANTIC_HASH][old_semantic_hash].discard(memory.memory_id)
|
||||||
|
|
||||||
|
self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory.memory_id)
|
||||||
|
self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory.memory_id)
|
||||||
|
if memory.semantic_hash:
|
||||||
|
self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory.memory_id)
|
||||||
|
|
||||||
|
# 同步关键词/标签/分类索引
|
||||||
|
for keyword in memory.keywords:
|
||||||
|
if keyword:
|
||||||
|
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory.memory_id)
|
||||||
|
|
||||||
|
for tag in memory.tags:
|
||||||
|
if tag:
|
||||||
|
self.indices[IndexType.TAG][tag.lower()].add(memory.memory_id)
|
||||||
|
|
||||||
|
for category in memory.categories:
|
||||||
|
if category:
|
||||||
|
self.indices[IndexType.CATEGORY][category.lower()].add(memory.memory_id)
|
||||||
|
|
||||||
|
for subject in memory.subjects:
|
||||||
|
if subject:
|
||||||
|
self.indices[IndexType.SUBJECT][subject.strip().lower()].add(memory.memory_id)
|
||||||
|
|
||||||
def _index_single_memory(self, memory: MemoryChunk):
|
def _index_single_memory(self, memory: MemoryChunk):
|
||||||
"""为单个记忆建立索引"""
|
"""为单个记忆建立索引"""
|
||||||
memory_id = memory.memory_id
|
memory_id = memory.memory_id
|
||||||
@@ -157,7 +257,8 @@ class MetadataIndexManager:
|
|||||||
"importance": memory.metadata.importance,
|
"importance": memory.metadata.importance,
|
||||||
"relationship_score": memory.metadata.relationship_score,
|
"relationship_score": memory.metadata.relationship_score,
|
||||||
"relevance_score": memory.metadata.relevance_score,
|
"relevance_score": memory.metadata.relevance_score,
|
||||||
"semantic_hash": memory.semantic_hash
|
"semantic_hash": memory.semantic_hash,
|
||||||
|
"subjects": memory.subjects
|
||||||
}
|
}
|
||||||
|
|
||||||
# 记忆类型索引
|
# 记忆类型索引
|
||||||
@@ -166,6 +267,12 @@ class MetadataIndexManager:
|
|||||||
# 用户ID索引
|
# 用户ID索引
|
||||||
self.indices[IndexType.USER_ID][memory.user_id].add(memory_id)
|
self.indices[IndexType.USER_ID][memory.user_id].add(memory_id)
|
||||||
|
|
||||||
|
# 主体索引
|
||||||
|
for subject in memory.subjects:
|
||||||
|
normalized = subject.strip().lower()
|
||||||
|
if normalized:
|
||||||
|
self.indices[IndexType.SUBJECT][normalized].add(memory_id)
|
||||||
|
|
||||||
# 关键词索引
|
# 关键词索引
|
||||||
for keyword in memory.keywords:
|
for keyword in memory.keywords:
|
||||||
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id)
|
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id)
|
||||||
@@ -282,13 +389,6 @@ class MetadataIndexManager:
|
|||||||
# 应用最严格的过滤条件
|
# 应用最严格的过滤条件
|
||||||
applied_filters = []
|
applied_filters = []
|
||||||
|
|
||||||
if query.user_ids:
|
|
||||||
user_ids_set = set()
|
|
||||||
for user_id in query.user_ids:
|
|
||||||
user_ids_set.update(self.indices[IndexType.USER_ID].get(user_id, set()))
|
|
||||||
candidate_ids.update(user_ids_set)
|
|
||||||
applied_filters.append("user_ids")
|
|
||||||
|
|
||||||
if query.memory_types:
|
if query.memory_types:
|
||||||
memory_types_set = set()
|
memory_types_set = set()
|
||||||
for memory_type in query.memory_types:
|
for memory_type in query.memory_types:
|
||||||
@@ -302,7 +402,7 @@ class MetadataIndexManager:
|
|||||||
if query.keywords:
|
if query.keywords:
|
||||||
keywords_set = set()
|
keywords_set = set()
|
||||||
for keyword in query.keywords:
|
for keyword in query.keywords:
|
||||||
keywords_set.update(self.indices[IndexType.KEYWORD].get(keyword.lower(), set()))
|
keywords_set.update(self._collect_index_matches(IndexType.KEYWORD, keyword))
|
||||||
if applied_filters:
|
if applied_filters:
|
||||||
candidate_ids &= keywords_set
|
candidate_ids &= keywords_set
|
||||||
else:
|
else:
|
||||||
@@ -329,12 +429,55 @@ class MetadataIndexManager:
|
|||||||
candidate_ids.update(categories_set)
|
candidate_ids.update(categories_set)
|
||||||
applied_filters.append("categories")
|
applied_filters.append("categories")
|
||||||
|
|
||||||
|
if query.subjects:
|
||||||
|
subjects_set = set()
|
||||||
|
for subject in query.subjects:
|
||||||
|
subjects_set.update(self._collect_index_matches(IndexType.SUBJECT, subject))
|
||||||
|
if applied_filters:
|
||||||
|
candidate_ids &= subjects_set
|
||||||
|
else:
|
||||||
|
candidate_ids.update(subjects_set)
|
||||||
|
applied_filters.append("subjects")
|
||||||
|
|
||||||
# 如果没有应用任何过滤条件,返回所有记忆
|
# 如果没有应用任何过滤条件,返回所有记忆
|
||||||
if not applied_filters:
|
if not applied_filters:
|
||||||
return all_memory_ids
|
return all_memory_ids
|
||||||
|
|
||||||
return candidate_ids
|
return candidate_ids
|
||||||
|
|
||||||
|
def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
|
||||||
|
"""根据给定token收集索引匹配,支持部分匹配"""
|
||||||
|
mapping = self.indices.get(index_type)
|
||||||
|
if mapping is None:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
key = ""
|
||||||
|
if isinstance(token, Enum):
|
||||||
|
key = str(token.value).strip().lower()
|
||||||
|
elif isinstance(token, str):
|
||||||
|
key = token.strip().lower()
|
||||||
|
elif token is not None:
|
||||||
|
key = str(token).strip().lower()
|
||||||
|
|
||||||
|
if not key:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
matches: Set[str] = set(mapping.get(key, set()))
|
||||||
|
|
||||||
|
if matches:
|
||||||
|
return set(matches)
|
||||||
|
|
||||||
|
for existing_key, ids in mapping.items():
|
||||||
|
if not existing_key or not isinstance(existing_key, str):
|
||||||
|
continue
|
||||||
|
normalized = existing_key.strip().lower()
|
||||||
|
if not normalized:
|
||||||
|
continue
|
||||||
|
if key in normalized or normalized in key:
|
||||||
|
matches.update(ids)
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
|
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
|
||||||
"""应用过滤条件"""
|
"""应用过滤条件"""
|
||||||
filtered_ids = list(candidate_ids)
|
filtered_ids = list(candidate_ids)
|
||||||
@@ -440,10 +583,10 @@ class MetadataIndexManager:
|
|||||||
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
|
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
|
||||||
"""获取应用的过滤器列表"""
|
"""获取应用的过滤器列表"""
|
||||||
filters = []
|
filters = []
|
||||||
if query.user_ids:
|
|
||||||
filters.append("user_ids")
|
|
||||||
if query.memory_types:
|
if query.memory_types:
|
||||||
filters.append("memory_types")
|
filters.append("memory_types")
|
||||||
|
if query.subjects:
|
||||||
|
filters.append("subjects")
|
||||||
if query.keywords:
|
if query.keywords:
|
||||||
filters.append("keywords")
|
filters.append("keywords")
|
||||||
if query.tags:
|
if query.tags:
|
||||||
@@ -502,6 +645,18 @@ class MetadataIndexManager:
|
|||||||
# 从各类索引中移除
|
# 从各类索引中移除
|
||||||
self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id)
|
self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id)
|
||||||
self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id)
|
self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id)
|
||||||
|
subjects = metadata.get("subjects") or []
|
||||||
|
for subject in subjects:
|
||||||
|
if not isinstance(subject, str):
|
||||||
|
continue
|
||||||
|
normalized = subject.strip().lower()
|
||||||
|
if not normalized:
|
||||||
|
continue
|
||||||
|
subject_bucket = self.indices[IndexType.SUBJECT].get(normalized)
|
||||||
|
if subject_bucket is not None:
|
||||||
|
subject_bucket.discard(memory_id)
|
||||||
|
if not subject_bucket:
|
||||||
|
self.indices[IndexType.SUBJECT].pop(normalized, None)
|
||||||
|
|
||||||
# 从时间索引中移除
|
# 从时间索引中移除
|
||||||
self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id]
|
self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id]
|
||||||
@@ -625,11 +780,13 @@ class MetadataIndexManager:
|
|||||||
logger.info("正在保存元数据索引...")
|
logger.info("正在保存元数据索引...")
|
||||||
|
|
||||||
# 保存各类索引
|
# 保存各类索引
|
||||||
indices_data = {}
|
indices_data: Dict[str, Dict[str, List[str]]] = {}
|
||||||
for index_type, index_data in self.indices.items():
|
for index_type, index_data in self.indices.items():
|
||||||
indices_data[index_type.value] = {
|
serialized_index = {}
|
||||||
key: list(values) for key, values in index_data.items()
|
for key, values in index_data.items():
|
||||||
}
|
serialized_key = self._serialize_index_key(index_type, key)
|
||||||
|
serialized_index[serialized_key] = list(values)
|
||||||
|
indices_data[index_type.value] = serialized_index
|
||||||
|
|
||||||
indices_file = self.index_path / "indices.json"
|
indices_file = self.index_path / "indices.json"
|
||||||
with open(indices_file, 'w', encoding='utf-8') as f:
|
with open(indices_file, 'w', encoding='utf-8') as f:
|
||||||
@@ -652,8 +809,12 @@ class MetadataIndexManager:
|
|||||||
|
|
||||||
# 保存元数据缓存
|
# 保存元数据缓存
|
||||||
metadata_cache_file = self.index_path / "metadata_cache.json"
|
metadata_cache_file = self.index_path / "metadata_cache.json"
|
||||||
|
metadata_serialized = {
|
||||||
|
memory_id: self._serialize_metadata_entry(metadata)
|
||||||
|
for memory_id, metadata in self.memory_metadata_cache.items()
|
||||||
|
}
|
||||||
with open(metadata_cache_file, 'w', encoding='utf-8') as f:
|
with open(metadata_cache_file, 'w', encoding='utf-8') as f:
|
||||||
f.write(orjson.dumps(self.memory_metadata_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
||||||
|
|
||||||
# 保存统计信息
|
# 保存统计信息
|
||||||
stats_file = self.index_path / "index_stats.json"
|
stats_file = self.index_path / "index_stats.json"
|
||||||
@@ -679,9 +840,11 @@ class MetadataIndexManager:
|
|||||||
|
|
||||||
for index_type_value, index_data in indices_data.items():
|
for index_type_value, index_data in indices_data.items():
|
||||||
index_type = IndexType(index_type_value)
|
index_type = IndexType(index_type_value)
|
||||||
self.indices[index_type] = {
|
restored_index = defaultdict(set)
|
||||||
key: set(values) for key, values in index_data.items()
|
for key_str, values in index_data.items():
|
||||||
}
|
restored_key = self._deserialize_index_key(index_type, key_str)
|
||||||
|
restored_index[restored_key] = set(values)
|
||||||
|
self.indices[index_type] = restored_index
|
||||||
|
|
||||||
# 加载时间索引
|
# 加载时间索引
|
||||||
time_index_file = self.index_path / "time_index.json"
|
time_index_file = self.index_path / "time_index.json"
|
||||||
@@ -709,10 +872,38 @@ class MetadataIndexManager:
|
|||||||
|
|
||||||
# 转换置信度和重要性为枚举类型
|
# 转换置信度和重要性为枚举类型
|
||||||
for memory_id, metadata in cache_data.items():
|
for memory_id, metadata in cache_data.items():
|
||||||
if isinstance(metadata["confidence"], str):
|
memory_type_value = metadata.get("memory_type")
|
||||||
metadata["confidence"] = ConfidenceLevel(metadata["confidence"])
|
if isinstance(memory_type_value, str):
|
||||||
if isinstance(metadata["importance"], str):
|
try:
|
||||||
metadata["importance"] = ImportanceLevel(metadata["importance"])
|
metadata["memory_type"] = MemoryType(memory_type_value)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("无法解析memory_type %s", memory_type_value)
|
||||||
|
|
||||||
|
confidence_value = metadata.get("confidence")
|
||||||
|
if isinstance(confidence_value, (str, int)):
|
||||||
|
try:
|
||||||
|
metadata["confidence"] = ConfidenceLevel(int(confidence_value))
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("无法解析confidence %s", confidence_value)
|
||||||
|
|
||||||
|
importance_value = metadata.get("importance")
|
||||||
|
if isinstance(importance_value, (str, int)):
|
||||||
|
try:
|
||||||
|
metadata["importance"] = ImportanceLevel(int(importance_value))
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("无法解析importance %s", importance_value)
|
||||||
|
|
||||||
|
subjects_value = metadata.get("subjects")
|
||||||
|
if isinstance(subjects_value, str):
|
||||||
|
metadata["subjects"] = [subjects_value]
|
||||||
|
elif isinstance(subjects_value, list):
|
||||||
|
cleaned_subjects = []
|
||||||
|
for item in subjects_value:
|
||||||
|
if isinstance(item, str) and item.strip():
|
||||||
|
cleaned_subjects.append(item.strip())
|
||||||
|
metadata["subjects"] = cleaned_subjects
|
||||||
|
else:
|
||||||
|
metadata["subjects"] = []
|
||||||
|
|
||||||
self.memory_metadata_cache = cache_data
|
self.memory_metadata_cache = cache_data
|
||||||
|
|
||||||
|
|||||||
@@ -203,11 +203,17 @@ class MultiStageRetrieval:
|
|||||||
try:
|
try:
|
||||||
from .metadata_index import IndexQuery
|
from .metadata_index import IndexQuery
|
||||||
|
|
||||||
# 构建索引查询
|
query_plan = context.get("query_plan")
|
||||||
|
|
||||||
|
memory_types = self._extract_memory_types_from_context(context)
|
||||||
|
keywords = self._extract_keywords_from_query(query, query_plan)
|
||||||
|
subjects = query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None
|
||||||
|
|
||||||
index_query = IndexQuery(
|
index_query = IndexQuery(
|
||||||
user_ids=[user_id],
|
user_ids=None,
|
||||||
memory_types=self._extract_memory_types_from_context(context),
|
memory_types=memory_types,
|
||||||
keywords=self._extract_keywords_from_query(query),
|
subjects=subjects,
|
||||||
|
keywords=keywords,
|
||||||
limit=self.config.metadata_filter_limit,
|
limit=self.config.metadata_filter_limit,
|
||||||
sort_by="last_accessed",
|
sort_by="last_accessed",
|
||||||
sort_order="desc"
|
sort_order="desc"
|
||||||
@@ -215,13 +221,66 @@ class MultiStageRetrieval:
|
|||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
result = await metadata_index.query_memories(index_query)
|
result = await metadata_index.query_memories(index_query)
|
||||||
filtered_count = result.total_count - len(result.memory_ids)
|
result_ids = list(result.memory_ids)
|
||||||
|
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
|
||||||
|
|
||||||
logger.debug(f"元数据过滤:{result.total_count} -> {len(result.memory_ids)} 条记忆")
|
# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
|
||||||
|
if not result_ids:
|
||||||
|
sorted_ids = sorted(
|
||||||
|
(memory.memory_id for memory in all_memories_cache.values()),
|
||||||
|
key=lambda mid: all_memories_cache[mid].metadata.last_accessed if mid in all_memories_cache else 0,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
if memory_types:
|
||||||
|
type_filtered = [
|
||||||
|
mid for mid in sorted_ids
|
||||||
|
if all_memories_cache[mid].memory_type in memory_types
|
||||||
|
]
|
||||||
|
sorted_ids = type_filtered or sorted_ids
|
||||||
|
if subjects:
|
||||||
|
subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()]
|
||||||
|
if subject_candidates:
|
||||||
|
subject_filtered = [
|
||||||
|
mid for mid in sorted_ids
|
||||||
|
if any(
|
||||||
|
subj.strip().lower() in subject_candidates
|
||||||
|
for subj in all_memories_cache[mid].subjects
|
||||||
|
)
|
||||||
|
]
|
||||||
|
sorted_ids = subject_filtered or sorted_ids
|
||||||
|
|
||||||
|
if keywords:
|
||||||
|
keyword_pool = {kw.lower() for kw in keywords if isinstance(kw, str) and kw.strip()}
|
||||||
|
if keyword_pool:
|
||||||
|
keyword_filtered = []
|
||||||
|
for mid in sorted_ids:
|
||||||
|
memory_text = (
|
||||||
|
(all_memories_cache[mid].display or "")
|
||||||
|
+ "\n"
|
||||||
|
+ (all_memories_cache[mid].text_content or "")
|
||||||
|
).lower()
|
||||||
|
if any(kw in memory_text for kw in keyword_pool):
|
||||||
|
keyword_filtered.append(mid)
|
||||||
|
sorted_ids = keyword_filtered or sorted_ids
|
||||||
|
|
||||||
|
result_ids = sorted_ids[: self.config.metadata_filter_limit]
|
||||||
|
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
|
||||||
|
logger.debug(
|
||||||
|
"元数据过滤未命中索引,使用近似回退: types=%s, subjects=%s, keywords=%s",
|
||||||
|
bool(memory_types),
|
||||||
|
bool(subjects),
|
||||||
|
bool(keywords),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"元数据过滤:候选=%d, 返回=%d",
|
||||||
|
len(all_memories_cache),
|
||||||
|
len(result_ids),
|
||||||
|
)
|
||||||
|
|
||||||
return StageResult(
|
return StageResult(
|
||||||
stage=RetrievalStage.METADATA_FILTERING,
|
stage=RetrievalStage.METADATA_FILTERING,
|
||||||
memory_ids=result.memory_ids,
|
memory_ids=result_ids,
|
||||||
processing_time=time.time() - start_time,
|
processing_time=time.time() - start_time,
|
||||||
filtered_count=filtered_count,
|
filtered_count=filtered_count,
|
||||||
score_threshold=0.0
|
score_threshold=0.0
|
||||||
@@ -251,7 +310,7 @@ class MultiStageRetrieval:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 生成查询向量
|
# 生成查询向量
|
||||||
query_embedding = await self._generate_query_embedding(query, context)
|
query_embedding = await self._generate_query_embedding(query, context, vector_storage)
|
||||||
|
|
||||||
if not query_embedding:
|
if not query_embedding:
|
||||||
return StageResult(
|
return StageResult(
|
||||||
@@ -263,22 +322,24 @@ class MultiStageRetrieval:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 执行向量搜索
|
# 执行向量搜索
|
||||||
search_result = await vector_storage.search_similar(
|
search_result = await vector_storage.search_similar_memories(
|
||||||
query_embedding,
|
query_vector=query_embedding,
|
||||||
limit=self.config.vector_search_limit
|
limit=self.config.vector_search_limit
|
||||||
)
|
)
|
||||||
|
|
||||||
|
candidate_pool = candidate_ids or set(all_memories_cache.keys())
|
||||||
|
|
||||||
# 过滤候选记忆
|
# 过滤候选记忆
|
||||||
filtered_memories = []
|
filtered_memories = []
|
||||||
for memory_id, similarity in search_result:
|
for memory_id, similarity in search_result:
|
||||||
if memory_id in candidate_ids and similarity >= self.config.vector_similarity_threshold:
|
if memory_id in candidate_pool and similarity >= self.config.vector_similarity_threshold:
|
||||||
filtered_memories.append((memory_id, similarity))
|
filtered_memories.append((memory_id, similarity))
|
||||||
|
|
||||||
# 按相似度排序
|
# 按相似度排序
|
||||||
filtered_memories.sort(key=lambda x: x[1], reverse=True)
|
filtered_memories.sort(key=lambda x: x[1], reverse=True)
|
||||||
result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]
|
result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]
|
||||||
|
|
||||||
filtered_count = len(candidate_ids) - len(result_ids)
|
filtered_count = max(0, len(candidate_pool) - len(result_ids))
|
||||||
|
|
||||||
logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆")
|
logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆")
|
||||||
|
|
||||||
@@ -407,12 +468,20 @@ class MultiStageRetrieval:
|
|||||||
score_threshold=0.0
|
score_threshold=0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _generate_query_embedding(self, query: str, context: Dict[str, Any]) -> Optional[List[float]]:
|
async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]:
|
||||||
"""生成查询向量"""
|
"""生成查询向量"""
|
||||||
try:
|
try:
|
||||||
# 这里应该调用embedding模型
|
query_plan = context.get("query_plan")
|
||||||
# 由于我们可能没有直接的embedding模型,返回None或使用简单的方法
|
query_text = query
|
||||||
# 在实际实现中,这里应该调用与记忆存储相同的embedding模型
|
if query_plan and getattr(query_plan, "semantic_query", None):
|
||||||
|
query_text = query_plan.semantic_query
|
||||||
|
|
||||||
|
if not query_text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if hasattr(vector_storage, "generate_query_embedding"):
|
||||||
|
return await vector_storage.generate_query_embedding(query_text)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"生成查询向量失败: {e}")
|
logger.warning(f"生成查询向量失败: {e}")
|
||||||
@@ -421,9 +490,15 @@ class MultiStageRetrieval:
|
|||||||
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
|
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
|
||||||
"""计算语义相似度"""
|
"""计算语义相似度"""
|
||||||
try:
|
try:
|
||||||
|
query_plan = context.get("query_plan")
|
||||||
|
query_text = query
|
||||||
|
if query_plan and getattr(query_plan, "semantic_query", None):
|
||||||
|
query_text = query_plan.semantic_query
|
||||||
|
|
||||||
# 简单的文本相似度计算
|
# 简单的文本相似度计算
|
||||||
query_words = set(query.lower().split())
|
query_words = set(query_text.lower().split())
|
||||||
memory_words = set(memory.text_content.lower().split())
|
memory_text = (memory.display or memory.text_content or "").lower()
|
||||||
|
memory_words = set(memory_text.split())
|
||||||
|
|
||||||
if not query_words or not memory_words:
|
if not query_words or not memory_words:
|
||||||
return 0.0
|
return 0.0
|
||||||
@@ -443,10 +518,15 @@ class MultiStageRetrieval:
|
|||||||
try:
|
try:
|
||||||
score = 0.0
|
score = 0.0
|
||||||
|
|
||||||
|
query_plan = context.get("query_plan")
|
||||||
|
|
||||||
# 检查记忆类型是否匹配上下文
|
# 检查记忆类型是否匹配上下文
|
||||||
if context.get("expected_memory_types"):
|
if context.get("expected_memory_types"):
|
||||||
if memory.memory_type in context["expected_memory_types"]:
|
if memory.memory_type in context["expected_memory_types"]:
|
||||||
score += 0.3
|
score += 0.3
|
||||||
|
elif query_plan and getattr(query_plan, "memory_types", None):
|
||||||
|
if memory.memory_type in query_plan.memory_types:
|
||||||
|
score += 0.3
|
||||||
|
|
||||||
# 检查关键词匹配
|
# 检查关键词匹配
|
||||||
if context.get("keywords"):
|
if context.get("keywords"):
|
||||||
@@ -456,6 +536,35 @@ class MultiStageRetrieval:
|
|||||||
if overlap:
|
if overlap:
|
||||||
score += len(overlap) / max(len(context_keywords), 1) * 0.4
|
score += len(overlap) / max(len(context_keywords), 1) * 0.4
|
||||||
|
|
||||||
|
if query_plan:
|
||||||
|
# 主体匹配
|
||||||
|
subject_score = self._calculate_subject_overlap(memory, getattr(query_plan, "subject_includes", []))
|
||||||
|
score += subject_score * 0.3
|
||||||
|
|
||||||
|
# 对象/描述匹配
|
||||||
|
object_keywords = getattr(query_plan, "object_includes", []) or []
|
||||||
|
if object_keywords:
|
||||||
|
display_text = (memory.display or memory.text_content or "").lower()
|
||||||
|
hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
|
||||||
|
if hits:
|
||||||
|
score += min(0.3, hits * 0.1)
|
||||||
|
|
||||||
|
optional_keywords = getattr(query_plan, "optional_keywords", []) or []
|
||||||
|
if optional_keywords:
|
||||||
|
display_text = (memory.display or memory.text_content or "").lower()
|
||||||
|
hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
|
||||||
|
if hits:
|
||||||
|
score += min(0.2, hits * 0.05)
|
||||||
|
|
||||||
|
# 时间偏好
|
||||||
|
recency_pref = getattr(query_plan, "recency_preference", "")
|
||||||
|
if recency_pref:
|
||||||
|
memory_age = time.time() - memory.metadata.created_at
|
||||||
|
if recency_pref == "recent" and memory_age < 7 * 24 * 3600:
|
||||||
|
score += 0.2
|
||||||
|
elif recency_pref == "historical" and memory_age > 30 * 24 * 3600:
|
||||||
|
score += 0.1
|
||||||
|
|
||||||
# 检查时效性
|
# 检查时效性
|
||||||
if context.get("recent_only", False):
|
if context.get("recent_only", False):
|
||||||
memory_age = time.time() - memory.metadata.created_at
|
memory_age = time.time() - memory.metadata.created_at
|
||||||
@@ -471,6 +580,8 @@ class MultiStageRetrieval:
|
|||||||
async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
|
async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
|
||||||
"""计算最终评分"""
|
"""计算最终评分"""
|
||||||
try:
|
try:
|
||||||
|
query_plan = context.get("query_plan")
|
||||||
|
|
||||||
# 语义相似度
|
# 语义相似度
|
||||||
semantic_score = await self._calculate_semantic_similarity(query, memory, context)
|
semantic_score = await self._calculate_semantic_similarity(query, memory, context)
|
||||||
|
|
||||||
@@ -482,13 +593,29 @@ class MultiStageRetrieval:
|
|||||||
|
|
||||||
# 时效性评分
|
# 时效性评分
|
||||||
recency_score = self._calculate_recency_score(memory.metadata.created_at)
|
recency_score = self._calculate_recency_score(memory.metadata.created_at)
|
||||||
|
if query_plan:
|
||||||
|
recency_pref = getattr(query_plan, "recency_preference", "")
|
||||||
|
if recency_pref == "recent":
|
||||||
|
recency_score = max(recency_score, 0.8)
|
||||||
|
elif recency_pref == "historical":
|
||||||
|
recency_score = min(recency_score, 0.5)
|
||||||
|
|
||||||
# 权重组合
|
# 权重组合
|
||||||
|
vector_weight = self.config.vector_weight
|
||||||
|
semantic_weight = self.config.semantic_weight
|
||||||
|
context_weight = self.config.context_weight
|
||||||
|
recency_weight = self.config.recency_weight
|
||||||
|
|
||||||
|
if query_plan and getattr(query_plan, "emphasis", None) == "precision":
|
||||||
|
semantic_weight += 0.05
|
||||||
|
elif query_plan and getattr(query_plan, "emphasis", None) == "recall":
|
||||||
|
context_weight += 0.05
|
||||||
|
|
||||||
final_score = (
|
final_score = (
|
||||||
semantic_score * self.config.semantic_weight +
|
semantic_score * semantic_weight +
|
||||||
vector_score * self.config.vector_weight +
|
vector_score * vector_weight +
|
||||||
context_score * self.config.context_weight +
|
context_score * context_weight +
|
||||||
recency_score * self.config.recency_weight
|
recency_score * recency_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
# 加入记忆重要性权重
|
# 加入记忆重要性权重
|
||||||
@@ -501,6 +628,31 @@ class MultiStageRetrieval:
|
|||||||
logger.warning(f"计算最终评分失败: {e}")
|
logger.warning(f"计算最终评分失败: {e}")
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
|
||||||
|
if not required_subjects:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
memory_subjects = {subject.lower() for subject in memory.subjects if isinstance(subject, str)}
|
||||||
|
if not memory_subjects:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
hit = 0
|
||||||
|
total = 0
|
||||||
|
for subject in required_subjects:
|
||||||
|
if not isinstance(subject, str):
|
||||||
|
continue
|
||||||
|
total += 1
|
||||||
|
normalized = subject.strip().lower()
|
||||||
|
if not normalized:
|
||||||
|
continue
|
||||||
|
if any(normalized in mem_subject for mem_subject in memory_subjects):
|
||||||
|
hit += 1
|
||||||
|
|
||||||
|
if total == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
return hit / total
|
||||||
|
|
||||||
def _calculate_recency_score(self, timestamp: float) -> float:
|
def _calculate_recency_score(self, timestamp: float) -> float:
|
||||||
"""计算时效性评分"""
|
"""计算时效性评分"""
|
||||||
try:
|
try:
|
||||||
@@ -524,6 +676,10 @@ class MultiStageRetrieval:
|
|||||||
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
|
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
|
||||||
"""从上下文中提取记忆类型"""
|
"""从上下文中提取记忆类型"""
|
||||||
try:
|
try:
|
||||||
|
query_plan = context.get("query_plan")
|
||||||
|
if query_plan and getattr(query_plan, "memory_types", None):
|
||||||
|
return query_plan.memory_types
|
||||||
|
|
||||||
if "expected_memory_types" in context:
|
if "expected_memory_types" in context:
|
||||||
return context["expected_memory_types"]
|
return context["expected_memory_types"]
|
||||||
|
|
||||||
@@ -544,15 +700,30 @@ class MultiStageRetrieval:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _extract_keywords_from_query(self, query: str) -> List[str]:
|
def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
|
||||||
"""从查询中提取关键词"""
|
"""从查询中提取关键词"""
|
||||||
try:
|
try:
|
||||||
|
extracted: List[str] = []
|
||||||
|
|
||||||
|
if query_plan and getattr(query_plan, "required_keywords", None):
|
||||||
|
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
|
||||||
|
|
||||||
# 简单的关键词提取
|
# 简单的关键词提取
|
||||||
words = query.lower().split()
|
words = query.lower().split()
|
||||||
# 过滤停用词
|
# 过滤停用词
|
||||||
stopwords = {"的", "是", "在", "有", "我", "你", "他", "她", "它", "这", "那", "了", "吗", "呢"}
|
stopwords = {"的", "是", "在", "有", "我", "你", "他", "她", "它", "这", "那", "了", "吗", "呢"}
|
||||||
keywords = [word for word in words if len(word) > 1 and word not in stopwords]
|
extracted.extend(word for word in words if len(word) > 1 and word not in stopwords)
|
||||||
return keywords[:10] # 最多返回10个关键词
|
|
||||||
|
# 去重并保留顺序
|
||||||
|
seen = set()
|
||||||
|
deduplicated = []
|
||||||
|
for word in extracted:
|
||||||
|
if word in seen or not word:
|
||||||
|
continue
|
||||||
|
seen.add(word)
|
||||||
|
deduplicated.append(word)
|
||||||
|
|
||||||
|
return deduplicated[:10]
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from pathlib import Path
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import model_config, global_config
|
from src.config.config import model_config, global_config
|
||||||
|
from src.common.config_helpers import resolve_embedding_dimension
|
||||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -36,12 +37,12 @@ except ImportError:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class VectorStorageConfig:
|
class VectorStorageConfig:
|
||||||
"""向量存储配置"""
|
"""向量存储配置"""
|
||||||
dimension: int = 768
|
dimension: int = 1024
|
||||||
similarity_threshold: float = 0.8
|
similarity_threshold: float = 0.8
|
||||||
index_type: str = "flat" # flat, ivf, hnsw
|
index_type: str = "flat" # flat, ivf, hnsw
|
||||||
max_index_size: int = 100000
|
max_index_size: int = 100000
|
||||||
storage_path: str = "data/memory_vectors"
|
storage_path: str = "data/memory_vectors"
|
||||||
auto_save_interval: int = 100 # 每N次操作自动保存
|
auto_save_interval: int = 10 # 每N次操作自动保存
|
||||||
enable_compression: bool = True
|
enable_compression: bool = True
|
||||||
|
|
||||||
|
|
||||||
@@ -50,6 +51,15 @@ class VectorStorageManager:
|
|||||||
|
|
||||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||||
self.config = config or VectorStorageConfig()
|
self.config = config or VectorStorageConfig()
|
||||||
|
|
||||||
|
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
|
||||||
|
if resolved_dimension and resolved_dimension != self.config.dimension:
|
||||||
|
logger.info(
|
||||||
|
"向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
|
||||||
|
resolved_dimension,
|
||||||
|
self.config.dimension,
|
||||||
|
)
|
||||||
|
self.config.dimension = resolved_dimension
|
||||||
self.storage_path = Path(self.config.storage_path)
|
self.storage_path = Path(self.config.storage_path)
|
||||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@@ -117,6 +127,32 @@ class VectorStorageManager:
|
|||||||
)
|
)
|
||||||
logger.info("✅ 嵌入模型初始化完成")
|
logger.info("✅ 嵌入模型初始化完成")
|
||||||
|
|
||||||
|
async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
|
||||||
|
"""生成查询向量,用于记忆召回"""
|
||||||
|
if not query_text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.initialize_embedding_model()
|
||||||
|
|
||||||
|
embedding, _ = await self.embedding_model.get_embedding(query_text)
|
||||||
|
if not embedding:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(embedding) != self.config.dimension:
|
||||||
|
logger.warning(
|
||||||
|
"查询向量维度不匹配: 期望 %d, 实际 %d",
|
||||||
|
self.config.dimension,
|
||||||
|
len(embedding)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._normalize_vector(embedding)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
async def store_memories(self, memories: List[MemoryChunk]):
|
async def store_memories(self, memories: List[MemoryChunk]):
|
||||||
"""存储记忆向量"""
|
"""存储记忆向量"""
|
||||||
if not memories:
|
if not memories:
|
||||||
@@ -213,7 +249,7 @@ class VectorStorageManager:
|
|||||||
results[memory_id] = embedding
|
results[memory_id] = embedding
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)",
|
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。",
|
||||||
self.config.dimension,
|
self.config.dimension,
|
||||||
len(embedding) if embedding else 0,
|
len(embedding) if embedding else 0,
|
||||||
memory_id,
|
memory_id,
|
||||||
@@ -299,14 +335,32 @@ class VectorStorageManager:
|
|||||||
|
|
||||||
async def search_similar_memories(
|
async def search_similar_memories(
|
||||||
self,
|
self,
|
||||||
query_vector: List[float],
|
query_vector: Optional[List[float]] = None,
|
||||||
|
*,
|
||||||
|
query_text: Optional[str] = None,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
user_id: Optional[str] = None
|
scope_id: Optional[str] = None
|
||||||
) -> List[Tuple[str, float]]:
|
) -> List[Tuple[str, float]]:
|
||||||
"""搜索相似记忆"""
|
"""搜索相似记忆"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if query_vector is None:
|
||||||
|
if not query_text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
query_vector = await self.generate_query_embedding(query_text)
|
||||||
|
if not query_vector:
|
||||||
|
return []
|
||||||
|
|
||||||
|
scope_filter: Optional[str] = None
|
||||||
|
if isinstance(scope_id, str):
|
||||||
|
normalized_scope = scope_id.strip().lower()
|
||||||
|
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
|
||||||
|
scope_filter = scope_id
|
||||||
|
elif scope_id:
|
||||||
|
scope_filter = str(scope_id)
|
||||||
|
|
||||||
# 规范化查询向量
|
# 规范化查询向量
|
||||||
query_vector = self._normalize_vector(query_vector)
|
query_vector = self._normalize_vector(query_vector)
|
||||||
|
|
||||||
@@ -341,10 +395,9 @@ class VectorStorageManager:
|
|||||||
|
|
||||||
memory_id = self.index_to_memory_id.get(index)
|
memory_id = self.index_to_memory_id.get(index)
|
||||||
if memory_id:
|
if memory_id:
|
||||||
# 应用用户过滤
|
if scope_filter:
|
||||||
if user_id:
|
|
||||||
memory = self.memory_cache.get(memory_id)
|
memory = self.memory_cache.get(memory_id)
|
||||||
if memory and memory.user_id != user_id:
|
if memory and str(memory.user_id) != scope_filter:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
|
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
|
||||||
@@ -481,8 +534,14 @@ class VectorStorageManager:
|
|||||||
# 保存映射关系
|
# 保存映射关系
|
||||||
mapping_file = self.storage_path / "id_mapping.json"
|
mapping_file = self.storage_path / "id_mapping.json"
|
||||||
mapping_data = {
|
mapping_data = {
|
||||||
"memory_id_to_index": self.memory_id_to_index,
|
"memory_id_to_index": {
|
||||||
"index_to_memory_id": self.index_to_memory_id
|
str(memory_id): int(index)
|
||||||
|
for memory_id, index in self.memory_id_to_index.items()
|
||||||
|
},
|
||||||
|
"index_to_memory_id": {
|
||||||
|
str(index): memory_id
|
||||||
|
for index, memory_id in self.index_to_memory_id.items()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
with open(mapping_file, 'w', encoding='utf-8') as f:
|
with open(mapping_file, 'w', encoding='utf-8') as f:
|
||||||
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
|
||||||
@@ -529,8 +588,17 @@ class VectorStorageManager:
|
|||||||
if mapping_file.exists():
|
if mapping_file.exists():
|
||||||
with open(mapping_file, 'r', encoding='utf-8') as f:
|
with open(mapping_file, 'r', encoding='utf-8') as f:
|
||||||
mapping_data = orjson.loads(f.read())
|
mapping_data = orjson.loads(f.read())
|
||||||
self.memory_id_to_index = mapping_data.get("memory_id_to_index", {})
|
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
|
||||||
self.index_to_memory_id = mapping_data.get("index_to_memory_id", {})
|
self.memory_id_to_index = {
|
||||||
|
str(memory_id): int(index)
|
||||||
|
for memory_id, index in raw_memory_to_index.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
|
||||||
|
self.index_to_memory_id = {
|
||||||
|
int(index): memory_id
|
||||||
|
for index, memory_id in raw_index_to_memory.items()
|
||||||
|
}
|
||||||
|
|
||||||
# 加载FAISS索引(如果可用)
|
# 加载FAISS索引(如果可用)
|
||||||
if FAISS_AVAILABLE:
|
if FAISS_AVAILABLE:
|
||||||
|
|||||||
@@ -469,14 +469,14 @@ class ChatBot:
|
|||||||
async def preprocess():
|
async def preprocess():
|
||||||
# 存储消息到数据库
|
# 存储消息到数据库
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await MessageStorage.store_message(message, message.chat_stream)
|
await MessageStorage.store_message(message, message.chat_stream)
|
||||||
logger.debug(f"消息已存储到数据库: {message.message_info.message_id}")
|
logger.debug(f"消息已存储到数据库: {message.message_info.message_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"存储消息到数据库失败: {e}")
|
logger.error(f"存储消息到数据库失败: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
# 使用消息管理器处理消息(保持原有功能)
|
# 使用消息管理器处理消息(保持原有功能)
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
|||||||
@@ -373,12 +373,12 @@ class Prompt:
|
|||||||
|
|
||||||
# 性能优化 - 为不同任务设置不同的超时时间
|
# 性能优化 - 为不同任务设置不同的超时时间
|
||||||
task_timeouts = {
|
task_timeouts = {
|
||||||
"memory_block": 5.0, # 记忆系统可能较慢,单独设置超时
|
"memory_block": 15.0, # 记忆系统
|
||||||
"tool_info": 3.0, # 工具信息中等速度
|
"tool_info": 15.0, # 工具信息
|
||||||
"relation_info": 2.0, # 关系信息通常较快
|
"relation_info": 10.0, # 关系信息
|
||||||
"knowledge_info": 3.0, # 知识库查询中等速度
|
"knowledge_info": 10.0, # 知识库查询
|
||||||
"cross_context": 2.0, # 上下文处理通常较快
|
"cross_context": 10.0, # 上下文处理
|
||||||
"expression_habits": 1.5, # 表达习惯处理很快
|
"expression_habits": 10.0, # 表达习惯
|
||||||
}
|
}
|
||||||
|
|
||||||
# 分别处理每个任务,避免慢任务影响快任务
|
# 分别处理每个任务,避免慢任务影响快任务
|
||||||
@@ -558,12 +558,8 @@ class Prompt:
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 等待所有记忆查询完成(最多3秒)
|
|
||||||
try:
|
try:
|
||||||
running_memories, instant_memory = await asyncio.wait_for(
|
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
|
||||||
asyncio.gather(*memory_tasks, return_exceptions=True),
|
|
||||||
timeout=3.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# 处理可能的异常结果
|
# 处理可能的异常结果
|
||||||
if isinstance(running_memories, Exception):
|
if isinstance(running_memories, Exception):
|
||||||
|
|||||||
@@ -207,6 +207,9 @@ class VideoAnalyzer:
|
|||||||
"""检查视频是否已经分析过"""
|
"""检查视频是否已经分析过"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
|
if not session:
|
||||||
|
logger.warning("无法获取数据库会话,跳过视频存在性检查。")
|
||||||
|
return None
|
||||||
# 明确刷新会话以确保看到其他事务的最新提交
|
# 明确刷新会话以确保看到其他事务的最新提交
|
||||||
await session.expire_all()
|
await session.expire_all()
|
||||||
stmt = select(Videos).where(Videos.video_hash == video_hash)
|
stmt = select(Videos).where(Videos.video_hash == video_hash)
|
||||||
@@ -227,6 +230,9 @@ class VideoAnalyzer:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
|
if not session:
|
||||||
|
logger.warning("无法获取数据库会话,跳过视频结果存储。")
|
||||||
|
return None
|
||||||
# 只根据video_hash查找
|
# 只根据video_hash查找
|
||||||
stmt = select(Videos).where(Videos.video_hash == video_hash)
|
stmt = select(Videos).where(Videos.video_hash == video_hash)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -540,11 +546,14 @@ class VideoAnalyzer:
|
|||||||
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
|
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
|
||||||
|
|
||||||
# 获取模型信息和客户端
|
# 获取模型信息和客户端
|
||||||
model_info, api_provider, client = self.video_llm._select_model()
|
selection_result = self.video_llm._model_selector.select_best_available_model(set(), "response")
|
||||||
|
if not selection_result:
|
||||||
|
raise RuntimeError("无法为视频分析选择可用模型。")
|
||||||
|
model_info, api_provider, client = selection_result
|
||||||
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
|
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
|
||||||
|
|
||||||
# 直接执行多图片请求
|
# 直接执行多图片请求
|
||||||
api_response = await self.video_llm._execute_request(
|
api_response = await self.video_llm._executor.execute_request(
|
||||||
api_provider=api_provider,
|
api_provider=api_provider,
|
||||||
client=client,
|
client=client,
|
||||||
request_type=RequestType.RESPONSE,
|
request_type=RequestType.RESPONSE,
|
||||||
|
|||||||
@@ -461,11 +461,14 @@ class LegacyVideoAnalyzer:
|
|||||||
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
|
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
|
||||||
|
|
||||||
# 获取模型信息和客户端
|
# 获取模型信息和客户端
|
||||||
model_info, api_provider, client = self.video_llm._select_model()
|
selection_result = self.video_llm._model_selector.select_best_available_model(set(), "response")
|
||||||
|
if not selection_result:
|
||||||
|
raise RuntimeError("无法为视频分析选择可用模型 (legacy)。")
|
||||||
|
model_info, api_provider, client = selection_result
|
||||||
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
|
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
|
||||||
|
|
||||||
# 直接执行多图片请求
|
# 直接执行多图片请求
|
||||||
api_response = await self.video_llm._execute_request(
|
api_response = await self.video_llm._executor.execute_request(
|
||||||
api_provider=api_provider,
|
api_provider=api_provider,
|
||||||
client=client,
|
client=client,
|
||||||
request_type=RequestType.RESPONSE,
|
request_type=RequestType.RESPONSE,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, Union
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
|
from src.common.config_helpers import resolve_embedding_dimension
|
||||||
from src.common.database.sqlalchemy_models import CacheEntries
|
from src.common.database.sqlalchemy_models import CacheEntries
|
||||||
from src.common.database.sqlalchemy_database_api import db_query, db_save
|
from src.common.database.sqlalchemy_database_api import db_query, db_save
|
||||||
from src.common.vector_db import vector_db_service
|
from src.common.vector_db import vector_db_service
|
||||||
@@ -40,7 +41,11 @@ class CacheManager:
|
|||||||
|
|
||||||
# L1 缓存 (内存)
|
# L1 缓存 (内存)
|
||||||
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
|
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
|
||||||
|
if not embedding_dim:
|
||||||
|
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||||
|
|
||||||
|
self.embedding_dimension = embedding_dim
|
||||||
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
||||||
self.l1_vector_id_to_key: Dict[int, str] = {}
|
self.l1_vector_id_to_key: Dict[int, str] = {}
|
||||||
|
|
||||||
@@ -72,7 +77,7 @@ class CacheManager:
|
|||||||
embedding_array = embedding_array.flatten()
|
embedding_array = embedding_array.flatten()
|
||||||
|
|
||||||
# 检查维度是否符合预期
|
# 检查维度是否符合预期
|
||||||
expected_dim = global_config.lpmm_knowledge.embedding_dimension
|
expected_dim = getattr(CacheManager, "embedding_dimension", None) or global_config.lpmm_knowledge.embedding_dimension
|
||||||
if embedding_array.shape[0] != expected_dim:
|
if embedding_array.shape[0] != expected_dim:
|
||||||
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
|
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
42
src/common/config_helpers.py
Normal file
42
src/common/config_helpers.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.config.config import global_config, model_config
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]:
|
||||||
|
"""获取当前配置的嵌入向量维度。
|
||||||
|
|
||||||
|
优先顺序:
|
||||||
|
1. 模型配置中 `model_task_config.embedding.embedding_dimension`
|
||||||
|
2. 机器人配置中 `lpmm_knowledge.embedding_dimension`
|
||||||
|
3. 调用方提供的 fallback
|
||||||
|
"""
|
||||||
|
|
||||||
|
candidates: list[Optional[int]] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||||
|
if embedding_task is not None:
|
||||||
|
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
||||||
|
except Exception:
|
||||||
|
candidates.append(None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
||||||
|
except Exception:
|
||||||
|
candidates.append(None)
|
||||||
|
|
||||||
|
candidates.append(fallback)
|
||||||
|
|
||||||
|
resolved: Optional[int] = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
|
||||||
|
|
||||||
|
if resolved and sync_global:
|
||||||
|
try:
|
||||||
|
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
|
||||||
|
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return resolved
|
||||||
@@ -759,30 +759,38 @@ async def initialize_database():
|
|||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]:
|
||||||
"""异步数据库会话上下文管理器"""
|
"""
|
||||||
|
异步数据库会话上下文管理器。
|
||||||
|
在初始化失败时会yield None,调用方需要检查会话是否为None。
|
||||||
|
"""
|
||||||
session: Optional[AsyncSession] = None
|
session: Optional[AsyncSession] = None
|
||||||
|
SessionLocal = None
|
||||||
try:
|
try:
|
||||||
engine, SessionLocal = await initialize_database()
|
_, SessionLocal = await initialize_database()
|
||||||
if not SessionLocal:
|
if not SessionLocal:
|
||||||
raise RuntimeError("Database session not initialized")
|
logger.error("数据库会话工厂 (_SessionLocal) 未初始化。")
|
||||||
session = SessionLocal()
|
yield None
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库初始化失败,无法创建会话: {e}")
|
||||||
|
yield None
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
session = SessionLocal()
|
||||||
# 对于 SQLite,在会话开始时设置 PRAGMA
|
# 对于 SQLite,在会话开始时设置 PRAGMA
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
if global_config.database.database_type == "sqlite":
|
if global_config.database.database_type == "sqlite":
|
||||||
try:
|
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
||||||
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
await session.execute(text("PRAGMA foreign_keys = ON"))
|
||||||
await session.execute(text("PRAGMA foreign_keys = ON"))
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SQLite] 设置会话 PRAGMA 失败: {e}")
|
|
||||||
|
|
||||||
yield session
|
yield session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"数据库会话错误: {e}")
|
logger.error(f"数据库会话期间发生错误: {e}")
|
||||||
if session:
|
if session:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise
|
raise # 将会话期间的错误重新抛出给调用者
|
||||||
finally:
|
finally:
|
||||||
if session:
|
if session:
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Dict, Any, Literal, Union
|
from typing import List, Dict, Any, Literal, Union, Optional
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
@@ -105,6 +105,11 @@ class TaskConfig(ValidatedConfigBase):
|
|||||||
max_tokens: int = Field(default=800, description="任务最大输出token数")
|
max_tokens: int = Field(default=800, description="任务最大输出token数")
|
||||||
temperature: float = Field(default=0.7, description="模型温度")
|
temperature: float = Field(default=0.7, description="模型温度")
|
||||||
concurrency_count: int = Field(default=1, description="并发请求数量")
|
concurrency_count: int = Field(default=1, description="并发请求数量")
|
||||||
|
embedding_dimension: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="嵌入模型输出向量维度,仅在嵌入任务中使用",
|
||||||
|
ge=1,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_model_list(cls, v):
|
def validate_model_list(cls, v):
|
||||||
|
|||||||
@@ -443,21 +443,6 @@ class MemoryConfig(ValidatedConfigBase):
|
|||||||
|
|
||||||
enable_memory: bool = Field(default=True, description="启用记忆")
|
enable_memory: bool = Field(default=True, description="启用记忆")
|
||||||
memory_build_interval: int = Field(default=600, description="记忆构建间隔")
|
memory_build_interval: int = Field(default=600, description="记忆构建间隔")
|
||||||
memory_build_distribution: list[float] = Field(
|
|
||||||
default_factory=lambda: [6.0, 3.0, 0.6, 32.0, 12.0, 0.4], description="记忆构建分布"
|
|
||||||
)
|
|
||||||
memory_build_sample_num: int = Field(default=8, description="记忆构建样本数量")
|
|
||||||
memory_build_sample_length: int = Field(default=40, description="记忆构建样本长度")
|
|
||||||
memory_compress_rate: float = Field(default=0.1, description="记忆压缩率")
|
|
||||||
forget_memory_interval: int = Field(default=1000, description="遗忘记忆间隔")
|
|
||||||
memory_forget_time: int = Field(default=24, description="记忆遗忘时间")
|
|
||||||
memory_forget_percentage: float = Field(default=0.01, description="记忆遗忘百分比")
|
|
||||||
consolidate_memory_interval: int = Field(default=1000, description="记忆巩固间隔")
|
|
||||||
consolidation_similarity_threshold: float = Field(default=0.7, description="巩固相似性阈值")
|
|
||||||
consolidate_memory_percentage: float = Field(default=0.01, description="巩固记忆百分比")
|
|
||||||
memory_ban_words: list[str] = Field(
|
|
||||||
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"], description="记忆禁用词"
|
|
||||||
)
|
|
||||||
enable_instant_memory: bool = Field(default=True, description="启用即时记忆")
|
enable_instant_memory: bool = Field(default=True, description="启用即时记忆")
|
||||||
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
|
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
|
||||||
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
|
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
|
||||||
@@ -472,8 +457,8 @@ class MemoryConfig(ValidatedConfigBase):
|
|||||||
memory_value_threshold: float = Field(default=0.7, description="记忆价值阈值")
|
memory_value_threshold: float = Field(default=0.7, description="记忆价值阈值")
|
||||||
|
|
||||||
# 向量存储配置
|
# 向量存储配置
|
||||||
vector_dimension: int = Field(default=768, description="向量维度")
|
|
||||||
vector_similarity_threshold: float = Field(default=0.8, description="向量相似度阈值")
|
vector_similarity_threshold: float = Field(default=0.8, description="向量相似度阈值")
|
||||||
|
semantic_similarity_threshold: float = Field(default=0.6, description="语义相似度阈值")
|
||||||
|
|
||||||
# 多阶段检索配置
|
# 多阶段检索配置
|
||||||
metadata_filter_limit: int = Field(default=100, description="元数据过滤阶段返回数量")
|
metadata_filter_limit: int = Field(default=100, description="元数据过滤阶段返回数量")
|
||||||
|
|||||||
15
src/main.py
15
src/main.py
@@ -27,9 +27,8 @@ from src.plugin_system.core.event_manager import event_manager
|
|||||||
from src.plugin_system.base.component_types import EventType
|
from src.plugin_system.base.component_types import EventType
|
||||||
# from src.api.main import start_api_server
|
# from src.api.main import start_api_server
|
||||||
|
|
||||||
# 导入新的插件管理器和热重载管理器
|
# 导入新的插件管理器
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
|
||||||
|
|
||||||
# 导入消息API和traceback模块
|
# 导入消息API和traceback模块
|
||||||
from src.common.message import get_global_api
|
from src.common.message import get_global_api
|
||||||
@@ -116,13 +115,7 @@ class MainSystem:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"停止消息重组器时出错: {e}")
|
logger.error(f"停止消息重组器时出错: {e}")
|
||||||
|
|
||||||
try:
|
|
||||||
# 停止插件热重载系统
|
|
||||||
hot_reload_manager.stop()
|
|
||||||
logger.info("🛑 插件热重载系统已停止")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"停止热重载系统时出错: {e}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 停止增强记忆系统
|
# 停止增强记忆系统
|
||||||
if global_config.memory.enable_memory:
|
if global_config.memory.enable_memory:
|
||||||
@@ -229,9 +222,7 @@ MoFox_Bot(第三方修改版)
|
|||||||
# 处理所有缓存的事件订阅(插件加载完成后)
|
# 处理所有缓存的事件订阅(插件加载完成后)
|
||||||
event_manager.process_all_pending_subscriptions()
|
event_manager.process_all_pending_subscriptions()
|
||||||
|
|
||||||
# 启动插件热重载系统
|
|
||||||
hot_reload_manager.start()
|
|
||||||
|
|
||||||
# 初始化表情管理器
|
# 初始化表情管理器
|
||||||
get_emoji_manager().initialize()
|
get_emoji_manager().initialize()
|
||||||
logger.info("表情包管理器初始化成功")
|
logger.info("表情包管理器初始化成功")
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import time
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional, List, Dict, Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
@@ -27,8 +27,21 @@ class BaseAction(ABC):
|
|||||||
- parallel_action: 是否允许并行执行
|
- parallel_action: 是否允许并行执行
|
||||||
- random_activation_probability: 随机激活概率
|
- random_activation_probability: 随机激活概率
|
||||||
- llm_judge_prompt: LLM判断提示词
|
- llm_judge_prompt: LLM判断提示词
|
||||||
|
|
||||||
|
二步Action相关属性:
|
||||||
|
- is_two_step_action: 是否为二步Action
|
||||||
|
- step_one_description: 第一步的描述
|
||||||
|
- sub_actions: 子Action列表
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 二步Action相关类属性
|
||||||
|
is_two_step_action: bool = False
|
||||||
|
"""是否为二步Action。如果为True,Action将分两步执行:第一步选择操作,第二步执行具体操作"""
|
||||||
|
step_one_description: str = ""
|
||||||
|
"""第一步的描述,用于向LLM展示Action的基本功能"""
|
||||||
|
sub_actions: List[Tuple[str, str, Dict[str, str]]] = []
|
||||||
|
"""子Action列表,格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
action_data: dict,
|
action_data: dict,
|
||||||
@@ -93,6 +106,13 @@ class BaseAction(ABC):
|
|||||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||||
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||||
|
|
||||||
|
# 二步Action相关实例属性
|
||||||
|
self.is_two_step_action: bool = getattr(self.__class__, "is_two_step_action", False)
|
||||||
|
self.step_one_description: str = getattr(self.__class__, "step_one_description", "")
|
||||||
|
self.sub_actions: List[Tuple[str, str, Dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy()
|
||||||
|
self._selected_sub_action: Optional[str] = None
|
||||||
|
"""当前选择的子Action名称,用于二步Action的状态管理"""
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -412,23 +432,32 @@ class BaseAction(ABC):
|
|||||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||||
return False, f"未找到Action组件信息: {action_name}"
|
return False, f"未找到Action组件信息: {action_name}"
|
||||||
|
|
||||||
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
# 确保获取的是Action组件
|
||||||
|
if component_info.component_type != ComponentType.ACTION:
|
||||||
|
logger.error(f"{log_prefix} 尝试调用的组件 '{action_name}' 不是一个Action,而是一个 '{component_info.component_type.value}'")
|
||||||
|
return False, f"组件 '{action_name}' 不是一个有效的Action"
|
||||||
|
|
||||||
|
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
||||||
# 3. 实例化被调用的Action
|
# 3. 实例化被调用的Action
|
||||||
action_instance = action_class(
|
action_params = {
|
||||||
action_data=called_action_data,
|
"action_data": called_action_data,
|
||||||
reasoning=f"Called by {self.action_name}",
|
"reasoning": f"Called by {self.action_name}",
|
||||||
cycle_timers=self.cycle_timers,
|
"cycle_timers": self.cycle_timers,
|
||||||
thinking_id=self.thinking_id,
|
"thinking_id": self.thinking_id,
|
||||||
chat_stream=self.chat_stream,
|
"chat_stream": self.chat_stream,
|
||||||
log_prefix=log_prefix,
|
"log_prefix": log_prefix,
|
||||||
plugin_config=plugin_config,
|
"plugin_config": plugin_config,
|
||||||
action_message=self.action_message,
|
"action_message": self.action_message,
|
||||||
)
|
}
|
||||||
|
action_instance = action_class(**action_params)
|
||||||
|
|
||||||
# 4. 执行Action
|
# 4. 执行Action
|
||||||
logger.debug(f"{log_prefix} 开始执行...")
|
logger.debug(f"{log_prefix} 开始执行...")
|
||||||
result = await action_instance.execute()
|
execute_result = await action_instance.execute()
|
||||||
|
# 确保返回类型符合 (bool, str) 格式
|
||||||
|
is_success = execute_result[0] if isinstance(execute_result, tuple) and len(execute_result) > 0 else False
|
||||||
|
message = execute_result[1] if isinstance(execute_result, tuple) and len(execute_result) > 1 else ""
|
||||||
|
result = (is_success, str(message))
|
||||||
logger.info(f"{log_prefix} 执行完成,结果: {result}")
|
logger.info(f"{log_prefix} 执行完成,结果: {result}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -477,15 +506,73 @@ class BaseAction(ABC):
|
|||||||
action_require=getattr(cls, "action_require", []).copy(),
|
action_require=getattr(cls, "action_require", []).copy(),
|
||||||
associated_types=getattr(cls, "associated_types", []).copy(),
|
associated_types=getattr(cls, "associated_types", []).copy(),
|
||||||
chat_type_allow=getattr(cls, "chat_type_allow", ChatType.ALL),
|
chat_type_allow=getattr(cls, "chat_type_allow", ChatType.ALL),
|
||||||
|
# 二步Action相关属性
|
||||||
|
is_two_step_action=getattr(cls, "is_two_step_action", False),
|
||||||
|
step_one_description=getattr(cls, "step_one_description", ""),
|
||||||
|
sub_actions=getattr(cls, "sub_actions", []).copy(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def handle_step_one(self) -> Tuple[bool, str]:
|
||||||
|
"""处理二步Action的第一步
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||||
|
"""
|
||||||
|
if not self.is_two_step_action:
|
||||||
|
return False, "此Action不是二步Action"
|
||||||
|
|
||||||
|
# 检查action_data中是否包含选择的子Action
|
||||||
|
selected_action = self.action_data.get("selected_action")
|
||||||
|
if not selected_action:
|
||||||
|
# 第一步:展示可用的子Action
|
||||||
|
available_actions = [sub_action[0] for sub_action in self.sub_actions]
|
||||||
|
description = self.step_one_description or f"{self.action_name}支持以下操作"
|
||||||
|
|
||||||
|
actions_list = "\n".join([f"- {action}: {desc}" for action, desc, _ in self.sub_actions])
|
||||||
|
response = f"{description}\n\n可用操作:\n{actions_list}\n\n请选择要执行的操作。"
|
||||||
|
|
||||||
|
return True, response
|
||||||
|
else:
|
||||||
|
# 验证选择的子Action是否有效
|
||||||
|
valid_actions = [sub_action[0] for sub_action in self.sub_actions]
|
||||||
|
if selected_action not in valid_actions:
|
||||||
|
return False, f"无效的操作选择: {selected_action}。可用操作: {valid_actions}"
|
||||||
|
|
||||||
|
# 保存选择的子Action
|
||||||
|
self._selected_sub_action = selected_action
|
||||||
|
|
||||||
|
# 调用第二步执行
|
||||||
|
return await self.execute_step_two(selected_action)
|
||||||
|
|
||||||
|
async def execute_step_two(self, sub_action_name: str) -> Tuple[bool, str]:
    """Default second step of a two-step action; always reports failure.

    Subclasses that are two-step actions are expected to override this
    method with the real handling of the chosen sub-action.

    Args:
        sub_action_name: Name of the sub-action that was selected.

    Returns:
        Tuple[bool, str]: (whether the step succeeded, reply text)
    """
    if self.is_two_step_action:
        # Two-step action that did not override this hook.
        reason = f"二步Action必须实现execute_step_two方法来处理操作: {sub_action_name}"
    else:
        reason = "此Action不是二步Action"
    return False, reason
|
||||||
|
|
||||||
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
    """Abstract entry point of an action; subclasses must implement it.

    Because the method is abstract, this base body only runs when a subclass
    explicitly calls it (for example through ``super().execute()``). For a
    two-step action it delegates to ``handle_step_one``; for any other action
    it does nothing and returns ``None``.

    Returns:
        Tuple[bool, str]: (whether execution succeeded, reply text)
    """
    if not self.is_two_step_action:
        # Ordinary actions are implemented entirely by the subclass.
        return None

    # Two-step actions get their first step handled here.
    return await self.handle_step_one()
|
||||||
|
|
||||||
async def handle_action(self) -> Tuple[bool, str]:
|
async def handle_action(self) -> Tuple[bool, str]:
|
||||||
|
|||||||
@@ -38,6 +38,14 @@ class BaseTool(ABC):
|
|||||||
semantic_cache_query_key: Optional[str] = None
|
semantic_cache_query_key: Optional[str] = None
|
||||||
"""用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索"""
|
"""用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索"""
|
||||||
|
|
||||||
|
# 二步工具调用相关属性
|
||||||
|
is_two_step_tool: bool = False
|
||||||
|
"""是否为二步工具。如果为True,工具将分两步调用:第一步展示工具信息,第二步执行具体操作"""
|
||||||
|
step_one_description: str = ""
|
||||||
|
"""第一步的描述,用于向LLM展示工具的基本功能"""
|
||||||
|
sub_tools: List[Tuple[str, str, List[Tuple[str, ToolParamType, str, bool, List[str] | None]]]] = []
|
||||||
|
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
|
||||||
|
|
||||||
def __init__(self, plugin_config: Optional[dict] = None):
|
def __init__(self, plugin_config: Optional[dict] = None):
|
||||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||||
|
|
||||||
@@ -48,10 +56,64 @@ class BaseTool(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
dict: 工具定义字典
|
dict: 工具定义字典
|
||||||
"""
|
"""
|
||||||
if not cls.name or not cls.description or not cls.parameters:
|
if not cls.name or not cls.description:
|
||||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name 和 description 属性")
|
||||||
|
|
||||||
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
# 如果是二步工具,第一步只返回基本信息
|
||||||
|
if cls.is_two_step_tool:
|
||||||
|
return {
|
||||||
|
"name": cls.name,
|
||||||
|
"description": cls.step_one_description or cls.description,
|
||||||
|
"parameters": [("action", ToolParamType.STRING, "选择要执行的操作", True, [sub_tool[0] for sub_tool in cls.sub_tools])]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 普通工具需要parameters
|
||||||
|
if not cls.parameters:
|
||||||
|
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 parameters 属性")
|
||||||
|
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
||||||
|
|
||||||
|
@classmethod
def get_step_two_tool_definition(cls, sub_tool_name: str) -> dict[str, Any]:
    """Build the second-step definition for one sub-tool of a two-step tool.

    Args:
        sub_tool_name: Name of the sub-tool to look up in ``cls.sub_tools``.

    Returns:
        dict: Definition with ``name`` (``<tool>_<sub_tool>``),
        ``description`` and ``parameters`` of the sub-tool.

    Raises:
        ValueError: If the tool is not a two-step tool, or no sub-tool with
            the given name exists.
    """
    if not cls.is_two_step_tool:
        raise ValueError(f"工具 {cls.name} 不是二步工具")

    # First sub-tool whose name matches, or None when there is none.
    match = next(
        ((desc, params) for name, desc, params in cls.sub_tools if name == sub_tool_name),
        None,
    )
    if match is None:
        raise ValueError(f"未找到子工具: {sub_tool_name}")

    sub_desc, sub_params = match
    return {
        "name": f"{cls.name}_{sub_tool_name}",
        "description": sub_desc,
        "parameters": sub_params,
    }
|
||||||
|
|
||||||
|
@classmethod
def get_all_sub_tool_definitions(cls) -> List[dict[str, Any]]:
    """Return the second-step definitions of every sub-tool.

    Returns:
        List[dict]: One ``{"name", "description", "parameters"}`` dict per
        entry of ``cls.sub_tools`` (names are ``<tool>_<sub_tool>``), or an
        empty list when the tool is not a two-step tool.
    """
    if not cls.is_two_step_tool:
        return []

    return [
        {
            "name": f"{cls.name}_{sub_name}",
            "description": sub_desc,
            "parameters": sub_params,
        }
        for sub_name, sub_desc, sub_params in cls.sub_tools
    ]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_info(cls) -> ToolInfo:
|
def get_tool_info(cls) -> ToolInfo:
|
||||||
@@ -79,8 +141,68 @@ class BaseTool(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
dict: 工具执行结果
|
dict: 工具执行结果
|
||||||
"""
|
"""
|
||||||
|
# 如果是二步工具,处理第一步调用
|
||||||
|
if self.is_two_step_tool and "action" in function_args:
|
||||||
|
return await self._handle_step_one(function_args)
|
||||||
|
|
||||||
raise NotImplementedError("子类必须实现execute方法")
|
raise NotImplementedError("子类必须实现execute方法")
|
||||||
|
|
||||||
|
async def _handle_step_one(self, function_args: dict[str, Any]) -> dict[str, Any]:
    """Handle the first-step call of a two-step tool.

    Args:
        function_args: Call arguments; the ``action`` entry names the
            sub-tool that was chosen.

    Returns:
        dict: On success, a ``two_step_tool_step_one`` result carrying the
        definition of the second-step tool; otherwise a dict with a single
        ``error`` message.
    """
    action = function_args.get("action")
    if not action:
        return {"error": "缺少action参数"}

    # Locate the chosen sub-tool; the else-branch runs only when no entry matched.
    for sub_name, sub_desc, sub_params in self.sub_tools:
        if sub_name == action:
            break
    else:
        available_actions = [entry[0] for entry in self.sub_tools]
        return {"error": f"未知的操作: {action}。可用操作: {available_actions}"}

    # Definition the caller should expose for the follow-up (second-step) call.
    step_two_definition = {
        "name": f"{self.name}_{sub_name}",
        "description": sub_desc,
        "parameters": sub_params,
    }
    return {
        "type": "two_step_tool_step_one",
        "content": f"已选择操作: {action}。请使用以下工具进行具体调用:",
        "next_tool_definition": step_two_definition,
        "selected_action": action,
    }
|
||||||
|
|
||||||
|
async def execute_step_two(self, sub_tool_name: str, function_args: dict[str, Any]) -> dict[str, Any]:
    """Default second step of a two-step tool; always raises.

    Two-step tools are expected to override this method with the real
    implementation of their sub-tools.

    Args:
        sub_tool_name: Name of the sub-tool to run.
        function_args: Arguments of the tool call.

    Returns:
        dict: Result of the tool execution (only from overriding subclasses).

    Raises:
        ValueError: If the tool is not a two-step tool.
        NotImplementedError: If a two-step tool did not override this method.
    """
    if self.is_two_step_tool:
        # A two-step tool reached the base implementation: the override is missing.
        raise NotImplementedError("二步工具必须实现execute_step_two方法")
    raise ValueError(f"工具 {self.name} 不是二步工具")
|
||||||
|
|
||||||
async def direct_execute(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
|
async def direct_execute(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""直接执行工具函数(供插件调用)
|
"""直接执行工具函数(供插件调用)
|
||||||
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
|
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
|
||||||
|
|||||||
@@ -142,6 +142,10 @@ class ActionInfo(ComponentInfo):
|
|||||||
mode_enable: ChatMode = ChatMode.ALL
|
mode_enable: ChatMode = ChatMode.ALL
|
||||||
parallel_action: bool = False
|
parallel_action: bool = False
|
||||||
chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型
|
chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型
|
||||||
|
# 二步Action相关属性
|
||||||
|
is_two_step_action: bool = False # 是否为二步Action
|
||||||
|
step_one_description: str = "" # 第一步的描述
|
||||||
|
sub_actions: List[Tuple[str, str, Dict[str, str]]] = field(default_factory=list) # 子Action列表
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
@@ -153,6 +157,8 @@ class ActionInfo(ComponentInfo):
|
|||||||
self.action_require = []
|
self.action_require = []
|
||||||
if self.associated_types is None:
|
if self.associated_types is None:
|
||||||
self.associated_types = []
|
self.associated_types = []
|
||||||
|
if self.sub_actions is None:
|
||||||
|
self.sub_actions = []
|
||||||
self.component_type = ComponentType.ACTION
|
self.component_type = ComponentType.ACTION
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,12 +8,10 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
|||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plugin_manager",
|
"plugin_manager",
|
||||||
"component_registry",
|
"component_registry",
|
||||||
"event_manager",
|
"event_manager",
|
||||||
"global_announcement_manager",
|
"global_announcement_manager",
|
||||||
"hot_reload_manager",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,512 +0,0 @@
|
|||||||
"""
|
|
||||||
插件热重载模块
|
|
||||||
|
|
||||||
使用 Watchdog 监听插件目录变化,自动重载插件
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import importlib
|
|
||||||
from pathlib import Path
|
|
||||||
from threading import Thread
|
|
||||||
from typing import Dict, Set, List, Optional, Tuple
|
|
||||||
|
|
||||||
from watchdog.observers import Observer
|
|
||||||
from watchdog.events import FileSystemEventHandler
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from .plugin_manager import plugin_manager
|
|
||||||
|
|
||||||
logger = get_logger("plugin_hot_reload")
|
|
||||||
|
|
||||||
|
|
||||||
class PluginFileHandler(FileSystemEventHandler):
    """File-system event handler that turns plugin file changes into reloads or unloads."""

    # Suffixes of files whose changes are treated as plugin changes.
    _WATCHED_SUFFIXES = (".py", ".toml")
    # Manifest file names that mark a plugin directory. "_manifest.json" has to be
    # admitted by name, otherwise its deletion would never reach _handle_file_change.
    _MANIFEST_FILE_NAMES = ("manifest.toml", "_manifest.json")
    # Repeated events for the same file and change type inside this window are ignored.
    _FILE_EVENT_MIN_INTERVAL = 0.5

    def __init__(self, hot_reload_manager):
        """Create a handler bound to the hot-reload manager that owns it.

        Args:
            hot_reload_manager: Manager providing ``watch_directories`` and the
                ``_resolve_plugin_name`` / ``_unload_plugin`` /
                ``_deep_reload_plugin`` operations used by this handler.
        """
        super().__init__()
        self.hot_reload_manager = hot_reload_manager
        self.pending_reloads: Set[str] = set()  # names of plugins waiting for a reload
        self.last_reload_time: Dict[str, float] = {}  # plugin name -> time a reload was last scheduled
        self.debounce_delay = 2.0  # seconds to wait so that file writes can finish
        self.file_change_cache: Dict[str, float] = {}  # "<path>_<change type>" -> time last seen

    def on_modified(self, event):
        """Handle a file-modified event."""
        self._dispatch_event(event, "modified")

    def on_created(self, event):
        """Handle a file-created event."""
        self._dispatch_event(event, "created")

    def on_deleted(self, event):
        """Handle a file-deleted event."""
        self._dispatch_event(event, "deleted")

    def _dispatch_event(self, event, change_type: str):
        """Forward a non-directory event on a watched file to ``_handle_file_change``."""
        if event.is_directory:
            return
        file_path = str(event.src_path)
        if self._is_watched_file(file_path):
            self._handle_file_change(file_path, change_type)

    @classmethod
    def _is_watched_file(cls, file_path: str) -> bool:
        """Return True for ``.py``/``.toml`` files and for plugin manifest files."""
        return file_path.endswith(cls._WATCHED_SUFFIXES) or Path(file_path).name in cls._MANIFEST_FILE_NAMES

    @staticmethod
    def _is_built_in_root(plugin_root: Path) -> bool:
        """Return True when ``plugin_root`` lies inside ``<cwd>/src``.

        A containment test is used instead of a substring test so that a
        project that merely lives under a path containing "src" is not
        labelled as built-in.
        """
        return plugin_root.absolute().is_relative_to(Path.cwd() / "src")

    def _handle_file_change(self, file_path: str, change_type: str):
        """Debounce a file change and schedule the matching reload or unload.

        Args:
            file_path: Path of the file that changed.
            change_type: One of ``"modified"``, ``"created"`` or ``"deleted"``.
        """
        try:
            plugin_info = self._get_plugin_info_from_path(file_path)
            if not plugin_info:
                return

            plugin_name, source_type = plugin_info
            current_time = time.time()

            # Per-file throttle: drop rapid repeats of the same change on the same file.
            file_cache_key = f"{file_path}_{change_type}"
            last_file_time = self.file_change_cache.get(file_cache_key, 0)
            if current_time - last_file_time < self._FILE_EVENT_MIN_INTERVAL:
                return
            self.file_change_cache[file_cache_key] = current_time

            # Per-plugin debounce: inside the window only mark the plugin as pending;
            # the reload thread that is already scheduled will pick it up.
            last_plugin_time = self.last_reload_time.get(plugin_name, 0)
            if current_time - last_plugin_time < self.debounce_delay:
                self.pending_reloads.add(plugin_name)
                return

            file_name = Path(file_path).name
            logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type}) [{source_type}] -> {plugin_name}")

            # Deleting a key file (main module or manifest) unloads the plugin instead of reloading it.
            if change_type == "deleted":
                actual_plugin_name = self.hot_reload_manager._resolve_plugin_name(plugin_name)
                if file_name == "plugin.py":
                    self._handle_key_file_deleted("插件主文件", plugin_name, actual_plugin_name, source_type)
                    return
                if file_name in self._MANIFEST_FILE_NAMES:
                    self._handle_key_file_deleted("插件配置文件", plugin_name, actual_plugin_name, source_type)
                    return

            # Every other change (including deletion of non-key files) triggers a reload.
            self.pending_reloads.add(plugin_name)
            self.last_reload_time[plugin_name] = current_time

            # Reload on a daemon thread after the debounce delay so that writes can finish.
            reload_thread = Thread(
                target=self._delayed_reload, args=(plugin_name, source_type, current_time), daemon=True
            )
            reload_thread.start()

        except Exception as e:
            logger.error(f"❌ 处理文件变化时发生错误: {e}", exc_info=True)

    def _handle_key_file_deleted(self, file_label: str, plugin_name: str, actual_plugin_name: str, source_type: str):
        """Unload the plugin whose key file was deleted, if it is currently loaded."""
        if actual_plugin_name in plugin_manager.loaded_plugins:
            logger.info(f"🗑️ {file_label}被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]")
            self.hot_reload_manager._unload_plugin(actual_plugin_name)
        else:
            logger.info(f"🗑️ {file_label}被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]")

    def _delayed_reload(self, plugin_name: str, source_type: str, trigger_time: float):
        """Wait for the debounce delay, then deep-reload the plugin if it is still due.

        Args:
            plugin_name: Folder name of the plugin to reload.
            source_type: ``"built-in"`` or ``"external"``; used in log output only.
            trigger_time: Time at which this reload was scheduled.
        """
        try:
            # Give the writer time to finish.
            time.sleep(self.debounce_delay)

            # The request may have been cancelled while waiting.
            if plugin_name not in self.pending_reloads:
                return

            # A newer request supersedes this one.
            if self.last_reload_time.get(plugin_name, 0) > trigger_time:
                return

            self.pending_reloads.discard(plugin_name)
            logger.info(f"🔄 开始延迟重载插件: {plugin_name} [{source_type}]")

            success = self.hot_reload_manager._deep_reload_plugin(plugin_name)
            if success:
                logger.info(f"✅ 插件重载成功: {plugin_name} [{source_type}]")
            else:
                logger.error(f"❌ 插件重载失败: {plugin_name} [{source_type}]")

        except Exception as e:
            logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)

    def _get_plugin_info_from_path(self, file_path: str) -> Optional[Tuple[str, str]]:
        """Map a changed file to the plugin that contains it.

        Returns:
            ``(plugin folder name, source type)`` when the file lies in a
            plugin directory below one of the watched directories, where
            source type is ``"built-in"`` or ``"external"``; otherwise ``None``.
        """
        try:
            path = Path(file_path)

            for watch_dir in self.hot_reload_manager.watch_directories:
                plugin_root = Path(watch_dir)
                if not path.is_relative_to(plugin_root):
                    continue

                relative_path = path.relative_to(plugin_root)
                if not relative_path.parts:
                    continue

                # The first component below the watched directory is the plugin folder.
                plugin_name = relative_path.parts[0]
                plugin_dir = plugin_root / plugin_name
                if not plugin_dir.is_dir():
                    continue

                # Only folders that still hold a main module or a manifest count as plugins.
                has_plugin_py = (plugin_dir / "plugin.py").exists()
                has_manifest = any((plugin_dir / name).exists() for name in self._MANIFEST_FILE_NAMES)
                if has_plugin_py or has_manifest:
                    source_type = "built-in" if self._is_built_in_root(plugin_root) else "external"
                    return plugin_name, source_type

            return None

        except Exception as e:
            # Best effort: an unresolvable path is treated as "not a plugin file",
            # but the reason is recorded instead of being dropped silently.
            logger.debug(f"🔍 无法从路径解析插件信息: {file_path} ({e})")
            return None
|
|
||||||
|
|
||||||
|
|
||||||
class PluginHotReloadManager:
    """Watches plugin directories and reloads or unloads plugins when their files change."""

    def __init__(self, watch_directories: Optional[List[str]] = None):
        """Set up the manager and make sure every watched directory exists.

        Args:
            watch_directories: Directories to watch. Defaults to ``<cwd>/plugins``
                (external plugins) and ``<cwd>/src/plugins/built_in`` (built-in plugins).
        """
        if watch_directories is None:
            self.watch_directories = [
                os.path.join(os.getcwd(), "plugins"),  # external plugins
                os.path.join(os.getcwd(), "src", "plugins", "built_in"),  # built-in plugins
            ]
        else:
            self.watch_directories = watch_directories

        self.observers = []
        self.file_handlers = []
        self.is_running = False

        for watch_dir in self.watch_directories:
            self._ensure_directory(watch_dir)

    @staticmethod
    def _ensure_directory(directory: str):
        """Create ``directory`` (and log it) when it does not exist yet."""
        if not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)
            logger.info(f"📁 创建插件监听目录: {directory}")

    @staticmethod
    def _is_built_in_directory(directory: str) -> bool:
        """Return True when ``directory`` lies inside ``<cwd>/src``.

        A containment test is used instead of a substring test so that a
        project that merely lives under a path containing "src" is not
        labelled as built-in.
        """
        return Path(os.path.abspath(directory)).is_relative_to(Path(os.getcwd()) / "src")

    def _start_observer(self, directory: str):
        """Create, schedule and start an observer with its own handler for ``directory``."""
        observer = Observer()
        file_handler = PluginFileHandler(self)
        observer.schedule(file_handler, directory, recursive=True)
        observer.start()
        # Register only after a successful start so that stop() never joins an unstarted observer.
        self.observers.append(observer)
        self.file_handlers.append(file_handler)

    def start(self):
        """Start watching all configured directories; does nothing if already running."""
        if self.is_running:
            logger.warning("插件热重载已经在运行中")
            return

        try:
            # One independent observer per watched directory.
            for watch_dir in self.watch_directories:
                self._start_observer(watch_dir)

            self.is_running = True

            logger.info("🚀 插件热重载已启动,监听目录:")
            for watch_dir in self.watch_directories:
                label = "内置插件" if self._is_built_in_directory(watch_dir) else "外部插件"
                logger.info(f" 📂 {watch_dir} ({label})")

        except Exception as e:
            logger.error(f"❌ 启动插件热重载失败: {e}")
            self.stop()  # tear down the observers that were already started
            self.is_running = False

    def stop(self):
        """Stop and join every observer; does nothing when nothing is running."""
        if not self.is_running and not self.observers:
            return

        for observer in self.observers:
            try:
                observer.stop()
                observer.join()
            except Exception as e:
                logger.error(f"❌ 停止观察者时发生错误: {e}")

        self.observers.clear()
        self.file_handlers.clear()
        self.is_running = False
        logger.info("🛑 插件热重载已停止")

    def _reload_plugin(self, plugin_name: str) -> bool:
        """Reload a plugin through ``plugin_manager.reload_plugin`` (simple reload).

        Returns:
            True when the plugin manager reports success, False otherwise
            (including when an exception was raised).
        """
        try:
            actual_plugin_name = self._resolve_plugin_name(plugin_name)
            logger.info(f"🔄 开始简单重载插件: {plugin_name} -> {actual_plugin_name}")

            if plugin_manager.reload_plugin(actual_plugin_name):
                logger.info(f"✅ 插件简单重载成功: {actual_plugin_name}")
                return True

            logger.error(f"❌ 插件简单重载失败: {actual_plugin_name}")
            return False

        except Exception as e:
            logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
            return False

    @staticmethod
    def _resolve_plugin_name(folder_name: str) -> str:
        """Resolve a plugin folder name to the name registered in the plugin manager.

        Resolution order: a direct hit in ``plugin_manager.plugin_classes``;
        then registered plugin paths that contain ``folder_name``, preferring
        a plugin that also has a registered class; finally the folder name itself.
        """
        if folder_name in plugin_manager.plugin_classes:
            logger.debug(f"🔍 直接匹配插件名: {folder_name}")
            return folder_name

        registered = list(plugin_manager.plugin_paths.items())

        # Prefer paths having folder_name as a whole path component, so that
        # folder "foo" does not match ".../foo_bar".
        matched_plugins = [
            (plugin_name, plugin_path)
            for plugin_name, plugin_path in registered
            if folder_name in str(plugin_path).replace("\\", "/").split("/")
        ]
        if not matched_plugins:
            # Fall back to the looser substring match when no component matches exactly.
            matched_plugins = [
                (plugin_name, plugin_path)
                for plugin_name, plugin_path in registered
                if folder_name in str(plugin_path)
            ]

        # Among the candidates, prefer one that has a registered plugin class.
        for plugin_name, plugin_path in matched_plugins:
            if plugin_name in plugin_manager.plugin_classes:
                logger.debug(f"🔍 文件夹名 '{folder_name}' 映射到插件名 '{plugin_name}' (路径: {plugin_path})")
                return plugin_name

        # Otherwise settle for the first candidate.
        if matched_plugins:
            plugin_name, plugin_path = matched_plugins[0]
            logger.warning(f"⚠️ 文件夹 '{folder_name}' 映射到 '{plugin_name}',但该插件类不存在")
            return plugin_name

        logger.warning(f"⚠️ 无法找到文件夹 '{folder_name}' 对应的插件名,使用原名称")
        return folder_name

    def _deep_reload_plugin(self, plugin_name: str) -> bool:
        """Reload a plugin after evicting its cached modules (deep reload).

        Falls back to a simple reload when the forced reload fails or raises.

        Returns:
            True when either the deep or the fallback reload succeeded.
        """
        try:
            actual_plugin_name = self._resolve_plugin_name(plugin_name)
            logger.info(f"🔄 开始深度重载插件: {plugin_name} -> {actual_plugin_name}")

            # Module cache entries are keyed by the folder name, not the resolved name.
            self._force_clear_plugin_modules(plugin_name)

            if plugin_manager.force_reload_plugin(actual_plugin_name):
                logger.info(f"✅ 插件深度重载成功: {actual_plugin_name}")
                return True

            logger.error(f"❌ 插件深度重载失败,尝试简单重载: {actual_plugin_name}")
            return self._reload_plugin(actual_plugin_name)

        except Exception as e:
            logger.error(f"❌ 深度重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
            return self._reload_plugin(plugin_name)

    @staticmethod
    def _force_clear_plugin_modules(plugin_name: str):
        """Evict the plugin's package and its submodules from ``sys.modules``.

        Only modules under ``src.plugins.built_in.<plugin_name>`` are handled.
        """
        package = f"src.plugins.built_in.{plugin_name}"
        submodule_prefix = f"{package}."

        # Match the package itself or real submodules only; a substring test would
        # also evict sibling plugins such as "<plugin_name>_extra".
        modules_to_remove = [
            module_name
            for module_name in list(sys.modules)
            if module_name == package or module_name.startswith(submodule_prefix)
        ]

        for module_name in modules_to_remove:
            if module_name in sys.modules:
                logger.debug(f"🗑️ 清理模块缓存: {module_name}")
                del sys.modules[module_name]

    @staticmethod
    def _force_reimport_plugin(plugin_name: str):
        """Re-import a plugin by delegating to ``plugin_manager.reload_plugin``.

        Returns:
            The plugin manager's result, or False when an exception was raised.
        """
        try:
            return plugin_manager.reload_plugin(plugin_name)
        except Exception as e:
            logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True)
            return False

    @staticmethod
    def _unload_plugin(plugin_name: str):
        """Unload a plugin through the plugin manager.

        Returns:
            True when the plugin manager reports success, False otherwise
            (including when an exception was raised).
        """
        try:
            logger.info(f"🗑️ 开始卸载插件: {plugin_name}")

            if plugin_manager.unload_plugin(plugin_name):
                logger.info(f"✅ 插件卸载成功: {plugin_name}")
                return True

            logger.error(f"❌ 插件卸载失败: {plugin_name}")
            return False

        except Exception as e:
            logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
            return False

    def reload_all_plugins(self):
        """Deep-reload every currently loaded plugin and log a success/failure summary."""
        try:
            logger.info("🔄 开始深度重载所有插件...")

            # Snapshot the names first: reloading mutates the loaded-plugin mapping.
            loaded_plugins = list(plugin_manager.loaded_plugins.keys())

            success_count = 0
            fail_count = 0
            for plugin_name in loaded_plugins:
                logger.info(f"🔄 重载插件: {plugin_name}")
                if self._deep_reload_plugin(plugin_name):
                    success_count += 1
                else:
                    fail_count += 1

            logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count} 个")

            # Let the import system notice files that changed on disk.
            importlib.invalidate_caches()

        except Exception as e:
            logger.error(f"❌ 重载所有插件时发生错误: {e}", exc_info=True)

    def force_reload_plugin(self, plugin_name: str):
        """Force-reload one plugin on request, delegating to the plugin manager.

        Returns:
            The plugin manager's result, or False when an exception was raised.
        """
        try:
            logger.info(f"🔄 手动强制重载插件: {plugin_name}")

            # Drop any pending automatic reload so that the plugin is not reloaded twice.
            for handler in self.file_handlers:
                handler.pending_reloads.discard(plugin_name)

            success = plugin_manager.force_reload_plugin(plugin_name)
            if success:
                logger.info(f"✅ 手动强制重载成功: {plugin_name}")
            else:
                logger.error(f"❌ 手动强制重载失败: {plugin_name}")
            return success

        except Exception as e:
            logger.error(f"❌ 手动强制重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
            return False

    def add_watch_directory(self, directory: str):
        """Add a directory to the watch list, starting an observer for it when running."""
        if directory in self.watch_directories:
            logger.info(f"目录 {directory} 已在监听列表中")
            return

        self._ensure_directory(directory)
        self.watch_directories.append(directory)

        if self.is_running:
            try:
                self._start_observer(directory)
                logger.info(f"📂 已添加新的监听目录: {directory}")
            except Exception as e:
                logger.error(f"❌ 添加监听目录 {directory} 失败: {e}")
                # Roll back so that the list only names directories actually being watched.
                self.watch_directories.remove(directory)

    def get_status(self) -> dict:
        """Return a snapshot of the hot-reload state.

        Returns:
            dict with ``is_running``, ``watch_directories``, ``active_observers``,
            ``loaded_plugins`` and ``failed_plugins`` (counts), ``pending_reloads``
            (union over all handlers) and ``debounce_delay`` (taken from the
            first handler, or 0 when there is none).
        """
        pending_reloads = set()
        for handler in self.file_handlers:
            pending_reloads.update(handler.pending_reloads)

        return {
            "is_running": self.is_running,
            "watch_directories": self.watch_directories,
            "active_observers": len(self.observers),
            "loaded_plugins": len(plugin_manager.loaded_plugins),
            "failed_plugins": len(plugin_manager.failed_plugins),
            "pending_reloads": list(pending_reloads),
            "debounce_delay": self.file_handlers[0].debounce_delay if self.file_handlers else 0,
        }

    @staticmethod
    def clear_all_caches():
        """Ask the plugin manager to rescan the plugin directories."""
        try:
            logger.info("🧹 开始清理所有Python模块缓存...")

            # The log messages describe this as a module-cache cleanup; the only
            # action visible here is the rescan.
            plugin_manager.rescan_plugin_directory()
            logger.info("✅ 模块缓存清理完成")

        except Exception as e:
            logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
# 全局热重载管理器实例
|
|
||||||
hot_reload_manager = PluginHotReloadManager()
|
|
||||||
@@ -55,6 +55,10 @@ class ToolExecutor:
|
|||||||
|
|
||||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||||
|
|
||||||
|
# 二步工具调用状态管理
|
||||||
|
self._pending_step_two_tools: Dict[str, Dict[str, Any]] = {}
|
||||||
|
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix}工具执行器初始化完成")
|
logger.info(f"{self.log_prefix}工具执行器初始化完成")
|
||||||
|
|
||||||
async def execute_from_chat_message(
|
async def execute_from_chat_message(
|
||||||
@@ -112,7 +116,18 @@ class ToolExecutor:
|
|||||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
    """Collect the tool definitions to offer for this chat.

    Returns:
        List[Dict[str, Any]]: Definitions of all available tools that are not
        disabled for ``self.chat_id`` (two-step tools contribute their
        first-step definition), followed by any pending second-step
        definitions.
    """
    all_tools = get_llm_available_tool_definitions()
    disabled = global_announcement_manager.get_disabled_chat_tools(self.chat_id)

    # Base definitions, minus the tools disabled for this chat.
    tool_definitions = [definition for name, definition in all_tools if name not in disabled]

    # Append the second-step definitions saved by earlier first-step calls.
    # getattr keeps this safe on instances that never set the attribute.
    pending_step_two = getattr(self, '_pending_step_two_tools', {})
    if pending_step_two:
        tool_definitions.extend(pending_step_two.values())

    return tool_definitions
|
||||||
|
|
||||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||||
"""执行工具调用
|
"""执行工具调用
|
||||||
@@ -251,6 +266,32 @@ class ToolExecutor:
|
|||||||
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
|
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
|
||||||
)
|
)
|
||||||
function_args["llm_called"] = True # 标记为LLM调用
|
function_args["llm_called"] = True # 标记为LLM调用
|
||||||
|
|
||||||
|
# 检查是否是二步工具的第二步调用
|
||||||
|
if "_" in function_name and function_name.count("_") >= 1:
|
||||||
|
# 可能是二步工具的第二步调用,格式为 "tool_name_sub_tool_name"
|
||||||
|
parts = function_name.split("_", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
base_tool_name, sub_tool_name = parts
|
||||||
|
base_tool_instance = get_tool_instance(base_tool_name)
|
||||||
|
|
||||||
|
if base_tool_instance and base_tool_instance.is_two_step_tool:
|
||||||
|
logger.info(f"{self.log_prefix}执行二步工具第二步: {base_tool_name}.{sub_tool_name}")
|
||||||
|
result = await base_tool_instance.execute_step_two(sub_tool_name, function_args)
|
||||||
|
|
||||||
|
# 清理待处理的第二步工具
|
||||||
|
self._pending_step_two_tools.pop(base_tool_name, None)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
logger.debug(f"{self.log_prefix}二步工具第二步 {function_name} 执行成功")
|
||||||
|
return {
|
||||||
|
"tool_call_id": tool_call.call_id,
|
||||||
|
"role": "tool",
|
||||||
|
"name": function_name,
|
||||||
|
"type": "function",
|
||||||
|
"content": result.get("content", ""),
|
||||||
|
}
|
||||||
|
|
||||||
# 获取对应工具实例
|
# 获取对应工具实例
|
||||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||||
if not tool_instance:
|
if not tool_instance:
|
||||||
@@ -260,6 +301,16 @@ class ToolExecutor:
|
|||||||
# 执行工具并记录日志
|
# 执行工具并记录日志
|
||||||
logger.debug(f"{self.log_prefix}执行工具 {function_name},参数: {function_args}")
|
logger.debug(f"{self.log_prefix}执行工具 {function_name},参数: {function_args}")
|
||||||
result = await tool_instance.execute(function_args)
|
result = await tool_instance.execute(function_args)
|
||||||
|
|
||||||
|
# 检查是否是二步工具的第一步结果
|
||||||
|
if result and result.get("type") == "two_step_tool_step_one":
|
||||||
|
logger.info(f"{self.log_prefix}二步工具第一步完成: {function_name}")
|
||||||
|
# 保存第二步工具定义
|
||||||
|
next_tool_def = result.get("next_tool_definition")
|
||||||
|
if next_tool_def:
|
||||||
|
self._pending_step_two_tools[function_name] = next_tool_def
|
||||||
|
logger.debug(f"{self.log_prefix}已保存第二步工具定义: {next_tool_def['name']}")
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
logger.debug(f"{self.log_prefix}工具 {function_name} 执行成功,结果: {result}")
|
logger.debug(f"{self.log_prefix}工具 {function_name} 执行成功,结果: {result}")
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -91,7 +91,6 @@ INSTALL_NAME_TO_IMPORT_NAME = {
|
|||||||
"pyusb": "usb", # USB访问
|
"pyusb": "usb", # USB访问
|
||||||
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
|
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
|
||||||
"psutil": "psutil", # 系统信息和进程管理
|
"psutil": "psutil", # 系统信息和进程管理
|
||||||
"watchdog": "watchdog", # 文件系统事件监控
|
|
||||||
"python-gnupg": "gnupg", # GnuPG的Python接口
|
"python-gnupg": "gnupg", # GnuPG的Python接口
|
||||||
# ============== 加密与安全 (Cryptography & Security) ==============
|
# ============== 加密与安全 (Cryptography & Security) ==============
|
||||||
"pycrypto": "Crypto", # 加密库 (较旧)
|
"pycrypto": "Crypto", # 加密库 (较旧)
|
||||||
|
|||||||
@@ -88,25 +88,27 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
async def on_plugin_loaded(self):
|
async def on_plugin_loaded(self):
|
||||||
|
"""插件加载完成后的回调,初始化服务并启动后台任务"""
|
||||||
|
# --- 注册权限节点 ---
|
||||||
await permission_api.register_permission_node(
|
await permission_api.register_permission_node(
|
||||||
"plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False
|
"plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False
|
||||||
)
|
)
|
||||||
await permission_api.register_permission_node(
|
await permission_api.register_permission_node(
|
||||||
"plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True
|
"plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True
|
||||||
)
|
)
|
||||||
# 创建所有服务实例
|
|
||||||
|
# --- 创建并注册所有服务实例 ---
|
||||||
content_service = ContentService(self.get_config)
|
content_service = ContentService(self.get_config)
|
||||||
image_service = ImageService(self.get_config)
|
image_service = ImageService(self.get_config)
|
||||||
cookie_service = CookieService(self.get_config)
|
cookie_service = CookieService(self.get_config)
|
||||||
reply_tracker_service = ReplyTrackerService()
|
reply_tracker_service = ReplyTrackerService()
|
||||||
|
|
||||||
# 使用已创建的 reply_tracker_service 实例
|
|
||||||
qzone_service = QZoneService(
|
qzone_service = QZoneService(
|
||||||
self.get_config,
|
self.get_config,
|
||||||
content_service,
|
content_service,
|
||||||
image_service,
|
image_service,
|
||||||
cookie_service,
|
cookie_service,
|
||||||
reply_tracker_service, # 传入已创建的实例
|
reply_tracker_service,
|
||||||
)
|
)
|
||||||
scheduler_service = SchedulerService(self.get_config, qzone_service)
|
scheduler_service = SchedulerService(self.get_config, qzone_service)
|
||||||
monitor_service = MonitorService(self.get_config, qzone_service)
|
monitor_service = MonitorService(self.get_config, qzone_service)
|
||||||
@@ -115,18 +117,12 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
|||||||
register_service("reply_tracker", reply_tracker_service)
|
register_service("reply_tracker", reply_tracker_service)
|
||||||
register_service("get_config", self.get_config)
|
register_service("get_config", self.get_config)
|
||||||
|
|
||||||
# 保存服务引用以便后续启动
|
logger.info("MaiZone重构版插件服务已注册。")
|
||||||
self.scheduler_service = scheduler_service
|
|
||||||
self.monitor_service = monitor_service
|
|
||||||
|
|
||||||
logger.info("MaiZone重构版插件已加载,服务已注册。")
|
# --- 启动后台任务 ---
|
||||||
|
asyncio.create_task(scheduler_service.start())
|
||||||
async def on_plugin_loaded(self):
|
asyncio.create_task(monitor_service.start())
|
||||||
"""插件加载完成后的回调,启动异步服务"""
|
logger.info("MaiZone后台监控和定时任务已启动。")
|
||||||
if hasattr(self, "scheduler_service") and hasattr(self, "monitor_service"):
|
|
||||||
asyncio.create_task(self.scheduler_service.start())
|
|
||||||
asyncio.create_task(self.monitor_service.start())
|
|
||||||
logger.info("MaiZone后台任务已启动。")
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -113,31 +113,32 @@ class CookieService:
|
|||||||
async def get_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
|
async def get_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
获取Cookie,按以下顺序尝试:
|
获取Cookie,按以下顺序尝试:
|
||||||
1. Adapter API
|
1. HTTP备用端点 (更稳定)
|
||||||
2. HTTP备用端点
|
2. 本地文件缓存
|
||||||
3. 本地文件缓存
|
3. Adapter API (作为最后手段)
|
||||||
"""
|
"""
|
||||||
# 1. 尝试从Adapter获取
|
# 1. 尝试从HTTP备用端点获取
|
||||||
cookies = await self._get_cookies_from_adapter(stream_id)
|
logger.info(f"开始尝试从HTTP备用地址获取 {qq_account} 的Cookie...")
|
||||||
if cookies:
|
|
||||||
logger.info("成功从Adapter获取Cookie。")
|
|
||||||
self._save_cookies_to_file(qq_account, cookies)
|
|
||||||
return cookies
|
|
||||||
|
|
||||||
# 2. 尝试从HTTP备用端点获取
|
|
||||||
logger.warning("从Adapter获取Cookie失败,尝试使用HTTP备用地址。")
|
|
||||||
cookies = await self._get_cookies_from_http()
|
cookies = await self._get_cookies_from_http()
|
||||||
if cookies:
|
if cookies:
|
||||||
logger.info("成功从HTTP备用地址获取Cookie。")
|
logger.info(f"成功从HTTP备用地址为 {qq_account} 获取Cookie。")
|
||||||
self._save_cookies_to_file(qq_account, cookies)
|
self._save_cookies_to_file(qq_account, cookies)
|
||||||
return cookies
|
return cookies
|
||||||
|
|
||||||
# 3. 尝试从本地文件加载
|
# 2. 尝试从本地文件加载
|
||||||
logger.warning("从HTTP备用地址获取Cookie失败,尝试加载本地缓存。")
|
logger.warning(f"从HTTP备用地址获取 {qq_account} 的Cookie失败,尝试加载本地缓存。")
|
||||||
cookies = self._load_cookies_from_file(qq_account)
|
cookies = self._load_cookies_from_file(qq_account)
|
||||||
if cookies:
|
if cookies:
|
||||||
logger.info("成功从本地文件加载缓存的Cookie。")
|
logger.info(f"成功从本地文件为 {qq_account} 加载缓存的Cookie。")
|
||||||
return cookies
|
return cookies
|
||||||
|
|
||||||
logger.error("所有Cookie获取方法均失败。")
|
# 3. 尝试从Adapter获取 (作为最后的备用方案)
|
||||||
|
logger.warning(f"从本地缓存加载 {qq_account} 的Cookie失败,最后尝试使用Adapter API。")
|
||||||
|
cookies = await self._get_cookies_from_adapter(stream_id)
|
||||||
|
if cookies:
|
||||||
|
logger.info(f"成功从Adapter API为 {qq_account} 获取Cookie。")
|
||||||
|
self._save_cookies_to_file(qq_account, cookies)
|
||||||
|
return cookies
|
||||||
|
|
||||||
|
logger.error(f"为 {qq_account} 获取Cookie的所有方法均失败。请确保Napcat HTTP服务或Adapter连接至少有一个正常工作,或存在有效的本地Cookie文件。")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -409,8 +409,9 @@ class QZoneService:
|
|||||||
cookie_dir.mkdir(exist_ok=True)
|
cookie_dir.mkdir(exist_ok=True)
|
||||||
cookie_file_path = cookie_dir / f"cookies-{qq_account}.json"
|
cookie_file_path = cookie_dir / f"cookies-{qq_account}.json"
|
||||||
|
|
||||||
|
# 优先尝试通过Napcat HTTP服务获取最新的Cookie
|
||||||
try:
|
try:
|
||||||
# 使用HTTP服务器方式获取Cookie
|
logger.info("尝试通过Napcat HTTP服务获取Cookie...")
|
||||||
host = self.get_config("cookie.http_fallback_host", "172.20.130.55")
|
host = self.get_config("cookie.http_fallback_host", "172.20.130.55")
|
||||||
port = self.get_config("cookie.http_fallback_port", "9999")
|
port = self.get_config("cookie.http_fallback_port", "9999")
|
||||||
napcat_token = self.get_config("cookie.napcat_token", "")
|
napcat_token = self.get_config("cookie.napcat_token", "")
|
||||||
@@ -421,23 +422,43 @@ class QZoneService:
|
|||||||
parsed_cookies = {
|
parsed_cookies = {
|
||||||
k.strip(): v.strip() for k, v in (p.split("=", 1) for p in cookie_str.split("; ") if "=" in p)
|
k.strip(): v.strip() for k, v in (p.split("=", 1) for p in cookie_str.split("; ") if "=" in p)
|
||||||
}
|
}
|
||||||
with open(cookie_file_path, "wb") as f:
|
# 成功获取后,异步写入本地文件作为备份
|
||||||
f.write(orjson.dumps(parsed_cookies))
|
try:
|
||||||
logger.info(f"Cookie已更新并保存至: {cookie_file_path}")
|
with open(cookie_file_path, "wb") as f:
|
||||||
|
f.write(orjson.dumps(parsed_cookies))
|
||||||
|
logger.info(f"通过Napcat服务成功更新Cookie,并已保存至: {cookie_file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"保存Cookie到文件时出错: {e}")
|
||||||
return parsed_cookies
|
return parsed_cookies
|
||||||
|
else:
|
||||||
|
logger.warning("通过Napcat服务未能获取有效Cookie。")
|
||||||
|
|
||||||
# 如果HTTP获取失败,尝试读取本地文件
|
|
||||||
if cookie_file_path.exists():
|
|
||||||
with open(cookie_file_path, "rb") as f:
|
|
||||||
return orjson.loads(f.read())
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新或加载Cookie时发生异常: {e}")
|
logger.warning(f"通过Napcat HTTP服务获取Cookie时发生异常: {e}。将尝试从本地文件加载。")
|
||||||
return None
|
|
||||||
|
|
||||||
async def _fetch_cookies_http(self, host: str, port: str, napcat_token: str) -> Optional[Dict]:
|
# 如果通过服务获取失败,则尝试从本地文件加载
|
||||||
|
logger.info("尝试从本地Cookie文件加载...")
|
||||||
|
if cookie_file_path.exists():
|
||||||
|
try:
|
||||||
|
with open(cookie_file_path, "rb") as f:
|
||||||
|
cookies = orjson.loads(f.read())
|
||||||
|
logger.info(f"成功从本地文件加载Cookie: {cookie_file_path}")
|
||||||
|
return cookies
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从本地文件 {cookie_file_path} 读取或解析Cookie失败: {e}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"本地Cookie文件不存在: {cookie_file_path}")
|
||||||
|
|
||||||
|
logger.error("所有获取Cookie的方式均失败。")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> Optional[Dict]:
|
||||||
"""通过HTTP服务器获取Cookie"""
|
"""通过HTTP服务器获取Cookie"""
|
||||||
url = f"http://{host}:{port}/get_cookies"
|
# 从配置中读取主机和端口,如果未提供则使用传入的参数
|
||||||
|
final_host = self.get_config("cookie.http_fallback_host", host)
|
||||||
|
final_port = self.get_config("cookie.http_fallback_port", port)
|
||||||
|
url = f"http://{final_host}:{final_port}/get_cookies"
|
||||||
|
|
||||||
max_retries = 5
|
max_retries = 5
|
||||||
retry_delay = 1
|
retry_delay = 1
|
||||||
|
|
||||||
@@ -481,14 +502,19 @@ class QZoneService:
|
|||||||
async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]:
|
async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]:
|
||||||
cookies = await self.cookie_service.get_cookies(qq_account, stream_id)
|
cookies = await self.cookie_service.get_cookies(qq_account, stream_id)
|
||||||
if not cookies:
|
if not cookies:
|
||||||
|
logger.error("获取API客户端失败:未能获取到Cookie。请检查Napcat连接是否正常,或是否存在有效的本地Cookie文件。")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
p_skey = cookies.get("p_skey") or cookies.get("p_skey".upper())
|
p_skey = cookies.get("p_skey") or cookies.get("p_skey".upper())
|
||||||
if not p_skey:
|
if not p_skey:
|
||||||
|
logger.error(f"获取API客户端失败:Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
gtk = self._generate_gtk(p_skey)
|
gtk = self._generate_gtk(p_skey)
|
||||||
uin = cookies.get("uin", "").lstrip("o")
|
uin = cookies.get("uin", "").lstrip("o")
|
||||||
|
if not uin:
|
||||||
|
logger.error(f"获取API客户端失败:Cookie中缺少关键的 'uin'。Cookie内容: {cookies}")
|
||||||
|
return None
|
||||||
|
|
||||||
async def _request(method, url, params=None, data=None, headers=None):
|
async def _request(method, url, params=None, data=None, headers=None):
|
||||||
final_headers = {"referer": f"https://user.qzone.qq.com/{uin}", "origin": "https://user.qzone.qq.com"}
|
final_headers = {"referer": f"https://user.qzone.qq.com/{uin}", "origin": "https://user.qzone.qq.com"}
|
||||||
|
|||||||
@@ -185,9 +185,13 @@ class SendHandler:
|
|||||||
|
|
||||||
logger.info(f"执行适配器命令: {action}")
|
logger.info(f"执行适配器命令: {action}")
|
||||||
|
|
||||||
# 直接向Napcat发送命令并获取响应
|
# 根据action决定处理方式
|
||||||
response_task = asyncio.create_task(self.send_message_to_napcat(action, params))
|
if action == "get_cookies":
|
||||||
response = await response_task
|
# 对于get_cookies,我们需要一个更长的超时时间
|
||||||
|
response = await self.send_message_to_napcat(action, params, timeout=40.0)
|
||||||
|
else:
|
||||||
|
# 对于其他命令,使用默认超时
|
||||||
|
response = await self.send_message_to_napcat(action, params)
|
||||||
|
|
||||||
# 发送响应回MaiBot
|
# 发送响应回MaiBot
|
||||||
await self.send_adapter_command_response(raw_message_base, response, request_id)
|
await self.send_adapter_command_response(raw_message_base, response, request_id)
|
||||||
@@ -196,6 +200,8 @@ class SendHandler:
|
|||||||
logger.info(f"适配器命令 {action} 执行成功")
|
logger.info(f"适配器命令 {action} 执行成功")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"适配器命令 {action} 执行失败,napcat返回:{str(response)}")
|
logger.warning(f"适配器命令 {action} 执行失败,napcat返回:{str(response)}")
|
||||||
|
# 无论成功失败,都记录下完整的响应内容以供调试
|
||||||
|
logger.debug(f"适配器命令 {action} 的完整响应: {response}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理适配器命令时发生错误: {e}")
|
logger.error(f"处理适配器命令时发生错误: {e}")
|
||||||
@@ -583,7 +589,7 @@ class SendHandler:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
|
async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict:
|
||||||
request_uuid = str(uuid.uuid4())
|
request_uuid = str(uuid.uuid4())
|
||||||
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
|
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
|
||||||
|
|
||||||
@@ -595,9 +601,9 @@ class SendHandler:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await connection.send(payload)
|
await connection.send(payload)
|
||||||
response = await get_response(request_uuid)
|
response = await get_response(request_uuid, timeout=timeout) # 使用传入的超时时间
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
logger.error("发送消息超时,未收到响应")
|
logger.error(f"发送消息超时({timeout}秒),未收到响应: action={action}, params={params}")
|
||||||
return {"status": "error", "message": "timeout"}
|
return {"status": "error", "message": "timeout"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送消息失败: {e}")
|
logger.error(f"发送消息失败: {e}")
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from src.plugin_system.base.command_args import CommandArgs
|
|||||||
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
|
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
|
||||||
from src.plugin_system.apis.permission_api import permission_api
|
from src.plugin_system.apis.permission_api import permission_api
|
||||||
from src.plugin_system.utils.permission_decorators import require_permission
|
from src.plugin_system.utils.permission_decorators import require_permission
|
||||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
|
||||||
|
|
||||||
|
|
||||||
class ManagementCommand(PlusCommand):
|
class ManagementCommand(PlusCommand):
|
||||||
@@ -78,10 +77,6 @@ class ManagementCommand(PlusCommand):
|
|||||||
await self._force_reload_plugin(args[1])
|
await self._force_reload_plugin(args[1])
|
||||||
elif action in ["add_dir", "添加目录"] and len(args) > 1:
|
elif action in ["add_dir", "添加目录"] and len(args) > 1:
|
||||||
await self._add_dir(args[1])
|
await self._add_dir(args[1])
|
||||||
elif action in ["hotreload_status", "热重载状态"]:
|
|
||||||
await self._show_hotreload_status()
|
|
||||||
elif action in ["clear_cache", "清理缓存"]:
|
|
||||||
await self._clear_all_caches()
|
|
||||||
else:
|
else:
|
||||||
await self.send_text("❌ 插件管理命令不合法\n使用 /pm plugin help 查看帮助")
|
await self.send_text("❌ 插件管理命令不合法\n使用 /pm plugin help 查看帮助")
|
||||||
return False, "命令不合法", True
|
return False, "命令不合法", True
|
||||||
@@ -179,14 +174,9 @@ class ManagementCommand(PlusCommand):
|
|||||||
• `/pm plugin force_reload <插件名>` - 强制重载指定插件(深度清理)
|
• `/pm plugin force_reload <插件名>` - 强制重载指定插件(深度清理)
|
||||||
• `/pm plugin add_dir <目录路径>` - 添加插件目录
|
• `/pm plugin add_dir <目录路径>` - 添加插件目录
|
||||||
|
|
||||||
热重载管理:
|
|
||||||
• `/pm plugin hotreload_status` - 查看热重载状态
|
|
||||||
• `/pm plugin clear_cache` - 清理所有模块缓存
|
|
||||||
|
|
||||||
📝 示例:
|
📝 示例:
|
||||||
• `/pm plugin load echo_example`
|
• `/pm plugin load echo_example`
|
||||||
• `/pm plugin force_reload permission_manager_plugin`
|
• `/pm plugin force_reload permission_manager_plugin`"""
|
||||||
• `/pm plugin clear_cache`"""
|
|
||||||
elif target == "component":
|
elif target == "component":
|
||||||
help_msg = """🧩 组件管理命令帮助
|
help_msg = """🧩 组件管理命令帮助
|
||||||
|
|
||||||
@@ -262,7 +252,7 @@ class ManagementCommand(PlusCommand):
|
|||||||
await self.send_text(f"🔄 开始强制重载插件: `{plugin_name}`...")
|
await self.send_text(f"🔄 开始强制重载插件: `{plugin_name}`...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
success = hot_reload_manager.force_reload_plugin(plugin_name)
|
success = plugin_manage_api.force_reload_plugin(plugin_name)
|
||||||
if success:
|
if success:
|
||||||
await self.send_text(f"✅ 插件强制重载成功: `{plugin_name}`")
|
await self.send_text(f"✅ 插件强制重载成功: `{plugin_name}`")
|
||||||
else:
|
else:
|
||||||
@@ -270,44 +260,7 @@ class ManagementCommand(PlusCommand):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}")
|
await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}")
|
||||||
|
|
||||||
async def _show_hotreload_status(self):
|
|
||||||
"""显示热重载状态"""
|
|
||||||
try:
|
|
||||||
status = hot_reload_manager.get_status()
|
|
||||||
|
|
||||||
status_text = f"""🔄 **热重载系统状态**
|
|
||||||
|
|
||||||
🟢 **运行状态:** {"运行中" if status["is_running"] else "已停止"}
|
|
||||||
📂 **监听目录:** {len(status["watch_directories"])} 个
|
|
||||||
👁️ **活跃观察者:** {status["active_observers"]} 个
|
|
||||||
📦 **已加载插件:** {status["loaded_plugins"]} 个
|
|
||||||
❌ **失败插件:** {status["failed_plugins"]} 个
|
|
||||||
⏱️ **防抖延迟:** {status.get("debounce_delay", 0)} 秒
|
|
||||||
|
|
||||||
📋 **监听的目录:**"""
|
|
||||||
|
|
||||||
for i, watch_dir in enumerate(status["watch_directories"], 1):
|
|
||||||
dir_type = "(内置插件)" if "src" in watch_dir else "(外部插件)"
|
|
||||||
status_text += f"\n{i}. `{watch_dir}` {dir_type}"
|
|
||||||
|
|
||||||
if status.get("pending_reloads"):
|
|
||||||
status_text += f"\n\n⏳ **待重载插件:** {', '.join([f'`{p}`' for p in status['pending_reloads']])}"
|
|
||||||
|
|
||||||
await self.send_text(status_text)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
await self.send_text(f"❌ 获取热重载状态时发生错误: {str(e)}")
|
|
||||||
|
|
||||||
async def _clear_all_caches(self):
|
|
||||||
"""清理所有模块缓存"""
|
|
||||||
await self.send_text("🧹 开始清理所有Python模块缓存...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
hot_reload_manager.clear_all_caches()
|
|
||||||
await self.send_text("✅ 模块缓存清理完成!建议重载相关插件以确保生效。")
|
|
||||||
except Exception as e:
|
|
||||||
await self.send_text(f"❌ 清理缓存时发生错误: {str(e)}")
|
|
||||||
|
|
||||||
async def _add_dir(self, dir_path: str):
|
async def _add_dir(self, dir_path: str):
|
||||||
"""添加插件目录"""
|
"""添加插件目录"""
|
||||||
await self.send_text(f"📁 正在添加插件目录: `{dir_path}`")
|
await self.send_text(f"📁 正在添加插件目录: `{dir_path}`")
|
||||||
|
|||||||
@@ -60,10 +60,12 @@ class ReminderTask(AsyncTask):
|
|||||||
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
|
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
|
||||||
|
|
||||||
extra_info = f"现在是提醒时间,请你以一种符合你人设的、俏皮的方式提醒 {self.target_user_name}。\n提醒内容: {self.event_details}\n设置提醒的人: {self.creator_name}"
|
extra_info = f"现在是提醒时间,请你以一种符合你人设的、俏皮的方式提醒 {self.target_user_name}。\n提醒内容: {self.event_details}\n设置提醒的人: {self.creator_name}"
|
||||||
|
last_message = self.chat_stream.context_manager.context.get_last_message()
|
||||||
|
reply_message_dict = last_message.flatten() if last_message else None
|
||||||
success, reply_set, _ = await generator_api.generate_reply(
|
success, reply_set, _ = await generator_api.generate_reply(
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream,
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
reply_message=self.chat_stream.context_manager.context.get_last_message().to_dict(),
|
reply_message=reply_message_dict,
|
||||||
request_type="plugin.reminder.remind_message",
|
request_type="plugin.reminder.remind_message",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -150,9 +152,11 @@ class PokeAction(BaseAction):
|
|||||||
action_require = ["当需要戳某个用户时使用", "当你想提醒特定用户时使用"]
|
action_require = ["当需要戳某个用户时使用", "当你想提醒特定用户时使用"]
|
||||||
llm_judge_prompt = """
|
llm_judge_prompt = """
|
||||||
判定是否需要使用戳一戳动作的条件:
|
判定是否需要使用戳一戳动作的条件:
|
||||||
1. 用户明确要求使用戳一戳。
|
1. **关键**: 这是一个高消耗的动作,请仅在绝对必要时使用,例如用户明确要求或作为提醒的关键部分。请极其谨慎地使用。
|
||||||
2. 你想以一种有趣的方式提醒或与某人互动。
|
2. **用户请求**: 用户明确要求使用戳一戳。
|
||||||
3. 上下文明确需要你戳一个或多个人。
|
3. **互动提醒**: 你想以一种有趣的方式提醒或与某人互动,但请确保这是对话的自然延伸,而不是无故打扰。
|
||||||
|
4. **上下文需求**: 上下文明确需要你戳一个或多个人。
|
||||||
|
5. **频率限制**: 如果最近已经戳过,或者用户情绪不高,请绝对不要使用。
|
||||||
|
|
||||||
请回答"是"或"否"。
|
请回答"是"或"否"。
|
||||||
"""
|
"""
|
||||||
@@ -217,7 +221,6 @@ class SetEmojiLikeAction(BaseAction):
|
|||||||
emoji_options.append(match.group(1))
|
emoji_options.append(match.group(1))
|
||||||
|
|
||||||
action_parameters = {
|
action_parameters = {
|
||||||
"emoji": f"要回应的表情,必须从以下表情中选择: {', '.join(emoji_options)}",
|
|
||||||
"set": "是否设置回应 (True/False)",
|
"set": "是否设置回应 (True/False)",
|
||||||
}
|
}
|
||||||
action_require = [
|
action_require = [
|
||||||
@@ -238,6 +241,7 @@ class SetEmojiLikeAction(BaseAction):
|
|||||||
async def execute(self) -> Tuple[bool, str]:
|
async def execute(self) -> Tuple[bool, str]:
|
||||||
"""执行设置表情回应的动作"""
|
"""执行设置表情回应的动作"""
|
||||||
message_id = None
|
message_id = None
|
||||||
|
set_like = self.action_data.get("set", True)
|
||||||
if self.has_action_message:
|
if self.has_action_message:
|
||||||
logger.debug(str(self.action_message))
|
logger.debug(str(self.action_message))
|
||||||
if isinstance(self.action_message, dict):
|
if isinstance(self.action_message, dict):
|
||||||
@@ -251,24 +255,49 @@ class SetEmojiLikeAction(BaseAction):
|
|||||||
action_done=False,
|
action_done=False,
|
||||||
)
|
)
|
||||||
return False, "未提供消息ID"
|
return False, "未提供消息ID"
|
||||||
|
available_models = llm_api.get_available_models()
|
||||||
|
if "utils_small" not in available_models:
|
||||||
|
logger.error("未找到 'utils_small' 模型配置,无法选择表情")
|
||||||
|
return False, "表情选择功能配置错误"
|
||||||
|
|
||||||
emoji_input = self.action_data.get("emoji")
|
model_to_use = available_models["utils_small"]
|
||||||
set_like = self.action_data.get("set", True)
|
|
||||||
|
# 获取最近的对话历史作为上下文
|
||||||
if not emoji_input:
|
context_text = ""
|
||||||
logger.error("未提供表情")
|
if self.action_message:
|
||||||
return False, "未提供表情"
|
context_text = self.action_message.get("processed_plain_text", "")
|
||||||
logger.info(f"设置表情回应: {emoji_input}, 是否设置: {set_like}")
|
else:
|
||||||
|
logger.error("无法找到动作选择的原始消息")
|
||||||
emoji_id = get_emoji_id(emoji_input)
|
return False, "无法找到动作选择的原始消息"
|
||||||
if not emoji_id:
|
|
||||||
logger.error(f"找不到表情: '{emoji_input}'。请从可用列表中选择。")
|
prompt = (
|
||||||
await self.store_action_info(
|
f"根据以下这条消息,从列表中选择一个最合适的表情名称来回应这条消息。\n"
|
||||||
action_build_into_prompt=True,
|
f"消息内容: '{context_text}'\n"
|
||||||
action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: 找不到表情: '{emoji_input}'",
|
f"可用表情列表: {', '.join(self.emoji_options)}\n"
|
||||||
action_done=False,
|
f"你的任务是:只输出你选择的表情的名称,不要包含任何其他文字或标点。\n"
|
||||||
|
f"例如,如果觉得应该用'赞',就只输出'赞'。"
|
||||||
)
|
)
|
||||||
return False, f"找不到表情: '{emoji_input}'。请从可用列表中选择。"
|
|
||||||
|
success, response, _, _ = await llm_api.generate_with_model(
|
||||||
|
prompt, model_config=model_to_use, request_type="plugin.set_emoji_like.select_emoji"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success or not response:
|
||||||
|
logger.error("二级LLM未能选择有效的表情。")
|
||||||
|
return False, "无法找到合适的表情。"
|
||||||
|
|
||||||
|
chosen_emoji_name = response.strip()
|
||||||
|
logger.info(f"二级LLM选择的表情是: '{chosen_emoji_name}'")
|
||||||
|
emoji_id = get_emoji_id(chosen_emoji_name)
|
||||||
|
|
||||||
|
if not emoji_id:
|
||||||
|
logger.error(f"二级LLM选择的表情 '{chosen_emoji_name}' 仍然无法匹配到有效的表情ID。")
|
||||||
|
await self.store_action_info(
|
||||||
|
action_build_into_prompt=True,
|
||||||
|
action_prompt_display=f"执行了set_emoji_like动作:{self.action_name},失败: 找不到表情: '{chosen_emoji_name}'",
|
||||||
|
action_done=False,
|
||||||
|
)
|
||||||
|
return False, f"找不到表情: '{chosen_emoji_name}'。"
|
||||||
|
|
||||||
# 4. 使用适配器API发送命令
|
# 4. 使用适配器API发送命令
|
||||||
if not message_id:
|
if not message_id:
|
||||||
@@ -291,7 +320,7 @@ class SetEmojiLikeAction(BaseAction):
|
|||||||
logger.info("设置表情回应成功")
|
logger.info("设置表情回应成功")
|
||||||
await self.store_action_info(
|
await self.store_action_info(
|
||||||
action_build_into_prompt=True,
|
action_build_into_prompt=True,
|
||||||
action_prompt_display=f"执行了set_emoji_like动作,{emoji_input},设置表情回应: {emoji_id}, 是否设置: {set_like}",
|
action_prompt_display=f"执行了set_emoji_like动作,{chosen_emoji_name},设置表情回应: {emoji_id}, 是否设置: {set_like}",
|
||||||
action_done=True,
|
action_done=True,
|
||||||
)
|
)
|
||||||
return True, "成功设置表情回应"
|
return True, "成功设置表情回应"
|
||||||
|
|||||||
@@ -28,20 +28,20 @@ class PlanManager:
|
|||||||
if target_month is None:
|
if target_month is None:
|
||||||
target_month = datetime.now().strftime("%Y-%m")
|
target_month = datetime.now().strftime("%Y-%m")
|
||||||
|
|
||||||
if not has_active_plans(target_month):
|
if not await has_active_plans(target_month):
|
||||||
logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。")
|
logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。")
|
||||||
generation_successful = await self._generate_monthly_plans_logic(target_month)
|
generation_successful = await self._generate_monthly_plans_logic(target_month)
|
||||||
return generation_successful
|
return generation_successful
|
||||||
else:
|
else:
|
||||||
logger.info(f"{target_month} 已存在有效的月度计划。")
|
logger.info(f"{target_month} 已存在有效的月度计划。")
|
||||||
plans = get_active_plans_for_month(target_month)
|
plans = await get_active_plans_for_month(target_month)
|
||||||
max_plans = global_config.planning_system.max_plans_per_month
|
max_plans = global_config.planning_system.max_plans_per_month
|
||||||
if len(plans) > max_plans:
|
if len(plans) > max_plans:
|
||||||
logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。")
|
logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。")
|
||||||
plans_to_delete = plans[: len(plans) - max_plans]
|
plans_to_delete = plans[: len(plans) - max_plans]
|
||||||
delete_ids = [p.id for p in plans_to_delete]
|
delete_ids = [p.id for p in plans_to_delete]
|
||||||
delete_plans_by_ids(delete_ids) # type: ignore
|
await delete_plans_by_ids(delete_ids) # type: ignore
|
||||||
plans = get_active_plans_for_month(target_month)
|
plans = await get_active_plans_for_month(target_month)
|
||||||
|
|
||||||
if plans:
|
if plans:
|
||||||
plan_texts = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans)])
|
plan_texts = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans)])
|
||||||
@@ -64,11 +64,11 @@ class PlanManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
last_month = self._get_previous_month(target_month)
|
last_month = self._get_previous_month(target_month)
|
||||||
archived_plans = get_archived_plans_for_month(last_month)
|
archived_plans = await get_archived_plans_for_month(last_month)
|
||||||
plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans)
|
plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans)
|
||||||
|
|
||||||
if plans:
|
if plans:
|
||||||
add_new_plans(plans, target_month)
|
await add_new_plans(plans, target_month)
|
||||||
logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。")
|
logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
@@ -95,11 +95,11 @@ class PlanManager:
|
|||||||
if target_month is None:
|
if target_month is None:
|
||||||
target_month = datetime.now().strftime("%Y-%m")
|
target_month = datetime.now().strftime("%Y-%m")
|
||||||
logger.info(f" 开始归档 {target_month} 的活跃月度计划...")
|
logger.info(f" 开始归档 {target_month} 的活跃月度计划...")
|
||||||
archived_count = archive_active_plans_for_month(target_month)
|
archived_count = await archive_active_plans_for_month(target_month)
|
||||||
logger.info(f" 成功归档了 {archived_count} 条 {target_month} 的月度计划。")
|
logger.info(f" 成功归档了 {archived_count} 条 {target_month} 的月度计划。")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}")
|
logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}")
|
||||||
|
|
||||||
def get_plans_for_schedule(self, month: str, max_count: int) -> List:
|
async def get_plans_for_schedule(self, month: str, max_count: int) -> List:
|
||||||
avoid_days = global_config.planning_system.avoid_repetition_days
|
avoid_days = global_config.planning_system.avoid_repetition_days
|
||||||
return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
|
return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
|
||||||
|
|||||||
@@ -255,25 +255,34 @@ max_context_emojis = 30 # 每次随机传递给LLM的表情包详细描述的最
|
|||||||
[memory]
|
[memory]
|
||||||
enable_memory = true # 是否启用记忆系统
|
enable_memory = true # 是否启用记忆系统
|
||||||
memory_build_interval = 600 # 记忆构建间隔 单位秒 间隔越低,MoFox-Bot学习越多,但是冗余信息也会增多
|
memory_build_interval = 600 # 记忆构建间隔 单位秒 间隔越低,MoFox-Bot学习越多,但是冗余信息也会增多
|
||||||
memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
|
enable_instant_memory = true # 是否启用即时记忆
|
||||||
memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
|
|
||||||
memory_build_sample_length = 30 # 采样长度,数值越高一段记忆内容越丰富
|
|
||||||
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
|
|
||||||
|
|
||||||
forget_memory_interval = 3000 # 记忆遗忘间隔 单位秒 间隔越低,MoFox-Bot遗忘越频繁,记忆更精简,但更难学习
|
|
||||||
memory_forget_time = 48 #多长时间后的记忆会被遗忘 单位小时
|
|
||||||
memory_forget_percentage = 0.008 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认
|
|
||||||
|
|
||||||
consolidate_memory_interval = 1000 # 记忆整合间隔 单位秒 间隔越低,MoFox-Bot整合越频繁,记忆更精简
|
|
||||||
consolidation_similarity_threshold = 0.7 # 相似度阈值
|
|
||||||
consolidation_check_percentage = 0.05 # 检查节点比例
|
|
||||||
|
|
||||||
enable_instant_memory = false # 是否启用即时记忆,测试功能,可能存在未知问题
|
|
||||||
enable_llm_instant_memory = true # 是否启用基于LLM的瞬时记忆
|
enable_llm_instant_memory = true # 是否启用基于LLM的瞬时记忆
|
||||||
enable_vector_instant_memory = true # 是否启用基于向量的瞬时记忆
|
enable_vector_instant_memory = true # 是否启用基于向量的瞬时记忆
|
||||||
|
enable_enhanced_memory = true # 是否启用增强记忆系统
|
||||||
|
enhanced_memory_auto_save = true # 是否自动保存增强记忆
|
||||||
|
|
||||||
#不希望记忆的词,已经记忆的不会受到影响,需要手动清理
|
min_memory_length = 10 # 最小记忆长度
|
||||||
memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]
|
max_memory_length = 500 # 最大记忆长度
|
||||||
|
memory_value_threshold = 0.7 # 记忆价值阈值,低于该值的记忆会被丢弃
|
||||||
|
vector_similarity_threshold = 0.8 # 向量相似度阈值
|
||||||
|
semantic_similarity_threshold = 0.6 # 语义重排阶段的最低匹配阈值
|
||||||
|
|
||||||
|
metadata_filter_limit = 100 # 元数据过滤阶段返回数量上限
|
||||||
|
vector_search_limit = 50 # 向量搜索阶段返回数量上限
|
||||||
|
semantic_rerank_limit = 20 # 语义重排阶段返回数量上限
|
||||||
|
final_result_limit = 10 # 综合筛选后的最终返回数量
|
||||||
|
|
||||||
|
vector_weight = 0.4 # 综合评分中向量相似度的权重
|
||||||
|
semantic_weight = 0.3 # 综合评分中语义匹配的权重
|
||||||
|
context_weight = 0.2 # 综合评分中上下文关联的权重
|
||||||
|
recency_weight = 0.1 # 综合评分中时效性的权重
|
||||||
|
|
||||||
|
fusion_similarity_threshold = 0.85 # 记忆融合时的相似度阈值
|
||||||
|
deduplication_window_hours = 24 # 记忆去重窗口(小时)
|
||||||
|
|
||||||
|
enable_memory_cache = true # 是否启用本地记忆缓存
|
||||||
|
cache_ttl_seconds = 300 # 缓存有效期(秒)
|
||||||
|
max_cache_size = 1000 # 缓存中允许的最大记忆条数
|
||||||
|
|
||||||
[voice]
|
[voice]
|
||||||
enable_asr = true # 是否启用语音识别,启用后MoFox-Bot可以识别语音消息,启用该功能需要配置语音识别模型[model.voice]
|
enable_asr = true # 是否启用语音识别,启用后MoFox-Bot可以识别语音消息,启用该功能需要配置语音识别模型[model.voice]
|
||||||
|
|||||||
@@ -203,6 +203,7 @@ max_tokens = 1000
|
|||||||
#嵌入模型
|
#嵌入模型
|
||||||
[model_task_config.embedding]
|
[model_task_config.embedding]
|
||||||
model_list = ["bge-m3"]
|
model_list = ["bge-m3"]
|
||||||
|
embedding_dimension = 1024
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user