This commit is contained in:
xiaoCZX
2025-10-01 11:00:46 +08:00
42 changed files with 2095 additions and 1035 deletions

View File

@@ -70,7 +70,6 @@ dependencies = [
"tqdm>=4.67.1",
"urllib3>=2.5.0",
"uvicorn>=0.35.0",
"watchdog>=6.0.0",
"websockets>=15.0.1",
"aiomysql>=0.2.0",
"aiosqlite>=0.21.0",

View File

@@ -50,7 +50,6 @@ reportportal-client
scikit-learn
seaborn
structlog
watchdog
httpx
requests
beautifulsoup4

View File

@@ -12,6 +12,7 @@ from sqlalchemy import select
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
logger = get_logger("bot_interest_manager")
@@ -28,7 +29,9 @@ class BotInterestManager:
# Embedding客户端配置
self.embedding_request = None
self.embedding_config = None
self.embedding_dimension = 1024 # 默认BGE-M3 embedding维度
configured_dim = resolve_embedding_dimension()
self.embedding_dimension = int(configured_dim) if configured_dim else 0
self._detected_embedding_dimension: Optional[int] = None
@property
def is_initialized(self) -> bool:
@@ -82,8 +85,11 @@ class BotInterestManager:
logger.info("📋 找到embedding模型配置")
self.embedding_config = model_config.model_task_config.embedding
self.embedding_dimension = 1024 # BGE-M3的维度
logger.info(f"📐 使用模型维度: {self.embedding_dimension}")
if self.embedding_dimension:
logger.info(f"📐 配置的embedding维度: {self.embedding_dimension}")
else:
logger.info("📐 未在配置中检测到embedding维度将根据首次返回的向量自动识别")
# 创建LLMRequest实例用于embedding
self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding")
@@ -350,7 +356,27 @@ class BotInterestManager:
if embedding and len(embedding) > 0:
self.embedding_cache[text] = embedding
logger.debug(f"✅ Embedding获取成功维度: {len(embedding)}, 模型: {model_name}")
current_dim = len(embedding)
if self._detected_embedding_dimension is None:
self._detected_embedding_dimension = current_dim
if self.embedding_dimension and self.embedding_dimension != current_dim:
logger.warning(
"⚠️ 实际embedding维度(%d)与配置值(%d)不一致,请在 model_config.model_task_config.embedding.embedding_dimension 中同步更新",
current_dim,
self.embedding_dimension,
)
else:
self.embedding_dimension = current_dim
logger.info(f"📏 检测到embedding维度: {current_dim}")
elif current_dim != self.embedding_dimension:
logger.warning(
"⚠️ 收到的embedding维度发生变化: 之前=%d, 当前=%d。请确认模型配置是否正确。",
self.embedding_dimension,
current_dim,
)
logger.debug(f"✅ Embedding获取成功维度: {current_dim}, 模型: {model_name}")
return embedding
else:
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")

View File

@@ -26,6 +26,7 @@ from rich.progress import (
TextColumn,
)
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
install(extra_lines=3)
@@ -504,7 +505,10 @@ class EmbeddingStore:
# L2归一化
faiss.normalize_L2(embeddings)
# 构建索引
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
if not embedding_dim:
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
self.faiss_index.add(embeddings)
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:

View File

@@ -11,12 +11,27 @@ from dataclasses import dataclass
from src.common.logger import get_logger
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
MEMORY_TYPE_LABELS = {
MemoryType.PERSONAL_FACT: "个人事实",
MemoryType.EVENT: "事件",
MemoryType.PREFERENCE: "偏好",
MemoryType.OPINION: "观点",
MemoryType.RELATIONSHIP: "关系",
MemoryType.EMOTION: "情感",
MemoryType.KNOWLEDGE: "知识",
MemoryType.SKILL: "技能",
MemoryType.GOAL: "目标",
MemoryType.EXPERIENCE: "经验",
MemoryType.CONTEXTUAL: "上下文",
}
@dataclass
class AdapterConfig:
"""适配器配置"""
@@ -85,12 +100,9 @@ class EnhancedMemoryAdapter:
async def process_conversation_memory(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""处理对话记忆"""
"""处理对话记忆,以上下文为唯一输入"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"success": False, "error": "Enhanced memory not available"}
@@ -98,10 +110,30 @@ class EnhancedMemoryAdapter:
self.adapter_stats["total_processed"] += 1
try:
payload_context: Dict[str, Any] = dict(context or {})
conversation_text = payload_context.get("conversation_text")
if not conversation_text:
conversation_candidate = (
payload_context.get("message_content")
or payload_context.get("latest_message")
or payload_context.get("raw_text")
)
if conversation_candidate is not None:
conversation_text = str(conversation_candidate)
payload_context["conversation_text"] = conversation_text
else:
conversation_text = ""
else:
conversation_text = str(conversation_text)
if "timestamp" not in payload_context:
payload_context["timestamp"] = time.time()
logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text))
# 使用集成层处理对话
result = await self.integration_layer.process_conversation(
conversation_text, context, user_id, timestamp
)
result = await self.integration_layer.process_conversation(payload_context)
# 更新统计
processing_time = time.time() - start_time
@@ -132,7 +164,7 @@ class EnhancedMemoryAdapter:
try:
limit = limit or self.config.max_retrieval_results
memories = await self.integration_layer.retrieve_relevant_memories(
query, user_id, context, limit
query, None, context, limit
)
self.adapter_stats["memories_retrieved"] += len(memories)
@@ -157,12 +189,15 @@ class EnhancedMemoryAdapter:
if not memories:
return ""
# 格式化记忆为提示词友好的格式
memory_context_parts = []
for memory in memories:
memory_context_parts.append(f"- {memory.text_content}")
# 格式化记忆为提示词友好的Markdown结构
lines: List[str] = ["### 🧠 相关记忆 (Relevant Memories)", ""]
return "\n".join(memory_context_parts)
for memory in memories:
type_label = MEMORY_TYPE_LABELS.get(memory.memory_type, memory.memory_type.value)
display_text = memory.display or memory.text_content
lines.append(f"- **[{type_label}]** {display_text}")
return "\n".join(lines)
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
"""获取增强记忆系统摘要"""
@@ -270,13 +305,10 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
async def process_conversation_with_enhanced_memory(
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None,
llm_model: Optional[LLMRequest] = None
) -> Dict[str, Any]:
"""使用增强记忆系统处理对话"""
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
if not llm_model:
# 获取默认的LLM模型
from src.llm_models.utils_model import get_global_llm_model
@@ -284,7 +316,18 @@ async def process_conversation_with_enhanced_memory(
try:
adapter = await get_enhanced_memory_adapter(llm_model)
return await adapter.process_conversation_memory(conversation_text, context, user_id, timestamp)
payload_context = dict(context or {})
if "conversation_text" not in payload_context:
conversation_candidate = (
payload_context.get("message_content")
or payload_context.get("latest_message")
or payload_context.get("raw_text")
)
if conversation_candidate is not None:
payload_context["conversation_text"] = str(conversation_candidate)
return await adapter.process_conversation_memory(payload_context)
except Exception as e:
logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}

View File

@@ -1,13 +1,15 @@
# -*- coding: utf-8 -*-
"""
增强型精准记忆系统核心模块
基于文档设计的高效记忆构建、存储与召回优化系统
1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。
2. 内置 LLM 查询规划器与嵌入维度自动解析机制,直接从模型配置推断向量存储参数。
"""
import asyncio
import time
import orjson
import re
import hashlib
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
@@ -22,12 +24,16 @@ from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
from src.chat.memory_system.metadata_index import MetadataIndexManager
from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger(__name__)
# 全局记忆作用域(共享记忆库)
GLOBAL_MEMORY_SCOPE = "global"
class MemorySystemStatus(Enum):
"""记忆系统状态"""
@@ -47,14 +53,20 @@ class MemorySystemConfig:
memory_value_threshold: float = 0.7
min_build_interval_seconds: float = 300.0
# 向量存储配置
vector_dimension: int = 768
# 向量存储配置(嵌入维度自动来自模型配置)
vector_dimension: int = 1024
similarity_threshold: float = 0.8
# 召回配置
coarse_recall_limit: int = 50
fine_recall_limit: int = 10
semantic_rerank_limit: int = 20
final_recall_limit: int = 5
semantic_similarity_threshold: float = 0.6
vector_weight: float = 0.4
semantic_weight: float = 0.3
context_weight: float = 0.2
recency_weight: float = 0.1
# 融合配置
fusion_similarity_threshold: float = 0.85
@@ -64,6 +76,23 @@ class MemorySystemConfig:
def from_global_config(cls):
"""从全局配置创建配置实例"""
embedding_dimension = None
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
if embedding_task is not None:
embedding_dimension = getattr(embedding_task, "embedding_dimension", None)
except Exception:
embedding_dimension = None
if not embedding_dimension:
try:
embedding_dimension = getattr(global_config.lpmm_knowledge, "embedding_dimension", None)
except Exception:
embedding_dimension = None
if not embedding_dimension:
embedding_dimension = 1024
return cls(
# 记忆构建配置
min_memory_length=global_config.memory.min_memory_length,
@@ -72,13 +101,19 @@ class MemorySystemConfig:
min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0),
# 向量存储配置
vector_dimension=global_config.memory.vector_dimension,
vector_dimension=int(embedding_dimension),
similarity_threshold=global_config.memory.vector_similarity_threshold,
# 召回配置
coarse_recall_limit=global_config.memory.metadata_filter_limit,
fine_recall_limit=global_config.memory.final_result_limit,
fine_recall_limit=global_config.memory.vector_search_limit,
semantic_rerank_limit=global_config.memory.semantic_rerank_limit,
final_recall_limit=global_config.memory.final_result_limit,
semantic_similarity_threshold=getattr(global_config.memory, "semantic_similarity_threshold", 0.6),
vector_weight=global_config.memory.vector_weight,
semantic_weight=global_config.memory.semantic_weight,
context_weight=global_config.memory.context_weight,
recency_weight=global_config.memory.recency_weight,
# 融合配置
fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
@@ -104,6 +139,7 @@ class EnhancedMemorySystem:
self.vector_storage: VectorStorageManager = None
self.metadata_index: MetadataIndexManager = None
self.retrieval_system: MultiStageRetrieval = None
self.query_planner: MemoryQueryPlanner = None
# LLM模型
self.value_assessment_model: LLMRequest = None
@@ -117,6 +153,9 @@ class EnhancedMemorySystem:
# 构建节流记录
self._last_memory_build_times: Dict[str, float] = {}
# 记忆指纹缓存,用于快速检测重复记忆
self._memory_fingerprints: Dict[str, str] = {}
logger.info("EnhancedMemorySystem 初始化开始")
async def initialize(self):
@@ -125,19 +164,29 @@ class EnhancedMemorySystem:
logger.info("正在初始化增强型记忆系统...")
# 初始化LLM模型
task_config = (
self.llm_model.model_for_task
if self.llm_model is not None
else model_config.model_task_config.utils_small
)
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
value_task_config = getattr(model_config.model_task_config, "utils_small", None)
extraction_task_config = getattr(model_config.model_task_config, "utils", None)
if value_task_config is None:
logger.warning("未找到 utils_small 模型配置,回退到 utils 或外部提供的模型配置。")
value_task_config = extraction_task_config or fallback_task
if extraction_task_config is None:
logger.warning("未找到 utils 模型配置,回退到 utils_small 或外部提供的模型配置。")
extraction_task_config = value_task_config or fallback_task
if value_task_config is None or extraction_task_config is None:
raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。")
self.value_assessment_model = LLMRequest(
model_set=task_config,
model_set=value_task_config,
request_type="memory.value_assessment"
)
self.memory_extraction_model = LLMRequest(
model_set=task_config,
model_set=extraction_task_config,
request_type="memory.extraction"
)
@@ -155,13 +204,36 @@ class EnhancedMemorySystem:
retrieval_config = RetrievalConfig(
metadata_filter_limit=self.config.coarse_recall_limit,
vector_search_limit=self.config.fine_recall_limit,
final_result_limit=self.config.final_recall_limit
semantic_rerank_limit=self.config.semantic_rerank_limit,
final_result_limit=self.config.final_recall_limit,
vector_similarity_threshold=self.config.similarity_threshold,
semantic_similarity_threshold=self.config.semantic_similarity_threshold,
vector_weight=self.config.vector_weight,
semantic_weight=self.config.semantic_weight,
context_weight=self.config.context_weight,
recency_weight=self.config.recency_weight,
)
self.retrieval_system = MultiStageRetrieval(retrieval_config)
planner_task_config = getattr(model_config.model_task_config, "planner", None)
planner_model: Optional[LLMRequest] = None
try:
planner_model = LLMRequest(
model_set=planner_task_config,
request_type="memory.query_planner"
)
except Exception as planner_exc:
logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True)
self.query_planner = MemoryQueryPlanner(
planner_model,
default_limit=self.config.final_recall_limit
)
# 加载持久化数据
await self.vector_storage.load_storage()
await self.metadata_index.load_index()
self._populate_memory_fingerprints()
self.status = MemorySystemStatus.READY
logger.info("✅ 增强型记忆系统初始化完成")
@@ -174,7 +246,7 @@ class EnhancedMemorySystem:
async def retrieve_memories_for_building(
self,
query_text: str,
user_id: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
) -> List[MemoryChunk]:
@@ -182,7 +254,6 @@ class EnhancedMemorySystem:
Args:
query_text: 查询文本
user_id: 用户ID
context: 上下文信息
limit: 返回结果数量限制
@@ -201,7 +272,6 @@ class EnhancedMemorySystem:
# 执行检索
memories = await self.vector_storage.search_similar_memories(
query_text=query_text,
user_id=user_id,
limit=limit
)
@@ -218,23 +288,18 @@ class EnhancedMemorySystem:
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
) -> List[MemoryChunk]:
"""从对话中构建记忆
Args:
conversation_text: 对话文本
context: 上下文信息(包括用户信息、群组信息等)
user_id: 用户ID
context: 上下文信息
timestamp: 时间戳,默认为当前时间
Returns:
构建的记忆块列表
"""
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
raise RuntimeError("记忆系统未就绪")
original_status = self.status
self.status = MemorySystemStatus.BUILDING
start_time = time.time()
@@ -243,9 +308,9 @@ class EnhancedMemorySystem:
build_marker_time: Optional[float] = None
try:
normalized_context = self._normalize_context(context, user_id, timestamp)
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
build_scope_key = self._get_build_scope_key(normalized_context, user_id)
build_scope_key = self._get_build_scope_key(normalized_context, GLOBAL_MEMORY_SCOPE)
min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0))
current_time = time.time()
@@ -266,7 +331,7 @@ class EnhancedMemorySystem:
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
logger.debug("开始构建记忆,文本长度: %d", len(conversation_text))
# 1. 信息价值评估
value_score = await self._assess_information_value(conversation_text, normalized_context)
@@ -280,7 +345,7 @@ class EnhancedMemorySystem:
memory_chunks = await self.memory_builder.build_memories(
conversation_text,
normalized_context,
user_id,
GLOBAL_MEMORY_SCOPE,
timestamp or time.time()
)
@@ -293,19 +358,24 @@ class EnhancedMemorySystem:
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks)
# 4. 存储记忆
await self._store_memories(fused_chunks)
stored_count = await self._store_memories(fused_chunks)
# 4.1 控制台预览
self._log_memory_preview(fused_chunks)
# 5. 更新统计
self.total_memories += len(fused_chunks)
self.total_memories += stored_count
self.last_build_time = time.time()
if build_scope_key:
self._last_memory_build_times[build_scope_key] = self.last_build_time
build_time = time.time() - start_time
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}")
logger.info(
"✅ 生成 %d 条记忆,成功入库 %d 条,耗时 %.2f",
len(fused_chunks),
stored_count,
build_time,
)
self.status = original_status
return fused_chunks
@@ -347,21 +417,34 @@ class EnhancedMemorySystem:
async def process_conversation_memory(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
context: Dict[str, Any]
) -> Dict[str, Any]:
"""对外暴露的对话记忆处理接口,兼容旧调用方式"""
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
start_time = time.time()
try:
normalized_context = self._normalize_context(context, user_id, timestamp)
context = dict(context or {})
conversation_candidate = (
context.get("conversation_text")
or context.get("message_content")
or context.get("latest_message")
or context.get("raw_text")
or ""
)
conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate)
timestamp = context.get("timestamp")
if timestamp is None:
timestamp = time.time()
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
normalized_context.setdefault("conversation_text", conversation_text)
memories = await self.build_memory_from_conversation(
conversation_text=conversation_text,
context=normalized_context,
user_id=user_id,
timestamp=timestamp
)
@@ -395,52 +478,77 @@ class EnhancedMemorySystem:
**kwargs
) -> List[MemoryChunk]:
"""检索相关记忆,兼容 query/query_text 参数形式"""
if self.status != MemorySystemStatus.READY:
raise RuntimeError("记忆系统未就绪")
query_text = query_text or kwargs.get("query")
if not query_text:
raw_query = query_text or kwargs.get("query")
if not raw_query:
raise ValueError("query_text 或 query 参数不能为空")
context = context or {}
user_id = user_id or kwargs.get("user_id")
resolved_user_id = GLOBAL_MEMORY_SCOPE
if self.retrieval_system is None or self.metadata_index is None:
raise RuntimeError("检索组件未初始化")
all_memories_cache = self.vector_storage.memory_cache
if not all_memories_cache:
logger.debug("记忆缓存为空,返回空结果")
self.last_retrieval_time = time.time()
self.status = MemorySystemStatus.READY
return []
self.status = MemorySystemStatus.RETRIEVING
start_time = time.time()
try:
normalized_context = self._normalize_context(context, user_id, None)
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
candidate_memories = list(self.vector_storage.memory_cache.values())
if user_id:
candidate_memories = [m for m in candidate_memories if m.user_id == user_id]
effective_limit = limit or self.config.final_recall_limit
query_plan = None
planner_ran = False
resolved_query_text = raw_query
if self.query_planner:
try:
planner_ran = True
query_plan = await self.query_planner.plan_query(raw_query, normalized_context)
normalized_context["query_plan"] = query_plan
effective_limit = min(effective_limit, query_plan.limit or effective_limit)
if getattr(query_plan, "semantic_query", None):
resolved_query_text = query_plan.semantic_query
logger.debug(
"查询规划: semantic='%s', types=%s, subjects=%s, limit=%d",
query_plan.semantic_query,
[mt.value for mt in query_plan.memory_types],
query_plan.subject_includes,
query_plan.limit,
)
except Exception as plan_exc:
logger.warning("查询规划失败,使用默认检索策略: %s", plan_exc, exc_info=True)
if not candidate_memories:
self.status = MemorySystemStatus.READY
self.last_retrieval_time = time.time()
logger.debug(f"未找到用户 {user_id} 的候选记忆")
return []
effective_limit = effective_limit or self.config.final_recall_limit
effective_limit = max(1, min(effective_limit, self.config.final_recall_limit))
normalized_context["resolved_query_text"] = resolved_query_text
scored_memories = []
for memory in candidate_memories:
score = self._compute_memory_score(query_text, memory, normalized_context)
if score > 0:
scored_memories.append((memory, score))
if not scored_memories:
# 如果所有分数为0返回最近的记忆作为降级策略
if normalized_context.get("__memory_building__"):
logger.debug("当前处于记忆构建流程,跳过查询规划并进行降级检索")
self.status = MemorySystemStatus.BUILDING
final_memories = []
candidate_memories = list(all_memories_cache.values())
candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True)
scored_memories = [(memory, 0.0) for memory in candidate_memories[:limit]]
final_memories = candidate_memories[:effective_limit]
else:
scored_memories.sort(key=lambda item: item[1], reverse=True)
retrieval_result = await self.retrieval_system.retrieve_memories(
query=resolved_query_text,
user_id=resolved_user_id,
context=normalized_context,
metadata_index=self.metadata_index,
vector_storage=self.vector_storage,
all_memories_cache=all_memories_cache,
limit=effective_limit,
)
top_memories = [memory for memory, _ in scored_memories[:limit]]
final_memories = retrieval_result.final_memories
# 更新访问信息和缓存
for memory, score in scored_memories[:limit]:
for memory in final_memories:
memory.update_access()
memory.update_relevance(score)
cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id)
if cache_entry is not None:
cache_entry["last_accessed"] = memory.metadata.last_accessed
@@ -448,14 +556,34 @@ class EnhancedMemorySystem:
cache_entry["relevance_score"] = memory.metadata.relevance_score
retrieval_time = time.time() - start_time
logger.info(
f"✅ 为用户 {user_id or 'unknown'} 检索到 {len(top_memories)} 条相关记忆,耗时 {retrieval_time:.3f}"
plan_summary = ""
if planner_ran and query_plan:
plan_types = ",".join(mt.value for mt in query_plan.memory_types) or "-"
plan_subjects = ",".join(query_plan.subject_includes) or "-"
plan_summary = (
f" | planner.semantic='{query_plan.semantic_query}'"
f" | planner.limit={query_plan.limit}"
f" | planner.types={plan_types}"
f" | planner.subjects={plan_subjects}"
)
log_message = (
"✅ 记忆检索完成"
f" | user={resolved_user_id}"
f" | count={len(final_memories)}"
f" | duration={retrieval_time:.3f}s"
f" | applied_limit={effective_limit}"
f" | raw_query='{raw_query}'"
f" | semantic_query='{resolved_query_text}'"
f"{plan_summary}"
)
logger.info(log_message)
self.last_retrieval_time = time.time()
self.status = MemorySystemStatus.READY
return top_memories
return final_memories
except Exception as e:
self.status = MemorySystemStatus.ERROR
@@ -499,8 +627,8 @@ class EnhancedMemorySystem:
except Exception:
context = dict(raw_context or {})
# 基础字段
context["user_id"] = context.get("user_id") or user_id or "unknown"
# 基础字段(统一使用全局作用域)
context["user_id"] = GLOBAL_MEMORY_SCOPE
context["timestamp"] = context.get("timestamp") or timestamp or time.time()
context["message_type"] = context.get("message_type") or "normal"
context["platform"] = context.get("platform") or context.get("source_platform") or "unknown"
@@ -523,8 +651,8 @@ class EnhancedMemorySystem:
if stream_id:
context["stream_id"] = stream_id
# chat_id 兜底
context["chat_id"] = context.get("chat_id") or context.get("stream_id") or f"session_{context['user_id']}"
# 全局记忆无需聊天隔离
context["chat_id"] = context.get("chat_id") or "global_chat"
# 历史窗口配置
window_candidate = (
@@ -616,18 +744,7 @@ class EnhancedMemorySystem:
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
"""确定用于节流控制的记忆构建作用域"""
stream_id = context.get("stream_id")
if stream_id:
return f"stream::{stream_id}"
chat_id = context.get("chat_id")
if chat_id:
return f"chat::{chat_id}"
if user_id:
return f"user::{user_id}"
return None
return "global_scope"
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
"""确定历史消息获取数量限制在30-50之间"""
@@ -789,24 +906,134 @@ class EnhancedMemorySystem:
logger.error(f"信息价值评估失败: {e}", exc_info=True)
return 0.5 # 默认中等价值
async def _store_memories(self, memory_chunks: List[MemoryChunk]):
"""存储记忆块到各个存储系统"""
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
"""存储记忆块到各个存储系统,返回成功入库数量"""
if not memory_chunks:
return
return 0
unique_memories: List[MemoryChunk] = []
skipped_duplicates = 0
for memory in memory_chunks:
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
existing_id = self._memory_fingerprints.get(key)
if existing_id:
existing = self.vector_storage.memory_cache.get(existing_id)
if existing:
self._merge_existing_memory(existing, memory)
await self.metadata_index.update_memory_entry(existing)
skipped_duplicates += 1
logger.debug(
"检测到重复记忆,已合并到现有记录 | memory_id=%s",
existing.memory_id,
)
continue
else:
# 指纹存在但缓存缺失,视为新记忆并覆盖旧映射
logger.debug("检测到过期指纹映射,重写现有条目")
unique_memories.append(memory)
if not unique_memories:
if skipped_duplicates:
logger.info("本次记忆全部与现有内容重复,跳过入库")
return 0
# 并行存储到向量数据库和元数据索引
storage_tasks = []
storage_tasks = [
self.vector_storage.store_memories(unique_memories),
self.metadata_index.index_memories(unique_memories),
]
# 向量存储
storage_tasks.append(self.vector_storage.store_memories(memory_chunks))
# 元数据索引
storage_tasks.append(self.metadata_index.index_memories(memory_chunks))
# 等待所有存储任务完成
await asyncio.gather(*storage_tasks, return_exceptions=True)
logger.debug(f"成功存储 {len(memory_chunks)} 条记忆到各个存储系统")
self._register_memory_fingerprints(unique_memories)
logger.debug(
"成功存储 %d 条记忆(跳过重复 %d 条)",
len(unique_memories),
skipped_duplicates,
)
return len(unique_memories)
def _merge_existing_memory(self, existing: MemoryChunk, incoming: MemoryChunk) -> None:
"""将新记忆的信息合并到已存在的记忆中"""
updated = False
for keyword in incoming.keywords:
if keyword not in existing.keywords:
existing.add_keyword(keyword)
updated = True
for tag in incoming.tags:
if tag not in existing.tags:
existing.add_tag(tag)
updated = True
for category in incoming.categories:
if category not in existing.categories:
existing.add_category(category)
updated = True
if incoming.metadata.source_context:
existing.metadata.source_context = incoming.metadata.source_context
if incoming.metadata.importance.value > existing.metadata.importance.value:
existing.metadata.importance = incoming.metadata.importance
updated = True
if incoming.metadata.confidence.value > existing.metadata.confidence.value:
existing.metadata.confidence = incoming.metadata.confidence
updated = True
if incoming.metadata.relevance_score > existing.metadata.relevance_score:
existing.metadata.relevance_score = incoming.metadata.relevance_score
updated = True
if updated:
existing.metadata.last_modified = time.time()
def _populate_memory_fingerprints(self) -> None:
"""基于当前缓存构建记忆指纹映射"""
self._memory_fingerprints.clear()
for memory in self.vector_storage.memory_cache.values():
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
self._memory_fingerprints[key] = memory.memory_id
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
for memory in memories:
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
self._memory_fingerprints[key] = memory.memory_id
def _build_memory_fingerprint(self, memory: MemoryChunk) -> str:
subjects = memory.subjects or []
subject_part = "|".join(sorted(s.strip() for s in subjects if s))
predicate_part = (memory.content.predicate or "").strip()
obj = memory.content.object
if isinstance(obj, (dict, list)):
obj_part = orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8")
else:
obj_part = str(obj).strip()
base = "|".join([
str(memory.user_id or "unknown"),
memory.memory_type.value,
subject_part,
predicate_part,
obj_part,
])
return hashlib.sha256(base.encode("utf-8")).hexdigest()
@staticmethod
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
return f"{str(user_id)}:{fingerprint}"
def get_system_stats(self) -> Dict[str, Any]:
"""获取系统统计信息"""

View File

@@ -241,12 +241,12 @@ class EnhancedMemoryManager:
return []
try:
result = await self.enhanced_system.process_conversation_memory(
conversation_text=conversation_text,
context=context,
user_id=user_id,
timestamp=timestamp
)
payload_context = dict(context or {})
payload_context.setdefault("conversation_text", conversation_text)
if timestamp is not None:
payload_context.setdefault("timestamp", timestamp)
result = await self.enhanced_system.process_conversation_memory(payload_context)
# 从结果中提取记忆块
memory_chunks = []
@@ -274,7 +274,7 @@ class EnhancedMemoryManager:
try:
relevant_memories = await self.enhanced_system.retrieve_relevant_memories(
query=query_text,
user_id=user_id,
user_id=None,
context=context or {},
limit=limit
)
@@ -303,6 +303,9 @@ class EnhancedMemoryManager:
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
"""将记忆块转换为更易读的文本描述"""
structure = memory.content.to_dict()
if memory.display:
return self._clean_text(memory.display), structure
subject = structure.get("subject")
predicate = structure.get("predicate") or ""
obj = structure.get("object")

View File

@@ -114,12 +114,9 @@ class MemoryIntegrationLayer:
async def process_conversation(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
context: Dict[str, Any]
) -> Dict[str, Any]:
"""处理对话记忆"""
"""处理对话记忆,仅使用上下文信息"""
if not self._initialized or not self.enhanced_memory:
return {"success": False, "error": "Memory system not available"}
@@ -128,13 +125,12 @@ class MemoryIntegrationLayer:
self.integration_stats["enhanced_queries"] += 1
try:
payload_context = dict(context or {})
conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or ""
logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text))
# 直接使用增强记忆系统处理
result = await self.enhanced_memory.process_conversation_memory(
conversation_text=conversation_text,
context=context,
user_id=user_id,
timestamp=timestamp
)
result = await self.enhanced_memory.process_conversation_memory(payload_context)
# 更新统计
processing_time = time.time() - start_time
@@ -156,7 +152,7 @@ class MemoryIntegrationLayer:
async def retrieve_relevant_memories(
self,
query: str,
user_id: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None
) -> List[MemoryChunk]:
@@ -168,7 +164,7 @@ class MemoryIntegrationLayer:
limit = limit or self.config.max_retrieval_results
memories = await self.enhanced_memory.retrieve_relevant_memories(
query=query,
user_id=user_id,
user_id=None,
context=context or {},
limit=limit
)

View File

@@ -2,21 +2,48 @@
"""
记忆构建模块
从对话流中提取高质量、结构化记忆单元
输出格式要求:
{{
"memories": [
{{
"type": "记忆类型",
"display": "用于直接展示和检索的自然语言描述",
"subject": ["主体1", "主体2"],
"predicate": "谓语(动作/状态)",
"object": "宾语(对象/属性或结构体)",
"keywords": ["关键词1", "关键词2"],
"importance": "重要性等级(1-4)",
"confidence": "置信度(1-4)",
"reasoning": "提取理由"
}}
]
}}
注意:
1. `subject` 可包含多个主体,请用数组表示;若主体不明确,请根据上下文给出最合理的称呼
2. `display` 必须是一句完整流畅的中文描述,可直接用于用户展示和向量搜索
3. 只提取确实值得记忆的信息,不要提取琐碎内容
4. 确保信息准确、具体、有价值
5. 重要性: 1=低, 2=一般, 3=高, 4=关键;置信度: 1=低, 2=中等, 3=高, 4=已验证
"""
import re
import time
import orjson
from typing import Dict, List, Optional, Any
from datetime import datetime
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union, Type
import orjson
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
create_memory_chunk
MemoryChunk,
MemoryType,
ConfidenceLevel,
ImportanceLevel,
create_memory_chunk,
)
logger = get_logger(__name__)
@@ -24,6 +51,7 @@ logger = get_logger(__name__)
class ExtractionStrategy(Enum):
"""提取策略"""
LLM_BASED = "llm_based" # 基于LLM的智能提取
RULE_BASED = "rule_based" # 基于规则的提取
HYBRID = "hybrid" # 混合策略
@@ -171,18 +199,18 @@ class MemoryBuilder:
"""使用规则提取记忆"""
memories = []
subject_display = self._resolve_user_display(context, user_id)
subjects = self._resolve_conversation_participants(context, user_id)
# 规则1: 检测个人信息
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subject_display)
personal_info = self._extract_personal_info(text, user_id, timestamp, context, subjects)
memories.extend(personal_info)
# 规则2: 检测偏好信息
preferences = self._extract_preferences(text, user_id, timestamp, context, subject_display)
preferences = self._extract_preferences(text, user_id, timestamp, context, subjects)
memories.extend(preferences)
# 规则3: 检测事件信息
events = self._extract_events(text, user_id, timestamp, context, subject_display)
events = self._extract_events(text, user_id, timestamp, context, subjects)
memories.extend(events)
return memories
@@ -258,10 +286,7 @@ class MemoryBuilder:
你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。
当前时间: {current_date}
聊天ID: {chat_id}
消息类型: {message_type}
目标用户ID: {target_user_id_display}
目标用户称呼: {target_user_name}
## 🤖 机器人身份(仅供参考,禁止写入记忆)
- 机器人名称: {bot_name_display}
@@ -272,7 +297,6 @@ class MemoryBuilder:
请务必遵守以下命名规范:
- 当说话者是机器人时,请使用“{bot_name_display}”或其他明确称呼作为主语;
- 如果看到系统自动生成的长ID类似 {target_user_id}),请改用“{target_user_name}”、机器人的称呼或“该用户”描述不要把ID写入记忆
- 记录关键事实时,请准确标记主体是机器人还是用户,避免混淆。
对话内容:
@@ -450,7 +474,7 @@ class MemoryBuilder:
bot_identifiers = self._collect_bot_identifiers(context)
system_identifiers = self._collect_system_identifiers(context)
default_subject = self._resolve_user_display(context, user_id)
default_subjects = self._resolve_conversation_participants(context, user_id)
bot_display = None
if context:
@@ -481,19 +505,33 @@ class MemoryBuilder:
for mem_data in memory_list:
try:
subject_value = mem_data.get("subject")
normalized_subject = self._normalize_subject(
normalized_subject = self._normalize_subjects(
subject_value,
bot_identifiers,
system_identifiers,
default_subject,
default_subjects,
bot_display
)
if normalized_subject is None:
if not normalized_subject:
logger.debug("跳过疑似机器人自身信息的记忆: %s", mem_data)
continue
# 创建记忆块
importance_level = self._parse_enum_value(
ImportanceLevel,
mem_data.get("importance"),
ImportanceLevel.NORMAL,
"importance"
)
confidence_level = self._parse_enum_value(
ConfidenceLevel,
mem_data.get("confidence"),
ConfidenceLevel.MEDIUM,
"confidence"
)
memory = create_memory_chunk(
user_id=user_id,
subject=normalized_subject,
@@ -502,8 +540,9 @@ class MemoryBuilder:
memory_type=MemoryType(mem_data.get("type", "contextual")),
chat_id=context.get("chat_id"),
source_context=mem_data.get("reasoning", ""),
importance=ImportanceLevel(mem_data.get("importance", 2)),
confidence=ConfidenceLevel(mem_data.get("confidence", 2))
importance=importance_level,
confidence=confidence_level,
display=mem_data.get("display")
)
# 添加关键词
@@ -511,13 +550,6 @@ class MemoryBuilder:
for keyword in keywords:
memory.add_keyword(keyword)
subject_text = memory.content.subject.strip() if isinstance(memory.content.subject, str) else str(memory.content.subject)
if not subject_text:
memory.content.subject = default_subject
elif subject_text.lower() in system_identifiers or self._looks_like_system_identifier(subject_text):
logger.debug("将系统标识主语替换为默认用户名称: %s", subject_text)
memory.content.subject = default_subject
memories.append(memory)
except Exception as e:
@@ -526,6 +558,64 @@ class MemoryBuilder:
return memories
def _parse_enum_value(
self,
enum_cls: Type[Enum],
raw_value: Any,
default: Enum,
field_name: str
) -> Enum:
"""解析枚举值,兼容数字/字符串表示"""
if isinstance(raw_value, enum_cls):
return raw_value
if raw_value is None:
return default
# 直接尝试整数转换
if isinstance(raw_value, (int, float)):
int_value = int(raw_value)
try:
return enum_cls(int_value)
except ValueError:
logger.debug("%s=%s 无法解析为 %s", field_name, raw_value, enum_cls.__name__)
return default
if isinstance(raw_value, str):
value_str = raw_value.strip()
if not value_str:
return default
if value_str.isdigit():
try:
return enum_cls(int(value_str))
except ValueError:
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
else:
normalized = value_str.replace("-", "_").replace(" ", "_").upper()
for member in enum_cls:
if member.name == normalized:
return member
for member in enum_cls:
if str(member.value).lower() == value_str.lower():
return member
try:
return enum_cls(value_str)
except ValueError:
logger.debug("%s='%s' 无法解析为 %s", field_name, value_str, enum_cls.__name__)
try:
return enum_cls(raw_value)
except Exception:
logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s",
field_name,
raw_value,
type(raw_value).__name__,
enum_cls.__name__,
default.name)
return default
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
identifiers: set[str] = {"bot", "机器人", "ai助手"}
if not context:
@@ -580,6 +670,58 @@ class MemoryBuilder:
return identifiers
def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
participants: List[str] = []
if context:
candidate_keys = [
"participants",
"participant_names",
"speaker_names",
"members",
"member_names",
"mention_users",
"audiences"
]
for key in candidate_keys:
value = context.get(key)
if isinstance(value, (list, tuple, set)):
for item in value:
if isinstance(item, str):
cleaned = self._clean_subject_text(item)
if cleaned:
participants.append(cleaned)
elif isinstance(value, str):
for part in self._split_subject_string(value):
if part:
participants.append(part)
fallback = self._resolve_user_display(context, user_id)
if fallback:
participants.append(fallback)
if context:
bot_name = context.get("bot_name") or context.get("bot_identity")
if isinstance(bot_name, str):
cleaned = self._clean_subject_text(bot_name)
if cleaned:
participants.append(cleaned)
if not participants:
participants = ["对话参与者"]
deduplicated: List[str] = []
seen = set()
for name in participants:
key = name.lower()
if key in seen:
continue
seen.add(key)
deduplicated.append(name)
return deduplicated
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
candidate_keys = [
"user_display_name",
@@ -626,51 +768,160 @@ class MemoryBuilder:
return False
def _normalize_subject(
def _split_subject_string(self, value: str) -> List[str]:
if not value:
return []
replaced = re.sub(r"\band\b", "", value, flags=re.IGNORECASE)
replaced = replaced.replace("", "").replace("", "").replace("", "")
replaced = replaced.replace("&", "").replace("/", "").replace("+", "")
tokens = [self._clean_subject_text(token) for token in re.split(r"[、,;]+", replaced)]
return [token for token in tokens if token]
def _normalize_subjects(
self,
subject: Any,
bot_identifiers: set[str],
system_identifiers: set[str],
default_subject: str,
default_subjects: List[str],
bot_display: Optional[str] = None
) -> Optional[str]:
if subject is None:
return default_subject
) -> List[str]:
defaults = default_subjects or ["对话参与者"]
subject_str = subject if isinstance(subject, str) else str(subject)
cleaned = self._clean_subject_text(subject_str)
if not cleaned:
return default_subject
raw_candidates: List[str] = []
if isinstance(subject, list):
for item in subject:
if isinstance(item, str):
raw_candidates.extend(self._split_subject_string(item))
elif item is not None:
raw_candidates.extend(self._split_subject_string(str(item)))
elif isinstance(subject, str):
raw_candidates.extend(self._split_subject_string(subject))
elif subject is not None:
raw_candidates.extend(self._split_subject_string(str(subject)))
lowered = cleaned.lower()
normalized: List[str] = []
bot_primary = self._clean_subject_text(bot_display or "")
if lowered in bot_identifiers:
return bot_primary or cleaned
for candidate in raw_candidates:
if not candidate:
continue
if lowered in {"用户", "user", "the user", "对方", "对手"}:
return default_subject
lowered = candidate.lower()
if lowered in bot_identifiers:
normalized.append(bot_primary or candidate)
continue
prefix_match = re.match(r"^(用户|User|user|USER|成员|member|Member|target|Target|TARGET)[\s:\-\u2014_]*?(.*)$", cleaned)
if prefix_match:
remainder = self._clean_subject_text(prefix_match.group(2))
if not remainder:
return default_subject
remainder_lower = remainder.lower()
if remainder_lower in bot_identifiers:
return bot_primary or remainder
if (
remainder_lower in system_identifiers
or self._looks_like_system_identifier(remainder)
):
return default_subject
cleaned = remainder
lowered = cleaned.lower()
if lowered in {"用户", "user", "the user", "对方", "对手"}:
normalized.extend(defaults)
continue
if lowered in system_identifiers or self._looks_like_system_identifier(cleaned):
return default_subject
if lowered in system_identifiers or self._looks_like_system_identifier(candidate):
continue
return cleaned
normalized.append(candidate)
if not normalized:
normalized = list(defaults)
deduplicated: List[str] = []
seen = set()
for name in normalized:
key = name.lower()
if key in seen:
continue
seen.add(key)
deduplicated.append(name)
return deduplicated
def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
if isinstance(obj, dict):
for key in keys:
value = obj.get(key)
if value is None:
continue
if isinstance(value, list):
compact = "".join(str(item) for item in value[:3])
if compact:
return compact
else:
value_str = str(value).strip()
if value_str:
return value_str
elif isinstance(obj, list):
compact = "".join(str(item) for item in obj[:3])
return compact or None
elif isinstance(obj, str):
return obj.strip() or None
return None
def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str:
subject_phrase = "".join(subjects) if subjects else "对话参与者"
predicate = (predicate or "").strip()
if predicate == "is_named":
name = self._extract_value_from_object(obj, ["name", "nickname"]) or ""
name = self._clean_subject_text(name)
if name:
quoted = name if (name.startswith("") and name.endswith("")) else f"{name}"
return f"{subject_phrase}的昵称是{quoted}"
elif predicate == "is_age":
age = self._extract_value_from_object(obj, ["age"]) or ""
age = self._clean_subject_text(age)
if age:
return f"{subject_phrase}今年{age}"
elif predicate == "is_profession":
profession = self._extract_value_from_object(obj, ["profession", "job"]) or ""
profession = self._clean_subject_text(profession)
if profession:
return f"{subject_phrase}的职业是{profession}"
elif predicate == "lives_in":
location = self._extract_value_from_object(obj, ["location", "city", "place"]) or ""
location = self._clean_subject_text(location)
if location:
return f"{subject_phrase}居住在{location}"
elif predicate == "has_phone":
phone = self._extract_value_from_object(obj, ["phone", "number"]) or ""
phone = self._clean_subject_text(phone)
if phone:
return f"{subject_phrase}的电话号码是{phone}"
elif predicate == "has_email":
email = self._extract_value_from_object(obj, ["email"]) or ""
email = self._clean_subject_text(email)
if email:
return f"{subject_phrase}的邮箱是{email}"
elif predicate in {"likes", "likes_food", "favorite_is"}:
liked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
liked = self._clean_subject_text(liked)
if liked:
verb = "喜欢" if predicate != "likes_food" else "爱吃"
if predicate == "favorite_is":
verb = "最喜欢"
return f"{subject_phrase}{verb}{liked}"
elif predicate in {"dislikes", "hates"}:
disliked = self._extract_value_from_object(obj, ["item", "value", "name"]) or ""
disliked = self._clean_subject_text(disliked)
if disliked:
verb = "不喜欢" if predicate == "dislikes" else "讨厌"
return f"{subject_phrase}{verb}{disliked}"
elif predicate == "mentioned_event":
description = self._extract_value_from_object(obj, ["event_text", "description"]) or ""
description = self._clean_subject_text(description)
if description:
return f"{subject_phrase}提到了:{description}"
obj_text = self._extract_value_from_object(obj, ["value", "detail", "content"]) or ""
obj_text = self._clean_subject_text(obj_text)
if predicate and obj_text:
return f"{subject_phrase}{predicate}{obj_text}".strip()
if obj_text:
return f"{subject_phrase}{obj_text}".strip()
if predicate:
return f"{subject_phrase}{predicate}".strip()
return subject_phrase
def _extract_personal_info(
self,
@@ -678,7 +929,7 @@ class MemoryBuilder:
user_id: str,
timestamp: float,
context: Dict[str, Any],
subject_display: str
subjects: List[str]
) -> List[MemoryChunk]:
"""提取个人信息"""
memories = []
@@ -702,13 +953,14 @@ class MemoryBuilder:
memory = create_memory_chunk(
user_id=user_id,
subject=subject_display,
subject=subjects,
predicate=predicate,
obj=obj,
memory_type=MemoryType.PERSONAL_FACT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.HIGH,
confidence=ConfidenceLevel.HIGH
confidence=ConfidenceLevel.HIGH,
display=self._compose_display_text(subjects, predicate, obj)
)
memories.append(memory)
@@ -721,7 +973,7 @@ class MemoryBuilder:
user_id: str,
timestamp: float,
context: Dict[str, Any],
subject_display: str
subjects: List[str]
) -> List[MemoryChunk]:
"""提取偏好信息"""
memories = []
@@ -740,13 +992,14 @@ class MemoryBuilder:
if match:
memory = create_memory_chunk(
user_id=user_id,
subject=subject_display,
subject=subjects,
predicate=predicate,
obj=match.group(1),
memory_type=MemoryType.PREFERENCE,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM
confidence=ConfidenceLevel.MEDIUM,
display=self._compose_display_text(subjects, predicate, match.group(1))
)
memories.append(memory)
@@ -759,7 +1012,7 @@ class MemoryBuilder:
user_id: str,
timestamp: float,
context: Dict[str, Any],
subject_display: str
subjects: List[str]
) -> List[MemoryChunk]:
"""提取事件信息"""
memories = []
@@ -770,13 +1023,14 @@ class MemoryBuilder:
if any(keyword in text for keyword in event_keywords):
memory = create_memory_chunk(
user_id=user_id,
subject=subject_display,
subject=subjects,
predicate="mentioned_event",
obj={"event_text": text, "timestamp": timestamp},
memory_type=MemoryType.EVENT,
chat_id=context.get("chat_id"),
importance=ImportanceLevel.NORMAL,
confidence=ConfidenceLevel.MEDIUM
confidence=ConfidenceLevel.MEDIUM,
display=self._compose_display_text(subjects, "mentioned_event", text)
)
memories.append(memory)

View File

@@ -7,7 +7,7 @@
import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union
from typing import Dict, List, Optional, Any, Union, Iterable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
@@ -52,17 +52,20 @@ class ImportanceLevel(Enum):
@dataclass
class ContentStructure:
"""主谓宾三元组结构"""
subject: str # 主语(通常为用户)
predicate: str # 谓语(动作、状态、关系)
object: Union[str, Dict] # 宾语(对象、属性、值)
"""主谓宾结构,包含自然语言描述"""
subject: Union[str, List[str]]
predicate: str
object: Union[str, Dict]
display: str = ""
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"subject": self.subject,
"predicate": self.predicate,
"object": self.object
"object": self.object,
"display": self.display
}
@classmethod
@@ -71,16 +74,25 @@ class ContentStructure:
return cls(
subject=data.get("subject", ""),
predicate=data.get("predicate", ""),
object=data.get("object", "")
object=data.get("object", ""),
display=data.get("display", "")
)
def to_subject_list(self) -> List[str]:
"""将主语转换为列表形式"""
if isinstance(self.subject, list):
return [s for s in self.subject if isinstance(s, str) and s.strip()]
if isinstance(self.subject, str) and self.subject.strip():
return [self.subject.strip()]
return []
def __str__(self) -> str:
"""字符串表示"""
if isinstance(self.object, dict):
object_str = str(self.object)
else:
object_str = str(self.object)
return f"{self.subject} {self.predicate} {object_str}"
if self.display:
return self.display
subjects = "".join(self.to_subject_list()) or str(self.subject)
object_str = self.object if isinstance(self.object, str) else str(self.object)
return f"{subjects} {self.predicate} {object_str}".strip()
@dataclass
@@ -236,9 +248,19 @@ class MemoryChunk:
@property
def text_content(self) -> str:
"""获取文本内容"""
"""获取文本内容优先使用display"""
return str(self.content)
@property
def display(self) -> str:
"""获取展示文本"""
return self.content.display or str(self.content)
@property
def subjects(self) -> List[str]:
"""获取主语列表"""
return self.content.to_subject_list()
def update_access(self):
"""更新访问信息"""
self.metadata.update_access()
@@ -415,16 +437,42 @@ class MemoryChunk:
confidence_icon = "" * self.metadata.confidence.value
importance_icon = "" * self.metadata.importance.value
return f"{emoji} [{self.memory_type.value}] {self.text_content} {confidence_icon} {importance_icon}"
return f"{emoji} [{self.memory_type.value}] {self.display} {confidence_icon} {importance_icon}"
def __repr__(self) -> str:
"""调试表示"""
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
"""根据主谓宾生成自然语言描述"""
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
subject_part = "".join(subjects_clean) if subjects_clean else "对话参与者"
if isinstance(obj, dict):
object_candidates = []
for key, value in obj.items():
if isinstance(value, (str, int, float)):
object_candidates.append(f"{key}:{value}")
elif isinstance(value, list):
compact = "".join(str(item) for item in value[:3])
object_candidates.append(f"{key}:{compact}")
object_part = "".join(object_candidates) if object_candidates else str(obj)
else:
object_part = str(obj).strip()
predicate_clean = predicate.strip()
if not predicate_clean:
return f"{subject_part} {object_part}".strip()
if object_part:
return f"{subject_part}{predicate_clean}{object_part}".strip()
return f"{subject_part}{predicate_clean}".strip()
def create_memory_chunk(
user_id: str,
subject: str,
subject: Union[str, List[str]],
predicate: str,
obj: Union[str, Dict],
memory_type: MemoryType,
@@ -432,6 +480,7 @@ def create_memory_chunk(
source_context: Optional[str] = None,
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
display: Optional[str] = None,
**kwargs
) -> MemoryChunk:
"""便捷的内存块创建函数"""
@@ -447,10 +496,22 @@ def create_memory_chunk(
source_context=source_context
)
subjects: List[str]
if isinstance(subject, list):
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
subject_payload: Union[str, List[str]] = subjects
else:
cleaned = subject.strip() if isinstance(subject, str) else ""
subjects = [cleaned] if cleaned else []
subject_payload = cleaned
display_text = display or _build_display_text(subjects, predicate, obj)
content = ContentStructure(
subject=subject,
subject=subject_payload,
predicate=predicate,
object=obj
object=obj,
display=display_text
)
chunk = MemoryChunk(

View File

@@ -266,8 +266,12 @@ class MemoryFusionEngine:
consistency_score = 0.0
# 主语一致性
if mem1.content.subject == mem2.content.subject:
consistency_score += 0.4
subjects1 = set(mem1.subjects)
subjects2 = set(mem2.subjects)
if subjects1 or subjects2:
overlap = len(subjects1 & subjects2)
union_count = max(len(subjects1 | subjects2), 1)
consistency_score += (overlap / union_count) * 0.4
# 谓语相似性
predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)

View File

@@ -282,9 +282,11 @@ class MemoryIntegrationHooks:
}
# 使用增强记忆系统处理对话
result = await process_conversation_with_enhanced_memory(
conversation_text, context, user_id
)
memory_context = dict(context)
memory_context["conversation_text"] = conversation_text
memory_context["user_id"] = user_id
result = await process_conversation_with_enhanced_memory(memory_context)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)
@@ -336,9 +338,11 @@ class MemoryIntegrationHooks:
}
# 使用增强记忆系统处理对话
result = await process_conversation_with_enhanced_memory(
conversation_text, context, user_id
)
memory_context = dict(context)
memory_context["conversation_text"] = conversation_text
memory_context["user_id"] = user_id
result = await process_conversation_with_enhanced_memory(memory_context)
processing_time = time.time() - start_time
self._update_hook_stats(processing_time)

View File

@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
"""记忆检索查询规划器"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import orjson
from src.chat.memory_system.memory_chunk import MemoryType
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@dataclass
class MemoryQueryPlan:
"""查询规划结果"""
semantic_query: str
memory_types: List[MemoryType] = field(default_factory=list)
subject_includes: List[str] = field(default_factory=list)
object_includes: List[str] = field(default_factory=list)
required_keywords: List[str] = field(default_factory=list)
optional_keywords: List[str] = field(default_factory=list)
owner_filters: List[str] = field(default_factory=list)
recency_preference: str = "any"
limit: int = 10
emphasis: Optional[str] = None
raw_plan: Dict[str, Any] = field(default_factory=dict)
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
if not self.semantic_query:
self.semantic_query = fallback_query
if self.limit <= 0:
self.limit = default_limit
self.recency_preference = (self.recency_preference or "any").lower()
if self.recency_preference not in {"any", "recent", "historical"}:
self.recency_preference = "any"
self.emphasis = (self.emphasis or "balanced").lower()
class MemoryQueryPlanner:
"""基于小模型的记忆检索查询规划器"""
def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
self.model = planner_model
self.default_limit = default_limit
async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
if not self.model:
logger.debug("未提供查询规划模型,使用默认规划")
return self._default_plan(query_text)
prompt = self._build_prompt(query_text, context)
try:
response, _ = await self.model.generate_response_async(prompt, temperature=0.2)
payload = self._extract_json_payload(response)
if not payload:
logger.debug("查询规划模型未返回结构化结果,使用默认规划")
return self._default_plan(query_text)
try:
data = orjson.loads(payload)
except orjson.JSONDecodeError as exc:
preview = payload[:200]
logger.warning("解析查询规划JSON失败: %s,片段: %s", exc, preview)
return self._default_plan(query_text)
plan = self._parse_plan_dict(data, query_text)
plan.ensure_defaults(query_text, self.default_limit)
return plan
except Exception as exc:
logger.error("查询规划模型调用失败: %s", exc, exc_info=True)
return self._default_plan(query_text)
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
return MemoryQueryPlan(
semantic_query=query_text,
limit=self.default_limit
)
def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
def _collect_list(key: str) -> List[str]:
value = data.get(key)
if isinstance(value, str):
return [value]
if isinstance(value, list):
return [self._safe_str(item) for item in value if self._safe_str(item)]
return []
memory_type_values = _collect_list("memory_types")
memory_types: List[MemoryType] = []
for item in memory_type_values:
if not item:
continue
try:
memory_types.append(MemoryType(item))
except ValueError:
# 尝试匹配value值
normalized = item.lower()
for mt in MemoryType:
if mt.value == normalized:
memory_types.append(mt)
break
plan = MemoryQueryPlan(
semantic_query=semantic_query,
memory_types=memory_types,
subject_includes=_collect_list("subject_includes"),
object_includes=_collect_list("object_includes"),
required_keywords=_collect_list("required_keywords"),
optional_keywords=_collect_list("optional_keywords"),
owner_filters=_collect_list("owner_filters"),
recency_preference=self._safe_str(data.get("recency")) or "any",
limit=self._safe_int(data.get("limit"), self.default_limit),
emphasis=self._safe_str(data.get("emphasis")) or "balanced",
raw_plan=data
)
return plan
def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
participants = context.get("participants") or context.get("speaker_names") or []
if isinstance(participants, str):
participants = [participants]
participants = [p for p in participants if isinstance(p, str) and p.strip()]
participant_preview = "".join(participants[:5]) or "未知"
persona = context.get("bot_personality") or context.get("bot_identity") or "未知"
return f"""
你是一名记忆检索分析师,将根据对话查询生成结构化的检索计划。
请结合提供的上下文输出一个JSON对象字段含义如下
- semantic_query: 提供给向量检索的自然语言查询,要求清晰具体;
- memory_types: 建议检索的记忆类型数组,取值范围参见 MemoryType 枚举 (personal_fact,event,preference,opinion,relationship,emotion,knowledge,skill,goal,experience,contextual)
- subject_includes: 需要出现在记忆主语中的人或角色列表;
- object_includes: 记忆中需要提到的重要对象或主题关键词列表;
- required_keywords: 检索时必须包含的关键词;
- optional_keywords: 可以提升相关性的附加关键词;
- owner_filters: 如果需要限制检索所属主体请列出用户ID或其它标识
- recency: 建议的时间偏好,可选 recent/any/historical
- emphasis: 检索策略倾向,可选 precision/recall/balanced
- limit: 推荐的最大返回数量(1-15之间)
- notes: 额外说明,可选。
当前查询: "{query_text}"
已知的对话参与者: {participant_preview}
机器人设定: {persona}
请输出符合要求的JSON禁止添加额外说明或Markdown代码块。
"""
def _extract_json_payload(self, response: str) -> Optional[str]:
if not response:
return None
stripped = response.strip()
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
if code_block_match:
candidate = code_block_match.group(1).strip()
if candidate:
return candidate
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1]
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
@staticmethod
def _safe_str(value: Any) -> str:
if isinstance(value, str):
return value.strip()
if value is None:
return ""
return str(value).strip()
@staticmethod
def _safe_int(value: Any, default: int) -> int:
try:
number = int(value)
if number <= 0:
return default
return number
except (TypeError, ValueError):
return default

View File

@@ -25,6 +25,7 @@ class IndexType(Enum):
"""索引类型"""
MEMORY_TYPE = "memory_type" # 记忆类型索引
USER_ID = "user_id" # 用户ID索引
SUBJECT = "subject" # 主体索引
KEYWORD = "keyword" # 关键词索引
TAG = "tag" # 标签索引
CATEGORY = "category" # 分类索引
@@ -41,6 +42,7 @@ class IndexQuery:
"""索引查询条件"""
user_ids: Optional[List[str]] = None
memory_types: Optional[List[MemoryType]] = None
subjects: Optional[List[str]] = None
keywords: Optional[List[str]] = None
tags: Optional[List[str]] = None
categories: Optional[List[str]] = None
@@ -76,6 +78,7 @@ class MetadataIndexManager:
self.indices = {
IndexType.MEMORY_TYPE: defaultdict(set),
IndexType.USER_ID: defaultdict(set),
IndexType.SUBJECT: defaultdict(set),
IndexType.KEYWORD: defaultdict(set),
IndexType.TAG: defaultdict(set),
IndexType.CATEGORY: defaultdict(set),
@@ -110,6 +113,41 @@ class MetadataIndexManager:
self.auto_save_interval = 500 # 每500次操作自动保存
self._operation_count = 0
@staticmethod
def _serialize_index_key(index_type: IndexType, key: Any) -> str:
"""将索引键序列化为字符串以便存储"""
if isinstance(key, Enum):
value = key.value
else:
value = key
return str(value)
@staticmethod
def _deserialize_index_key(index_type: IndexType, key: str) -> Any:
"""根据索引类型反序列化索引键"""
try:
if index_type == IndexType.MEMORY_TYPE:
return MemoryType(key)
if index_type == IndexType.CONFIDENCE:
return ConfidenceLevel(int(key))
if index_type == IndexType.IMPORTANCE:
return ImportanceLevel(int(key))
# 其他索引键默认使用原始字符串可能已经是lower后的字符串
return key
except Exception:
logger.warning("无法反序列化索引键 %s 在索引 %s 中,使用原始字符串", key, index_type.value)
return key
@staticmethod
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
serialized = {}
for field_name, value in metadata.items():
if isinstance(value, Enum):
serialized[field_name] = value.value
else:
serialized[field_name] = value
return serialized
async def index_memories(self, memories: List[MemoryChunk]):
"""为记忆建立索引"""
if not memories:
@@ -142,6 +180,68 @@ class MetadataIndexManager:
except Exception as e:
logger.error(f"❌ 元数据索引失败: {e}", exc_info=True)
async def update_memory_entry(self, memory: MemoryChunk):
"""更新已存在记忆的索引信息"""
if not memory:
return
with self._lock:
entry = self.memory_metadata_cache.get(memory.memory_id)
if entry is None:
# 若不存在则作为新记忆索引
self._index_single_memory(memory)
return
old_confidence = entry.get("confidence")
old_importance = entry.get("importance")
old_semantic_hash = entry.get("semantic_hash")
entry.update(
{
"user_id": memory.user_id,
"memory_type": memory.memory_type,
"created_at": memory.metadata.created_at,
"last_accessed": memory.metadata.last_accessed,
"access_count": memory.metadata.access_count,
"confidence": memory.metadata.confidence,
"importance": memory.metadata.importance,
"relationship_score": memory.metadata.relationship_score,
"relevance_score": memory.metadata.relevance_score,
"semantic_hash": memory.semantic_hash,
"subjects": memory.subjects,
}
)
# 更新置信度/重要性索引
if isinstance(old_confidence, ConfidenceLevel):
self.indices[IndexType.CONFIDENCE][old_confidence].discard(memory.memory_id)
if isinstance(old_importance, ImportanceLevel):
self.indices[IndexType.IMPORTANCE][old_importance].discard(memory.memory_id)
if isinstance(old_semantic_hash, str):
self.indices[IndexType.SEMANTIC_HASH][old_semantic_hash].discard(memory.memory_id)
self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory.memory_id)
self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory.memory_id)
if memory.semantic_hash:
self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory.memory_id)
# 同步关键词/标签/分类索引
for keyword in memory.keywords:
if keyword:
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory.memory_id)
for tag in memory.tags:
if tag:
self.indices[IndexType.TAG][tag.lower()].add(memory.memory_id)
for category in memory.categories:
if category:
self.indices[IndexType.CATEGORY][category.lower()].add(memory.memory_id)
for subject in memory.subjects:
if subject:
self.indices[IndexType.SUBJECT][subject.strip().lower()].add(memory.memory_id)
def _index_single_memory(self, memory: MemoryChunk):
"""为单个记忆建立索引"""
memory_id = memory.memory_id
@@ -157,7 +257,8 @@ class MetadataIndexManager:
"importance": memory.metadata.importance,
"relationship_score": memory.metadata.relationship_score,
"relevance_score": memory.metadata.relevance_score,
"semantic_hash": memory.semantic_hash
"semantic_hash": memory.semantic_hash,
"subjects": memory.subjects
}
# 记忆类型索引
@@ -166,6 +267,12 @@ class MetadataIndexManager:
# 用户ID索引
self.indices[IndexType.USER_ID][memory.user_id].add(memory_id)
# 主体索引
for subject in memory.subjects:
normalized = subject.strip().lower()
if normalized:
self.indices[IndexType.SUBJECT][normalized].add(memory_id)
# 关键词索引
for keyword in memory.keywords:
self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id)
@@ -282,13 +389,6 @@ class MetadataIndexManager:
# 应用最严格的过滤条件
applied_filters = []
if query.user_ids:
user_ids_set = set()
for user_id in query.user_ids:
user_ids_set.update(self.indices[IndexType.USER_ID].get(user_id, set()))
candidate_ids.update(user_ids_set)
applied_filters.append("user_ids")
if query.memory_types:
memory_types_set = set()
for memory_type in query.memory_types:
@@ -302,7 +402,7 @@ class MetadataIndexManager:
if query.keywords:
keywords_set = set()
for keyword in query.keywords:
keywords_set.update(self.indices[IndexType.KEYWORD].get(keyword.lower(), set()))
keywords_set.update(self._collect_index_matches(IndexType.KEYWORD, keyword))
if applied_filters:
candidate_ids &= keywords_set
else:
@@ -329,12 +429,55 @@ class MetadataIndexManager:
candidate_ids.update(categories_set)
applied_filters.append("categories")
if query.subjects:
subjects_set = set()
for subject in query.subjects:
subjects_set.update(self._collect_index_matches(IndexType.SUBJECT, subject))
if applied_filters:
candidate_ids &= subjects_set
else:
candidate_ids.update(subjects_set)
applied_filters.append("subjects")
# 如果没有应用任何过滤条件,返回所有记忆
if not applied_filters:
return all_memory_ids
return candidate_ids
def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
"""根据给定token收集索引匹配支持部分匹配"""
mapping = self.indices.get(index_type)
if mapping is None:
return set()
key = ""
if isinstance(token, Enum):
key = str(token.value).strip().lower()
elif isinstance(token, str):
key = token.strip().lower()
elif token is not None:
key = str(token).strip().lower()
if not key:
return set()
matches: Set[str] = set(mapping.get(key, set()))
if matches:
return set(matches)
for existing_key, ids in mapping.items():
if not existing_key or not isinstance(existing_key, str):
continue
normalized = existing_key.strip().lower()
if not normalized:
continue
if key in normalized or normalized in key:
matches.update(ids)
return matches
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
"""应用过滤条件"""
filtered_ids = list(candidate_ids)
@@ -440,10 +583,10 @@ class MetadataIndexManager:
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
"""获取应用的过滤器列表"""
filters = []
if query.user_ids:
filters.append("user_ids")
if query.memory_types:
filters.append("memory_types")
if query.subjects:
filters.append("subjects")
if query.keywords:
filters.append("keywords")
if query.tags:
@@ -502,6 +645,18 @@ class MetadataIndexManager:
# 从各类索引中移除
self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id)
self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id)
subjects = metadata.get("subjects") or []
for subject in subjects:
if not isinstance(subject, str):
continue
normalized = subject.strip().lower()
if not normalized:
continue
subject_bucket = self.indices[IndexType.SUBJECT].get(normalized)
if subject_bucket is not None:
subject_bucket.discard(memory_id)
if not subject_bucket:
self.indices[IndexType.SUBJECT].pop(normalized, None)
# 从时间索引中移除
self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id]
@@ -625,11 +780,13 @@ class MetadataIndexManager:
logger.info("正在保存元数据索引...")
# 保存各类索引
indices_data = {}
indices_data: Dict[str, Dict[str, List[str]]] = {}
for index_type, index_data in self.indices.items():
indices_data[index_type.value] = {
key: list(values) for key, values in index_data.items()
}
serialized_index = {}
for key, values in index_data.items():
serialized_key = self._serialize_index_key(index_type, key)
serialized_index[serialized_key] = list(values)
indices_data[index_type.value] = serialized_index
indices_file = self.index_path / "indices.json"
with open(indices_file, 'w', encoding='utf-8') as f:
@@ -652,8 +809,12 @@ class MetadataIndexManager:
# 保存元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
metadata_serialized = {
memory_id: self._serialize_metadata_entry(metadata)
for memory_id, metadata in self.memory_metadata_cache.items()
}
with open(metadata_cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.memory_metadata_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8'))
# 保存统计信息
stats_file = self.index_path / "index_stats.json"
@@ -679,9 +840,11 @@ class MetadataIndexManager:
for index_type_value, index_data in indices_data.items():
index_type = IndexType(index_type_value)
self.indices[index_type] = {
key: set(values) for key, values in index_data.items()
}
restored_index = defaultdict(set)
for key_str, values in index_data.items():
restored_key = self._deserialize_index_key(index_type, key_str)
restored_index[restored_key] = set(values)
self.indices[index_type] = restored_index
# 加载时间索引
time_index_file = self.index_path / "time_index.json"
@@ -709,10 +872,38 @@ class MetadataIndexManager:
# 转换置信度和重要性为枚举类型
for memory_id, metadata in cache_data.items():
if isinstance(metadata["confidence"], str):
metadata["confidence"] = ConfidenceLevel(metadata["confidence"])
if isinstance(metadata["importance"], str):
metadata["importance"] = ImportanceLevel(metadata["importance"])
memory_type_value = metadata.get("memory_type")
if isinstance(memory_type_value, str):
try:
metadata["memory_type"] = MemoryType(memory_type_value)
except ValueError:
logger.warning("无法解析memory_type %s", memory_type_value)
confidence_value = metadata.get("confidence")
if isinstance(confidence_value, (str, int)):
try:
metadata["confidence"] = ConfidenceLevel(int(confidence_value))
except ValueError:
logger.warning("无法解析confidence %s", confidence_value)
importance_value = metadata.get("importance")
if isinstance(importance_value, (str, int)):
try:
metadata["importance"] = ImportanceLevel(int(importance_value))
except ValueError:
logger.warning("无法解析importance %s", importance_value)
subjects_value = metadata.get("subjects")
if isinstance(subjects_value, str):
metadata["subjects"] = [subjects_value]
elif isinstance(subjects_value, list):
cleaned_subjects = []
for item in subjects_value:
if isinstance(item, str) and item.strip():
cleaned_subjects.append(item.strip())
metadata["subjects"] = cleaned_subjects
else:
metadata["subjects"] = []
self.memory_metadata_cache = cache_data

View File

@@ -203,11 +203,17 @@ class MultiStageRetrieval:
try:
from .metadata_index import IndexQuery
# 构建索引查询
query_plan = context.get("query_plan")
memory_types = self._extract_memory_types_from_context(context)
keywords = self._extract_keywords_from_query(query, query_plan)
subjects = query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None
index_query = IndexQuery(
user_ids=[user_id],
memory_types=self._extract_memory_types_from_context(context),
keywords=self._extract_keywords_from_query(query),
user_ids=None,
memory_types=memory_types,
subjects=subjects,
keywords=keywords,
limit=self.config.metadata_filter_limit,
sort_by="last_accessed",
sort_order="desc"
@@ -215,13 +221,66 @@ class MultiStageRetrieval:
# 执行查询
result = await metadata_index.query_memories(index_query)
filtered_count = result.total_count - len(result.memory_ids)
result_ids = list(result.memory_ids)
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
logger.debug(f"元数据过滤:{result.total_count} -> {len(result.memory_ids)} 条记忆")
# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
if not result_ids:
sorted_ids = sorted(
(memory.memory_id for memory in all_memories_cache.values()),
key=lambda mid: all_memories_cache[mid].metadata.last_accessed if mid in all_memories_cache else 0,
reverse=True,
)
if memory_types:
type_filtered = [
mid for mid in sorted_ids
if all_memories_cache[mid].memory_type in memory_types
]
sorted_ids = type_filtered or sorted_ids
if subjects:
subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()]
if subject_candidates:
subject_filtered = [
mid for mid in sorted_ids
if any(
subj.strip().lower() in subject_candidates
for subj in all_memories_cache[mid].subjects
)
]
sorted_ids = subject_filtered or sorted_ids
if keywords:
keyword_pool = {kw.lower() for kw in keywords if isinstance(kw, str) and kw.strip()}
if keyword_pool:
keyword_filtered = []
for mid in sorted_ids:
memory_text = (
(all_memories_cache[mid].display or "")
+ "\n"
+ (all_memories_cache[mid].text_content or "")
).lower()
if any(kw in memory_text for kw in keyword_pool):
keyword_filtered.append(mid)
sorted_ids = keyword_filtered or sorted_ids
result_ids = sorted_ids[: self.config.metadata_filter_limit]
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
logger.debug(
"元数据过滤未命中索引,使用近似回退: types=%s, subjects=%s, keywords=%s",
bool(memory_types),
bool(subjects),
bool(keywords),
)
logger.debug(
"元数据过滤:候选=%d, 返回=%d",
len(all_memories_cache),
len(result_ids),
)
return StageResult(
stage=RetrievalStage.METADATA_FILTERING,
memory_ids=result.memory_ids,
memory_ids=result_ids,
processing_time=time.time() - start_time,
filtered_count=filtered_count,
score_threshold=0.0
@@ -251,7 +310,7 @@ class MultiStageRetrieval:
try:
# 生成查询向量
query_embedding = await self._generate_query_embedding(query, context)
query_embedding = await self._generate_query_embedding(query, context, vector_storage)
if not query_embedding:
return StageResult(
@@ -263,22 +322,24 @@ class MultiStageRetrieval:
)
# 执行向量搜索
search_result = await vector_storage.search_similar(
query_embedding,
search_result = await vector_storage.search_similar_memories(
query_vector=query_embedding,
limit=self.config.vector_search_limit
)
candidate_pool = candidate_ids or set(all_memories_cache.keys())
# 过滤候选记忆
filtered_memories = []
for memory_id, similarity in search_result:
if memory_id in candidate_ids and similarity >= self.config.vector_similarity_threshold:
if memory_id in candidate_pool and similarity >= self.config.vector_similarity_threshold:
filtered_memories.append((memory_id, similarity))
# 按相似度排序
filtered_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]
filtered_count = len(candidate_ids) - len(result_ids)
filtered_count = max(0, len(candidate_pool) - len(result_ids))
logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆")
@@ -407,12 +468,20 @@ class MultiStageRetrieval:
score_threshold=0.0
)
async def _generate_query_embedding(self, query: str, context: Dict[str, Any]) -> Optional[List[float]]:
async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]:
"""生成查询向量"""
try:
# 这里应该调用embedding模型
# 由于我们可能没有直接的embedding模型返回None或使用简单的方法
# 在实际实现中这里应该调用与记忆存储相同的embedding模型
query_plan = context.get("query_plan")
query_text = query
if query_plan and getattr(query_plan, "semantic_query", None):
query_text = query_plan.semantic_query
if not query_text:
return None
if hasattr(vector_storage, "generate_query_embedding"):
return await vector_storage.generate_query_embedding(query_text)
return None
except Exception as e:
logger.warning(f"生成查询向量失败: {e}")
@@ -421,9 +490,15 @@ class MultiStageRetrieval:
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
"""计算语义相似度"""
try:
query_plan = context.get("query_plan")
query_text = query
if query_plan and getattr(query_plan, "semantic_query", None):
query_text = query_plan.semantic_query
# 简单的文本相似度计算
query_words = set(query.lower().split())
memory_words = set(memory.text_content.lower().split())
query_words = set(query_text.lower().split())
memory_text = (memory.display or memory.text_content or "").lower()
memory_words = set(memory_text.split())
if not query_words or not memory_words:
return 0.0
@@ -443,10 +518,15 @@ class MultiStageRetrieval:
try:
score = 0.0
query_plan = context.get("query_plan")
# 检查记忆类型是否匹配上下文
if context.get("expected_memory_types"):
if memory.memory_type in context["expected_memory_types"]:
score += 0.3
elif query_plan and getattr(query_plan, "memory_types", None):
if memory.memory_type in query_plan.memory_types:
score += 0.3
# 检查关键词匹配
if context.get("keywords"):
@@ -456,6 +536,35 @@ class MultiStageRetrieval:
if overlap:
score += len(overlap) / max(len(context_keywords), 1) * 0.4
if query_plan:
# 主体匹配
subject_score = self._calculate_subject_overlap(memory, getattr(query_plan, "subject_includes", []))
score += subject_score * 0.3
# 对象/描述匹配
object_keywords = getattr(query_plan, "object_includes", []) or []
if object_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
if hits:
score += min(0.3, hits * 0.1)
optional_keywords = getattr(query_plan, "optional_keywords", []) or []
if optional_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
if hits:
score += min(0.2, hits * 0.05)
# 时间偏好
recency_pref = getattr(query_plan, "recency_preference", "")
if recency_pref:
memory_age = time.time() - memory.metadata.created_at
if recency_pref == "recent" and memory_age < 7 * 24 * 3600:
score += 0.2
elif recency_pref == "historical" and memory_age > 30 * 24 * 3600:
score += 0.1
# 检查时效性
if context.get("recent_only", False):
memory_age = time.time() - memory.metadata.created_at
@@ -471,6 +580,8 @@ class MultiStageRetrieval:
async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
"""计算最终评分"""
try:
query_plan = context.get("query_plan")
# 语义相似度
semantic_score = await self._calculate_semantic_similarity(query, memory, context)
@@ -482,13 +593,29 @@ class MultiStageRetrieval:
# 时效性评分
recency_score = self._calculate_recency_score(memory.metadata.created_at)
if query_plan:
recency_pref = getattr(query_plan, "recency_preference", "")
if recency_pref == "recent":
recency_score = max(recency_score, 0.8)
elif recency_pref == "historical":
recency_score = min(recency_score, 0.5)
# 权重组合
vector_weight = self.config.vector_weight
semantic_weight = self.config.semantic_weight
context_weight = self.config.context_weight
recency_weight = self.config.recency_weight
if query_plan and getattr(query_plan, "emphasis", None) == "precision":
semantic_weight += 0.05
elif query_plan and getattr(query_plan, "emphasis", None) == "recall":
context_weight += 0.05
final_score = (
semantic_score * self.config.semantic_weight +
vector_score * self.config.vector_weight +
context_score * self.config.context_weight +
recency_score * self.config.recency_weight
semantic_score * semantic_weight +
vector_score * vector_weight +
context_score * context_weight +
recency_score * recency_weight
)
# 加入记忆重要性权重
@@ -501,6 +628,31 @@ class MultiStageRetrieval:
logger.warning(f"计算最终评分失败: {e}")
return 0.0
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
if not required_subjects:
return 0.0
memory_subjects = {subject.lower() for subject in memory.subjects if isinstance(subject, str)}
if not memory_subjects:
return 0.0
hit = 0
total = 0
for subject in required_subjects:
if not isinstance(subject, str):
continue
total += 1
normalized = subject.strip().lower()
if not normalized:
continue
if any(normalized in mem_subject for mem_subject in memory_subjects):
hit += 1
if total == 0:
return 0.0
return hit / total
def _calculate_recency_score(self, timestamp: float) -> float:
"""计算时效性评分"""
try:
@@ -524,6 +676,10 @@ class MultiStageRetrieval:
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
"""从上下文中提取记忆类型"""
try:
query_plan = context.get("query_plan")
if query_plan and getattr(query_plan, "memory_types", None):
return query_plan.memory_types
if "expected_memory_types" in context:
return context["expected_memory_types"]
@@ -544,15 +700,30 @@ class MultiStageRetrieval:
except Exception:
return []
def _extract_keywords_from_query(self, query: str) -> List[str]:
def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
"""从查询中提取关键词"""
try:
extracted: List[str] = []
if query_plan and getattr(query_plan, "required_keywords", None):
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
# 简单的关键词提取
words = query.lower().split()
# 过滤停用词
stopwords = {"", "", "", "", "", "", "", "", "", "", "", "", "", ""}
keywords = [word for word in words if len(word) > 1 and word not in stopwords]
return keywords[:10] # 最多返回10个关键词
extracted.extend(word for word in words if len(word) > 1 and word not in stopwords)
# 去重并保留顺序
seen = set()
deduplicated = []
for word in extracted:
if word in seen or not word:
continue
seen.add(word)
deduplicated.append(word)
return deduplicated[:10]
except Exception:
return []

View File

@@ -20,6 +20,7 @@ from pathlib import Path
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.common.config_helpers import resolve_embedding_dimension
from src.chat.memory_system.memory_chunk import MemoryChunk
logger = get_logger(__name__)
@@ -36,12 +37,12 @@ except ImportError:
@dataclass
class VectorStorageConfig:
"""向量存储配置"""
dimension: int = 768
dimension: int = 1024
similarity_threshold: float = 0.8
index_type: str = "flat" # flat, ivf, hnsw
max_index_size: int = 100000
storage_path: str = "data/memory_vectors"
auto_save_interval: int = 100 # 每N次操作自动保存
auto_save_interval: int = 10 # 每N次操作自动保存
enable_compression: bool = True
@@ -50,6 +51,15 @@ class VectorStorageManager:
def __init__(self, config: Optional[VectorStorageConfig] = None):
self.config = config or VectorStorageConfig()
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
if resolved_dimension and resolved_dimension != self.config.dimension:
logger.info(
"向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)",
resolved_dimension,
self.config.dimension,
)
self.config.dimension = resolved_dimension
self.storage_path = Path(self.config.storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
@@ -117,6 +127,32 @@ class VectorStorageManager:
)
logger.info("✅ 嵌入模型初始化完成")
async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
"""生成查询向量,用于记忆召回"""
if not query_text:
return None
try:
await self.initialize_embedding_model()
embedding, _ = await self.embedding_model.get_embedding(query_text)
if not embedding:
return None
if len(embedding) != self.config.dimension:
logger.warning(
"查询向量维度不匹配: 期望 %d, 实际 %d",
self.config.dimension,
len(embedding)
)
return None
return self._normalize_vector(embedding)
except Exception as exc:
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
return None
async def store_memories(self, memories: List[MemoryChunk]):
"""存储记忆向量"""
if not memories:
@@ -213,7 +249,7 @@ class VectorStorageManager:
results[memory_id] = embedding
else:
logger.warning(
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)",
"嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。",
self.config.dimension,
len(embedding) if embedding else 0,
memory_id,
@@ -299,14 +335,32 @@ class VectorStorageManager:
async def search_similar_memories(
self,
query_vector: List[float],
query_vector: Optional[List[float]] = None,
*,
query_text: Optional[str] = None,
limit: int = 10,
user_id: Optional[str] = None
scope_id: Optional[str] = None
) -> List[Tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()
try:
if query_vector is None:
if not query_text:
return []
query_vector = await self.generate_query_embedding(query_text)
if not query_vector:
return []
scope_filter: Optional[str] = None
if isinstance(scope_id, str):
normalized_scope = scope_id.strip().lower()
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
scope_filter = scope_id
elif scope_id:
scope_filter = str(scope_id)
# 规范化查询向量
query_vector = self._normalize_vector(query_vector)
@@ -341,10 +395,9 @@ class VectorStorageManager:
memory_id = self.index_to_memory_id.get(index)
if memory_id:
# 应用用户过滤
if user_id:
if scope_filter:
memory = self.memory_cache.get(memory_id)
if memory and memory.user_id != user_id:
if memory and str(memory.user_id) != scope_filter:
continue
similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内
@@ -481,8 +534,14 @@ class VectorStorageManager:
# 保存映射关系
mapping_file = self.storage_path / "id_mapping.json"
mapping_data = {
"memory_id_to_index": self.memory_id_to_index,
"index_to_memory_id": self.index_to_memory_id
"memory_id_to_index": {
str(memory_id): int(index)
for memory_id, index in self.memory_id_to_index.items()
},
"index_to_memory_id": {
str(index): memory_id
for index, memory_id in self.index_to_memory_id.items()
}
}
with open(mapping_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
@@ -529,8 +588,17 @@ class VectorStorageManager:
if mapping_file.exists():
with open(mapping_file, 'r', encoding='utf-8') as f:
mapping_data = orjson.loads(f.read())
self.memory_id_to_index = mapping_data.get("memory_id_to_index", {})
self.index_to_memory_id = mapping_data.get("index_to_memory_id", {})
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
str(memory_id): int(index)
for memory_id, index in raw_memory_to_index.items()
}
raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
self.index_to_memory_id = {
int(index): memory_id
for index, memory_id in raw_index_to_memory.items()
}
# 加载FAISS索引如果可用
if FAISS_AVAILABLE:

View File

@@ -373,12 +373,12 @@ class Prompt:
# 性能优化 - 为不同任务设置不同的超时时间
task_timeouts = {
"memory_block": 5.0, # 记忆系统可能较慢,单独设置超时
"tool_info": 3.0, # 工具信息中等速度
"relation_info": 2.0, # 关系信息通常较快
"knowledge_info": 3.0, # 知识库查询中等速度
"cross_context": 2.0, # 上下文处理通常较快
"expression_habits": 1.5, # 表达习惯处理很快
"memory_block": 15.0, # 记忆系统
"tool_info": 15.0, # 工具信息
"relation_info": 10.0, # 关系信息
"knowledge_info": 10.0, # 知识库查询
"cross_context": 10.0, # 上下文处理
"expression_habits": 10.0, # 表达习惯
}
# 分别处理每个任务,避免慢任务影响快任务
@@ -558,12 +558,8 @@ class Prompt:
)
]
# 等待所有记忆查询完成最多3秒
try:
running_memories, instant_memory = await asyncio.wait_for(
asyncio.gather(*memory_tasks, return_exceptions=True),
timeout=3.0
)
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
# 处理可能的异常结果
if isinstance(running_memories, Exception):

View File

@@ -207,6 +207,9 @@ class VideoAnalyzer:
"""检查视频是否已经分析过"""
try:
async with get_db_session() as session:
if not session:
logger.warning("无法获取数据库会话,跳过视频存在性检查。")
return None
# 明确刷新会话以确保看到其他事务的最新提交
await session.expire_all()
stmt = select(Videos).where(Videos.video_hash == video_hash)
@@ -227,6 +230,9 @@ class VideoAnalyzer:
try:
async with get_db_session() as session:
if not session:
logger.warning("无法获取数据库会话,跳过视频结果存储。")
return None
# 只根据video_hash查找
stmt = select(Videos).where(Videos.video_hash == video_hash)
result = await session.execute(stmt)
@@ -540,11 +546,14 @@ class VideoAnalyzer:
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
# 获取模型信息和客户端
model_info, api_provider, client = self.video_llm._select_model()
selection_result = self.video_llm._model_selector.select_best_available_model(set(), "response")
if not selection_result:
raise RuntimeError("无法为视频分析选择可用模型。")
model_info, api_provider, client = selection_result
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
# 直接执行多图片请求
api_response = await self.video_llm._execute_request(
api_response = await self.video_llm._executor.execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,

View File

@@ -461,11 +461,14 @@ class LegacyVideoAnalyzer:
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
# 获取模型信息和客户端
model_info, api_provider, client = self.video_llm._select_model()
selection_result = self.video_llm._model_selector.select_best_available_model(set(), "response")
if not selection_result:
raise RuntimeError("无法为视频分析选择可用模型 (legacy)。")
model_info, api_provider, client = selection_result
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
# 直接执行多图片请求
api_response = await self.video_llm._execute_request(
api_response = await self.video_llm._executor.execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, Union
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.config_helpers import resolve_embedding_dimension
from src.common.database.sqlalchemy_models import CacheEntries
from src.common.database.sqlalchemy_database_api import db_query, db_save
from src.common.vector_db import vector_db_service
@@ -40,7 +41,11 @@ class CacheManager:
# L1 缓存 (内存)
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
if not embedding_dim:
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
self.embedding_dimension = embedding_dim
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
self.l1_vector_id_to_key: Dict[int, str] = {}
@@ -72,7 +77,7 @@ class CacheManager:
embedding_array = embedding_array.flatten()
# 检查维度是否符合预期
expected_dim = global_config.lpmm_knowledge.embedding_dimension
expected_dim = getattr(CacheManager, "embedding_dimension", None) or global_config.lpmm_knowledge.embedding_dimension
if embedding_array.shape[0] != expected_dim:
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
return None

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
from typing import Optional
from src.config.config import global_config, model_config
def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]:
"""获取当前配置的嵌入向量维度。
优先顺序:
1. 模型配置中 `model_task_config.embedding.embedding_dimension`
2. 机器人配置中 `lpmm_knowledge.embedding_dimension`
3. 调用方提供的 fallback
"""
candidates: list[Optional[int]] = []
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
if embedding_task is not None:
candidates.append(getattr(embedding_task, "embedding_dimension", None))
except Exception:
candidates.append(None)
try:
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
except Exception:
candidates.append(None)
candidates.append(fallback)
resolved: Optional[int] = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
if resolved and sync_global:
try:
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]
except Exception:
pass
return resolved

View File

@@ -759,30 +759,38 @@ async def initialize_database():
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""异步数据库会话上下文管理器"""
async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]:
"""
异步数据库会话上下文管理器。
在初始化失败时会yield None调用方需要检查会话是否为None。
"""
session: Optional[AsyncSession] = None
SessionLocal = None
try:
engine, SessionLocal = await initialize_database()
_, SessionLocal = await initialize_database()
if not SessionLocal:
raise RuntimeError("Database session not initialized")
session = SessionLocal()
logger.error("数据库会话工厂 (_SessionLocal) 未初始化。")
yield None
return
except Exception as e:
logger.error(f"数据库初始化失败,无法创建会话: {e}")
yield None
return
try:
session = SessionLocal()
# 对于 SQLite在会话开始时设置 PRAGMA
from src.config.config import global_config
if global_config.database.database_type == "sqlite":
try:
await session.execute(text("PRAGMA busy_timeout = 60000"))
await session.execute(text("PRAGMA foreign_keys = ON"))
except Exception as e:
logger.warning(f"[SQLite] 设置会话 PRAGMA 失败: {e}")
await session.execute(text("PRAGMA busy_timeout = 60000"))
await session.execute(text("PRAGMA foreign_keys = ON"))
yield session
except Exception as e:
logger.error(f"数据库会话错误: {e}")
logger.error(f"数据库会话期间发生错误: {e}")
if session:
await session.rollback()
raise
raise # 将会话期间的错误重新抛出给调用者
finally:
if session:
await session.close()

View File

@@ -1,4 +1,4 @@
from typing import List, Dict, Any, Literal, Union
from typing import List, Dict, Any, Literal, Union, Optional
from pydantic import Field
from threading import Lock
@@ -105,6 +105,11 @@ class TaskConfig(ValidatedConfigBase):
max_tokens: int = Field(default=800, description="任务最大输出token数")
temperature: float = Field(default=0.7, description="模型温度")
concurrency_count: int = Field(default=1, description="并发请求数量")
embedding_dimension: Optional[int] = Field(
default=None,
description="嵌入模型输出向量维度,仅在嵌入任务中使用",
ge=1,
)
@classmethod
def validate_model_list(cls, v):

View File

@@ -443,21 +443,6 @@ class MemoryConfig(ValidatedConfigBase):
enable_memory: bool = Field(default=True, description="启用记忆")
memory_build_interval: int = Field(default=600, description="记忆构建间隔")
memory_build_distribution: list[float] = Field(
default_factory=lambda: [6.0, 3.0, 0.6, 32.0, 12.0, 0.4], description="记忆构建分布"
)
memory_build_sample_num: int = Field(default=8, description="记忆构建样本数量")
memory_build_sample_length: int = Field(default=40, description="记忆构建样本长度")
memory_compress_rate: float = Field(default=0.1, description="记忆压缩率")
forget_memory_interval: int = Field(default=1000, description="遗忘记忆间隔")
memory_forget_time: int = Field(default=24, description="记忆遗忘时间")
memory_forget_percentage: float = Field(default=0.01, description="记忆遗忘百分比")
consolidate_memory_interval: int = Field(default=1000, description="记忆巩固间隔")
consolidation_similarity_threshold: float = Field(default=0.7, description="巩固相似性阈值")
consolidate_memory_percentage: float = Field(default=0.01, description="巩固记忆百分比")
memory_ban_words: list[str] = Field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"], description="记忆禁用词"
)
enable_instant_memory: bool = Field(default=True, description="启用即时记忆")
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")
@@ -472,8 +457,8 @@ class MemoryConfig(ValidatedConfigBase):
memory_value_threshold: float = Field(default=0.7, description="记忆价值阈值")
# 向量存储配置
vector_dimension: int = Field(default=768, description="向量维度")
vector_similarity_threshold: float = Field(default=0.8, description="向量相似度阈值")
semantic_similarity_threshold: float = Field(default=0.6, description="语义相似度阈值")
# 多阶段检索配置
metadata_filter_limit: int = Field(default=100, description="元数据过滤阶段返回数量")

View File

@@ -27,9 +27,8 @@ from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.component_types import EventType
# from src.api.main import start_api_server
# 导入新的插件管理器和热重载管理器
# 导入新的插件管理器
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
# 导入消息API和traceback模块
from src.common.message import get_global_api
@@ -116,12 +115,6 @@ class MainSystem:
except Exception as e:
logger.error(f"停止消息重组器时出错: {e}")
try:
# 停止插件热重载系统
hot_reload_manager.stop()
logger.info("🛑 插件热重载系统已停止")
except Exception as e:
logger.error(f"停止热重载系统时出错: {e}")
try:
# 停止增强记忆系统
@@ -229,8 +222,6 @@ MoFox_Bot(第三方修改版)
# 处理所有缓存的事件订阅(插件加载完成后)
event_manager.process_all_pending_subscriptions()
# 启动插件热重载系统
hot_reload_manager.start()
# 初始化表情管理器
get_emoji_manager().initialize()

View File

@@ -2,7 +2,7 @@ import time
import asyncio
from abc import ABC, abstractmethod
from typing import Tuple, Optional
from typing import Tuple, Optional, List, Dict, Any
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream
@@ -27,8 +27,21 @@ class BaseAction(ABC):
- parallel_action: 是否允许并行执行
- random_activation_probability: 随机激活概率
- llm_judge_prompt: LLM判断提示词
二步Action相关属性
- is_two_step_action: 是否为二步Action
- step_one_description: 第一步的描述
- sub_actions: 子Action列表
"""
# 二步Action相关类属性
is_two_step_action: bool = False
"""是否为二步Action。如果为TrueAction将分两步执行第一步选择操作第二步执行具体操作"""
step_one_description: str = ""
"""第一步的描述用于向LLM展示Action的基本功能"""
sub_actions: List[Tuple[str, str, Dict[str, str]]] = []
"""子Action列表格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用"""
def __init__(
self,
action_data: dict,
@@ -93,6 +106,13 @@ class BaseAction(ABC):
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
# 二步Action相关实例属性
self.is_two_step_action: bool = getattr(self.__class__, "is_two_step_action", False)
self.step_one_description: str = getattr(self.__class__, "step_one_description", "")
self.sub_actions: List[Tuple[str, str, Dict[str, str]]] = getattr(self.__class__, "sub_actions", []).copy()
self._selected_sub_action: Optional[str] = None
"""当前选择的子Action名称用于二步Action的状态管理"""
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# =============================================================================
@@ -412,23 +432,32 @@ class BaseAction(ABC):
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
return False, f"未找到Action组件信息: {action_name}"
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
# 确保获取的是Action组件
if component_info.component_type != ComponentType.ACTION:
logger.error(f"{log_prefix} 尝试调用的组件 '{action_name}' 不是一个Action而是一个 '{component_info.component_type.value}'")
return False, f"组件 '{action_name}' 不是一个有效的Action"
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
# 3. 实例化被调用的Action
action_instance = action_class(
action_data=called_action_data,
reasoning=f"Called by {self.action_name}",
cycle_timers=self.cycle_timers,
thinking_id=self.thinking_id,
chat_stream=self.chat_stream,
log_prefix=log_prefix,
plugin_config=plugin_config,
action_message=self.action_message,
)
action_params = {
"action_data": called_action_data,
"reasoning": f"Called by {self.action_name}",
"cycle_timers": self.cycle_timers,
"thinking_id": self.thinking_id,
"chat_stream": self.chat_stream,
"log_prefix": log_prefix,
"plugin_config": plugin_config,
"action_message": self.action_message,
}
action_instance = action_class(**action_params)
# 4. 执行Action
logger.debug(f"{log_prefix} 开始执行...")
result = await action_instance.execute()
execute_result = await action_instance.execute()
# 确保返回类型符合 (bool, str) 格式
is_success = execute_result[0] if isinstance(execute_result, tuple) and len(execute_result) > 0 else False
message = execute_result[1] if isinstance(execute_result, tuple) and len(execute_result) > 1 else ""
result = (is_success, str(message))
logger.info(f"{log_prefix} 执行完成,结果: {result}")
return result
@@ -477,15 +506,73 @@ class BaseAction(ABC):
action_require=getattr(cls, "action_require", []).copy(),
associated_types=getattr(cls, "associated_types", []).copy(),
chat_type_allow=getattr(cls, "chat_type_allow", ChatType.ALL),
# 二步Action相关属性
is_two_step_action=getattr(cls, "is_two_step_action", False),
step_one_description=getattr(cls, "step_one_description", ""),
sub_actions=getattr(cls, "sub_actions", []).copy(),
)
async def handle_step_one(self) -> Tuple[bool, str]:
"""处理二步Action的第一步
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
if not self.is_two_step_action:
return False, "此Action不是二步Action"
# 检查action_data中是否包含选择的子Action
selected_action = self.action_data.get("selected_action")
if not selected_action:
# 第一步展示可用的子Action
available_actions = [sub_action[0] for sub_action in self.sub_actions]
description = self.step_one_description or f"{self.action_name}支持以下操作"
actions_list = "\n".join([f"- {action}: {desc}" for action, desc, _ in self.sub_actions])
response = f"{description}\n\n可用操作:\n{actions_list}\n\n请选择要执行的操作。"
return True, response
else:
# 验证选择的子Action是否有效
valid_actions = [sub_action[0] for sub_action in self.sub_actions]
if selected_action not in valid_actions:
return False, f"无效的操作选择: {selected_action}。可用操作: {valid_actions}"
# 保存选择的子Action
self._selected_sub_action = selected_action
# 调用第二步执行
return await self.execute_step_two(selected_action)
async def execute_step_two(self, sub_action_name: str) -> Tuple[bool, str]:
"""执行二步Action的第二步
Args:
sub_action_name: 子Action名称
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
if not self.is_two_step_action:
return False, "此Action不是二步Action"
# 子类需要重写此方法来实现具体的第二步逻辑
return False, f"二步Action必须实现execute_step_two方法来处理操作: {sub_action_name}"
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
"""执行Action的抽象方法子类必须实现
对于二步Action会自动处理第一步逻辑
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
# 如果是二步Action自动处理第一步
if self.is_two_step_action:
return await self.handle_step_one()
# 普通Action由子类实现
pass
async def handle_action(self) -> Tuple[bool, str]:

View File

@@ -38,6 +38,14 @@ class BaseTool(ABC):
semantic_cache_query_key: Optional[str] = None
"""用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索"""
# 二步工具调用相关属性
is_two_step_tool: bool = False
"""是否为二步工具。如果为True工具将分两步调用第一步展示工具信息第二步执行具体操作"""
step_one_description: str = ""
"""第一步的描述用于向LLM展示工具的基本功能"""
sub_tools: List[Tuple[str, str, List[Tuple[str, ToolParamType, str, bool, List[str] | None]]]] = []
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
def __init__(self, plugin_config: Optional[dict] = None):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
@@ -48,10 +56,64 @@ class BaseTool(ABC):
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
if not cls.name or not cls.description:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name description 属性")
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
# 如果是二步工具,第一步只返回基本信息
if cls.is_two_step_tool:
return {
"name": cls.name,
"description": cls.step_one_description or cls.description,
"parameters": [("action", ToolParamType.STRING, "选择要执行的操作", True, [sub_tool[0] for sub_tool in cls.sub_tools])]
}
else:
# 普通工具需要parameters
if not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 parameters 属性")
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
@classmethod
def get_step_two_tool_definition(cls, sub_tool_name: str) -> dict[str, Any]:
"""获取二步工具的第二步定义
Args:
sub_tool_name: 子工具名称
Returns:
dict: 第二步工具定义字典
"""
if not cls.is_two_step_tool:
raise ValueError(f"工具 {cls.name} 不是二步工具")
# 查找对应的子工具
for sub_name, sub_desc, sub_params in cls.sub_tools:
if sub_name == sub_tool_name:
return {
"name": f"{cls.name}_{sub_tool_name}",
"description": sub_desc,
"parameters": sub_params
}
raise ValueError(f"未找到子工具: {sub_tool_name}")
@classmethod
def get_all_sub_tool_definitions(cls) -> List[dict[str, Any]]:
"""获取所有子工具的定义
Returns:
List[dict]: 所有子工具定义列表
"""
if not cls.is_two_step_tool:
return []
definitions = []
for sub_name, sub_desc, sub_params in cls.sub_tools:
definitions.append({
"name": f"{cls.name}_{sub_name}",
"description": sub_desc,
"parameters": sub_params
})
return definitions
@classmethod
def get_tool_info(cls) -> ToolInfo:
@@ -79,8 +141,68 @@ class BaseTool(ABC):
Returns:
dict: 工具执行结果
"""
# 如果是二步工具,处理第一步调用
if self.is_two_step_tool and "action" in function_args:
return await self._handle_step_one(function_args)
raise NotImplementedError("子类必须实现execute方法")
async def _handle_step_one(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""处理二步工具的第一步调用
Args:
function_args: 包含action参数的函数参数
Returns:
dict: 第一步执行结果,包含第二步的工具定义
"""
action = function_args.get("action")
if not action:
return {"error": "缺少action参数"}
# 查找对应的子工具
sub_tool_found = None
for sub_name, sub_desc, sub_params in self.sub_tools:
if sub_name == action:
sub_tool_found = (sub_name, sub_desc, sub_params)
break
if not sub_tool_found:
available_actions = [sub_tool[0] for sub_tool in self.sub_tools]
return {"error": f"未知的操作: {action}。可用操作: {available_actions}"}
sub_name, sub_desc, sub_params = sub_tool_found
# 返回第二步工具定义
step_two_definition = {
"name": f"{self.name}_{sub_name}",
"description": sub_desc,
"parameters": sub_params
}
return {
"type": "two_step_tool_step_one",
"content": f"已选择操作: {action}。请使用以下工具进行具体调用:",
"next_tool_definition": step_two_definition,
"selected_action": action
}
async def execute_step_two(self, sub_tool_name: str, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行二步工具的第二步
Args:
sub_tool_name: 子工具名称
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
if not self.is_two_step_tool:
raise ValueError(f"工具 {self.name} 不是二步工具")
# 子类需要重写此方法来实现具体的第二步逻辑
raise NotImplementedError("二步工具必须实现execute_step_two方法")
async def direct_execute(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
"""直接执行工具函数(供插件调用)
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数

View File

@@ -142,6 +142,10 @@ class ActionInfo(ComponentInfo):
mode_enable: ChatMode = ChatMode.ALL
parallel_action: bool = False
chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型
# 二步Action相关属性
is_two_step_action: bool = False # 是否为二步Action
step_one_description: str = "" # 第一步的描述
sub_actions: List[Tuple[str, str, Dict[str, str]]] = field(default_factory=list) # 子Action列表
def __post_init__(self):
super().__post_init__()
@@ -153,6 +157,8 @@ class ActionInfo(ComponentInfo):
self.action_require = []
if self.associated_types is None:
self.associated_types = []
if self.sub_actions is None:
self.sub_actions = []
self.component_type = ComponentType.ACTION

View File

@@ -8,12 +8,10 @@ from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
__all__ = [
"plugin_manager",
"component_registry",
"event_manager",
"global_announcement_manager",
"hot_reload_manager",
]

View File

@@ -1,512 +0,0 @@
"""
插件热重载模块
使用 Watchdog 监听插件目录变化,自动重载插件
"""
import os
import sys
import time
import importlib
from pathlib import Path
from threading import Thread
from typing import Dict, Set, List, Optional, Tuple
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from src.common.logger import get_logger
from .plugin_manager import plugin_manager
logger = get_logger("plugin_hot_reload")
class PluginFileHandler(FileSystemEventHandler):
"""插件文件变化处理器"""
def __init__(self, hot_reload_manager):
super().__init__()
self.hot_reload_manager = hot_reload_manager
self.pending_reloads: Set[str] = set() # 待重载的插件名称
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
self.debounce_delay = 2.0 # 增加防抖延迟到2秒确保文件写入完成
self.file_change_cache: Dict[str, float] = {} # 文件变化缓存
def on_modified(self, event):
"""文件修改事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "modified")
def on_created(self, event):
"""文件创建事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "created")
def on_deleted(self, event):
"""文件删除事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "deleted")
def _handle_file_change(self, file_path: str, change_type: str):
"""处理文件变化"""
try:
# 获取插件名称
plugin_info = self._get_plugin_info_from_path(file_path)
if not plugin_info:
return
plugin_name, source_type = plugin_info
current_time = time.time()
# 文件变化缓存,避免重复处理同一文件的快速连续变化
file_cache_key = f"{file_path}_{change_type}"
last_file_time = self.file_change_cache.get(file_cache_key, 0)
if current_time - last_file_time < 0.5: # 0.5秒内的重复文件变化忽略
return
self.file_change_cache[file_cache_key] = current_time
# 插件级别的防抖处理
last_plugin_time = self.last_reload_time.get(plugin_name, 0)
if current_time - last_plugin_time < self.debounce_delay:
# 如果在防抖期内,更新待重载标记但不立即处理
self.pending_reloads.add(plugin_name)
return
file_name = Path(file_path).name
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type}) [{source_type}] -> {plugin_name}")
# 如果是删除事件,处理关键文件删除
if change_type == "deleted":
# 解析实际的插件名称
actual_plugin_name = self.hot_reload_manager._resolve_plugin_name(plugin_name)
if file_name == "plugin.py":
if actual_plugin_name in plugin_manager.loaded_plugins:
logger.info(
f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
self.hot_reload_manager._unload_plugin(actual_plugin_name)
else:
logger.info(
f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
return
elif file_name in ("manifest.toml", "_manifest.json"):
if actual_plugin_name in plugin_manager.loaded_plugins:
logger.info(
f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
self.hot_reload_manager._unload_plugin(actual_plugin_name)
else:
logger.info(
f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
return
# 对于修改和创建事件,都进行重载
# 添加到待重载列表
self.pending_reloads.add(plugin_name)
self.last_reload_time[plugin_name] = current_time
# 延迟重载,确保文件写入完成
reload_thread = Thread(
target=self._delayed_reload, args=(plugin_name, source_type, current_time), daemon=True
)
reload_thread.start()
except Exception as e:
logger.error(f"❌ 处理文件变化时发生错误: {e}", exc_info=True)
def _delayed_reload(self, plugin_name: str, source_type: str, trigger_time: float):
"""延迟重载插件"""
try:
# 等待文件写入完成
time.sleep(self.debounce_delay)
# 检查是否还需要重载(可能在等待期间有更新的变化)
if plugin_name not in self.pending_reloads:
return
# 检查是否有更新的重载请求
if self.last_reload_time.get(plugin_name, 0) > trigger_time:
return
self.pending_reloads.discard(plugin_name)
logger.info(f"🔄 开始延迟重载插件: {plugin_name} [{source_type}]")
# 执行深度重载
success = self.hot_reload_manager._deep_reload_plugin(plugin_name)
if success:
logger.info(f"✅ 插件重载成功: {plugin_name} [{source_type}]")
else:
logger.error(f"❌ 插件重载失败: {plugin_name} [{source_type}]")
except Exception as e:
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
def _get_plugin_info_from_path(self, file_path: str) -> Optional[Tuple[str, str]]:
"""从文件路径获取插件信息
Returns:
tuple[插件名称, 源类型] 或 None
"""
try:
path = Path(file_path)
# 检查是否在任何一个监听的插件目录中
for watch_dir in self.hot_reload_manager.watch_directories:
plugin_root = Path(watch_dir)
if path.is_relative_to(plugin_root):
# 确定源类型
if "src" in str(plugin_root):
source_type = "built-in"
else:
source_type = "external"
# 获取插件目录名(插件名)
relative_path = path.relative_to(plugin_root)
if len(relative_path.parts) == 0:
continue
plugin_name = relative_path.parts[0]
# 确认这是一个有效的插件目录
plugin_dir = plugin_root / plugin_name
if plugin_dir.is_dir():
# 检查是否有插件主文件或配置文件
has_plugin_py = (plugin_dir / "plugin.py").exists()
has_manifest = (plugin_dir / "manifest.toml").exists() or (
plugin_dir / "_manifest.json"
).exists()
if has_plugin_py or has_manifest:
return plugin_name, source_type
return None
except Exception:
return None
class PluginHotReloadManager:
"""插件热重载管理器"""
def __init__(self, watch_directories: Optional[List[str]] = None):
if watch_directories is None:
# 默认监听两个目录:根目录下的 plugins 和 src 下的插件目录
self.watch_directories = [
os.path.join(os.getcwd(), "plugins"), # 外部插件目录
os.path.join(os.getcwd(), "src", "plugins", "built_in"), # 内置插件目录
]
else:
self.watch_directories = watch_directories
self.observers = []
self.file_handlers = []
self.is_running = False
# 确保监听目录存在
for watch_dir in self.watch_directories:
if not os.path.exists(watch_dir):
os.makedirs(watch_dir, exist_ok=True)
logger.info(f"📁 创建插件监听目录: {watch_dir}")
def start(self):
"""启动热重载监听"""
if self.is_running:
logger.warning("插件热重载已经在运行中")
return
try:
# 为每个监听目录创建独立的观察者
for watch_dir in self.watch_directories:
observer = Observer()
file_handler = PluginFileHandler(self)
observer.schedule(file_handler, watch_dir, recursive=True)
observer.start()
self.observers.append(observer)
self.file_handlers.append(file_handler)
self.is_running = True
# 打印监听的目录信息
dir_info = []
for watch_dir in self.watch_directories:
if "src" in watch_dir:
dir_info.append(f"{watch_dir} (内置插件)")
else:
dir_info.append(f"{watch_dir} (外部插件)")
logger.info("🚀 插件热重载已启动,监听目录:")
for info in dir_info:
logger.info(f" 📂 {info}")
except Exception as e:
logger.error(f"❌ 启动插件热重载失败: {e}")
self.stop() # 清理已创建的观察者
self.is_running = False
def stop(self):
"""停止热重载监听"""
if not self.is_running and not self.observers:
return
# 停止所有观察者
for observer in self.observers:
try:
observer.stop()
observer.join()
except Exception as e:
logger.error(f"❌ 停止观察者时发生错误: {e}")
self.observers.clear()
self.file_handlers.clear()
self.is_running = False
logger.info("🛑 插件热重载已停止")
def _reload_plugin(self, plugin_name: str):
"""重载指定插件(简单重载)"""
try:
# 解析实际的插件名称
actual_plugin_name = self._resolve_plugin_name(plugin_name)
logger.info(f"🔄 开始简单重载插件: {plugin_name} -> {actual_plugin_name}")
if plugin_manager.reload_plugin(actual_plugin_name):
logger.info(f"✅ 插件简单重载成功: {actual_plugin_name}")
return True
else:
logger.error(f"❌ 插件简单重载失败: {actual_plugin_name}")
return False
except Exception as e:
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
@staticmethod
def _resolve_plugin_name(folder_name: str) -> str:
"""
将文件夹名称解析为实际的插件名称
通过检查插件管理器中的路径映射来找到对应的插件名
"""
# 首先检查是否直接匹配
if folder_name in plugin_manager.plugin_classes:
logger.debug(f"🔍 直接匹配插件名: {folder_name}")
return folder_name
# 如果没有直接匹配,搜索路径映射,并优先返回在插件类中存在的名称
matched_plugins = []
for plugin_name, plugin_path in plugin_manager.plugin_paths.items():
# 检查路径是否包含该文件夹名
if folder_name in plugin_path:
matched_plugins.append((plugin_name, plugin_path))
# 在匹配的插件中,优先选择在插件类中存在的
for plugin_name, plugin_path in matched_plugins:
if plugin_name in plugin_manager.plugin_classes:
logger.debug(f"🔍 文件夹名 '{folder_name}' 映射到插件名 '{plugin_name}' (路径: {plugin_path})")
return plugin_name
# 如果还是没找到在插件类中存在的,返回第一个匹配项
if matched_plugins:
plugin_name, plugin_path = matched_plugins[0]
logger.warning(f"⚠️ 文件夹 '{folder_name}' 映射到 '{plugin_name}',但该插件类不存在")
return plugin_name
# 如果还是没找到,返回原文件夹名
logger.warning(f"⚠️ 无法找到文件夹 '{folder_name}' 对应的插件名,使用原名称")
return folder_name
def _deep_reload_plugin(self, plugin_name: str):
"""深度重载指定插件(清理模块缓存)"""
try:
# 解析实际的插件名称
actual_plugin_name = self._resolve_plugin_name(plugin_name)
logger.info(f"🔄 开始深度重载插件: {plugin_name} -> {actual_plugin_name}")
# 强制清理相关模块缓存
self._force_clear_plugin_modules(plugin_name)
# 使用插件管理器的强制重载功能
success = plugin_manager.force_reload_plugin(actual_plugin_name)
if success:
logger.info(f"✅ 插件深度重载成功: {actual_plugin_name}")
return True
else:
logger.error(f"❌ 插件深度重载失败,尝试简单重载: {actual_plugin_name}")
# 如果深度重载失败,尝试简单重载
return self._reload_plugin(actual_plugin_name)
except Exception as e:
logger.error(f"❌ 深度重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
# 出错时尝试简单重载
return self._reload_plugin(plugin_name)
@staticmethod
def _force_clear_plugin_modules(plugin_name: str):
"""强制清理插件相关的模块缓存"""
# 找到所有相关的模块名
modules_to_remove = []
plugin_module_prefix = f"src.plugins.built_in.{plugin_name}"
for module_name in list(sys.modules.keys()):
if plugin_module_prefix in module_name:
modules_to_remove.append(module_name)
# 删除模块缓存
for module_name in modules_to_remove:
if module_name in sys.modules:
logger.debug(f"🗑️ 清理模块缓存: {module_name}")
del sys.modules[module_name]
@staticmethod
def _force_reimport_plugin(plugin_name: str):
"""强制重新导入插件(委托给插件管理器)"""
try:
# 使用插件管理器的重载功能
success = plugin_manager.reload_plugin(plugin_name)
return success
except Exception as e:
logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
@staticmethod
def _unload_plugin(plugin_name: str):
"""卸载指定插件"""
try:
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
if plugin_manager.unload_plugin(plugin_name):
logger.info(f"✅ 插件卸载成功: {plugin_name}")
return True
else:
logger.error(f"❌ 插件卸载失败: {plugin_name}")
return False
except Exception as e:
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
def reload_all_plugins(self):
"""重载所有插件"""
try:
logger.info("🔄 开始深度重载所有插件...")
# 获取当前已加载的插件列表
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
success_count = 0
fail_count = 0
for plugin_name in loaded_plugins:
logger.info(f"🔄 重载插件: {plugin_name}")
if self._deep_reload_plugin(plugin_name):
success_count += 1
else:
fail_count += 1
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count}")
# 清理全局缓存
importlib.invalidate_caches()
except Exception as e:
logger.error(f"❌ 重载所有插件时发生错误: {e}", exc_info=True)
def force_reload_plugin(self, plugin_name: str):
"""手动强制重载指定插件(委托给插件管理器)"""
try:
logger.info(f"🔄 手动强制重载插件: {plugin_name}")
# 清理待重载列表中的该插件(避免重复重载)
for handler in self.file_handlers:
handler.pending_reloads.discard(plugin_name)
# 使用插件管理器的强制重载功能
success = plugin_manager.force_reload_plugin(plugin_name)
if success:
logger.info(f"✅ 手动强制重载成功: {plugin_name}")
else:
logger.error(f"❌ 手动强制重载失败: {plugin_name}")
return success
except Exception as e:
logger.error(f"❌ 手动强制重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
def add_watch_directory(self, directory: str):
"""添加新的监听目录"""
if directory in self.watch_directories:
logger.info(f"目录 {directory} 已在监听列表中")
return
# 确保目录存在
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"📁 创建插件监听目录: {directory}")
self.watch_directories.append(directory)
# 如果热重载正在运行,为新目录创建观察者
if self.is_running:
try:
observer = Observer()
file_handler = PluginFileHandler(self)
observer.schedule(file_handler, directory, recursive=True)
observer.start()
self.observers.append(observer)
self.file_handlers.append(file_handler)
logger.info(f"📂 已添加新的监听目录: {directory}")
except Exception as e:
logger.error(f"❌ 添加监听目录 {directory} 失败: {e}")
self.watch_directories.remove(directory)
def get_status(self) -> dict:
"""获取热重载状态"""
pending_reloads = set()
if self.file_handlers:
for handler in self.file_handlers:
pending_reloads.update(handler.pending_reloads)
return {
"is_running": self.is_running,
"watch_directories": self.watch_directories,
"active_observers": len(self.observers),
"loaded_plugins": len(plugin_manager.loaded_plugins),
"failed_plugins": len(plugin_manager.failed_plugins),
"pending_reloads": list(pending_reloads),
"debounce_delay": self.file_handlers[0].debounce_delay if self.file_handlers else 0,
}
@staticmethod
def clear_all_caches():
"""清理所有Python模块缓存"""
try:
logger.info("🧹 开始清理所有Python模块缓存...")
# 重新扫描所有插件目录,这会重新加载模块
plugin_manager.rescan_plugin_directory()
logger.info("✅ 模块缓存清理完成")
except Exception as e:
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)
# 全局热重载管理器实例
hot_reload_manager = PluginHotReloadManager()

View File

@@ -55,6 +55,10 @@ class ToolExecutor:
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 二步工具调用状态管理
self._pending_step_two_tools: Dict[str, Dict[str, Any]] = {}
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
logger.info(f"{self.log_prefix}工具执行器初始化完成")
async def execute_from_chat_message(
@@ -112,7 +116,18 @@ class ToolExecutor:
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [definition for name, definition in all_tools if name not in user_disabled_tools]
# 获取基础工具定义(包括二步工具的第一步)
tool_definitions = [definition for name, definition in all_tools if name not in user_disabled_tools]
# 检查是否有待处理的二步工具第二步调用
pending_step_two = getattr(self, '_pending_step_two_tools', {})
if pending_step_two:
# 添加第二步工具定义
for tool_name, step_two_def in pending_step_two.items():
tool_definitions.append(step_two_def)
return tool_definitions
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""执行工具调用
@@ -251,6 +266,32 @@ class ToolExecutor:
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
)
function_args["llm_called"] = True # 标记为LLM调用
# 检查是否是二步工具的第二步调用
if "_" in function_name and function_name.count("_") >= 1:
# 可能是二步工具的第二步调用,格式为 "tool_name_sub_tool_name"
parts = function_name.split("_", 1)
if len(parts) == 2:
base_tool_name, sub_tool_name = parts
base_tool_instance = get_tool_instance(base_tool_name)
if base_tool_instance and base_tool_instance.is_two_step_tool:
logger.info(f"{self.log_prefix}执行二步工具第二步: {base_tool_name}.{sub_tool_name}")
result = await base_tool_instance.execute_step_two(sub_tool_name, function_args)
# 清理待处理的第二步工具
self._pending_step_two_tools.pop(base_tool_name, None)
if result:
logger.debug(f"{self.log_prefix}二步工具第二步 {function_name} 执行成功")
return {
"tool_call_id": tool_call.call_id,
"role": "tool",
"name": function_name,
"type": "function",
"content": result.get("content", ""),
}
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
if not tool_instance:
@@ -260,6 +301,16 @@ class ToolExecutor:
# 执行工具并记录日志
logger.debug(f"{self.log_prefix}执行工具 {function_name},参数: {function_args}")
result = await tool_instance.execute(function_args)
# 检查是否是二步工具的第一步结果
if result and result.get("type") == "two_step_tool_step_one":
logger.info(f"{self.log_prefix}二步工具第一步完成: {function_name}")
# 保存第二步工具定义
next_tool_def = result.get("next_tool_definition")
if next_tool_def:
self._pending_step_two_tools[function_name] = next_tool_def
logger.debug(f"{self.log_prefix}已保存第二步工具定义: {next_tool_def['name']}")
if result:
logger.debug(f"{self.log_prefix}工具 {function_name} 执行成功,结果: {result}")
return {

View File

@@ -91,7 +91,6 @@ INSTALL_NAME_TO_IMPORT_NAME = {
"pyusb": "usb", # USB访问
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
"psutil": "psutil", # 系统信息和进程管理
"watchdog": "watchdog", # 文件系统事件监控
"python-gnupg": "gnupg", # GnuPG的Python接口
# ============== 加密与安全 (Cryptography & Security) ==============
"pycrypto": "Crypto", # 加密库 (较旧)

View File

@@ -88,25 +88,27 @@ class MaiZoneRefactoredPlugin(BasePlugin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def on_plugin_loaded(self):
"""插件加载完成后的回调,初始化服务并启动后台任务"""
# --- 注册权限节点 ---
await permission_api.register_permission_node(
"plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False
)
await permission_api.register_permission_node(
"plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True
)
# 创建所有服务实例
# --- 创建并注册所有服务实例 ---
content_service = ContentService(self.get_config)
image_service = ImageService(self.get_config)
cookie_service = CookieService(self.get_config)
reply_tracker_service = ReplyTrackerService()
# 使用已创建的 reply_tracker_service 实例
qzone_service = QZoneService(
self.get_config,
content_service,
image_service,
cookie_service,
reply_tracker_service, # 传入已创建的实例
reply_tracker_service,
)
scheduler_service = SchedulerService(self.get_config, qzone_service)
monitor_service = MonitorService(self.get_config, qzone_service)
@@ -115,18 +117,12 @@ class MaiZoneRefactoredPlugin(BasePlugin):
register_service("reply_tracker", reply_tracker_service)
register_service("get_config", self.get_config)
# 保存服务引用以便后续启动
self.scheduler_service = scheduler_service
self.monitor_service = monitor_service
logger.info("MaiZone重构版插件服务已注册。")
logger.info("MaiZone重构版插件已加载服务已注册。")
async def on_plugin_loaded(self):
"""插件加载完成后的回调,启动异步服务"""
if hasattr(self, "scheduler_service") and hasattr(self, "monitor_service"):
asyncio.create_task(self.scheduler_service.start())
asyncio.create_task(self.monitor_service.start())
logger.info("MaiZone后台任务已启动。")
# --- 启动后台任务 ---
asyncio.create_task(scheduler_service.start())
asyncio.create_task(monitor_service.start())
logger.info("MaiZone后台监控和定时任务已启动。")
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [

View File

@@ -113,31 +113,32 @@ class CookieService:
async def get_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
"""
获取Cookie按以下顺序尝试
1. Adapter API
2. HTTP备用端点
3. 本地文件缓存
1. HTTP备用端点 (更稳定)
2. 本地文件缓存
3. Adapter API (作为最后手段)
"""
# 1. 尝试从Adapter获取
cookies = await self._get_cookies_from_adapter(stream_id)
if cookies:
logger.info("成功从Adapter获取Cookie。")
self._save_cookies_to_file(qq_account, cookies)
return cookies
# 2. 尝试从HTTP备用端点获取
logger.warning("从Adapter获取Cookie失败尝试使用HTTP备用地址。")
# 1. 尝试从HTTP备用端点获取
logger.info(f"开始尝试从HTTP备用地址获取 {qq_account} 的Cookie...")
cookies = await self._get_cookies_from_http()
if cookies:
logger.info("成功从HTTP备用地址获取Cookie。")
logger.info(f"成功从HTTP备用地址{qq_account} 获取Cookie。")
self._save_cookies_to_file(qq_account, cookies)
return cookies
# 3. 尝试从本地文件加载
logger.warning("从HTTP备用地址获取Cookie失败尝试加载本地缓存。")
# 2. 尝试从本地文件加载
logger.warning(f"从HTTP备用地址获取 {qq_account}Cookie失败尝试加载本地缓存。")
cookies = self._load_cookies_from_file(qq_account)
if cookies:
logger.info("成功从本地文件加载缓存的Cookie。")
logger.info(f"成功从本地文件{qq_account} 加载缓存的Cookie。")
return cookies
logger.error("所有Cookie获取方法均失败。")
# 3. 尝试从Adapter获取 (作为最后的备用方案)
logger.warning(f"从本地缓存加载 {qq_account} 的Cookie失败最后尝试使用Adapter API。")
cookies = await self._get_cookies_from_adapter(stream_id)
if cookies:
logger.info(f"成功从Adapter API为 {qq_account} 获取Cookie。")
self._save_cookies_to_file(qq_account, cookies)
return cookies
logger.error(f"{qq_account} 获取Cookie的所有方法均失败。请确保Napcat HTTP服务或Adapter连接至少有一个正常工作或存在有效的本地Cookie文件。")
return None

View File

@@ -409,8 +409,9 @@ class QZoneService:
cookie_dir.mkdir(exist_ok=True)
cookie_file_path = cookie_dir / f"cookies-{qq_account}.json"
# 优先尝试通过Napcat HTTP服务获取最新的Cookie
try:
# 使用HTTP服务器方式获取Cookie
logger.info("尝试通过Napcat HTTP服务获取Cookie...")
host = self.get_config("cookie.http_fallback_host", "172.20.130.55")
port = self.get_config("cookie.http_fallback_port", "9999")
napcat_token = self.get_config("cookie.napcat_token", "")
@@ -421,23 +422,43 @@ class QZoneService:
parsed_cookies = {
k.strip(): v.strip() for k, v in (p.split("=", 1) for p in cookie_str.split("; ") if "=" in p)
}
with open(cookie_file_path, "wb") as f:
f.write(orjson.dumps(parsed_cookies))
logger.info(f"Cookie已更新并保存至: {cookie_file_path}")
# 成功获取后,异步写入本地文件作为备份
try:
with open(cookie_file_path, "wb") as f:
f.write(orjson.dumps(parsed_cookies))
logger.info(f"通过Napcat服务成功更新Cookie并已保存至: {cookie_file_path}")
except Exception as e:
logger.warning(f"保存Cookie到文件时出错: {e}")
return parsed_cookies
else:
logger.warning("通过Napcat服务未能获取有效Cookie。")
# 如果HTTP获取失败尝试读取本地文件
if cookie_file_path.exists():
with open(cookie_file_path, "rb") as f:
return orjson.loads(f.read())
return None
except Exception as e:
logger.error(f"更新或加载Cookie时发生异常: {e}")
return None
logger.warning(f"通过Napcat HTTP服务获取Cookie时发生异常: {e}。将尝试从本地文件加载。")
async def _fetch_cookies_http(self, host: str, port: str, napcat_token: str) -> Optional[Dict]:
# 如果通过服务获取失败,则尝试从本地文件加载
logger.info("尝试从本地Cookie文件加载...")
if cookie_file_path.exists():
try:
with open(cookie_file_path, "rb") as f:
cookies = orjson.loads(f.read())
logger.info(f"成功从本地文件加载Cookie: {cookie_file_path}")
return cookies
except Exception as e:
logger.error(f"从本地文件 {cookie_file_path} 读取或解析Cookie失败: {e}")
else:
logger.warning(f"本地Cookie文件不存在: {cookie_file_path}")
logger.error("所有获取Cookie的方式均失败。")
return None
async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> Optional[Dict]:
"""通过HTTP服务器获取Cookie"""
url = f"http://{host}:{port}/get_cookies"
# 从配置中读取主机和端口,如果未提供则使用传入的参数
final_host = self.get_config("cookie.http_fallback_host", host)
final_port = self.get_config("cookie.http_fallback_port", port)
url = f"http://{final_host}:{final_port}/get_cookies"
max_retries = 5
retry_delay = 1
@@ -481,14 +502,19 @@ class QZoneService:
async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]:
cookies = await self.cookie_service.get_cookies(qq_account, stream_id)
if not cookies:
logger.error("获取API客户端失败未能获取到Cookie。请检查Napcat连接是否正常或是否存在有效的本地Cookie文件。")
return None
p_skey = cookies.get("p_skey") or cookies.get("p_skey".upper())
if not p_skey:
logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}")
return None
gtk = self._generate_gtk(p_skey)
uin = cookies.get("uin", "").lstrip("o")
if not uin:
logger.error(f"获取API客户端失败Cookie中缺少关键的 'uin'。Cookie内容: {cookies}")
return None
async def _request(method, url, params=None, data=None, headers=None):
final_headers = {"referer": f"https://user.qzone.qq.com/{uin}", "origin": "https://user.qzone.qq.com"}

View File

@@ -185,9 +185,13 @@ class SendHandler:
logger.info(f"执行适配器命令: {action}")
# 直接向Napcat发送命令并获取响应
response_task = asyncio.create_task(self.send_message_to_napcat(action, params))
response = await response_task
# 根据action决定处理方式
if action == "get_cookies":
# 对于get_cookies我们需要一个更长的超时时间
response = await self.send_message_to_napcat(action, params, timeout=40.0)
else:
# 对于其他命令,使用默认超时
response = await self.send_message_to_napcat(action, params)
# 发送响应回MaiBot
await self.send_adapter_command_response(raw_message_base, response, request_id)
@@ -196,6 +200,8 @@ class SendHandler:
logger.info(f"适配器命令 {action} 执行成功")
else:
logger.warning(f"适配器命令 {action} 执行失败napcat返回{str(response)}")
# 无论成功失败,都记录下完整的响应内容以供调试
logger.debug(f"适配器命令 {action} 的完整响应: {response}")
except Exception as e:
logger.error(f"处理适配器命令时发生错误: {e}")
@@ -583,7 +589,7 @@ class SendHandler:
},
)
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict:
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
@@ -595,9 +601,9 @@ class SendHandler:
try:
await connection.send(payload)
response = await get_response(request_uuid)
response = await get_response(request_uuid, timeout=timeout) # 使用传入的超时时间
except TimeoutError:
logger.error("发送消息超时,未收到响应")
logger.error(f"发送消息超时{timeout}秒),未收到响应: action={action}, params={params}")
return {"status": "error", "message": "timeout"}
except Exception as e:
logger.error(f"发送消息失败: {e}")

View File

@@ -15,7 +15,6 @@ from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.utils.permission_decorators import require_permission
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
class ManagementCommand(PlusCommand):
@@ -78,10 +77,6 @@ class ManagementCommand(PlusCommand):
await self._force_reload_plugin(args[1])
elif action in ["add_dir", "添加目录"] and len(args) > 1:
await self._add_dir(args[1])
elif action in ["hotreload_status", "热重载状态"]:
await self._show_hotreload_status()
elif action in ["clear_cache", "清理缓存"]:
await self._clear_all_caches()
else:
await self.send_text("❌ 插件管理命令不合法\n使用 /pm plugin help 查看帮助")
return False, "命令不合法", True
@@ -179,14 +174,9 @@ class ManagementCommand(PlusCommand):
• `/pm plugin force_reload <插件名>` - 强制重载指定插件(深度清理)
• `/pm plugin add_dir <目录路径>` - 添加插件目录
<EFBFBD> 热重载管理:
• `/pm plugin hotreload_status` - 查看热重载状态
• `/pm plugin clear_cache` - 清理所有模块缓存
<EFBFBD>📝 示例:
• `/pm plugin load echo_example`
• `/pm plugin force_reload permission_manager_plugin`
• `/pm plugin clear_cache`"""
• `/pm plugin force_reload permission_manager_plugin`"""
elif target == "component":
help_msg = """🧩 组件管理命令帮助
@@ -262,7 +252,7 @@ class ManagementCommand(PlusCommand):
await self.send_text(f"🔄 开始强制重载插件: `{plugin_name}`...")
try:
success = hot_reload_manager.force_reload_plugin(plugin_name)
success = plugin_manage_api.force_reload_plugin(plugin_name)
if success:
await self.send_text(f"✅ 插件强制重载成功: `{plugin_name}`")
else:
@@ -270,43 +260,6 @@ class ManagementCommand(PlusCommand):
except Exception as e:
await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}")
async def _show_hotreload_status(self):
"""显示热重载状态"""
try:
status = hot_reload_manager.get_status()
status_text = f"""🔄 **热重载系统状态**
🟢 **运行状态:** {"运行中" if status["is_running"] else "已停止"}
📂 **监听目录:** {len(status["watch_directories"])}
👁️ **活跃观察者:** {status["active_observers"]}
📦 **已加载插件:** {status["loaded_plugins"]}
❌ **失败插件:** {status["failed_plugins"]}
⏱️ **防抖延迟:** {status.get("debounce_delay", 0)}
📋 **监听的目录:**"""
for i, watch_dir in enumerate(status["watch_directories"], 1):
dir_type = "(内置插件)" if "src" in watch_dir else "(外部插件)"
status_text += f"\n{i}. `{watch_dir}` {dir_type}"
if status.get("pending_reloads"):
status_text += f"\n\n⏳ **待重载插件:** {', '.join([f'`{p}`' for p in status['pending_reloads']])}"
await self.send_text(status_text)
except Exception as e:
await self.send_text(f"❌ 获取热重载状态时发生错误: {str(e)}")
async def _clear_all_caches(self):
"""清理所有模块缓存"""
await self.send_text("🧹 开始清理所有Python模块缓存...")
try:
hot_reload_manager.clear_all_caches()
await self.send_text("✅ 模块缓存清理完成!建议重载相关插件以确保生效。")
except Exception as e:
await self.send_text(f"❌ 清理缓存时发生错误: {str(e)}")
async def _add_dir(self, dir_path: str):
"""添加插件目录"""

View File

@@ -60,10 +60,12 @@ class ReminderTask(AsyncTask):
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
extra_info = f"现在是提醒时间,请你以一种符合你人设的、俏皮的方式提醒 {self.target_user_name}\n提醒内容: {self.event_details}\n设置提醒的人: {self.creator_name}"
last_message = self.chat_stream.context_manager.context.get_last_message()
reply_message_dict = last_message.flatten() if last_message else None
success, reply_set, _ = await generator_api.generate_reply(
chat_stream=self.chat_stream,
extra_info=extra_info,
reply_message=self.chat_stream.context_manager.context.get_last_message().to_dict(),
reply_message=reply_message_dict,
request_type="plugin.reminder.remind_message",
)
@@ -150,9 +152,11 @@ class PokeAction(BaseAction):
action_require = ["当需要戳某个用户时使用", "当你想提醒特定用户时使用"]
llm_judge_prompt = """
判定是否需要使用戳一戳动作的条件:
1. 用户明确要求使用戳一戳
2. 你想以一种有趣的方式提醒或与某人互动
3. 上下文明确需要你戳一个或多个人
1. **关键**: 这是一个高消耗的动作,请仅在绝对必要时使用,例如用户明确要求或作为提醒的关键部分。请极其谨慎地使用
2. **用户请求**: 用户明确要求使用戳一戳
3. **互动提醒**: 你想以一种有趣的方式提醒或与某人互动,但请确保这是对话的自然延伸,而不是无故打扰
4. **上下文需求**: 上下文明确需要你戳一个或多个人。
5. **频率限制**: 如果最近已经戳过,或者用户情绪不高,请绝对不要使用。
请回答""""
"""
@@ -217,7 +221,6 @@ class SetEmojiLikeAction(BaseAction):
emoji_options.append(match.group(1))
action_parameters = {
"emoji": f"要回应的表情,必须从以下表情中选择: {', '.join(emoji_options)}",
"set": "是否设置回应 (True/False)",
}
action_require = [
@@ -238,6 +241,7 @@ class SetEmojiLikeAction(BaseAction):
async def execute(self) -> Tuple[bool, str]:
"""执行设置表情回应的动作"""
message_id = None
set_like = self.action_data.get("set", True)
if self.has_action_message:
logger.debug(str(self.action_message))
if isinstance(self.action_message, dict):
@@ -251,24 +255,49 @@ class SetEmojiLikeAction(BaseAction):
action_done=False,
)
return False, "未提供消息ID"
available_models = llm_api.get_available_models()
if "utils_small" not in available_models:
logger.error("未找到 'utils_small' 模型配置,无法选择表情")
return False, "表情选择功能配置错误"
emoji_input = self.action_data.get("emoji")
set_like = self.action_data.get("set", True)
model_to_use = available_models["utils_small"]
if not emoji_input:
logger.error("未提供表情")
return False, "未提供表情"
logger.info(f"设置表情回应: {emoji_input}, 是否设置: {set_like}")
# 获取最近的对话历史作为上下文
context_text = ""
if self.action_message:
context_text = self.action_message.get("processed_plain_text", "")
else:
logger.error("无法找到动作选择的原始消息")
return False, "无法找到动作选择的原始消息"
emoji_id = get_emoji_id(emoji_input)
if not emoji_id:
logger.error(f"找不到表情: '{emoji_input}'。请从可用列表中选择。")
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"执行了set_emoji_like动作{self.action_name},失败: 找不到表情: '{emoji_input}'",
action_done=False,
prompt = (
f"根据以下这条消息,从列表中选择一个最合适的表情名称来回应这条消息。\n"
f"消息内容: '{context_text}'\n"
f"可用表情列表: {', '.join(self.emoji_options)}\n"
f"你的任务是:只输出你选择的表情的名称,不要包含任何其他文字或标点。\n"
f"例如,如果觉得应该用'',就只输出''"
)
return False, f"找不到表情: '{emoji_input}'。请从可用列表中选择。"
success, response, _, _ = await llm_api.generate_with_model(
prompt, model_config=model_to_use, request_type="plugin.set_emoji_like.select_emoji"
)
if not success or not response:
logger.error("二级LLM未能选择有效的表情。")
return False, "无法找到合适的表情。"
chosen_emoji_name = response.strip()
logger.info(f"二级LLM选择的表情是: '{chosen_emoji_name}'")
emoji_id = get_emoji_id(chosen_emoji_name)
if not emoji_id:
logger.error(f"二级LLM选择的表情 '{chosen_emoji_name}' 仍然无法匹配到有效的表情ID。")
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"执行了set_emoji_like动作{self.action_name},失败: 找不到表情: '{chosen_emoji_name}'",
action_done=False,
)
return False, f"找不到表情: '{chosen_emoji_name}'"
# 4. 使用适配器API发送命令
if not message_id:
@@ -291,7 +320,7 @@ class SetEmojiLikeAction(BaseAction):
logger.info("设置表情回应成功")
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"执行了set_emoji_like动作,{emoji_input},设置表情回应: {emoji_id}, 是否设置: {set_like}",
action_prompt_display=f"执行了set_emoji_like动作,{chosen_emoji_name},设置表情回应: {emoji_id}, 是否设置: {set_like}",
action_done=True,
)
return True, "成功设置表情回应"

View File

@@ -28,20 +28,20 @@ class PlanManager:
if target_month is None:
target_month = datetime.now().strftime("%Y-%m")
if not has_active_plans(target_month):
if not await has_active_plans(target_month):
logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。")
generation_successful = await self._generate_monthly_plans_logic(target_month)
return generation_successful
else:
logger.info(f"{target_month} 已存在有效的月度计划。")
plans = get_active_plans_for_month(target_month)
plans = await get_active_plans_for_month(target_month)
max_plans = global_config.planning_system.max_plans_per_month
if len(plans) > max_plans:
logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。")
plans_to_delete = plans[: len(plans) - max_plans]
delete_ids = [p.id for p in plans_to_delete]
delete_plans_by_ids(delete_ids) # type: ignore
plans = get_active_plans_for_month(target_month)
await delete_plans_by_ids(delete_ids) # type: ignore
plans = await get_active_plans_for_month(target_month)
if plans:
plan_texts = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans)])
@@ -64,11 +64,11 @@ class PlanManager:
return False
last_month = self._get_previous_month(target_month)
archived_plans = get_archived_plans_for_month(last_month)
archived_plans = await get_archived_plans_for_month(last_month)
plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans)
if plans:
add_new_plans(plans, target_month)
await add_new_plans(plans, target_month)
logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。")
return True
else:
@@ -95,11 +95,11 @@ class PlanManager:
if target_month is None:
target_month = datetime.now().strftime("%Y-%m")
logger.info(f" 开始归档 {target_month} 的活跃月度计划...")
archived_count = archive_active_plans_for_month(target_month)
archived_count = await archive_active_plans_for_month(target_month)
logger.info(f" 成功归档了 {archived_count}{target_month} 的月度计划。")
except Exception as e:
logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}")
def get_plans_for_schedule(self, month: str, max_count: int) -> List:
async def get_plans_for_schedule(self, month: str, max_count: int) -> List:
avoid_days = global_config.planning_system.avoid_repetition_days
return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)

View File

@@ -255,25 +255,34 @@ max_context_emojis = 30 # 每次随机传递给LLM的表情包详细描述的最
[memory]
enable_memory = true # 是否启用记忆系统
memory_build_interval = 600 # 记忆构建间隔 单位秒 间隔越低MoFox-Bot学习越多但是冗余信息也会增多
memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
memory_build_sample_length = 30 # 采样长度,数值越高一段记忆内容越丰富
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
forget_memory_interval = 3000 # 记忆遗忘间隔 单位秒 间隔越低MoFox-Bot遗忘越频繁记忆更精简但更难学习
memory_forget_time = 48 #多长时间后的记忆会被遗忘 单位小时
memory_forget_percentage = 0.008 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认
consolidate_memory_interval = 1000 # 记忆整合间隔 单位秒 间隔越低MoFox-Bot整合越频繁记忆更精简
consolidation_similarity_threshold = 0.7 # 相似度阈值
consolidation_check_percentage = 0.05 # 检查节点比例
enable_instant_memory = false # 是否启用即时记忆,测试功能,可能存在未知问题
enable_instant_memory = true # 是否启用即时记忆
enable_llm_instant_memory = true # 是否启用基于LLM的瞬时记忆
enable_vector_instant_memory = true # 是否启用基于向量的瞬时记忆
enable_enhanced_memory = true # 是否启用增强记忆系统
enhanced_memory_auto_save = true # 是否自动保存增强记忆
#不希望记忆的词,已经记忆的不会受到影响,需要手动清理
memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]
min_memory_length = 10 # 最小记忆长度
max_memory_length = 500 # 最大记忆长度
memory_value_threshold = 0.7 # 记忆价值阈值,低于该值的记忆会被丢弃
vector_similarity_threshold = 0.8 # 向量相似度阈值
semantic_similarity_threshold = 0.6 # 语义重排阶段的最低匹配阈值
metadata_filter_limit = 100 # 元数据过滤阶段返回数量上限
vector_search_limit = 50 # 向量搜索阶段返回数量上限
semantic_rerank_limit = 20 # 语义重排阶段返回数量上限
final_result_limit = 10 # 综合筛选后的最终返回数量
vector_weight = 0.4 # 综合评分中向量相似度的权重
semantic_weight = 0.3 # 综合评分中语义匹配的权重
context_weight = 0.2 # 综合评分中上下文关联的权重
recency_weight = 0.1 # 综合评分中时效性的权重
fusion_similarity_threshold = 0.85 # 记忆融合时的相似度阈值
deduplication_window_hours = 24 # 记忆去重窗口(小时)
enable_memory_cache = true # 是否启用本地记忆缓存
cache_ttl_seconds = 300 # 缓存有效期(秒)
max_cache_size = 1000 # 缓存中允许的最大记忆条数
[voice]
enable_asr = true # 是否启用语音识别启用后MoFox-Bot可以识别语音消息启用该功能需要配置语音识别模型[model.voice]

View File

@@ -203,6 +203,7 @@ max_tokens = 1000
#嵌入模型
[model_task_config.embedding]
model_list = ["bge-m3"]
embedding_dimension = 1024