Merge branch 'dev' into dev

拾风, 2025-11-07 13:14:27 +08:00, committed by GitHub
98 changed files with 16116 additions and 8718 deletions


@@ -132,6 +132,56 @@ class ExpressionLearner:
self.chat_name = stream_name or self.chat_id
self._chat_name_initialized = True
async def cleanup_expired_expressions(self, expiration_days: int | None = None) -> int:
"""
Clean up expired expressions.
Args:
expiration_days: expiration period in days; expressions not activated for longer than this are deleted (read from config if not given)
Returns:
int: number of expressions deleted
"""
# Read the expiration period from config
if expiration_days is None:
expiration_days = global_config.expression.expiration_days
current_time = time.time()
expiration_threshold = current_time - (expiration_days * 24 * 3600)
try:
deleted_count = 0
async with get_db_session() as session:
# Query expired expressions (only those belonging to the current chat_id)
query = await session.execute(
select(Expression).where(
(Expression.chat_id == self.chat_id)
& (Expression.last_active_time < expiration_threshold)
)
)
expired_expressions = list(query.scalars())
if expired_expressions:
for expr in expired_expressions:
await session.delete(expr)
deleted_count += 1
await session.commit()
logger.info(f"Cleaned up {deleted_count} expired expressions (unused for more than {expiration_days} days)")
# Invalidate the cache
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
await cache.delete(generate_cache_key("chat_expressions", self.chat_id))
else:
logger.debug(f"No expired expressions found (threshold: {expiration_days} days)")
return deleted_count
except Exception as e:
logger.error(f"Failed to clean up expired expressions: {e}")
return 0
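
Note: further down in this file's diff the cleanup is wired into trigger-time learning (Improvement 3), but it could just as well be driven by a scheduler. A minimal sketch, assuming `learner` is an ExpressionLearner instance; the driver function and the 24-hour interval are illustrative, not part of the diff:

import asyncio

async def run_expression_cleanup(learner, interval_hours: float = 24.0) -> None:
    # Hypothetical periodic driver around cleanup_expired_expressions().
    while True:
        # Rows with last_active_time < now - expiration_days * 86400 are deleted for this chat_id.
        deleted = await learner.cleanup_expired_expressions()  # expiration_days falls back to config
        if deleted:
            print(f"removed {deleted} stale expressions")
        # deleted is 0 both when nothing expired and when the cleanup failed and was logged.
        await asyncio.sleep(interval_hours * 3600)
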
def can_learn_for_chat(self) -> bool:
"""
Check whether expression learning is allowed for the given chat stream
@@ -214,6 +264,9 @@ class ExpressionLearner:
try:
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
# 🔥 改进3在学习前清理过期的表达方式
await self.cleanup_expired_expressions()
# 学习语言风格
learnt_style = await self.learn_and_store(type="style", num=25)
@@ -397,9 +450,29 @@ class ExpressionLearner:
for chat_id, expr_list in chat_dict.items():
async with get_db_session() as session:
for new_expr in expr_list:
# Check whether a similar expression already exists
# Note: get_all_by does not support complex conditions, so a session is still needed here
query = await session.execute(
# 🔥 Improvement 1: check whether a record with the same situation or the same style already exists
# Case 1: same chat_id + type + situation (same situation, different style)
query_same_situation = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.situation == new_expr["situation"])
)
)
same_situation_expr = query_same_situation.scalar()
# Case 2: same chat_id + type + style (same style, different situation)
query_same_style = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.style == new_expr["style"])
)
)
same_style_expr = query_same_style.scalar()
# Case 3: exact match (same situation + same style)
query_exact_match = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
@@ -407,16 +480,29 @@ class ExpressionLearner:
& (Expression.style == new_expr["style"])
)
)
existing_expr = query.scalar()
if existing_expr:
expr_obj = existing_expr
# Replace the content with 50% probability
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
exact_match_expr = query_exact_match.scalar()
# Handle the exact-match case first
if exact_match_expr:
# Exact match: increment count and refresh the timestamp
expr_obj = exact_match_expr
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
logger.debug(f"Exact match, updated count: {expr_obj.count}")
elif same_situation_expr:
# Same situation, different style: overwrite the old style
logger.info(f"Same-situation overwrite: style for '{same_situation_expr.situation}' updated from '{same_situation_expr.style}' to '{new_expr['style']}'")
same_situation_expr.style = new_expr["style"]
same_situation_expr.count = same_situation_expr.count + 1
same_situation_expr.last_active_time = current_time
elif same_style_expr:
# Same style, different situation: overwrite the old situation
logger.info(f"Same-style overwrite: situation for '{same_style_expr.style}' updated from '{same_style_expr.situation}' to '{new_expr['situation']}'")
same_style_expr.situation = new_expr["situation"]
same_style_expr.count = same_style_expr.count + 1
same_style_expr.last_active_time = current_time
else:
# Completely new expression: create a new record
new_expression = Expression(
situation=new_expr["situation"],
style=new_expr["style"],
@@ -427,6 +513,7 @@ class ExpressionLearner:
create_date=current_time, # set the creation date manually
)
session.add(new_expression)
logger.debug(f"Added new expression: {new_expr['situation']} -> {new_expr['style']}")
# Enforce the maximum count - use get_all_by_sorted to fetch sorted results
exprs_result = await session.execute(


@@ -61,6 +61,34 @@ class ExpressorModel:
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def remove_candidate(self, cid: str) -> bool:
"""
Remove a candidate text.
Args:
cid: candidate ID
Returns:
Whether the removal succeeded
"""
removed = False
if cid in self._candidates:
del self._candidates[cid]
removed = True
if cid in self._situations:
del self._situations[cid]
# Remove from the naive Bayes (nb) model
if cid in self.nb.cls_counts:
del self.nb.cls_counts[cid]
if cid in self.nb.token_counts:
del self.nb.token_counts[cid]
return removed
def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]:
"""
Score all candidates directly with naive Bayes
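
The docstring above describes scoring every stored candidate with naive Bayes, but the body of predict is not shown in this diff. The following is only a generic Laplace-smoothed multinomial-NB sketch over the cls_counts / token_counts structures that the new remove_candidate touches; the tokenization and the smoothing constant are assumptions, not the project's actual implementation:

import math

def nb_log_scores(tokens: list[str], cls_counts: dict, token_counts: dict, alpha: float = 1.0) -> dict[str, float]:
    # Illustrative per-candidate log-scores: log prior + sum of smoothed token log-likelihoods.
    total = sum(cls_counts.values()) or 1.0
    vocab = {t for counts in token_counts.values() for t in counts}
    scores: dict[str, float] = {}
    for cid, c_count in cls_counts.items():
        score = math.log((c_count + alpha) / (total + alpha * max(len(cls_counts), 1)))
        denom = sum(token_counts.get(cid, {}).values()) + alpha * max(len(vocab), 1)
        for tok in tokens:
            score += math.log((token_counts.get(cid, {}).get(tok, 0.0) + alpha) / denom)
        scores[cid] = score
    return scores
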


@@ -36,6 +36,8 @@ class StyleLearner:
# Dynamic style management
self.max_styles = 2000 # at most 2000 styles per chat_id
self.cleanup_threshold = 0.9 # trigger cleanup at 90% capacity
self.cleanup_ratio = 0.2 # clean up 20% of the styles each time
self.style_to_id: dict[str, str] = {} # style text -> style_id
self.id_to_style: dict[str, str] = {} # style_id -> style text
self.id_to_situation: dict[str, str] = {} # style_id -> situation text
@@ -45,6 +47,7 @@ class StyleLearner:
self.learning_stats = {
"total_samples": 0,
"style_counts": {},
"style_last_used": {}, # 记录每个风格最后使用时间
"last_update": time.time(),
}
@@ -66,10 +69,19 @@ class StyleLearner:
if style in self.style_to_id:
return True
# Check whether the maximum limit has been exceeded
if len(self.style_to_id) >= self.max_styles:
logger.warning(f"Maximum style count limit reached ({self.max_styles})")
return False
# Check whether cleanup is needed
current_count = len(self.style_to_id)
cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
if current_count >= cleanup_trigger:
if current_count >= self.max_styles:
# The hard limit has been reached, cleanup is mandatory
logger.warning(f"Maximum style count limit reached ({self.max_styles}), starting cleanup")
self._cleanup_styles()
elif current_count >= cleanup_trigger:
# Approaching the limit, clean up preemptively
logger.info(f"Style count reached {current_count}/{self.max_styles}, triggering preventive cleanup")
self._cleanup_styles()
# Generate a new style_id
style_id = f"style_{self.next_style_id}"
@@ -94,6 +106,80 @@ class StyleLearner:
logger.error(f"Failed to add style: {e}")
return False
def _cleanup_styles(self):
"""
Clean up low-value styles to make room for new ones.
Cleanup strategy:
1. Consider both usage count and last-used time
2. Delete the lowest-scoring styles
3. By default clean up cleanup_ratio (20%) of the styles
"""
try:
current_time = time.time()
cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio))
# Compute a value score for each style
style_scores = []
for style_id in self.style_to_id.values():
# Usage count
usage_count = self.learning_stats["style_counts"].get(style_id, 0)
# Last-used time (the more recent the better)
last_used = self.learning_stats["style_last_used"].get(style_id, 0)
time_since_used = current_time - last_used if last_used > 0 else float('inf')
# Combined score: more usage is better, shorter time since last use is better
# Use a logarithm to smooth the effect of the usage count
import math
usage_score = math.log1p(usage_count) # log(1 + count)
# Time score: convert to days and apply exponential decay
days_unused = time_since_used / 86400 # convert to days
time_score = math.exp(-days_unused / 30) # 30-day decay factor
# Combined score: 80% usage frequency + 20% recency
total_score = 0.8 * usage_score + 0.2 * time_score
style_scores.append((style_id, total_score, usage_count, days_unused))
# Sort by score; the lowest-scoring styles are deleted first
style_scores.sort(key=lambda x: x[1])
# Delete the lowest-scoring styles
deleted_styles = []
for style_id, score, usage, days in style_scores[:cleanup_count]:
style_text = self.id_to_style.get(style_id)
if style_text:
# Remove from the mappings
del self.style_to_id[style_text]
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
# Remove from the statistics
if style_id in self.learning_stats["style_counts"]:
del self.learning_stats["style_counts"][style_id]
if style_id in self.learning_stats["style_last_used"]:
del self.learning_stats["style_last_used"][style_id]
# Remove from the expressor model
self.expressor.remove_candidate(style_id)
deleted_styles.append((style_text[:30], usage, f"{days:.1f}"))
logger.info(
f"Style cleanup finished: deleted {len(deleted_styles)}/{len(style_scores)} styles, "
f"{len(self.style_to_id)} styles remaining"
)
# Log the first 5 deleted styles for debugging
if deleted_styles:
logger.debug(f"Sample of deleted styles (first 5): {deleted_styles[:5]}")
except Exception as e:
logger.error(f"Style cleanup failed: {e}", exc_info=True)
def learn_mapping(self, up_content: str, style: str) -> bool:
"""
Learn a mapping from up_content to style
@@ -118,9 +204,11 @@ class StyleLearner:
self.expressor.update_positive(up_content, style_id)
# Update statistics
current_time = time.time()
self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1
self.learning_stats["last_update"] = time.time()
self.learning_stats["style_last_used"][style_id] = current_time # update last-used time
self.learning_stats["last_update"] = current_time
logger.debug(f"Mapping learned: {up_content[:20]}... -> {style}")
return True
@@ -171,6 +259,10 @@ class StyleLearner:
else:
logger.warning(f"Skipping style_id that cannot be converted: {sid}")
# Update the last-used time (only for the best style)
if best_style_id:
self.learning_stats["style_last_used"][best_style_id] = time.time()
logger.debug(
f"Prediction succeeded: up_content={up_content[:30]}..., "
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
@@ -208,6 +300,30 @@ class StyleLearner:
"""
return list(self.style_to_id.keys())
def cleanup_old_styles(self, ratio: float | None = None) -> int:
"""
Manually clean up old styles.
Args:
ratio: cleanup ratio; if None, the default cleanup_ratio is used
Returns:
Number of styles cleaned up
"""
old_count = len(self.style_to_id)
if ratio is not None:
old_cleanup_ratio = self.cleanup_ratio
self.cleanup_ratio = ratio
self._cleanup_styles()
self.cleanup_ratio = old_cleanup_ratio
else:
self._cleanup_styles()
new_count = len(self.style_to_id)
cleaned = old_count - new_count
logger.info(f"手动清理完成: chat_id={self.chat_id}, 清理了 {cleaned} 个风格")
return cleaned
def apply_decay(self, factor: float | None = None):
"""
Apply knowledge decay
@@ -241,6 +357,11 @@ class StyleLearner:
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
# Make sure learning_stats contains all required fields
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
meta_data = {
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
@@ -295,6 +416,10 @@ class StyleLearner:
self.id_to_situation = meta_data["id_to_situation"]
self.next_style_id = meta_data["next_style_id"]
self.learning_stats = meta_data["learning_stats"]
# Keep old data compatible: add the style_last_used field if it is missing
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
logger.info(f"StyleLearner加载成功: {save_dir}")
return True
@@ -398,6 +523,26 @@ class StyleLearnerManager:
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
return success
def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]:
"""
Clean up old styles across all learners.
Args:
ratio: cleanup ratio
Returns:
{chat_id: number of styles cleaned}
"""
cleanup_results = {}
for chat_id, learner in self.learners.items():
cleaned = learner.cleanup_old_styles(ratio)
if cleaned > 0:
cleanup_results[chat_id] = cleaned
total_cleaned = sum(cleanup_results.values())
logger.info(f"清理所有StyleLearner完成: 总共清理了 {total_cleaned} 个风格")
return cleanup_results
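
A minimal usage sketch for the manager-level cleanup; the manager variable name and the 0.3 ratio are made up for illustration, only cleanup_all_old_styles() comes from the diff:

# Hypothetical invocation: evict the lowest-value 30% of styles in every learner.
results = style_learner_manager.cleanup_all_old_styles(ratio=0.3)
for chat_id, n in results.items():
    print(f"{chat_id}: evicted {n} styles")
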
def apply_decay_all(self, factor: float | None = None):
"""
Apply knowledge decay to all learners


@@ -1,73 +0,0 @@
"""
Simplified memory system module
Removes the instant/long-term memory split in favor of a unified memory architecture and intelligent forgetting
"""
# Core data structures
# Activators
from .enhanced_memory_activator import MemoryActivator, enhanced_memory_activator, memory_activator
from .memory_chunk import (
ConfidenceLevel,
ContentStructure,
ImportanceLevel,
MemoryChunk,
MemoryMetadata,
MemoryType,
create_memory_chunk,
)
# 兼容性别名
from .memory_chunk import MemoryChunk as Memory
# 遗忘引擎
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
from .memory_formatter import format_memories_bracket_style
# 记忆管理器
from .memory_manager import MemoryManager, MemoryResult, memory_manager
# 记忆核心系统
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
# Vector DB存储系统
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
__all__ = [
"ConfidenceLevel",
"ContentStructure",
"ForgettingConfig",
"ImportanceLevel",
"Memory", # 兼容性别名
# 激活器
"MemoryActivator",
# 核心数据结构
"MemoryChunk",
# 遗忘引擎
"MemoryForgettingEngine",
# 记忆管理器
"MemoryManager",
"MemoryMetadata",
"MemoryResult",
# 记忆系统
"MemorySystem",
"MemorySystemConfig",
"MemoryType",
# Vector DB存储
"VectorMemoryStorage",
"VectorStorageConfig",
"create_memory_chunk",
"enhanced_memory_activator", # 兼容性别名
# 格式化工具
"format_memories_bracket_style",
"get_memory_forgetting_engine",
"get_memory_system",
"get_vector_memory_storage",
"initialize_memory_system",
"memory_activator",
"memory_manager",
]
# 版本信息
__version__ = "3.0.0"
__author__ = "MoFox Team"
__description__ = "简化记忆系统 - 统一记忆架构与智能遗忘机制"


@@ -1,240 +0,0 @@
"""
Memory activator
Activator component of the memory system
"""
import difflib
from datetime import datetime
import orjson
from json_repair import repair_json
from src.chat.memory_system.memory_manager import MemoryResult
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> list:
"""
从JSON字符串中提取关键词列表
Args:
json_str: JSON格式的字符串
Returns:
List[str]: 关键词列表
"""
try:
# 使用repair_json修复JSON格式
fixed_json = repair_json(json_str)
# 如果repair_json返回的是字符串需要解析为Python对象
result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
return result.get("keywords", [])
except Exception as e:
logger.error(f"解析关键词JSON失败: {e}")
return []
def init_prompt():
# --- Memory Activator Prompt ---
memory_activator_prompt = """
你是一个记忆分析器,你需要根据以下信息来进行记忆检索
以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆检索的触发词
聊天记录:
{obs_info_text}
用户想要回复的消息:
{target_message}
历史关键词(请避免重复提取这些关键词):
{cached_keywords}
请输出一个json格式包含以下字段
{{
"keywords": ["关键词1", "关键词2", "关键词3",......]
}}
不要输出其他多余内容只输出json格式就好
"""
Prompt(memory_activator_prompt, "memory_activator_prompt")
class MemoryActivator:
"""记忆激活器"""
def __init__(self):
self.key_words_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.activator",
)
self.running_memory = []
self.cached_keywords = set() # 用于缓存历史关键词
self.last_memory_query_time = 0 # 上次查询记忆的时间
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
"""
激活记忆
"""
# 如果记忆系统被禁用,直接返回空列表
if not global_config.memory.enable_memory:
return []
# 将缓存的关键词转换为字符串用于prompt
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"
prompt = await global_prompt_manager.format_prompt(
"memory_activator_prompt",
obs_info_text=chat_history_prompt,
target_message=target_message,
cached_keywords=cached_keywords_str,
)
# 生成关键词
response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
prompt, temperature=0.5
)
keywords = list(get_keywords_from_json(response))
# 更新关键词缓存
if keywords:
# 限制缓存大小最多保留10个关键词
if len(self.cached_keywords) > 10:
# 转换为列表,移除最早的关键词
cached_list = list(self.cached_keywords)
self.cached_keywords = set(cached_list[-8:])
# 添加新的关键词到缓存
self.cached_keywords.update(keywords)
logger.debug(f"记忆关键词: {self.cached_keywords}")
# 使用记忆系统获取相关记忆
memory_results = await self._query_unified_memory(keywords, target_message)
# 处理和记忆结果
if memory_results:
for result in memory_results:
# 检查是否已存在相似内容的记忆
exists = any(
m["content"] == result.content
or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
for m in self.running_memory
)
if not exists:
memory_entry = {
"topic": result.memory_type,
"content": result.content,
"timestamp": datetime.fromtimestamp(result.timestamp).isoformat(),
"duration": 1,
"confidence": result.confidence,
"importance": result.importance,
"source": result.source,
"relevance_score": result.relevance_score, # 添加相关度评分
}
self.running_memory.append(memory_entry)
logger.debug(f"添加新记忆: {result.memory_type} - {result.content}")
# 激活时所有已有记忆的duration+1达到3则移除
for m in self.running_memory[:]:
m["duration"] = m.get("duration", 1) + 1
self.running_memory = [m for m in self.running_memory if m["duration"] < 3]
# 限制同时加载的记忆条数最多保留最后5条
if len(self.running_memory) > 5:
self.running_memory = self.running_memory[-5:]
return self.running_memory
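
The duration bookkeeping above means a memory added with duration=1 survives exactly one further activation (1 -> 2) and is dropped on the next (2 -> 3). A compact trace of that lifecycle, with a made-up entry:

running = [{"content": "user likes oolong tea", "duration": 1}]
for activation in range(3):
    for m in running:
        m["duration"] += 1
    running = [m for m in running if m["duration"] < 3]
    print(activation, running)
# activation 0: duration 2 -> kept; activation 1: duration 3 -> dropped; activation 2: already empty
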
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
"""查询统一记忆系统"""
try:
# 使用记忆系统
from src.chat.memory_system.memory_system import get_memory_system
memory_system = get_memory_system()
if not memory_system or memory_system.status.value != "ready":
logger.warning("记忆系统未就绪")
return []
# 构建查询上下文
context = {"keywords": keywords, "query_intent": "conversation_response"}
# 查询记忆
memories = await memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="global", # 使用全局作用域
context=context,
limit=5,
)
# 转换为 MemoryResult 格式
memory_results = []
for memory in memories:
result = MemoryResult(
content=memory.display,
memory_type=memory.memory_type.value,
confidence=memory.metadata.confidence.value,
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="unified_memory",
relevance_score=memory.metadata.relevance_score,
)
memory_results.append(result)
logger.debug(f"统一记忆查询返回 {len(memory_results)} 条结果")
return memory_results
except Exception as e:
logger.error(f"查询统一记忆失败: {e}")
return []
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
"""
获取即时记忆 - 兼容原有接口(使用统一存储)
"""
try:
# 使用统一存储系统获取相关记忆
from src.chat.memory_system.memory_system import get_memory_system
memory_system = get_memory_system()
if not memory_system or memory_system.status.value != "ready":
return None
context = {"query_intent": "instant_response", "chat_id": chat_id}
memories = await memory_system.retrieve_relevant_memories(
query_text=target_message, user_id="global", context=context, limit=1
)
if memories:
return memories[0].display
return None
except Exception as e:
logger.error(f"获取即时记忆失败: {e}")
return None
def clear_cache(self):
"""清除缓存"""
self.cached_keywords.clear()
self.running_memory.clear()
logger.debug("记忆激活器缓存已清除")
# 创建全局实例
memory_activator = MemoryActivator()
# 兼容性别名
enhanced_memory_activator = memory_activator
init_prompt()


@@ -1,721 +0,0 @@
"""
Hippocampus bimodal-distribution sampler
Based on the legacy hippocampus sampling strategy, adapted to the new memory system
Implements a low-cost, high-efficiency memory sampling mode
"""
import asyncio
import random
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
import numpy as np
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
)
from src.chat.utils.utils import translate_timestamp_to_human_readable
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
# 全局背景任务集合
_background_tasks = set()
@dataclass
class HippocampusSampleConfig:
"""海马体采样配置"""
# 双峰分布参数
recent_mean_hours: float = 12.0 # 近期分布均值(小时)
recent_std_hours: float = 8.0 # 近期分布标准差(小时)
recent_weight: float = 0.7 # 近期分布权重
distant_mean_hours: float = 48.0 # 远期分布均值(小时)
distant_std_hours: float = 24.0 # 远期分布标准差(小时)
distant_weight: float = 0.3 # 远期分布权重
# 采样参数
total_samples: int = 50 # 总采样数
sample_interval: int = 1800 # 采样间隔(秒)
max_sample_length: int = 30 # 每次采样的最大消息数量
batch_size: int = 5 # 批处理大小
@classmethod
def from_global_config(cls) -> "HippocampusSampleConfig":
"""从全局配置创建海马体采样配置"""
config = global_config.memory.hippocampus_distribution_config
return cls(
recent_mean_hours=config[0],
recent_std_hours=config[1],
recent_weight=config[2],
distant_mean_hours=config[3],
distant_std_hours=config[4],
distant_weight=config[5],
total_samples=global_config.memory.hippocampus_sample_size,
sample_interval=global_config.memory.hippocampus_sample_interval,
max_sample_length=global_config.memory.hippocampus_batch_size,
batch_size=global_config.memory.hippocampus_batch_size,
)
class HippocampusSampler:
"""海马体双峰分布采样器"""
def __init__(self, memory_system=None):
self.memory_system = memory_system
self.config = HippocampusSampleConfig.from_global_config()
self.last_sample_time = 0
self.is_running = False
# 记忆构建模型
self.memory_builder_model: LLMRequest | None = None
# 统计信息
self.sample_count = 0
self.success_count = 0
self.last_sample_results: list[dict[str, Any]] = []
async def initialize(self):
"""初始化采样器"""
try:
# 初始化LLM模型
from src.config.config import model_config
task_config = getattr(model_config.model_task_config, "utils", None)
if task_config:
self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
task = asyncio.create_task(self.start_background_sampling())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.info("✅ 海马体采样器初始化成功")
else:
raise RuntimeError("未找到记忆构建模型配置")
except Exception as e:
logger.error(f"❌ 海马体采样器初始化失败: {e}")
raise
def generate_time_samples(self) -> list[datetime]:
"""生成双峰分布的时间采样点"""
# 计算每个分布的样本数
recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight))
distant_samples = max(1, self.config.total_samples - recent_samples)
# 生成两个正态分布的小时偏移
recent_offsets = np.random.normal(
loc=self.config.recent_mean_hours, scale=self.config.recent_std_hours, size=recent_samples
)
distant_offsets = np.random.normal(
loc=self.config.distant_mean_hours, scale=self.config.distant_std_hours, size=distant_samples
)
# 合并两个分布的偏移
all_offsets = np.concatenate([recent_offsets, distant_offsets])
# 转换为时间戳(使用绝对值确保时间点在过去)
base_time = datetime.now()
timestamps = [base_time - timedelta(hours=abs(offset)) for offset in all_offsets]
# 按时间排序(从最早到最近)
return sorted(timestamps)
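
For context, the HippocampusSampleConfig defaults above split 50 samples roughly 70/30 between a "recent" normal centred 12 h ago (sigma 8 h) and a "distant" one centred 48 h ago (sigma 24 h). A standalone sketch of the same idea, using the dataclass defaults rather than the runtime config:

from datetime import datetime, timedelta
import numpy as np

def bimodal_sample_times(total: int = 50, now: datetime | None = None) -> list[datetime]:
    # Illustrative restatement of generate_time_samples with the default parameters.
    now = now or datetime.now()
    recent_n = max(1, int(total * 0.7))            # 70% recent samples
    distant_n = max(1, total - recent_n)           # 30% distant samples
    offsets = np.concatenate([
        np.random.normal(12.0, 8.0, recent_n),     # hours before now, recent peak
        np.random.normal(48.0, 24.0, distant_n),   # hours before now, distant peak
    ])
    # abs() folds any negative draws back into the past, matching the original code
    return sorted(now - timedelta(hours=abs(h)) for h in offsets)
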
async def collect_message_samples(self, target_timestamp: float) -> list[dict[str, Any]] | None:
"""收集指定时间戳附近的消息样本"""
try:
# 随机时间窗口5-30分钟
time_window_seconds = random.randint(300, 1800)
# 尝试3次获取消息
for attempt in range(3):
timestamp_start = target_timestamp
timestamp_end = target_timestamp + time_window_seconds
# 获取单条消息作为锚点
anchor_messages = await get_raw_msg_by_timestamp(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=1,
limit_mode="earliest",
)
if not anchor_messages:
target_timestamp -= 120 # 向前调整2分钟
continue
anchor_message = anchor_messages[0]
chat_id = anchor_message.get("chat_id")
if not chat_id:
continue
# 获取同聊天的多条消息
messages = await get_raw_msg_by_timestamp_with_chat(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=self.config.max_sample_length,
limit_mode="earliest",
chat_id=chat_id,
)
if messages and len(messages) >= 2: # 至少需要2条消息
# 过滤掉已经记忆过的消息
filtered_messages = [
msg
for msg in messages
if msg.get("memorized_times", 0) < 2 # 最多记忆2次
]
if filtered_messages:
logger.debug(f"成功收集 {len(filtered_messages)} 条消息样本")
return filtered_messages
target_timestamp -= 120 # 向前调整再试
logger.debug(f"时间戳 {target_timestamp} 附近未找到有效消息样本")
return None
except Exception as e:
logger.error(f"收集消息样本失败: {e}")
return None
async def build_memory_from_samples(self, messages: list[dict[str, Any]], target_timestamp: float) -> str | None:
"""从消息样本构建记忆"""
if not messages or not self.memory_system or not self.memory_builder_model:
return None
try:
# 构建可读消息文本
readable_text = await build_readable_messages(
messages,
merge_messages=True,
timestamp_mode="normal_no_YMD",
replace_bot_name=False,
)
if not readable_text:
logger.warning("无法从消息样本生成可读文本")
return None
# 直接使用对话文本,不添加系统标识符
input_text = readable_text
logger.debug(f"开始构建记忆,文本长度: {len(input_text)}")
# 构建上下文
context = {
"user_id": "hippocampus_sampler",
"timestamp": time.time(),
"source": "hippocampus_sampling",
"message_count": len(messages),
"sample_mode": "bimodal_distribution",
"is_hippocampus_sample": True, # 标识为海马体样本
"bypass_value_threshold": True, # 绕过价值阈值检查
"hippocampus_sample_time": target_timestamp, # 记录样本时间
}
# 使用记忆系统构建记忆(绕过构建间隔检查)
memories = await self.memory_system.build_memory_from_conversation(
conversation_text=input_text,
context=context,
timestamp=time.time(),
bypass_interval=True, # 海马体采样器绕过构建间隔限制
)
if memories:
memory_count = len(memories)
self.success_count += 1
# 记录采样结果
result = {
"timestamp": time.time(),
"memory_count": memory_count,
"message_count": len(messages),
"text_preview": readable_text[:100] + "..." if len(readable_text) > 100 else readable_text,
"memory_types": [m.memory_type.value for m in memories],
}
self.last_sample_results.append(result)
# 限制结果历史长度
if len(self.last_sample_results) > 10:
self.last_sample_results.pop(0)
logger.info(f"✅ 海马体采样成功构建 {memory_count} 条记忆")
return f"构建{memory_count}条记忆"
else:
logger.debug("海马体采样未生成有效记忆")
return None
except Exception as e:
logger.error(f"海马体采样构建记忆失败: {e}")
return None
async def perform_sampling_cycle(self) -> dict[str, Any]:
"""执行一次完整的采样周期(优化版:批量融合构建)"""
if not self.should_sample():
return {"status": "skipped", "reason": "interval_not_met"}
start_time = time.time()
self.sample_count += 1
try:
# 生成时间采样点
time_samples = self.generate_time_samples()
logger.debug(f"生成 {len(time_samples)} 个时间采样点")
# 记录时间采样点(调试用)
readable_timestamps = [
translate_timestamp_to_human_readable(int(ts.timestamp()), mode="normal")
for ts in time_samples[:5] # 只显示前5个
]
logger.debug(f"时间采样点示例: {readable_timestamps}")
# 第一步:批量收集所有消息样本
logger.debug("开始批量收集消息样本...")
collected_messages = await self._collect_all_message_samples(time_samples)
if not collected_messages:
logger.info("未收集到有效消息样本,跳过本次采样")
self.last_sample_time = time.time()
return {
"status": "success",
"sample_count": self.sample_count,
"success_count": self.success_count,
"processed_samples": len(time_samples),
"successful_builds": 0,
"duration": time.time() - start_time,
"samples_generated": len(time_samples),
"message": "未收集到有效消息样本",
}
logger.info(f"收集到 {len(collected_messages)} 组消息样本")
# 第二步:融合和去重消息
logger.debug("开始融合和去重消息...")
fused_messages = await self._fuse_and_deduplicate_messages(collected_messages)
if not fused_messages:
logger.info("消息融合后为空,跳过记忆构建")
self.last_sample_time = time.time()
return {
"status": "success",
"sample_count": self.sample_count,
"success_count": self.success_count,
"processed_samples": len(time_samples),
"successful_builds": 0,
"duration": time.time() - start_time,
"samples_generated": len(time_samples),
"message": "消息融合后为空",
}
logger.info(f"融合后得到 {len(fused_messages)} 组有效消息")
# 第三步:一次性构建记忆
logger.debug("开始批量构建记忆...")
build_result = await self._build_batch_memory(fused_messages, time_samples)
# 更新最后采样时间
self.last_sample_time = time.time()
duration = time.time() - start_time
result = {
"status": "success",
"sample_count": self.sample_count,
"success_count": self.success_count,
"processed_samples": len(time_samples),
"successful_builds": build_result.get("memory_count", 0),
"duration": duration,
"samples_generated": len(time_samples),
"messages_collected": len(collected_messages),
"messages_fused": len(fused_messages),
"optimization_mode": "batch_fusion",
}
logger.info(
f"✅ 海马体采样周期完成(批量融合模式) | "
f"采样点: {len(time_samples)} | "
f"收集消息: {len(collected_messages)} | "
f"融合消息: {len(fused_messages)} | "
f"构建记忆: {build_result.get('memory_count', 0)} | "
f"耗时: {duration:.2f}s"
)
return result
except Exception as e:
logger.error(f"❌ 海马体采样周期失败: {e}")
return {
"status": "error",
"error": str(e),
"sample_count": self.sample_count,
"duration": time.time() - start_time,
}
async def _collect_all_message_samples(self, time_samples: list[datetime]) -> list[list[dict[str, Any]]]:
"""批量收集所有时间点的消息样本"""
collected_messages = []
max_concurrent = min(5, len(time_samples)) # 提高并发数到5
for i in range(0, len(time_samples), max_concurrent):
batch = time_samples[i : i + max_concurrent]
tasks = []
# 创建并发收集任务
for timestamp in batch:
target_ts = timestamp.timestamp()
task = self.collect_message_samples(target_ts)
tasks.append(task)
# 执行并发收集
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理收集结果
for result in results:
if isinstance(result, list) and result:
collected_messages.append(result)
elif isinstance(result, Exception):
logger.debug(f"消息收集异常: {result}")
# 批次间短暂延迟
if i + max_concurrent < len(time_samples):
await asyncio.sleep(0.5)
return collected_messages
async def _fuse_and_deduplicate_messages(
self, collected_messages: list[list[dict[str, Any]]]
) -> list[list[dict[str, Any]]]:
"""融合和去重消息样本"""
if not collected_messages:
return []
try:
# 展平所有消息
all_messages = []
for message_group in collected_messages:
all_messages.extend(message_group)
logger.debug(f"展开后总消息数: {len(all_messages)}")
# 去重逻辑:基于消息内容和时间戳
unique_messages = []
seen_hashes = set()
for message in all_messages:
# 创建消息哈希用于去重
content = message.get("processed_plain_text", "") or message.get("display_message", "")
timestamp = message.get("time", 0)
chat_id = message.get("chat_id", "")
# 简单哈希内容前50字符 + 时间戳(精确到分钟) + 聊天ID
hash_key = f"{content[:50]}_{int(timestamp // 60)}_{chat_id}"
if hash_key not in seen_hashes and len(content.strip()) > 10:
seen_hashes.add(hash_key)
unique_messages.append(message)
logger.debug(f"去重后消息数: {len(unique_messages)}")
# 按时间排序
unique_messages.sort(key=lambda x: x.get("time", 0))
# 按聊天ID分组重新组织
chat_groups = {}
for message in unique_messages:
chat_id = message.get("chat_id", "unknown")
if chat_id not in chat_groups:
chat_groups[chat_id] = []
chat_groups[chat_id].append(message)
# 合并相邻时间范围内的消息
fused_groups = []
for chat_id, messages in chat_groups.items():
fused_groups.extend(self._merge_adjacent_messages(messages))
logger.debug(f"融合后消息组数: {len(fused_groups)}")
return fused_groups
except Exception as e:
logger.error(f"消息融合失败: {e}")
# 返回原始消息组作为备选
return collected_messages[:5] # 限制返回数量
def _merge_adjacent_messages(
self, messages: list[dict[str, Any]], time_gap: int = 1800
) -> list[list[dict[str, Any]]]:
"""合并时间间隔内的消息"""
if not messages:
return []
merged_groups = []
current_group = [messages[0]]
for i in range(1, len(messages)):
current_time = messages[i].get("time", 0)
prev_time = current_group[-1].get("time", 0)
# 如果时间间隔小于阈值,合并到当前组
if current_time - prev_time <= time_gap:
current_group.append(messages[i])
else:
# 否则开始新组
merged_groups.append(current_group)
current_group = [messages[i]]
# 添加最后一组
merged_groups.append(current_group)
# 过滤掉只有一条消息的组(除非内容较长)
result_groups = [
group for group in merged_groups
if len(group) > 1 or any(len(msg.get("processed_plain_text", "")) > 100 for msg in group)
]
return result_groups
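
A tiny worked example of the merge rule above (timestamps and texts invented): messages less than 30 minutes apart stay in one group, a larger gap starts a new one, and single-message groups survive only if some message exceeds 100 characters.

def merge_adjacent(messages: list[dict], time_gap: int = 1800) -> list[list[dict]]:
    # Standalone restatement of _merge_adjacent_messages for illustration.
    groups: list[list[dict]] = []
    for msg in sorted(messages, key=lambda m: m["time"]):
        if groups and msg["time"] - groups[-1][-1]["time"] <= time_gap:
            groups[-1].append(msg)
        else:
            groups.append([msg])
    # keep multi-message groups, or single ones with a long text
    return [g for g in groups if len(g) > 1 or any(len(m.get("processed_plain_text", "")) > 100 for m in g)]

msgs = [{"time": t, "processed_plain_text": "x" * n} for t, n in [(0, 20), (600, 20), (3000, 20), (9000, 120)]]
print([[m["time"] for m in g] for g in merge_adjacent(msgs)])  # [[0, 600], [9000]]
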
async def _build_batch_memory(
self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]
) -> dict[str, Any]:
"""批量构建记忆"""
if not fused_messages:
return {"memory_count": 0, "memories": []}
try:
total_memories = []
total_memory_count = 0
# 构建融合后的文本
batch_input_text = await self._build_fused_conversation_text(fused_messages)
if not batch_input_text:
logger.warning("无法构建融合文本,尝试单独构建")
# 备选方案:分别构建
return await self._fallback_individual_build(fused_messages)
# 创建批量上下文
batch_context = {
"user_id": "hippocampus_batch_sampler",
"timestamp": time.time(),
"source": "hippocampus_batch_sampling",
"message_groups_count": len(fused_messages),
"total_messages": sum(len(group) for group in fused_messages),
"sample_count": len(time_samples),
"is_hippocampus_sample": True,
"bypass_value_threshold": True,
"optimization_mode": "batch_fusion",
}
logger.debug(f"批量构建记忆,文本长度: {len(batch_input_text)}")
# 一次性构建记忆
memories = await self.memory_system.build_memory_from_conversation(
conversation_text=batch_input_text, context=batch_context, timestamp=time.time(), bypass_interval=True
)
if memories:
memory_count = len(memories)
self.success_count += 1
total_memory_count += memory_count
total_memories.extend(memories)
logger.info(f"✅ 批量海马体采样成功构建 {memory_count} 条记忆")
else:
logger.debug("批量海马体采样未生成有效记忆")
# 记录采样结果
result = {
"timestamp": time.time(),
"memory_count": total_memory_count,
"message_groups_count": len(fused_messages),
"total_messages": sum(len(group) for group in fused_messages),
"text_preview": batch_input_text[:200] + "..." if len(batch_input_text) > 200 else batch_input_text,
"memory_types": [m.memory_type.value for m in total_memories],
}
self.last_sample_results.append(result)
# 限制结果历史长度
if len(self.last_sample_results) > 10:
self.last_sample_results.pop(0)
return {"memory_count": total_memory_count, "memories": total_memories, "result": result}
except Exception as e:
logger.error(f"批量构建记忆失败: {e}")
return {"memory_count": 0, "error": str(e)}
async def _build_fused_conversation_text(self, fused_messages: list[list[dict[str, Any]]]) -> str:
"""构建融合后的对话文本"""
try:
conversation_parts = []
for group_idx, message_group in enumerate(fused_messages):
if not message_group:
continue
# 为每个消息组添加分隔符
group_header = f"\n=== 对话片段 {group_idx + 1} ==="
conversation_parts.append(group_header)
# 构建可读消息
group_text = await build_readable_messages(
message_group,
merge_messages=True,
timestamp_mode="normal_no_YMD",
replace_bot_name=False,
)
if group_text and len(group_text.strip()) > 10:
conversation_parts.append(group_text.strip())
return "\n".join(conversation_parts)
except Exception as e:
logger.error(f"构建融合文本失败: {e}")
return ""
async def _fallback_individual_build(self, fused_messages: list[list[dict[str, Any]]]) -> dict[str, Any]:
"""备选方案:单独构建每个消息组"""
total_memories = []
total_count = 0
for group in fused_messages[:5]: # 限制最多5组
try:
memories = await self.build_memory_from_samples(group, time.time())
if memories:
total_memories.extend(memories)
total_count += len(memories)
except Exception as e:
logger.debug(f"单独构建失败: {e}")
return {"memory_count": total_count, "memories": total_memories, "fallback_mode": True}
async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
"""处理单个时间戳采样(保留作为备选方法)"""
try:
# 收集消息样本
messages = await self.collect_message_samples(target_timestamp)
if not messages:
return None
# 构建记忆
result = await self.build_memory_from_samples(messages, target_timestamp)
return result
except Exception as e:
logger.debug(f"处理时间戳采样失败 {target_timestamp}: {e}")
return None
def should_sample(self) -> bool:
"""检查是否应该进行采样"""
current_time = time.time()
# 检查时间间隔
if current_time - self.last_sample_time < self.config.sample_interval:
return False
# 检查是否已初始化
if not self.memory_builder_model:
logger.warning("海马体采样器未初始化")
return False
return True
async def start_background_sampling(self):
"""启动后台采样"""
if self.is_running:
logger.warning("海马体后台采样已在运行")
return
self.is_running = True
logger.info("🚀 启动海马体后台采样任务")
try:
while self.is_running:
try:
# 执行采样周期
result = await self.perform_sampling_cycle()
# 如果是跳过状态,短暂睡眠
if result.get("status") == "skipped":
await asyncio.sleep(60) # 1分钟后重试
else:
# 正常等待下一个采样间隔
await asyncio.sleep(self.config.sample_interval)
except Exception as e:
logger.error(f"海马体后台采样异常: {e}")
await asyncio.sleep(300) # 异常时等待5分钟
except asyncio.CancelledError:
logger.info("海马体后台采样任务被取消")
finally:
self.is_running = False
def stop_background_sampling(self):
"""停止后台采样"""
self.is_running = False
logger.info("🛑 停止海马体后台采样任务")
def get_sampling_stats(self) -> dict[str, Any]:
"""获取采样统计信息"""
success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
# 计算最近的平均数据
recent_avg_messages = 0
recent_avg_memory_count = 0
if self.last_sample_results:
recent_results = self.last_sample_results[-5:] # 最近5次
recent_avg_messages = sum(r.get("total_messages", 0) for r in recent_results) / len(recent_results)
recent_avg_memory_count = sum(r.get("memory_count", 0) for r in recent_results) / len(recent_results)
return {
"is_running": self.is_running,
"sample_count": self.sample_count,
"success_count": self.success_count,
"success_rate": f"{success_rate:.1f}%",
"last_sample_time": self.last_sample_time,
"optimization_mode": "batch_fusion", # 显示优化模式
"performance_metrics": {
"avg_messages_per_sample": f"{recent_avg_messages:.1f}",
"avg_memories_per_sample": f"{recent_avg_memory_count:.1f}",
"fusion_efficiency": f"{(recent_avg_messages / max(recent_avg_memory_count, 1)):.1f}x"
if recent_avg_messages > 0
else "N/A",
},
"config": {
"sample_interval": self.config.sample_interval,
"total_samples": self.config.total_samples,
"recent_weight": f"{self.config.recent_weight:.1%}",
"distant_weight": f"{self.config.distant_weight:.1%}",
"max_concurrent": 5, # 批量模式并发数
"fusion_time_gap": "30分钟", # 消息融合时间间隔
},
"recent_results": self.last_sample_results[-5:], # 最近5次结果
}
# 全局海马体采样器实例
_hippocampus_sampler: HippocampusSampler | None = None
def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
"""获取全局海马体采样器实例"""
global _hippocampus_sampler
if _hippocampus_sampler is None:
_hippocampus_sampler = HippocampusSampler(memory_system)
return _hippocampus_sampler
async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
"""初始化全局海马体采样器"""
sampler = get_hippocampus_sampler(memory_system)
await sampler.initialize()
return sampler


@@ -1,238 +0,0 @@
"""
Memory activator
Activator component of the memory system
"""
import difflib
from datetime import datetime
import orjson
from json_repair import repair_json
from src.chat.memory_system.memory_manager import MemoryResult
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> list:
"""
从JSON字符串中提取关键词列表
Args:
json_str: JSON格式的字符串
Returns:
List[str]: 关键词列表
"""
try:
# 使用repair_json修复JSON格式
fixed_json = repair_json(json_str)
# 如果repair_json返回的是字符串需要解析为Python对象
result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
return result.get("keywords", [])
except Exception as e:
logger.error(f"解析关键词JSON失败: {e}")
return []
def init_prompt():
# --- Memory Activator Prompt ---
memory_activator_prompt = """
你是一个记忆分析器,你需要根据以下信息来进行记忆检索
以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆检索的触发词
聊天记录:
{obs_info_text}
用户想要回复的消息:
{target_message}
历史关键词(请避免重复提取这些关键词):
{cached_keywords}
请输出一个json格式包含以下字段
{{
"keywords": ["关键词1", "关键词2", "关键词3",......]
}}
不要输出其他多余内容只输出json格式就好
"""
Prompt(memory_activator_prompt, "memory_activator_prompt")
class MemoryActivator:
"""记忆激活器"""
def __init__(self):
self.key_words_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.activator",
)
self.running_memory = []
self.cached_keywords = set() # 用于缓存历史关键词
self.last_memory_query_time = 0 # 上次查询记忆的时间
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
"""
激活记忆
"""
# 如果记忆系统被禁用,直接返回空列表
if not global_config.memory.enable_memory:
return []
# 将缓存的关键词转换为字符串用于prompt
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"
prompt = await global_prompt_manager.format_prompt(
"memory_activator_prompt",
obs_info_text=chat_history_prompt,
target_message=target_message,
cached_keywords=cached_keywords_str,
)
# 生成关键词
response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
prompt, temperature=0.5
)
keywords = list(get_keywords_from_json(response))
# 更新关键词缓存
if keywords:
# 限制缓存大小最多保留10个关键词
if len(self.cached_keywords) > 10:
# 转换为列表,移除最早的关键词
cached_list = list(self.cached_keywords)
self.cached_keywords = set(cached_list[-8:])
# 添加新的关键词到缓存
self.cached_keywords.update(keywords)
logger.debug(f"记忆关键词: {self.cached_keywords}")
# 使用记忆系统获取相关记忆
memory_results = await self._query_unified_memory(keywords, target_message)
# 处理和记忆结果
if memory_results:
for result in memory_results:
# 检查是否已存在相似内容的记忆
exists = any(
m["content"] == result.content
or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
for m in self.running_memory
)
if not exists:
memory_entry = {
"topic": result.memory_type,
"content": result.content,
"timestamp": datetime.fromtimestamp(result.timestamp).isoformat(),
"duration": 1,
"confidence": result.confidence,
"importance": result.importance,
"source": result.source,
"relevance_score": result.relevance_score, # 添加相关度评分
}
self.running_memory.append(memory_entry)
logger.debug(f"添加新记忆: {result.memory_type} - {result.content}")
# 激活时所有已有记忆的duration+1达到3则移除
for m in self.running_memory[:]:
m["duration"] = m.get("duration", 1) + 1
self.running_memory = [m for m in self.running_memory if m["duration"] < 3]
# 限制同时加载的记忆条数最多保留最后5条
if len(self.running_memory) > 5:
self.running_memory = self.running_memory[-5:]
return self.running_memory
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
"""查询统一记忆系统"""
try:
# 使用记忆系统
from src.chat.memory_system.memory_system import get_memory_system
memory_system = get_memory_system()
if not memory_system or memory_system.status.value != "ready":
logger.warning("记忆系统未就绪")
return []
# 构建查询上下文
context = {"keywords": keywords, "query_intent": "conversation_response"}
# 查询记忆
memories = await memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="global", # 使用全局作用域
context=context,
limit=5,
)
# 转换为 MemoryResult 格式
memory_results = []
for memory in memories:
result = MemoryResult(
content=memory.display,
memory_type=memory.memory_type.value,
confidence=memory.metadata.confidence.value,
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="unified_memory",
relevance_score=memory.metadata.relevance_score,
)
memory_results.append(result)
logger.debug(f"统一记忆查询返回 {len(memory_results)} 条结果")
return memory_results
except Exception as e:
logger.error(f"查询统一记忆失败: {e}")
return []
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
"""
获取即时记忆 - 兼容原有接口(使用统一存储)
"""
try:
# 使用统一存储系统获取相关记忆
from src.chat.memory_system.memory_system import get_memory_system
memory_system = get_memory_system()
if not memory_system or memory_system.status.value != "ready":
return None
context = {"query_intent": "instant_response", "chat_id": chat_id}
memories = await memory_system.retrieve_relevant_memories(
query_text=target_message, user_id="global", context=context, limit=1
)
if memories:
return memories[0].display
return None
except Exception as e:
logger.error(f"获取即时记忆失败: {e}")
return None
def clear_cache(self):
"""清除缓存"""
self.cached_keywords.clear()
self.running_memory.clear()
logger.debug("记忆激活器缓存已清除")
# 创建全局实例
memory_activator = MemoryActivator()
init_prompt()

File diff suppressed because it is too large.


@@ -1,647 +0,0 @@
"""
Structured memory chunk design
Implements high-quality, structured memory chunks in line with the design document
"""
import hashlib
import time
import uuid
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
import numpy as np
import orjson
from src.common.logger import get_logger
logger = get_logger(__name__)
class MemoryType(Enum):
"""记忆类型分类"""
PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等)
EVENT = "event" # 事件(重要经历、约会等)
PREFERENCE = "preference" # 偏好(喜好、习惯等)
OPINION = "opinion" # 观点(对事物的看法)
RELATIONSHIP = "relationship" # 关系(与他人的关系)
EMOTION = "emotion" # 情感状态
KNOWLEDGE = "knowledge" # 知识信息
SKILL = "skill" # 技能能力
GOAL = "goal" # 目标计划
EXPERIENCE = "experience" # 经验教训
CONTEXTUAL = "contextual" # 上下文信息
class ConfidenceLevel(Enum):
"""置信度等级"""
LOW = 1 # 低置信度,可能不准确
MEDIUM = 2 # 中等置信度,有一定依据
HIGH = 3 # 高置信度,有明确来源
VERIFIED = 4 # 已验证,非常可靠
class ImportanceLevel(Enum):
"""重要性等级"""
LOW = 1 # 低重要性,普通信息
NORMAL = 2 # 一般重要性,日常信息
HIGH = 3 # 高重要性,重要信息
CRITICAL = 4 # 关键重要性,核心信息
@dataclass
class ContentStructure:
"""主谓宾结构,包含自然语言描述"""
subject: str | list[str]
predicate: str
object: str | dict
display: str = ""
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ContentStructure":
"""从字典创建实例"""
return cls(
subject=data.get("subject", ""),
predicate=data.get("predicate", ""),
object=data.get("object", ""),
display=data.get("display", ""),
)
def to_subject_list(self) -> list[str]:
"""将主语转换为列表形式"""
if isinstance(self.subject, list):
return [s for s in self.subject if isinstance(s, str) and s.strip()]
if isinstance(self.subject, str) and self.subject.strip():
return [self.subject.strip()]
return []
def __str__(self) -> str:
"""字符串表示"""
if self.display:
return self.display
subjects = "".join(self.to_subject_list()) or str(self.subject)
object_str = self.object if isinstance(self.object, str) else str(self.object)
return f"{subjects} {self.predicate} {object_str}".strip()
@dataclass
class MemoryMetadata:
"""记忆元数据 - 简化版本"""
# 基础信息
memory_id: str # 唯一标识符
user_id: str # 用户ID
chat_id: str | None = None # 聊天ID群聊或私聊
# 时间信息
created_at: float = 0.0 # 创建时间戳
last_accessed: float = 0.0 # 最后访问时间
last_modified: float = 0.0 # 最后修改时间
# 激活频率管理
last_activation_time: float = 0.0 # 最后激活时间
activation_frequency: int = 0 # 激活频率(单位时间内的激活次数)
total_activations: int = 0 # 总激活次数
# 统计信息
access_count: int = 0 # 访问次数
relevance_score: float = 0.0 # 相关度评分
# 信心和重要性(核心字段)
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM
importance: ImportanceLevel = ImportanceLevel.NORMAL
# 遗忘机制相关
forgetting_threshold: float = 0.0 # 遗忘阈值(动态计算)
last_forgetting_check: float = 0.0 # 上次遗忘检查时间
# 来源信息
source_context: str | None = None # 来源上下文片段
# 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source
source: str | None = None
def __post_init__(self):
"""后初始化处理"""
if not self.memory_id:
self.memory_id = str(uuid.uuid4())
current_time = time.time()
if self.created_at == 0:
self.created_at = current_time
if self.last_accessed == 0:
self.last_accessed = current_time
if self.last_modified == 0:
self.last_modified = current_time
if self.last_activation_time == 0:
self.last_activation_time = current_time
if self.last_forgetting_check == 0:
self.last_forgetting_check = current_time
# 兼容性:如果旧字段 source 被使用,保证 source 与 source_context 同步
if not getattr(self, "source", None) and getattr(self, "source_context", None):
try:
self.source = str(self.source_context)
except Exception:
self.source = None
# 如果有 source 字段但 source_context 为空,也同步回去
if not getattr(self, "source_context", None) and getattr(self, "source", None):
try:
self.source_context = str(self.source)
except Exception:
self.source_context = None
def update_access(self):
"""更新访问信息"""
current_time = time.time()
self.last_accessed = current_time
self.access_count += 1
self.total_activations += 1
# 更新激活频率
self._update_activation_frequency(current_time)
def _update_activation_frequency(self, current_time: float):
"""更新激活频率24小时内的激活次数"""
# 如果超过24小时重置激活频率
if current_time - self.last_activation_time > 86400: # 24小时 = 86400秒
self.activation_frequency = 1
else:
self.activation_frequency += 1
self.last_activation_time = current_time
def update_relevance(self, new_score: float):
"""更新相关度评分"""
self.relevance_score = max(0.0, min(1.0, new_score))
self.last_modified = time.time()
def calculate_forgetting_threshold(self) -> float:
"""计算遗忘阈值(天数)"""
# 基础天数
base_days = 30.0
# 重要性权重 (1-4 -> 0-3)
importance_weight = (self.importance.value - 1) * 15 # 0, 15, 30, 45
# 置信度权重 (1-4 -> 0-3)
confidence_weight = (self.confidence.value - 1) * 10 # 0, 10, 20, 30
# 激活频率权重每5次激活增加1天
frequency_weight = min(self.activation_frequency, 20) * 0.5 # 最多10天
# 计算最终阈值
threshold = base_days + importance_weight + confidence_weight + frequency_weight
# 设置最小和最大阈值
return max(7.0, min(threshold, 365.0)) # 7天到1年之间
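
Worked example of the threshold formula above (values invented): a HIGH-importance (3), MEDIUM-confidence (2) memory with 10 recent activations is kept for 30 + 2*15 + 1*10 + 10*0.5 = 75 days, comfortably inside the 7-365 day clamp.

base, importance, confidence, freq = 30.0, 3, 2, 10
threshold = base + (importance - 1) * 15 + (confidence - 1) * 10 + min(freq, 20) * 0.5
print(max(7.0, min(threshold, 365.0)))  # 75.0
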
def should_forget(self, current_time: float | None = None) -> bool:
"""判断是否应该遗忘"""
if current_time is None:
current_time = time.time()
# 计算遗忘阈值
self.forgetting_threshold = self.calculate_forgetting_threshold()
# 计算距离最后激活的时间
days_since_activation = (current_time - self.last_activation_time) / 86400
return days_since_activation > self.forgetting_threshold
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
"""判断是否处于休眠状态(长期未激活)"""
if current_time is None:
current_time = time.time()
days_since_last_access = (current_time - self.last_accessed) / 86400
return days_since_last_access > inactive_days
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"memory_id": self.memory_id,
"user_id": self.user_id,
"chat_id": self.chat_id,
"created_at": self.created_at,
"last_accessed": self.last_accessed,
"last_modified": self.last_modified,
"last_activation_time": self.last_activation_time,
"activation_frequency": self.activation_frequency,
"total_activations": self.total_activations,
"access_count": self.access_count,
"relevance_score": self.relevance_score,
"confidence": self.confidence.value,
"importance": self.importance.value,
"forgetting_threshold": self.forgetting_threshold,
"last_forgetting_check": self.last_forgetting_check,
"source_context": self.source_context,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MemoryMetadata":
"""从字典创建实例"""
return cls(
memory_id=data.get("memory_id", ""),
user_id=data.get("user_id", ""),
chat_id=data.get("chat_id"),
created_at=data.get("created_at", 0),
last_accessed=data.get("last_accessed", 0),
last_modified=data.get("last_modified", 0),
last_activation_time=data.get("last_activation_time", 0),
activation_frequency=data.get("activation_frequency", 0),
total_activations=data.get("total_activations", 0),
access_count=data.get("access_count", 0),
relevance_score=data.get("relevance_score", 0.0),
confidence=ConfidenceLevel(data.get("confidence", ConfidenceLevel.MEDIUM.value)),
importance=ImportanceLevel(data.get("importance", ImportanceLevel.NORMAL.value)),
forgetting_threshold=data.get("forgetting_threshold", 0.0),
last_forgetting_check=data.get("last_forgetting_check", 0),
source_context=data.get("source_context"),
)
@dataclass
class MemoryChunk:
"""结构化记忆单元 - 核心数据结构"""
# 元数据
metadata: MemoryMetadata
# 内容结构
content: ContentStructure # 主谓宾结构
memory_type: MemoryType # 记忆类型
# 扩展信息
keywords: list[str] = field(default_factory=list) # 关键词列表
tags: list[str] = field(default_factory=list) # 标签列表
categories: list[str] = field(default_factory=list) # 分类列表
# 语义信息
embedding: list[float] | None = None # 语义向量
semantic_hash: str | None = None # 语义哈希值
# 关联信息
related_memories: list[str] = field(default_factory=list) # 关联记忆ID列表
temporal_context: dict[str, Any] | None = None # 时间上下文
def __post_init__(self):
"""后初始化处理"""
if self.embedding and len(self.embedding) > 0:
self._generate_semantic_hash()
def _generate_semantic_hash(self):
"""生成语义哈希值"""
if not self.embedding:
return
try:
# 使用向量和内容生成稳定的哈希
content_str = f"{self.content.subject}:{self.content.predicate}:{self.content.object!s}"
embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))
hash_input = f"{content_str}|{embedding_str}"
hash_object = hashlib.sha256(hash_input.encode("utf-8"))
self.semantic_hash = hash_object.hexdigest()[:16]
except Exception as e:
logger.warning(f"生成语义哈希失败: {e}")
self.semantic_hash = str(uuid.uuid4())[:16]
@property
def memory_id(self) -> str:
"""获取记忆ID"""
return self.metadata.memory_id
@property
def user_id(self) -> str:
"""获取用户ID"""
return self.metadata.user_id
@property
def text_content(self) -> str:
"""获取文本内容优先使用display"""
return str(self.content)
@property
def display(self) -> str:
"""获取展示文本"""
return self.content.display or str(self.content)
@property
def subjects(self) -> list[str]:
"""获取主语列表"""
return self.content.to_subject_list()
def update_access(self):
"""更新访问信息"""
self.metadata.update_access()
def update_relevance(self, new_score: float):
"""更新相关度评分"""
self.metadata.update_relevance(new_score)
def should_forget(self, current_time: float | None = None) -> bool:
"""判断是否应该遗忘"""
return self.metadata.should_forget(current_time)
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
"""判断是否处于休眠状态(长期未激活)"""
return self.metadata.is_dormant(current_time, inactive_days)
def calculate_forgetting_threshold(self) -> float:
"""计算遗忘阈值(天数)"""
return self.metadata.calculate_forgetting_threshold()
def add_keyword(self, keyword: str):
"""添加关键词"""
if keyword and keyword not in self.keywords:
self.keywords.append(keyword.strip())
def add_tag(self, tag: str):
"""添加标签"""
if tag and tag not in self.tags:
self.tags.append(tag.strip())
def add_category(self, category: str):
"""添加分类"""
if category and category not in self.categories:
self.categories.append(category.strip())
def add_related_memory(self, memory_id: str):
"""添加关联记忆"""
if memory_id and memory_id not in self.related_memories:
self.related_memories.append(memory_id)
def set_embedding(self, embedding: list[float]):
"""设置语义向量"""
self.embedding = embedding
self._generate_semantic_hash()
def calculate_similarity(self, other: "MemoryChunk") -> float:
"""计算与另一个记忆块的相似度"""
if not self.embedding or not other.embedding:
return 0.0
try:
# 计算余弦相似度
v1 = np.array(self.embedding)
v2 = np.array(other.embedding)
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
if norm1 == 0 or norm2 == 0:
return 0.0
similarity = dot_product / (norm1 * norm2)
return max(0.0, min(1.0, similarity))
except Exception as e:
logger.warning(f"计算记忆相似度失败: {e}")
return 0.0
def to_dict(self) -> dict[str, Any]:
"""转换为完整的字典格式"""
return {
"metadata": self.metadata.to_dict(),
"content": self.content.to_dict(),
"memory_type": self.memory_type.value,
"keywords": self.keywords,
"tags": self.tags,
"categories": self.categories,
"embedding": self.embedding,
"semantic_hash": self.semantic_hash,
"related_memories": self.related_memories,
"temporal_context": self.temporal_context,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MemoryChunk":
"""从字典创建实例"""
metadata = MemoryMetadata.from_dict(data.get("metadata", {}))
content = ContentStructure.from_dict(data.get("content", {}))
chunk = cls(
metadata=metadata,
content=content,
memory_type=MemoryType(data.get("memory_type", MemoryType.CONTEXTUAL.value)),
keywords=data.get("keywords", []),
tags=data.get("tags", []),
categories=data.get("categories", []),
embedding=data.get("embedding"),
semantic_hash=data.get("semantic_hash"),
related_memories=data.get("related_memories", []),
temporal_context=data.get("temporal_context"),
)
return chunk
def to_json(self) -> str:
"""转换为JSON字符串"""
return orjson.dumps(self.to_dict()).decode("utf-8")
@classmethod
def from_json(cls, json_str: str) -> "MemoryChunk":
"""从JSON字符串创建实例"""
try:
data = orjson.loads(json_str)
return cls.from_dict(data)
except Exception as e:
logger.error(f"从JSON创建记忆块失败: {e}")
raise
def is_similar_to(self, other: "MemoryChunk", threshold: float = 0.8) -> bool:
"""判断是否与另一个记忆块相似"""
if self.semantic_hash and other.semantic_hash:
return self.semantic_hash == other.semantic_hash
return self.calculate_similarity(other) >= threshold
def merge_with(self, other: "MemoryChunk") -> bool:
"""与另一个记忆块合并(如果相似)"""
if not self.is_similar_to(other):
return False
try:
# 合并关键词
for keyword in other.keywords:
self.add_keyword(keyword)
# 合并标签
for tag in other.tags:
self.add_tag(tag)
# 合并分类
for category in other.categories:
self.add_category(category)
# 合并关联记忆
for memory_id in other.related_memories:
self.add_related_memory(memory_id)
# 更新元数据
self.metadata.last_modified = time.time()
self.metadata.access_count += other.metadata.access_count
self.metadata.relevance_score = max(self.metadata.relevance_score, other.metadata.relevance_score)
# 更新置信度
if other.metadata.confidence.value > self.metadata.confidence.value:
self.metadata.confidence = other.metadata.confidence
# 更新重要性
if other.metadata.importance.value > self.metadata.importance.value:
self.metadata.importance = other.metadata.importance
logger.debug(f"记忆块 {self.memory_id} 合并了记忆块 {other.memory_id}")
return True
except Exception as e:
logger.error(f"合并记忆块失败: {e}")
return False
def __str__(self) -> str:
"""字符串表示"""
type_emoji = {
MemoryType.PERSONAL_FACT: "👤",
MemoryType.EVENT: "📅",
MemoryType.PREFERENCE: "❤️",
MemoryType.OPINION: "💭",
MemoryType.RELATIONSHIP: "👥",
MemoryType.EMOTION: "😊",
MemoryType.KNOWLEDGE: "📚",
MemoryType.SKILL: "🛠️",
MemoryType.GOAL: "🎯",
MemoryType.EXPERIENCE: "💡",
MemoryType.CONTEXTUAL: "📝",
}
emoji = type_emoji.get(self.memory_type, "📝")
confidence_icon = "" * self.metadata.confidence.value
importance_icon = "" * self.metadata.importance.value
return f"{emoji} [{self.memory_type.value}] {self.display} {confidence_icon} {importance_icon}"
def __repr__(self) -> str:
"""调试表示"""
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
def _build_display_text(subjects: Iterable[str], predicate: str, obj: str | dict) -> str:
"""根据主谓宾生成自然语言描述"""
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
subject_part = "".join(subjects_clean) if subjects_clean else "对话参与者"
if isinstance(obj, dict):
object_candidates = []
for key, value in obj.items():
if isinstance(value, str | int | float):
object_candidates.append(f"{key}:{value}")
elif isinstance(value, list):
compact = "".join(str(item) for item in value[:3])
object_candidates.append(f"{key}:{compact}")
object_part = "".join(object_candidates) if object_candidates else str(obj)
else:
object_part = str(obj).strip()
predicate_clean = predicate.strip()
if not predicate_clean:
return f"{subject_part} {object_part}".strip()
if object_part:
return f"{subject_part}{predicate_clean}{object_part}".strip()
return f"{subject_part}{predicate_clean}".strip()
def create_memory_chunk(
user_id: str,
subject: str | list[str],
predicate: str,
obj: str | dict,
memory_type: MemoryType,
chat_id: str | None = None,
source_context: str | None = None,
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
display: str | None = None,
**kwargs,
) -> MemoryChunk:
"""便捷的内存块创建函数"""
metadata = MemoryMetadata(
memory_id="",
user_id=user_id,
chat_id=chat_id,
created_at=time.time(),
last_accessed=0,
last_modified=0,
confidence=confidence,
importance=importance,
source_context=source_context,
)
subjects: list[str]
if isinstance(subject, list):
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
subject_payload: str | list[str] = subjects
else:
cleaned = subject.strip() if isinstance(subject, str) else ""
subjects = [cleaned] if cleaned else []
subject_payload = cleaned
display_text = display or _build_display_text(subjects, predicate, obj)
content = ContentStructure(subject=subject_payload, predicate=predicate, object=obj, display=display_text)
chunk = MemoryChunk(metadata=metadata, content=content, memory_type=memory_type, **kwargs)
return chunk
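# 用法示例(仅作说明,非本模块原有内容):使用上面的便捷函数构建一条偏好记忆。
# 其中 user_id / chat_id 及主谓宾内容均为占位示例值。
def _demo_create_memory_chunk() -> MemoryChunk:
    chunk = create_memory_chunk(
        user_id="user_123",
        subject="小明",
        predicate="喜欢",
        obj="黑咖啡",
        memory_type=MemoryType.PREFERENCE,
        chat_id="chat_456",
        importance=ImportanceLevel.NORMAL,
        confidence=ConfidenceLevel.MEDIUM,
    )
    # 未显式传入 display 时,由 _build_display_text 生成自然语言描述
    return chunk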
@dataclass
class MessageCollection:
"""消息集合数据结构"""
collection_id: str = field(default_factory=lambda: str(uuid.uuid4()))
chat_id: str | None = None # 聊天ID(群聊或私聊)
messages: list[str] = field(default_factory=list)
combined_text: str = ""
created_at: float = field(default_factory=time.time)
embedding: list[float] | None = None
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"collection_id": self.collection_id,
"chat_id": self.chat_id,
"messages": self.messages,
"combined_text": self.combined_text,
"created_at": self.created_at,
"embedding": self.embedding,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MessageCollection":
"""从字典创建实例"""
return cls(
collection_id=data.get("collection_id", str(uuid.uuid4())),
chat_id=data.get("chat_id"),
messages=data.get("messages", []),
combined_text=data.get("combined_text", ""),
created_at=data.get("created_at", time.time()),
embedding=data.get("embedding"),
)
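# 用法示例(仅作说明,非本模块原有内容):MessageCollection 与普通字典互转,
# 便于作为向量库 metadata 存储后再还原;chat_id 与消息内容均为占位值。
def _demo_collection_roundtrip() -> bool:
    original = MessageCollection(
        chat_id="chat_456",
        messages=["早上好", "今天喝咖啡吗"],
        combined_text="早上好\n今天喝咖啡吗",
    )
    restored = MessageCollection.from_dict(original.to_dict())
    return restored.collection_id == original.collection_id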

View File

@@ -1,355 +0,0 @@
"""
智能记忆遗忘引擎
基于重要程度、置信度和激活频率的智能遗忘机制
"""
import asyncio
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class ForgettingStats:
"""遗忘统计信息"""
total_checked: int = 0
marked_for_forgetting: int = 0
actually_forgotten: int = 0
dormant_memories: int = 0
last_check_time: float = 0.0
check_duration: float = 0.0
@dataclass
class ForgettingConfig:
"""遗忘引擎配置"""
# 检查频率配置
check_interval_hours: int = 24 # 定期检查间隔(小时)
batch_size: int = 100 # 批处理大小
# 遗忘阈值配置
base_forgetting_days: float = 30.0 # 基础遗忘天数
min_forgetting_days: float = 7.0 # 最小遗忘天数
max_forgetting_days: float = 365.0 # 最大遗忘天数
# 重要程度权重
critical_importance_bonus: float = 45.0 # 关键重要性额外天数
high_importance_bonus: float = 30.0 # 高重要性额外天数
normal_importance_bonus: float = 15.0 # 一般重要性额外天数
low_importance_bonus: float = 0.0 # 低重要性额外天数
# 置信度权重
verified_confidence_bonus: float = 30.0 # 已验证置信度额外天数
high_confidence_bonus: float = 20.0 # 高置信度额外天数
medium_confidence_bonus: float = 10.0 # 中等置信度额外天数
low_confidence_bonus: float = 0.0 # 低置信度额外天数
# 激活频率权重
activation_frequency_weight: float = 0.5 # 每次激活增加的天数权重
max_frequency_bonus: float = 10.0 # 最大激活频率奖励天数
# 休眠配置
dormant_threshold_days: int = 90 # 休眠状态判定天数
force_forget_dormant_days: int = 180 # 强制遗忘休眠记忆的天数
class MemoryForgettingEngine:
"""智能记忆遗忘引擎"""
def __init__(self, config: ForgettingConfig | None = None):
self.config = config or ForgettingConfig()
self.stats = ForgettingStats()
self._last_forgetting_check = 0.0
self._forgetting_lock = asyncio.Lock()
logger.info("MemoryForgettingEngine 初始化完成")
def calculate_forgetting_threshold(self, memory: MemoryChunk) -> float:
"""
计算记忆的遗忘阈值(天数)
Args:
memory: 记忆块
Returns:
遗忘阈值(天数)
"""
# 基础天数
threshold = self.config.base_forgetting_days
# 重要性权重
importance = memory.metadata.importance
if importance == ImportanceLevel.CRITICAL:
threshold += self.config.critical_importance_bonus
elif importance == ImportanceLevel.HIGH:
threshold += self.config.high_importance_bonus
elif importance == ImportanceLevel.NORMAL:
threshold += self.config.normal_importance_bonus
# LOW 级别不增加额外天数
# 置信度权重
confidence = memory.metadata.confidence
if confidence == ConfidenceLevel.VERIFIED:
threshold += self.config.verified_confidence_bonus
elif confidence == ConfidenceLevel.HIGH:
threshold += self.config.high_confidence_bonus
elif confidence == ConfidenceLevel.MEDIUM:
threshold += self.config.medium_confidence_bonus
# LOW 级别不增加额外天数
# 激活频率权重
frequency_bonus = min(
memory.metadata.activation_frequency * self.config.activation_frequency_weight,
self.config.max_frequency_bonus,
)
threshold += frequency_bonus
# 确保在合理范围内
return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days))
def should_forget_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
判断记忆是否应该被遗忘
Args:
memory: 记忆块
current_time: 当前时间戳
Returns:
是否应该遗忘
"""
if current_time is None:
current_time = time.time()
# 关键重要性的记忆永不遗忘
if memory.metadata.importance == ImportanceLevel.CRITICAL:
return False
# 计算遗忘阈值
forgetting_threshold = self.calculate_forgetting_threshold(memory)
# 计算距离最后激活的时间
days_since_activation = (current_time - memory.metadata.last_activation_time) / 86400
# 判断是否超过阈值
should_forget = days_since_activation > forgetting_threshold
if should_forget:
logger.debug(
f"记忆 {memory.memory_id[:8]} 触发遗忘条件: "
f"重要性={memory.metadata.importance.name}, "
f"置信度={memory.metadata.confidence.name}, "
f"激活频率={memory.metadata.activation_frequency}, "
f"阈值={forgetting_threshold:.1f}天, "
f"未激活天数={days_since_activation:.1f}"
)
return should_forget
def is_dormant_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
判断记忆是否处于休眠状态
Args:
memory: 记忆块
current_time: 当前时间戳
Returns:
是否处于休眠状态
"""
return memory.is_dormant(current_time, self.config.dormant_threshold_days)
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
判断是否应该强制遗忘休眠记忆
Args:
memory: 记忆块
current_time: 当前时间戳
Returns:
是否应该强制遗忘
"""
if current_time is None:
current_time = time.time()
# 只有非关键重要性的记忆才会被强制遗忘
if memory.metadata.importance == ImportanceLevel.CRITICAL:
return False
days_since_last_access = (current_time - memory.metadata.last_accessed) / 86400
return days_since_last_access > self.config.force_forget_dormant_days
async def check_memories_for_forgetting(self, memories: list[MemoryChunk]) -> tuple[list[str], list[str]]:
"""
检查记忆列表,识别需要遗忘的记忆
Args:
memories: 记忆块列表
Returns:
(普通遗忘列表, 强制遗忘列表)
"""
start_time = time.time()
current_time = start_time
normal_forgetting_ids = []
force_forgetting_ids = []
self.stats.total_checked = len(memories)
self.stats.last_check_time = current_time
for memory in memories:
try:
# 检查休眠状态
if self.is_dormant_memory(memory, current_time):
self.stats.dormant_memories += 1
# 检查是否应该强制遗忘休眠记忆
if self.should_force_forget_dormant(memory, current_time):
force_forgetting_ids.append(memory.memory_id)
logger.debug(f"休眠记忆 {memory.memory_id[:8]} 被标记为强制遗忘")
continue
# 检查普通遗忘条件
if self.should_forget_memory(memory, current_time):
normal_forgetting_ids.append(memory.memory_id)
self.stats.marked_for_forgetting += 1
except Exception as e:
logger.warning(f"检查记忆 {memory.memory_id[:8]} 遗忘状态失败: {e}")
continue
self.stats.check_duration = time.time() - start_time
logger.info(
f"遗忘检查完成 | 总数={self.stats.total_checked}, "
f"标记遗忘={len(normal_forgetting_ids)}, "
f"强制遗忘={len(force_forgetting_ids)}, "
f"休眠={self.stats.dormant_memories}, "
f"耗时={self.stats.check_duration:.3f}s"
)
return normal_forgetting_ids, force_forgetting_ids
async def perform_forgetting_check(self, memories: list[MemoryChunk]) -> dict[str, Any]:
"""
执行完整的遗忘检查流程
Args:
memories: 记忆块列表
Returns:
检查结果统计
"""
async with self._forgetting_lock:
normal_forgetting, force_forgetting = await self.check_memories_for_forgetting(memories)
# 更新统计
self.stats.actually_forgotten = len(normal_forgetting) + len(force_forgetting)
return {
"normal_forgetting": normal_forgetting,
"force_forgetting": force_forgetting,
"stats": {
"total_checked": self.stats.total_checked,
"marked_for_forgetting": self.stats.marked_for_forgetting,
"actually_forgotten": self.stats.actually_forgotten,
"dormant_memories": self.stats.dormant_memories,
"check_duration": self.stats.check_duration,
"last_check_time": self.stats.last_check_time,
},
}
def is_forgetting_check_needed(self) -> bool:
"""检查是否需要进行遗忘检查"""
current_time = time.time()
hours_since_last_check = (current_time - self._last_forgetting_check) / 3600
return hours_since_last_check >= self.config.check_interval_hours
async def schedule_periodic_check(self, memories_provider, enable_auto_cleanup: bool = True):
"""
定期执行遗忘检查(可以在后台任务中调用)
Args:
memories_provider: 提供记忆列表的函数
enable_auto_cleanup: 是否启用自动清理
"""
if not self.is_forgetting_check_needed():
return
try:
logger.info("开始执行定期遗忘检查...")
# 获取记忆列表
memories = await memories_provider()
if not memories:
logger.debug("无记忆数据需要检查")
return
# 执行遗忘检查
result = await self.perform_forgetting_check(memories)
# 如果启用自动清理,执行实际的遗忘操作
if enable_auto_cleanup and (result["normal_forgetting"] or result["force_forgetting"]):
logger.info(
f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆"
)
# 这里可以调用实际的删除逻辑
# await self.cleanup_forgotten_memories(result["normal_forgetting"] + result["force_forgetting"])
self._last_forgetting_check = time.time()
except Exception as e:
logger.error(f"定期遗忘检查失败: {e}", exc_info=True)
def get_forgetting_stats(self) -> dict[str, Any]:
"""获取遗忘统计信息"""
return {
"total_checked": self.stats.total_checked,
"marked_for_forgetting": self.stats.marked_for_forgetting,
"actually_forgotten": self.stats.actually_forgotten,
"dormant_memories": self.stats.dormant_memories,
"last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat()
if self.stats.last_check_time
else None,
"last_check_duration": self.stats.check_duration,
"config": {
"check_interval_hours": self.config.check_interval_hours,
"base_forgetting_days": self.config.base_forgetting_days,
"min_forgetting_days": self.config.min_forgetting_days,
"max_forgetting_days": self.config.max_forgetting_days,
},
}
def reset_stats(self):
"""重置统计信息"""
self.stats = ForgettingStats()
logger.debug("遗忘统计信息已重置")
def update_config(self, **kwargs):
"""更新配置"""
for key, value in kwargs.items():
if hasattr(self.config, key):
setattr(self.config, key, value)
logger.debug(f"遗忘配置更新: {key} = {value}")
else:
logger.warning(f"未知的配置项: {key}")
# 创建全局遗忘引擎实例
memory_forgetting_engine = MemoryForgettingEngine()
def get_memory_forgetting_engine() -> MemoryForgettingEngine:
"""获取全局遗忘引擎实例"""
return memory_forgetting_engine
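# 用法示例(仅作说明,非本模块原有内容):在后台维护任务中驱动一次遗忘检查。
# load_memories 是假设的外部提供函数(可等待,返回 list[MemoryChunk]),此处不涉及真实删除逻辑。
async def _demo_forgetting_maintenance(load_memories) -> dict:
    engine = get_memory_forgetting_engine()
    # 按需调整检查频率与基础遗忘天数
    engine.update_config(check_interval_hours=12, base_forgetting_days=21.0)
    memories = await load_memories()
    result = await engine.perform_forgetting_check(memories)
    # 返回值包含 normal_forgetting / force_forgetting 两组记忆 ID,由调用方决定如何清理
    return result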

View File

@@ -1,120 +0,0 @@
"""记忆格式化工具
提供统一的记忆块格式化函数,供构建 Prompt 时使用。
当前使用的函数: format_memories_bracket_style
输入: list[dict] 其中每个元素包含:
- display: str 记忆可读内容
- memory_type: str 记忆类型 (personal_fact/opinion/preference/event 等)
- metadata: dict 可选,包括
- confidence: 置信度 (str|float)
- importance: 重要度 (str|float)
- timestamp: 时间戳 (float|str)
- source: 来源 (str)
- relevance_score: 相关度 (float)
返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。
"""
from __future__ import annotations
import time
from collections.abc import Iterable
from typing import Any
def _format_timestamp(ts: Any) -> str:
try:
if ts in (None, ""):
return ""
if isinstance(ts, int | float) and ts > 0:
return time.strftime("%Y-%m-%d %H:%M", time.localtime(float(ts)))
return str(ts)
except Exception:
return ""
def _coerce_str(v: Any) -> str:
if v is None:
return ""
return str(v)
def format_memories_bracket_style(
memories: Iterable[dict[str, Any]] | None,
query_context: str | None = None,
max_items: int = 15,
) -> str:
"""以方括号 + 标注字段的方式格式化记忆列表。
例子输出:
## 相关记忆回顾
- [类型:personal_fact|重要:高|置信:0.83|相关:0.72] 他喜欢黑咖啡 (来源: chat, 2025-10-05 09:30)
Args:
memories: 记忆字典迭代器
query_context: 当前查询/用户的消息,用于在首行提示(可选)
max_items: 最多输出的记忆条数
Returns:
str: 格式化文本;若无内容返回空串
"""
if not memories:
return ""
lines: list[str] = ["## 相关记忆回顾"]
if query_context:
lines.append(f"(与当前消息相关:{query_context[:60]}{'...' if len(query_context) > 60 else ''}")
lines.append("")
count = 0
for mem in memories:
if count >= max_items:
break
if not isinstance(mem, dict):
continue
display = _coerce_str(mem.get("display", "")).strip()
if not display:
continue
mtype = _coerce_str(mem.get("memory_type", "fact")) or "fact"
meta = mem.get("metadata", {}) if isinstance(mem.get("metadata"), dict) else {}
confidence = _coerce_str(meta.get("confidence", ""))
importance = _coerce_str(meta.get("importance", ""))
source = _coerce_str(meta.get("source", ""))
rel = meta.get("relevance_score")
try:
rel_str = f"{float(rel):.2f}" if rel is not None else ""
except Exception:
rel_str = ""
ts = _format_timestamp(meta.get("timestamp"))
# 构建标签段
tags: list[str] = [f"类型:{mtype}"]
if importance:
tags.append(f"重要:{importance}")
if confidence:
tags.append(f"置信:{confidence}")
if rel_str:
tags.append(f"相关:{rel_str}")
tag_block = "|".join(tags)
suffix_parts = []
if source:
suffix_parts.append(source)
if ts:
suffix_parts.append(ts)
suffix = (" (" + ", ".join(suffix_parts) + ")") if suffix_parts else ""
lines.append(f"- [{tag_block}] {display}{suffix}")
count += 1
if count == 0:
return ""
if count >= max_items:
lines.append(f"\n(已截断,仅显示前 {max_items} 条相关记忆)")
return "\n".join(lines)
__all__ = ["format_memories_bracket_style"]

View File

@@ -1,505 +0,0 @@
"""
记忆融合与去重机制
避免记忆碎片化,确保长期记忆库的高质量
"""
import time
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class FusionResult:
"""融合结果"""
original_count: int
fused_count: int
removed_duplicates: int
merged_memories: list[MemoryChunk]
fusion_time: float
details: list[str]
@dataclass
class DuplicateGroup:
"""重复记忆组"""
group_id: str
memories: list[MemoryChunk]
similarity_matrix: list[list[float]]
representative_memory: MemoryChunk | None = None
class MemoryFusionEngine:
"""记忆融合引擎"""
def __init__(self, similarity_threshold: float = 0.85):
self.similarity_threshold = similarity_threshold
self.fusion_stats = {
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0,
}
# 融合策略配置
self.fusion_strategies = {
"semantic_similarity": True, # 语义相似性融合
"temporal_proximity": True, # 时间接近性融合
"logical_consistency": True, # 逻辑一致性融合
"confidence_boosting": True, # 置信度提升
"importance_preservation": True, # 重要性保持
}
async def fuse_memories(
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] | None = None
) -> list[MemoryChunk]:
"""融合记忆列表"""
start_time = time.time()
try:
if not new_memories:
return []
logger.info(f"开始记忆融合,新记忆: {len(new_memories)},现有记忆: {len(existing_memories or [])}")
# 1. 检测重复记忆组
duplicate_groups = await self._detect_duplicate_groups(new_memories, existing_memories or [])
if not duplicate_groups:
fusion_time = time.time() - start_time
self._update_fusion_stats(len(new_memories), 0, fusion_time)
logger.info("✅ 记忆融合完成: %d 条记忆,移除 0 条重复", len(new_memories))
return new_memories
# 2. 对每个重复组进行融合
fused_memories = []
removed_count = 0
for group in duplicate_groups:
if len(group.memories) == 1:
# 单个记忆,直接添加
fused_memories.append(group.memories[0])
else:
# 多个记忆,进行融合
fused_memory = await self._fuse_memory_group(group)
if fused_memory:
fused_memories.append(fused_memory)
removed_count += len(group.memories) - 1
# 3. 更新统计
fusion_time = time.time() - start_time
self._update_fusion_stats(len(new_memories), removed_count, fusion_time)
logger.info(f"✅ 记忆融合完成: {len(fused_memories)} 条记忆,移除 {removed_count} 条重复")
return fused_memories
except Exception as e:
logger.error(f"❌ 记忆融合失败: {e}", exc_info=True)
return new_memories # 失败时返回原始记忆
async def _detect_duplicate_groups(
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk]
) -> list[DuplicateGroup]:
"""检测重复记忆组"""
all_memories = new_memories + existing_memories
new_memory_ids = {memory.memory_id for memory in new_memories}
groups = []
processed_ids = set()
for i, memory1 in enumerate(all_memories):
if memory1.memory_id in processed_ids:
continue
# 创建新的重复组
group = DuplicateGroup(group_id=f"group_{len(groups)}", memories=[memory1], similarity_matrix=[[1.0]])
processed_ids.add(memory1.memory_id)
# 寻找相似记忆
for j, memory2 in enumerate(all_memories[i + 1 :], i + 1):
if memory2.memory_id in processed_ids:
continue
similarity = self._calculate_comprehensive_similarity(memory1, memory2)
if similarity >= self.similarity_threshold:
group.memories.append(memory2)
processed_ids.add(memory2.memory_id)
# 更新相似度矩阵
self._update_similarity_matrix(group, memory2, similarity)
if len(group.memories) > 1:
# 选择代表性记忆
group.representative_memory = self._select_representative_memory(group)
groups.append(group)
else:
# 仅包含单条记忆,只有当其来自新记忆列表时保留
if memory1.memory_id in new_memory_ids:
groups.append(group)
logger.debug(f"检测到 {len(groups)} 个重复记忆组")
return groups
def _calculate_comprehensive_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float:
"""计算综合相似度"""
similarity_scores = []
# 1. 语义向量相似度
if self.fusion_strategies["semantic_similarity"]:
semantic_sim = mem1.calculate_similarity(mem2)
similarity_scores.append(("semantic", semantic_sim))
# 2. 文本相似度
text_sim = self._calculate_text_similarity(mem1.text_content, mem2.text_content)
similarity_scores.append(("text", text_sim))
# 3. 关键词重叠度
keyword_sim = self._calculate_keyword_similarity(mem1.keywords, mem2.keywords)
similarity_scores.append(("keyword", keyword_sim))
# 4. 类型一致性
type_consistency = 1.0 if mem1.memory_type == mem2.memory_type else 0.0
similarity_scores.append(("type", type_consistency))
# 5. 时间接近性
if self.fusion_strategies["temporal_proximity"]:
temporal_sim = self._calculate_temporal_similarity(mem1.metadata.created_at, mem2.metadata.created_at)
similarity_scores.append(("temporal", temporal_sim))
# 6. 逻辑一致性
if self.fusion_strategies["logical_consistency"]:
logical_sim = self._calculate_logical_similarity(mem1, mem2)
similarity_scores.append(("logical", logical_sim))
# 计算加权平均相似度
weights = {"semantic": 0.35, "text": 0.25, "keyword": 0.15, "type": 0.10, "temporal": 0.10, "logical": 0.05}
weighted_sum = 0.0
total_weight = 0.0
for score_type, score in similarity_scores:
weight = weights.get(score_type, 0.1)
weighted_sum += weight * score
total_weight += weight
final_similarity = weighted_sum / total_weight if total_weight > 0 else 0.0
logger.debug(f"综合相似度计算: {final_similarity:.3f} - {[(t, f'{s:.3f}') for t, s in similarity_scores]}")
return final_similarity
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""计算文本相似度"""
# 简单的词汇重叠度计算
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1 & words2
union = words1 | words2
jaccard_similarity = len(intersection) / len(union)
return jaccard_similarity
def _calculate_keyword_similarity(self, keywords1: list[str], keywords2: list[str]) -> float:
"""计算关键词相似度"""
if not keywords1 or not keywords2:
return 0.0
set1 = set(k.lower() for k in keywords1) # noqa: C401
set2 = set(k.lower() for k in keywords2) # noqa: C401
intersection = set1 & set2
union = set1 | set2
return len(intersection) / len(union) if union else 0.0
def _calculate_temporal_similarity(self, time1: float, time2: float) -> float:
"""计算时间相似度"""
time_diff = abs(time1 - time2)
hours_diff = time_diff / 3600
# 24小时内相似度较高
if hours_diff <= 24:
return 1.0 - (hours_diff / 24)
elif hours_diff <= 168: # 一周内
return 0.7 - ((hours_diff - 24) / 168) * 0.5
else:
return 0.2
def _calculate_logical_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float:
"""计算逻辑一致性"""
# 检查主谓宾结构的逻辑一致性
consistency_score = 0.0
# 主语一致性
subjects1 = set(mem1.subjects)
subjects2 = set(mem2.subjects)
if subjects1 or subjects2:
overlap = len(subjects1 & subjects2)
union_count = max(len(subjects1 | subjects2), 1)
consistency_score += (overlap / union_count) * 0.4
# 谓语相似性
predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)
consistency_score += predicate_sim * 0.3
# 宾语相似性
if isinstance(mem1.content.object, str) and isinstance(mem2.content.object, str):
object_sim = self._calculate_text_similarity(str(mem1.content.object), str(mem2.content.object))
consistency_score += object_sim * 0.3
return consistency_score
def _update_similarity_matrix(self, group: DuplicateGroup, new_memory: MemoryChunk, similarity: float):
"""更新组的相似度矩阵"""
# 为新记忆添加行和列
for i in range(len(group.similarity_matrix)):
group.similarity_matrix[i].append(similarity)
# 添加新行
new_row = [similarity] + [1.0] * len(group.similarity_matrix)
group.similarity_matrix.append(new_row)
def _select_representative_memory(self, group: DuplicateGroup) -> MemoryChunk | None:
"""选择代表性记忆"""
if not group.memories:
return None
# 评分标准
best_memory = None
best_score = -1.0
for memory in group.memories:
score = 0.0
# 置信度权重
score += memory.metadata.confidence.value * 0.3
# 重要性权重
score += memory.metadata.importance.value * 0.3
# 访问次数权重
score += min(memory.metadata.access_count * 0.1, 0.2)
# 相关度权重
score += memory.metadata.relevance_score * 0.2
if score > best_score:
best_score = score
best_memory = memory
return best_memory
async def _fuse_memory_group(self, group: DuplicateGroup) -> MemoryChunk | None:
"""融合记忆组"""
if not group.memories:
return None
if len(group.memories) == 1:
return group.memories[0]
try:
# 选择基础记忆(通常是代表性记忆)
base_memory = group.representative_memory or group.memories[0]
# 融合其他记忆的属性
fused_memory = await self._merge_memory_attributes(base_memory, group.memories)
# 更新元数据
self._update_fused_metadata(fused_memory, group)
logger.debug(f"成功融合记忆组,包含 {len(group.memories)} 条原始记忆")
return fused_memory
except Exception as e:
logger.error(f"融合记忆组失败: {e}")
# 返回置信度最高的记忆
return max(group.memories, key=lambda m: m.metadata.confidence.value)
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: list[MemoryChunk]) -> MemoryChunk:
"""合并记忆属性"""
# 创建基础记忆的深拷贝
fused_memory = MemoryChunk.from_dict(base_memory.to_dict())
# 合并关键词
all_keywords = set()
for memory in memories:
all_keywords.update(memory.keywords)
fused_memory.keywords = sorted(all_keywords)
# 合并标签
all_tags = set()
for memory in memories:
all_tags.update(memory.tags)
fused_memory.tags = sorted(all_tags)
# 合并分类
all_categories = set()
for memory in memories:
all_categories.update(memory.categories)
fused_memory.categories = sorted(all_categories)
# 合并关联记忆
all_related = set()
for memory in memories:
all_related.update(memory.related_memories)
# 移除对自身和组内记忆的引用
all_related = {rid for rid in all_related if rid not in [m.memory_id for m in memories]}
fused_memory.related_memories = sorted(all_related)
# 合并时间上下文
if self.fusion_strategies["temporal_proximity"]:
fused_memory.temporal_context = self._merge_temporal_context(memories)
return fused_memory
def _update_fused_metadata(self, fused_memory: MemoryChunk, group: DuplicateGroup):
"""更新融合记忆的元数据"""
# 更新修改时间
fused_memory.metadata.last_modified = time.time()
# 计算平均访问次数
total_access = sum(m.metadata.access_count for m in group.memories)
fused_memory.metadata.access_count = total_access
# 提升置信度(如果有多个来源支持)
if self.fusion_strategies["confidence_boosting"] and len(group.memories) > 1:
max_confidence = max(m.metadata.confidence.value for m in group.memories)
if max_confidence < ConfidenceLevel.VERIFIED.value:
fused_memory.metadata.confidence = ConfidenceLevel(
min(max_confidence + 1, ConfidenceLevel.VERIFIED.value)
)
# 保持最高重要性
if self.fusion_strategies["importance_preservation"]:
max_importance = max(m.metadata.importance.value for m in group.memories)
fused_memory.metadata.importance = ImportanceLevel(max_importance)
# 计算平均相关度
avg_relevance = sum(m.metadata.relevance_score for m in group.memories) / len(group.memories)
fused_memory.metadata.relevance_score = min(avg_relevance * 1.1, 1.0) # 稍微提升相关度
# 设置来源信息
source_ids = [m.memory_id[:8] for m in group.memories]
fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}"
def _merge_temporal_context(self, memories: list[MemoryChunk]) -> dict[str, Any]:
"""合并时间上下文"""
contexts = [m.temporal_context for m in memories if m.temporal_context]
if not contexts:
return {}
# 计算时间范围
timestamps = [m.metadata.created_at for m in memories]
earliest_time = min(timestamps)
latest_time = max(timestamps)
merged_context = {
"earliest_timestamp": earliest_time,
"latest_timestamp": latest_time,
"time_span_hours": (latest_time - earliest_time) / 3600,
"source_memories": len(memories),
}
# 合并其他上下文信息
for context in contexts:
for key, value in context.items():
if key not in ["timestamp", "earliest_timestamp", "latest_timestamp"]:
if key not in merged_context:
merged_context[key] = value
elif merged_context[key] != value:
merged_context[key] = f"multiple: {value}"
return merged_context
async def incremental_fusion(
self, new_memory: MemoryChunk, existing_memories: list[MemoryChunk]
) -> tuple[MemoryChunk, list[MemoryChunk]]:
"""增量融合(单个新记忆与现有记忆融合)"""
# 寻找相似记忆
similar_memories = []
for existing in existing_memories:
similarity = self._calculate_comprehensive_similarity(new_memory, existing)
if similarity >= self.similarity_threshold:
similar_memories.append((existing, similarity))
if not similar_memories:
# 没有相似记忆,直接返回
return new_memory, existing_memories
# 按相似度排序
similar_memories.sort(key=lambda x: x[1], reverse=True)
# 与最相似的记忆融合
best_match, similarity = similar_memories[0]
# 创建融合组
group = DuplicateGroup(
group_id=f"incremental_{int(time.time())}",
memories=[new_memory, best_match],
similarity_matrix=[[1.0, similarity], [similarity, 1.0]],
)
# 执行融合
fused_memory = await self._fuse_memory_group(group)
# 从现有记忆中移除被融合的记忆
updated_existing = [m for m in existing_memories if m.memory_id != best_match.memory_id]
updated_existing.append(fused_memory)
logger.debug(f"增量融合完成,相似度: {similarity:.3f}")
return fused_memory, updated_existing
def _update_fusion_stats(self, original_count: int, removed_count: int, fusion_time: float):
"""更新融合统计"""
self.fusion_stats["total_fusions"] += 1
self.fusion_stats["memories_fused"] += original_count
self.fusion_stats["duplicates_removed"] += removed_count
# 更新平均相似度(估算)
if removed_count > 0:
avg_similarity = 0.9 # 假设平均相似度较高
total_similarity = self.fusion_stats["average_similarity"] * (self.fusion_stats["total_fusions"] - 1)
total_similarity += avg_similarity
self.fusion_stats["average_similarity"] = total_similarity / self.fusion_stats["total_fusions"]
async def maintenance(self):
"""维护操作"""
try:
logger.info("开始记忆融合引擎维护...")
# 可以在这里添加定期维护任务,如:
# - 重新评估低置信度记忆
# - 清理孤立记忆引用
# - 优化融合策略参数
logger.info("✅ 记忆融合引擎维护完成")
except Exception as e:
logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True)
def get_fusion_stats(self) -> dict[str, Any]:
"""获取融合统计信息"""
return self.fusion_stats.copy()
def reset_stats(self):
"""重置统计信息"""
self.fusion_stats = {
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0,
}
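# 用法示例(仅作说明,非本模块原有内容):对新提取的记忆与已加载的记忆做一次去重融合。
# new_chunks / existing_chunks 假设由记忆系统的其他部分产出。
async def _demo_fuse_memories(
    new_chunks: list[MemoryChunk], existing_chunks: list[MemoryChunk]
) -> list[MemoryChunk]:
    engine = MemoryFusionEngine(similarity_threshold=0.85)
    fused = await engine.fuse_memories(new_chunks, existing_chunks)
    logger.debug(f"融合统计: {engine.get_fusion_stats()}")
    return fused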

View File

@@ -1,512 +0,0 @@
"""
记忆系统管理器
替代原有的 Hippocampus 和 instant_memory 系统
"""
import re
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.message_collection_processor import MessageCollectionProcessor
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
from src.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class MemoryResult:
"""记忆查询结果"""
content: str
memory_type: str
confidence: float
importance: float
timestamp: float
source: str = "memory"
relevance_score: float = 0.0
structure: dict[str, Any] | None = None
class MemoryManager:
"""记忆系统管理器 - 替代原有的 HippocampusManager"""
def __init__(self):
self.memory_system: MemorySystem | None = None
self.message_collection_storage: MessageCollectionStorage | None = None
self.message_collection_processor: MessageCollectionProcessor | None = None
self.is_initialized = False
self.user_cache = {} # 用户记忆缓存
def _clean_text(self, text: Any) -> str:
if text is None:
return ""
cleaned = re.sub(r"[\s\u3000]+", " ", str(text)).strip()
cleaned = re.sub(r"[、,,;]+$", "", cleaned)
return cleaned
async def initialize(self):
"""初始化记忆系统"""
if self.is_initialized:
return
try:
from src.config.config import global_config
# 检查是否启用记忆系统
if not global_config.memory.enable_memory:
logger.info("记忆系统已禁用,跳过初始化")
self.is_initialized = True
return
logger.info("正在初始化记忆系统...")
# 初始化记忆系统
from src.chat.memory_system.memory_system import get_memory_system
self.memory_system = get_memory_system()
# 初始化消息集合系统
self.message_collection_storage = MessageCollectionStorage()
self.message_collection_processor = MessageCollectionProcessor(self.message_collection_storage)
self.is_initialized = True
logger.info(" 记忆系统初始化完成")
except Exception as e:
logger.error(f"记忆系统初始化失败: {e}")
# 如果系统初始化失败,创建一个空的管理器避免系统崩溃
self.memory_system = None
self.message_collection_storage = None
self.message_collection_processor = None
self.is_initialized = True # 标记为已初始化但系统不可用
def get_hippocampus(self):
"""兼容原有接口 - 返回空"""
logger.debug("get_hippocampus 调用 - 记忆系统不使用此方法")
return {}
async def build_memory(self):
"""兼容原有接口 - 构建记忆"""
if not self.is_initialized or not self.memory_system:
return
try:
# 记忆系统使用实时构建,不需要定时构建
logger.debug("build_memory 调用 - 记忆系统使用实时构建")
except Exception as e:
logger.error(f"build_memory 失败: {e}")
async def forget_memory(self, percentage: float = 0.005):
"""兼容原有接口 - 遗忘机制"""
if not self.is_initialized or not self.memory_system:
return
try:
# 增强记忆系统有内置的遗忘机制
logger.debug(f"forget_memory 调用 - 参数: {percentage}")
# 可以在这里调用增强系统的维护功能
await self.memory_system.maintenance()
except Exception as e:
logger.error(f"forget_memory 失败: {e}")
async def get_memory_from_text(
self,
text: str,
chat_id: str,
user_id: str,
max_memory_num: int = 3,
max_memory_length: int = 2,
time_weight: float = 1.0,
keyword_weight: float = 1.0,
) -> list[tuple[str, str]]:
"""从文本获取相关记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
try:
# 使用增强记忆系统检索
context = {
"chat_id": chat_id,
"expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE],
}
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query=text, user_id=user_id, context=context, limit=max_memory_num
)
# 转换为原有格式 (topic, content)
results = []
for memory in relevant_memories:
topic = memory.memory_type.value
content = memory.text_content
results.append((topic, content))
logger.debug(f"从文本检索到 {len(results)} 条相关记忆")
# 如果检索到有效记忆,打印详细信息
if results:
logger.info(f"📚 从文本 '{text[:50]}...' 检索到 {len(results)} 条有效记忆:")
for i, (topic, content) in enumerate(results, 1):
# 处理长内容如果超过150字符则截断
display_content = content
if len(content) > 150:
display_content = content[:150] + "..."
logger.info(f" 记忆#{i} [{topic}]: {display_content}")
return results
except Exception as e:
logger.error(f"get_memory_from_text 失败: {e}")
return []
async def get_memory_from_topic(
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> list[tuple[str, str]]:
"""从关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
try:
# 将关键词转换为查询文本
query_text = " ".join(valid_keywords)
# 使用增强记忆系统检索
context = {
"keywords": valid_keywords,
"expected_memory_types": [
MemoryType.PERSONAL_FACT,
MemoryType.EVENT,
MemoryType.PREFERENCE,
MemoryType.OPINION,
],
}
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="default_user", # 可以根据实际需要传递
context=context,
limit=max_memory_num,
)
# 转换为原有格式 (topic, content)
results = []
for memory in relevant_memories:
topic = memory.memory_type.value
content = memory.text_content
results.append((topic, content))
logger.debug(f"从关键词 {valid_keywords} 检索到 {len(results)} 条相关记忆")
# 如果检索到有效记忆,打印详细信息
if results:
keywords_str = ", ".join(valid_keywords[:5]) # 最多显示5个关键词
if len(valid_keywords) > 5:
keywords_str += f" ... (共{len(valid_keywords)}个关键词)"
logger.info(f"🔍 从关键词 [{keywords_str}] 检索到 {len(results)} 条有效记忆:")
for i, (topic, content) in enumerate(results, 1):
# 处理长内容如果超过150字符则截断
display_content = content
if len(content) > 150:
display_content = content[:150] + "..."
logger.info(f" 记忆#{i} [{topic}]: {display_content}")
return results
except Exception as e:
logger.error(f"get_memory_from_topic 失败: {e}")
return []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
"""从单个关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
try:
# 同步方法,返回空列表
logger.debug(f"get_memory_from_keyword 调用 - 关键词: {keyword}")
return []
except Exception as e:
logger.error(f"get_memory_from_keyword 失败: {e}")
return []
async def process_conversation(
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float | None = None
) -> list[MemoryChunk]:
"""处理对话并构建记忆 - 新增功能"""
if not self.is_initialized or not self.memory_system:
return []
try:
# 将消息添加到消息集合处理器
chat_id = context.get("chat_id")
if self.message_collection_processor and chat_id:
await self.message_collection_processor.add_message(conversation_text, chat_id)
payload_context = dict(context or {})
payload_context.setdefault("conversation_text", conversation_text)
if timestamp is not None:
payload_context.setdefault("timestamp", timestamp)
result = await self.memory_system.process_conversation_memory(payload_context)
# 从结果中提取记忆块
memory_chunks = []
if result.get("success"):
memory_chunks = result.get("created_memories", [])
logger.info(f"从对话构建了 {len(memory_chunks)} 条记忆")
return memory_chunks
except Exception as e:
logger.error(f"process_conversation 失败: {e}")
return []
async def get_enhanced_memory_context(
self, query_text: str, user_id: str, context: dict[str, Any] | None = None, limit: int = 5
) -> list[MemoryResult]:
"""获取增强记忆上下文 - 新增功能"""
if not self.is_initialized or not self.memory_system:
return []
try:
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query=query_text, user_id=None, context=context or {}, limit=limit
)
results = []
for memory in relevant_memories:
formatted_content, structure = self._format_memory_chunk(memory)
result = MemoryResult(
content=formatted_content,
memory_type=memory.memory_type.value,
confidence=memory.metadata.confidence.value,
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="enhanced_memory",
relevance_score=memory.metadata.relevance_score,
structure=structure,
)
results.append(result)
return results
except Exception as e:
logger.error(f"get_enhanced_memory_context 失败: {e}")
return []
def _format_memory_chunk(self, memory: MemoryChunk) -> tuple[str, dict[str, Any]]:
"""将记忆块转换为更易读的文本描述"""
structure = memory.content.to_dict()
if memory.display:
return self._clean_text(memory.display), structure
subject = structure.get("subject")
predicate = structure.get("predicate") or ""
obj = structure.get("object")
subject_display = self._format_subject(subject, memory)
formatted = self._apply_predicate_format(subject_display, predicate, obj)
if not formatted:
predicate_display = self._format_predicate(predicate)
object_display = self._format_object(obj)
formatted = f"{subject_display}{predicate_display}{object_display}".strip()
formatted = self._clean_text(formatted)
return formatted, structure
def _format_subject(self, subject: str | None, memory: MemoryChunk) -> str:
if not subject:
return "该用户"
if subject == memory.metadata.user_id:
return "该用户"
if memory.metadata.chat_id and subject == memory.metadata.chat_id:
return "该聊天"
return self._clean_text(subject)
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> str | None:
predicate = (predicate or "").strip()
obj_value = obj
if predicate == "is_named":
name = self._extract_from_object(obj_value, ["name", "nickname"]) or self._format_object(obj_value)
name = self._clean_text(name)
if not name:
return None
name_display = name if (name.startswith("“") and name.endswith("”")) else f"“{name}”"
return f"{subject}的昵称是{name_display}"
if predicate == "is_age":
age = self._extract_from_object(obj_value, ["age"]) or self._format_object(obj_value)
age = self._clean_text(age)
if not age:
return None
return f"{subject}今年{age}"
if predicate == "is_profession":
profession = self._extract_from_object(obj_value, ["profession", "job"]) or self._format_object(obj_value)
profession = self._clean_text(profession)
if not profession:
return None
return f"{subject}的职业是{profession}"
if predicate == "lives_in":
location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object(
obj_value
)
location = self._clean_text(location)
if not location:
return None
return f"{subject}居住在{location}"
if predicate == "has_phone":
phone = self._extract_from_object(obj_value, ["phone", "number"]) or self._format_object(obj_value)
phone = self._clean_text(phone)
if not phone:
return None
return f"{subject}的电话号码是{phone}"
if predicate == "has_email":
email = self._extract_from_object(obj_value, ["email"]) or self._format_object(obj_value)
email = self._clean_text(email)
if not email:
return None
return f"{subject}的邮箱是{email}"
if predicate == "likes":
liked = self._format_object(obj_value)
if not liked:
return None
return f"{subject}喜欢{liked}"
if predicate == "likes_food":
food = self._format_object(obj_value)
if not food:
return None
return f"{subject}爱吃{food}"
if predicate == "dislikes":
disliked = self._format_object(obj_value)
if not disliked:
return None
return f"{subject}不喜欢{disliked}"
if predicate == "hates":
hated = self._format_object(obj_value)
if not hated:
return None
return f"{subject}讨厌{hated}"
if predicate == "favorite_is":
favorite = self._format_object(obj_value)
if not favorite:
return None
return f"{subject}最喜欢{favorite}"
if predicate == "mentioned_event":
event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object(
obj_value
)
event_text = self._clean_text(self._truncate(event_text))
if not event_text:
return None
return f"{subject}提到了计划或事件:{event_text}"
if predicate in {"正在", "", "正在进行"}:
action = self._format_object(obj_value)
if not action:
return None
return f"{subject}{predicate}{action}"
if predicate in {"感到", "觉得", "表示", "提到", "说道", ""}:
feeling = self._format_object(obj_value)
if not feeling:
return None
return f"{subject}{predicate}{feeling}"
if predicate in {"", "", ""}:
counterpart = self._format_object(obj_value)
if counterpart:
return f"{subject}{predicate}{counterpart}"
return f"{subject}{predicate}"
return None
def _format_predicate(self, predicate: str) -> str:
if not predicate:
return ""
predicate_map = {
"is_named": "的昵称是",
"is_profession": "的职业是",
"lives_in": "居住在",
"has_phone": "的电话是",
"has_email": "的邮箱是",
"likes": "喜欢",
"dislikes": "不喜欢",
"likes_food": "爱吃",
"hates": "讨厌",
"favorite_is": "最喜欢",
"mentioned_event": "提到的事件",
}
if predicate in predicate_map:
connector = predicate_map[predicate]
if connector.startswith("的"):
return connector
return f" {connector} "
cleaned = predicate.replace("_", " ").strip()
if re.search(r"[\u4e00-\u9fff]", cleaned):
return cleaned
return f" {cleaned} "
def _format_object(self, obj: Any) -> str:
if obj is None:
return ""
if isinstance(obj, dict):
parts = []
for key, value in obj.items():
formatted_value = self._format_object(value)
if not formatted_value:
continue
pretty_key = {
"name": "名字",
"profession": "职业",
"location": "位置",
"event_text": "内容",
"timestamp": "时间",
}.get(key, key)
parts.append(f"{pretty_key}: {formatted_value}")
return self._clean_text("".join(parts))
if isinstance(obj, list):
formatted_items = [self._format_object(item) for item in obj]
filtered = [item for item in formatted_items if item]
return self._clean_text("".join(filtered)) if filtered else ""
if isinstance(obj, int | float):
return str(obj)
text = self._truncate(str(obj).strip())
return self._clean_text(text)
def _extract_from_object(self, obj: Any, keys: list[str]) -> str | None:
if isinstance(obj, dict):
for key in keys:
if obj.get(key):
value = obj[key]
if isinstance(value, dict | list):
return self._clean_text(self._format_object(value))
return self._clean_text(value)
if isinstance(obj, list) and obj:
return self._clean_text(self._format_object(obj[0]))
if isinstance(obj, str | int | float):
return self._clean_text(obj)
return None
def _truncate(self, text: str, max_length: int = 80) -> str:
if len(text) <= max_length:
return text
return text[: max_length - 1] + "…"
async def shutdown(self):
"""关闭增强记忆系统"""
if not self.is_initialized:
return
try:
if self.memory_system:
await self.memory_system.shutdown()
logger.info(" 记忆系统已关闭")
except Exception as e:
logger.error(f"关闭记忆系统失败: {e}")
# 全局记忆管理器实例
memory_manager = MemoryManager()
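# 用法示例(仅作说明,非本模块原有内容):初始化管理器并为 prompt 获取记忆上下文。
# user_123 / chat_456 为占位标识符。
async def _demo_memory_context(query: str) -> list[MemoryResult]:
    await memory_manager.initialize()
    return await memory_manager.get_enhanced_memory_context(
        query_text=query,
        user_id="user_123",
        context={"chat_id": "chat_456"},
        limit=5,
    )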

View File

@@ -1,122 +0,0 @@
"""
记忆元数据索引。
"""
from dataclasses import asdict, dataclass
from typing import Any
from src.common.logger import get_logger
logger = get_logger(__name__)
from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
@dataclass
class MemoryMetadataIndexEntry:
memory_id: str
user_id: str
memory_type: str
subjects: list[str]
objects: list[str]
keywords: list[str]
tags: list[str]
importance: int
confidence: int
created_at: float
access_count: int
chat_id: str | None = None
content_preview: str | None = None
class MemoryMetadataIndex:
"""Rust 加速版本唯一实现。"""
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
self._rust = _RustIndex(index_file)
# 仅为向量层和调试提供最小缓存(长度判断、get_entry 返回)
self.index: dict[str, MemoryMetadataIndexEntry] = {}
logger.info("✅ MemoryMetadataIndex (Rust) 初始化完成,仅支持加速实现")
# 向后代码仍调用的接口:batch_add_or_update / add_or_update
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
if not entries:
return
payload = []
for e in entries:
if not e.memory_id:
continue
self.index[e.memory_id] = e
payload.append(asdict(e))
if payload:
try:
self._rust.batch_add(payload)
except Exception as ex:
logger.error(f"Rust 元数据批量添加失败: {ex}")
def add_or_update(self, entry: MemoryMetadataIndexEntry):
self.batch_add_or_update([entry])
def search(
self,
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
keywords: list[str] | None = None,
tags: list[str] | None = None,
importance_min: int | None = None,
importance_max: int | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
flexible_mode: bool = True,
) -> list[str]:
params: dict[str, Any] = {
"user_id": user_id,
"memory_types": memory_types,
"subjects": subjects,
"keywords": keywords,
"tags": tags,
"importance_min": importance_min,
"importance_max": importance_max,
"created_after": created_after,
"created_before": created_before,
"limit": limit,
}
params = {k: v for k, v in params.items() if v is not None}
try:
if flexible_mode:
return list(self._rust.search_flexible(params))
return list(self._rust.search_strict(params))
except Exception as ex:
logger.error(f"Rust 搜索失败返回空: {ex}")
return []
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
return self.index.get(memory_id)
def get_stats(self) -> dict[str, Any]:
try:
raw = self._rust.stats()
return {
"total_memories": raw.get("total", 0),
"types": raw.get("types_dist", {}),
"subjects_count": raw.get("subjects_indexed", 0),
"keywords_count": raw.get("keywords_indexed", 0),
"tags_count": raw.get("tags_indexed", 0),
}
except Exception as ex:
logger.warning(f"读取 Rust stats 失败: {ex}")
return {"total_memories": 0}
def save(self): # 仅调用 rust save
try:
self._rust.save()
except Exception as ex:
logger.warning(f"Rust save 失败: {ex}")
__all__ = [
"MemoryMetadataIndex",
"MemoryMetadataIndexEntry",
]
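# 用法示例(仅作说明,非本模块原有内容):写入一条索引并按关键词查询。
# 各字段取值为占位数据;真实环境需要 inkfox Rust 扩展可用。
def _demo_metadata_index() -> list[str]:
    index = MemoryMetadataIndex("data/memory_metadata_index.json")
    entry = MemoryMetadataIndexEntry(
        memory_id="mem_001",
        user_id="user_123",
        memory_type="preference",
        subjects=["小明"],
        objects=["黑咖啡"],
        keywords=["咖啡"],
        tags=["饮食"],
        importance=2,
        confidence=2,
        created_at=1759628800.0,
        access_count=0,
        chat_id="chat_456",
        content_preview="小明喜欢黑咖啡",
    )
    index.add_or_update(entry)
    return index.search(keywords=["咖啡"], user_id="user_123", limit=5)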

View File

@@ -1,219 +0,0 @@
"""记忆检索查询规划器"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any
import orjson
from src.chat.memory_system.memory_chunk import MemoryType
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.utils.json_parser import extract_and_parse_json
logger = get_logger(__name__)
@dataclass
class MemoryQueryPlan:
"""查询规划结果"""
semantic_query: str
memory_types: list[MemoryType] = field(default_factory=list)
subject_includes: list[str] = field(default_factory=list)
object_includes: list[str] = field(default_factory=list)
required_keywords: list[str] = field(default_factory=list)
optional_keywords: list[str] = field(default_factory=list)
owner_filters: list[str] = field(default_factory=list)
recency_preference: str = "any"
limit: int = 10
emphasis: str | None = None
raw_plan: dict[str, Any] = field(default_factory=dict)
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
if not self.semantic_query:
self.semantic_query = fallback_query
if self.limit <= 0:
self.limit = default_limit
self.recency_preference = (self.recency_preference or "any").lower()
if self.recency_preference not in {"any", "recent", "historical"}:
self.recency_preference = "any"
self.emphasis = (self.emphasis or "balanced").lower()
class MemoryQueryPlanner:
"""基于小模型的记忆检索查询规划器"""
def __init__(self, planner_model: LLMRequest | None, default_limit: int = 10):
self.model = planner_model
self.default_limit = default_limit
async def plan_query(self, query_text: str, context: dict[str, Any]) -> MemoryQueryPlan:
if not self.model:
logger.debug("未提供查询规划模型,使用默认规划")
return self._default_plan(query_text)
prompt = self._build_prompt(query_text, context)
try:
response, _ = await self.model.generate_response_async(prompt, temperature=0.2)
# 使用统一的 JSON 解析工具
data = extract_and_parse_json(response, strict=False)
if not data or not isinstance(data, dict):
logger.debug("查询规划模型未返回有效的结构化结果,使用默认规划")
return self._default_plan(query_text)
plan = self._parse_plan_dict(data, query_text)
plan.ensure_defaults(query_text, self.default_limit)
return plan
except Exception as exc:
logger.error("查询规划模型调用失败: %s", exc, exc_info=True)
return self._default_plan(query_text)
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit)
def _parse_plan_dict(self, data: dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
def _collect_list(key: str) -> list[str]:
value = data.get(key)
if isinstance(value, str):
return [value]
if isinstance(value, list):
return [self._safe_str(item) for item in value if self._safe_str(item)]
return []
memory_type_values = _collect_list("memory_types")
memory_types: list[MemoryType] = []
for item in memory_type_values:
if not item:
continue
try:
memory_types.append(MemoryType(item))
except ValueError:
# 尝试匹配value值
normalized = item.lower()
for mt in MemoryType:
if mt.value == normalized:
memory_types.append(mt)
break
plan = MemoryQueryPlan(
semantic_query=semantic_query,
memory_types=memory_types,
subject_includes=_collect_list("subject_includes"),
object_includes=_collect_list("object_includes"),
required_keywords=_collect_list("required_keywords"),
optional_keywords=_collect_list("optional_keywords"),
owner_filters=_collect_list("owner_filters"),
recency_preference=self._safe_str(data.get("recency")) or "any",
limit=self._safe_int(data.get("limit"), self.default_limit),
emphasis=self._safe_str(data.get("emphasis")) or "balanced",
raw_plan=data,
)
return plan
def _build_prompt(self, query_text: str, context: dict[str, Any]) -> str:
participants = context.get("participants") or context.get("speaker_names") or []
if isinstance(participants, str):
participants = [participants]
participants = [p for p in participants if isinstance(p, str) and p.strip()]
participant_preview = "".join(participants[:5]) or "未知"
persona = context.get("bot_personality") or context.get("bot_identity") or "未知"
# 构建未读消息上下文信息
context_section = ""
if context.get("has_unread_context") and context.get("unread_messages_context"):
unread_context = context["unread_messages_context"]
unread_messages = unread_context.get("messages", [])
unread_keywords = unread_context.get("keywords", [])
unread_participants = unread_context.get("participants", [])
context_summary = unread_context.get("context_summary", "")
if unread_messages:
# 构建未读消息摘要
message_previews = []
for msg in unread_messages[:5]: # 最多显示5条
sender = msg.get("sender", "未知")
content = msg.get("content", "")[:100] # 限制每条消息长度
message_previews.append(f"{sender}: {content}")
context_section = f"""
## 📋 未读消息上下文 (共{unread_context.get("total_count", 0)}条未读消息)
### 最近消息预览:
{chr(10).join(message_previews)}
### 上下文关键词:
{", ".join(unread_keywords[:15]) if unread_keywords else ""}
### 对话参与者:
{", ".join(unread_participants) if unread_participants else ""}
### 上下文摘要:
{context_summary[:300] if context_summary else ""}
"""
else:
context_section = """
## 📋 未读消息上下文:
无未读消息或上下文信息不可用
"""
return f"""
你是一名记忆检索规划助手,请基于输入生成一个简洁的 JSON 检索计划。
你的任务是分析当前查询并结合未读消息的上下文,生成更精准的记忆检索策略。
仅需提供以下字段:
- semantic_query: 用于向量召回的自然语言描述,要求具体且贴合当前查询和上下文;
- memory_types: 建议检索的记忆类型列表,取值范围来自 MemoryType 枚举 (personal_fact,event,preference,opinion,relationship,emotion,knowledge,skill,goal,experience,contextual)
- subject_includes: 建议出现在记忆主语中的人物或角色;
- object_includes: 建议关注的对象、主题或关键信息;
- required_keywords: 建议必须包含的关键词(从上下文中提取);
- recency: 推荐的时间偏好,可选 recent/any/historical
- limit: 推荐的最大返回数量 (1-15)
- emphasis: 检索重点,可选 balanced/contextual/recent/comprehensive。
请不要生成谓语字段,也不要额外补充其它参数。
## 当前查询:
"{query_text}"
## 已知对话参与者:
{participant_preview}
## 机器人设定:
{persona}{context_section}
## 🎯 指导原则:
1. **上下文关联**: 优先分析与当前查询相关的未读消息内容和关键词
2. **语义理解**: 结合上下文理解查询的真实意图,而非字面意思
3. **参与者感知**: 考虑未读消息中的参与者,检索与他们相关的记忆
4. **主题延续**: 关注未读消息中讨论的主题,检索相关的历史记忆
5. **时间相关性**: 如果未读消息讨论最近的事件,偏向检索相关时期的记忆
请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。
"""
@staticmethod
def _safe_str(value: Any) -> str:
if isinstance(value, str):
return value.strip()
if value is None:
return ""
return str(value).strip()
@staticmethod
def _safe_int(value: Any, default: int) -> int:
try:
number = int(value)
if number <= 0:
return default
return number
except (TypeError, ValueError):
return default
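# 用法示例(仅作说明,非本模块原有内容):不接入小模型时的规划流程。
# planner_model=None 会让 plan_query 直接走默认规划,便于在无模型环境下验证调用链。
async def _demo_plan_query() -> MemoryQueryPlan:
    planner = MemoryQueryPlanner(planner_model=None, default_limit=8)
    return await planner.plan_query("他最近在忙什么", context={"participants": ["小明"]})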

File diff suppressed because it is too large

View File

@@ -1,75 +0,0 @@
"""
消息集合处理器
负责收集消息、创建集合并将其存入向量存储。
"""
import asyncio
from collections import deque
from typing import Any
from src.chat.memory_system.memory_chunk import MessageCollection
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
from src.common.logger import get_logger
logger = get_logger(__name__)
class MessageCollectionProcessor:
"""处理消息集合的创建和存储"""
def __init__(self, storage: MessageCollectionStorage, buffer_size: int = 5):
self.storage = storage
self.buffer_size = buffer_size
self.message_buffers: dict[str, deque[str]] = {}
self._lock = asyncio.Lock()
async def add_message(self, message_text: str, chat_id: str):
"""添加一条新消息到指定聊天的缓冲区,并在满时触发处理"""
async with self._lock:
if not isinstance(message_text, str) or not message_text.strip():
return
if chat_id not in self.message_buffers:
self.message_buffers[chat_id] = deque(maxlen=self.buffer_size)
buffer = self.message_buffers[chat_id]
buffer.append(message_text)
logger.debug(f"消息已添加到聊天 '{chat_id}' 的缓冲区,当前数量: {len(buffer)}/{self.buffer_size}")
if len(buffer) == self.buffer_size:
await self._process_buffer(chat_id)
async def _process_buffer(self, chat_id: str):
"""处理指定聊天缓冲区中的消息,创建并存储一个集合"""
buffer = self.message_buffers.get(chat_id)
if not buffer or len(buffer) < self.buffer_size:
return
messages_to_process = list(buffer)
buffer.clear()
logger.info(f"聊天 '{chat_id}' 的消息缓冲区已满,开始创建消息集合...")
try:
combined_text = "\n".join(messages_to_process)
collection = MessageCollection(
chat_id=chat_id,
messages=messages_to_process,
combined_text=combined_text,
)
await self.storage.add_collection(collection)
logger.info(f"成功为聊天 '{chat_id}' 创建并存储了新的消息集合: {collection.collection_id}")
except Exception as e:
logger.error(f"处理聊天 '{chat_id}' 的消息缓冲区失败: {e}", exc_info=True)
def get_stats(self) -> dict[str, Any]:
"""获取处理器统计信息"""
total_buffered_messages = sum(len(buf) for buf in self.message_buffers.values())
return {
"active_buffers": len(self.message_buffers),
"total_buffered_messages": total_buffered_messages,
"buffer_capacity_per_chat": self.buffer_size,
}
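# 用法示例(仅作说明,非本模块原有内容):向处理器连续写入消息,缓冲区满后自动生成并存储集合。
# MessageCollectionStorage 依赖已配置好的向量库,因此这里只演示接线方式,消息内容为占位值。
async def _demo_collect_messages(storage: MessageCollectionStorage) -> dict[str, Any]:
    processor = MessageCollectionProcessor(storage, buffer_size=3)
    for text in ["早上好", "今天喝咖啡吗", "好啊,九点见"]:
        await processor.add_message(text, chat_id="chat_456")
    return processor.get_stats()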

View File

@@ -1,193 +0,0 @@
"""
消息集合向量存储系统
专用于存储和检索消息集合,以提供即时上下文。
"""
import time
from typing import Any
from src.chat.memory_system.memory_chunk import MessageCollection
from src.chat.utils.utils import get_embedding
from src.common.logger import get_logger
from src.common.vector_db import vector_db_service
from src.config.config import global_config
logger = get_logger(__name__)
class MessageCollectionStorage:
"""消息集合向量存储"""
def __init__(self):
self.config = global_config.memory
self.vector_db_service = vector_db_service
self.collection_name = "message_collections"
self._initialize_storage()
def _initialize_storage(self):
"""初始化存储"""
try:
self.vector_db_service.get_or_create_collection(
name=self.collection_name,
metadata={"description": "短期消息集合记忆", "hnsw:space": "cosine"},
)
logger.info(f"消息集合存储初始化完成,集合: '{self.collection_name}'")
except Exception as e:
logger.error(f"消息集合存储初始化失败: {e}", exc_info=True)
raise
async def add_collection(self, collection: MessageCollection):
"""添加一个新的消息集合,并处理容量和时间限制"""
try:
# 清理过期和超额的集合
await self._cleanup_collections()
# 向量化并存储
embedding = await get_embedding(collection.combined_text)
if not embedding:
logger.warning(f"无法为消息集合 {collection.collection_id} 生成向量,跳过存储。")
return
collection.embedding = embedding
self.vector_db_service.add(
collection_name=self.collection_name,
embeddings=[embedding],
ids=[collection.collection_id],
documents=[collection.combined_text],
metadatas=[collection.to_dict()],
)
logger.debug(f"成功存储消息集合: {collection.collection_id}")
except Exception as e:
logger.error(f"存储消息集合失败: {e}", exc_info=True)
async def _cleanup_collections(self):
"""清理超额和过期的消息集合"""
try:
# 基于时间清理
if self.config.instant_memory_retention_hours > 0:
expiration_time = time.time() - self.config.instant_memory_retention_hours * 3600
expired_docs = self.vector_db_service.get(
collection_name=self.collection_name,
where={"created_at": {"$lt": expiration_time}},
include=[], # 只获取ID
)
if expired_docs and expired_docs.get("ids"):
self.vector_db_service.delete(collection_name=self.collection_name, ids=expired_docs["ids"])
logger.info(f"删除了 {len(expired_docs['ids'])} 个过期的瞬时记忆")
# 基于数量清理
current_count = self.vector_db_service.count(self.collection_name)
if current_count > self.config.instant_memory_max_collections:
num_to_delete = current_count - self.config.instant_memory_max_collections
# 获取所有文档的元数据以进行排序
all_docs = self.vector_db_service.get(
collection_name=self.collection_name,
include=["metadatas"]
)
if all_docs and all_docs.get("ids"):
# 在内存中排序找到最旧的文档
sorted_docs = sorted(
zip(all_docs["ids"], all_docs["metadatas"]),
key=lambda item: item[1].get("created_at", 0),
)
ids_to_delete = [doc[0] for doc in sorted_docs[:num_to_delete]]
if ids_to_delete:
self.vector_db_service.delete(collection_name=self.collection_name, ids=ids_to_delete)
logger.info(f"消息集合已满,删除最旧的 {len(ids_to_delete)} 个集合")
except Exception as e:
logger.error(f"清理消息集合失败: {e}", exc_info=True)
async def get_relevant_collection(self, query_text: str, n_results: int = 1) -> list[MessageCollection]:
"""根据查询文本检索最相关的消息集合"""
if not query_text.strip():
return []
try:
query_embedding = await get_embedding(query_text)
if not query_embedding:
return []
results = self.vector_db_service.query(
collection_name=self.collection_name,
query_embeddings=[query_embedding],
n_results=n_results,
)
collections = []
if results and results.get("ids") and results["ids"][0]:
collections.extend(MessageCollection.from_dict(metadata) for metadata in results["metadatas"][0])
return collections
except Exception as e:
logger.error(f"检索相关消息集合失败: {e}", exc_info=True)
return []
async def get_message_collection_context(self, query_text: str, chat_id: str) -> str:
"""获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。"""
try:
collections = await self.get_relevant_collection(query_text, n_results=5)
if not collections:
return ""
# 根据传入的 chat_id 对集合进行排序
collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True)
context_parts = []
for collection in collections:
if not collection.combined_text:
continue
header = "## 📝 相关对话上下文\n"
if collection.chat_id == chat_id:
# 匹配的ID使用更明显的标识
context_parts.append(
f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```"
)
else:
# 不匹配的ID
context_parts.append(
f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```"
)
if not context_parts:
return ""
# 格式化消息集合为 prompt 上下文
final_context = "\n\n---\n\n".join(context_parts) + "\n\n---"
logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文")
return f"\n{final_context}\n"
except Exception as e:
logger.error(f"get_message_collection_context 失败: {e}")
return ""
def clear_all(self):
"""清空所有消息集合"""
try:
# In ChromaDB, the easiest way to clear a collection is to delete and recreate it.
self.vector_db_service.delete_collection(name=self.collection_name)
self._initialize_storage()
logger.info(f"已清空所有消息集合: '{self.collection_name}'")
except Exception as e:
logger.error(f"清空消息集合失败: {e}", exc_info=True)
def get_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
try:
count = self.vector_db_service.count(self.collection_name)
return {
"collection_name": self.collection_name,
"total_collections": count,
"storage_limit": self.config.instant_memory_max_collections,
}
except Exception as e:
logger.error(f"获取消息集合存储统计失败: {e}")
return {}
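# 使用示意(假设性示例,非本仓库既有代码):构建回复 prompt 时,可以按如下方式
# 消费上面的消息集合存储;storage 代表该存储类的实例,初始化方式此处从略。
async def build_prompt_with_collections(storage, query_text: str, chat_id: str) -> str:
    """示意:把相关消息集合上下文拼接到基础 prompt 之前。"""
    base_prompt = f"请回复:{query_text}"
    # get_message_collection_context 在没有相关集合时返回空字符串,可直接拼接
    context = await storage.get_message_collection_context(query_text, chat_id)
    return f"{context}{base_prompt}" if context else base_prompt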

File diff suppressed because it is too large

View File

@@ -69,7 +69,11 @@ class SingleStreamContextManager:
try:
from .message_manager import message_manager as mm
message_manager = mm
use_cache_system = message_manager.is_running
# 检查配置是否启用消息缓存系统
cache_enabled = global_config.chat.enable_message_cache
use_cache_system = message_manager.is_running and cache_enabled
if not cache_enabled:
logger.debug(f"消息缓存系统已在配置中禁用")
except Exception as e:
logger.debug(f"MessageManager不可用使用直接添加: {e}")
use_cache_system = False

View File

@@ -323,8 +323,8 @@ class GlobalNoticeManager:
return message.additional_config.get("is_notice", False)
elif isinstance(message.additional_config, str):
# 兼容JSON字符串格式
import json
config = json.loads(message.additional_config)
import orjson
config = orjson.loads(message.additional_config)
return config.get("is_notice", False)
# 检查消息类型或其他标识
@@ -349,8 +349,8 @@ class GlobalNoticeManager:
if isinstance(message.additional_config, dict):
return message.additional_config.get("notice_type")
elif isinstance(message.additional_config, str):
import json
config = json.loads(message.additional_config)
import orjson
config = orjson.loads(message.additional_config)
return config.get("notice_type")
return None
except Exception:

View File

@@ -12,6 +12,7 @@ from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.core import get_db_session
from src.common.database.core.models import Images, Messages
from src.common.logger import get_logger
from src.config.config import global_config
from .chat_stream import ChatStream
from .message import MessageSending
@@ -181,12 +182,14 @@ class MessageStorageBatcher:
is_command = message.is_command or False
is_public_notice = message.is_public_notice or False
notice_type = message.notice_type
actions = message.actions
# 序列化actions列表为JSON字符串
actions = orjson.dumps(message.actions).decode("utf-8") if message.actions else None
should_reply = message.should_reply
should_act = message.should_act
additional_config = message.additional_config
key_words = ""
key_words_lite = ""
# 确保关键词字段是字符串格式(如果不是,则序列化)
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
memorized_times = 0
user_platform = message.user_info.platform if message.user_info else ""
@@ -253,7 +256,8 @@ class MessageStorageBatcher:
is_command = message.is_command
is_public_notice = getattr(message, "is_public_notice", False)
notice_type = getattr(message, "notice_type", None)
actions = getattr(message, "actions", None)
# 序列化actions列表为JSON字符串
raw_actions = getattr(message, "actions", None)
actions = orjson.dumps(raw_actions).decode("utf-8") if raw_actions else None
should_reply = getattr(message, "should_reply", None)
should_act = getattr(message, "should_act", None)
additional_config = getattr(message, "additional_config", None)
@@ -275,6 +279,9 @@ class MessageStorageBatcher:
user_platform = user_info_dict.get("platform")
user_id = user_info_dict.get("user_id")
# 将机器人自己的user_id标记为"SELF",增强对自我身份的识别
if user_id == global_config.bot.qq_account:
user_id = "SELF"
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
@@ -576,6 +583,11 @@ class MessageStorage:
is_picid = False
is_notify = False
is_command = False
is_public_notice = False
notice_type = None
actions = None
should_reply = False
should_act = False
key_words = ""
key_words_lite = ""
else:
@@ -589,6 +601,12 @@ class MessageStorage:
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
is_public_notice = getattr(message, "is_public_notice", False)
notice_type = getattr(message, "notice_type", None)
# 序列化actions列表为JSON字符串
raw_actions = getattr(message, "actions", None)
actions = orjson.dumps(raw_actions).decode("utf-8") if raw_actions else None
should_reply = getattr(message, "should_reply", False)
should_act = getattr(message, "should_act", False)
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
@@ -612,6 +630,9 @@ class MessageStorage:
user_platform = user_info_dict.get("platform")
user_id = user_info_dict.get("user_id")
# 将机器人自己的user_id标记为"SELF",增强对自我身份的识别
if user_id == global_config.bot.qq_account:
user_id = "SELF"
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
@@ -659,6 +680,11 @@ class MessageStorage:
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
is_public_notice=is_public_notice,
notice_type=notice_type,
actions=actions,
should_reply=should_reply,
should_act=should_act,
key_words=key_words,
key_words_lite=key_words_lite,
)
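# 补充示意(假设性示例):上面的改动把 actions 序列化为 JSON 字符串、并把机器人
# 自己的 user_id 统一写成 "SELF";读取侧大致可以这样还原(实际查询逻辑见相关模块)。
import orjson

def deserialize_actions(raw: str | None) -> list | None:
    """把数据库中的 actions JSON 字符串还原为列表,空值保持 None。"""
    return orjson.loads(raw) if raw else None

def resolve_display_name(user_id: str, bot_qq: str, bot_nickname: str) -> str:
    """示意 SELF 标记的消费方式:遇到 SELF 或机器人 QQ 号时显示为昵称(你)。"""
    if user_id in ("SELF", bot_qq):
        return f"{bot_nickname}(你)"
    return user_id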

View File

@@ -255,8 +255,6 @@ class DefaultReplyer:
self._chat_info_initialized = False
self.heart_fc_sender = HeartFCSender()
# 使用新的增强记忆系统
# from src.chat.memory_system.enhanced_memory_activator import EnhancedMemoryActivator
self._chat_info_initialized = False
async def _initialize_chat_info(self):
@@ -393,19 +391,9 @@ class DefaultReplyer:
f"插件{result.get_summary().get('stopped_handlers', '')}于请求后取消了内容生成"
)
# 回复生成成功后,异步存储聊天记忆(不阻塞返回)
try:
# 将记忆存储作为子任务创建,可以被取消
memory_task = asyncio.create_task(
self._store_chat_memory_async(reply_to, reply_message),
name=f"store_memory_{self.chat_stream.stream_id}"
)
# 不等待完成,让它在后台运行
# 如果父任务被取消,这个子任务也会被垃圾回收
logger.debug(f"创建记忆存储子任务: {memory_task.get_name()}")
except Exception as memory_e:
# 记忆存储失败不应该影响回复生成的成功返回
logger.warning(f"记忆存储失败,但不影响回复生成: {memory_e}")
# 旧的自动记忆存储已移除,现在使用记忆图系统通过工具创建记忆
# 记忆由LLM在对话过程中通过CreateMemoryTool主动创建而非自动存储
pass
return True, llm_response, prompt
@@ -550,178 +538,116 @@ class DefaultReplyer:
Returns:
str: 记忆信息字符串
"""
if not global_config.memory.enable_memory:
return ""
# 使用新的记忆图系统检索记忆(带智能查询优化)
all_memories = []
try:
from src.memory_graph.manager_singleton import get_memory_manager, is_initialized
if is_initialized():
manager = get_memory_manager()
if manager:
# 构建查询上下文
stream = self.chat_stream
user_info_obj = getattr(stream, "user_info", None)
sender_name = ""
if user_info_obj:
sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "")
# 获取参与者信息
participants = []
try:
# 尝试从聊天流中获取参与者信息
if hasattr(stream, 'chat_history_manager'):
history_manager = stream.chat_history_manager
# 获取最近的参与者列表
recent_records = history_manager.get_memory_chat_history(
user_id=getattr(stream, "user_id", ""),
count=10,
memory_types=["chat_message", "system_message"]
)
# 提取唯一的参与者名称
for record in recent_records[:5]: # 最近5条记录
content = record.get("content", {})
participant = content.get("participant_name")
if participant and participant not in participants:
participants.append(participant)
instant_memory = None
# 如果消息包含发送者信息,也添加到参与者列表
if content.get("sender_name") and content.get("sender_name") not in participants:
participants.append(content.get("sender_name"))
except Exception as e:
logger.debug(f"获取参与者信息失败: {e}")
# 使用新的增强记忆系统检索记忆
running_memories = []
instant_memory = None
# 如果发送者不在参与者列表中,添加进去
if sender_name and sender_name not in participants:
participants.insert(0, sender_name)
if global_config.memory.enable_memory:
try:
# 使用新的统一记忆系统
from src.chat.memory_system import get_memory_system
# 格式化聊天历史为更友好的格式
formatted_history = ""
if chat_history:
# 移除过长的历史记录,只保留最近部分
lines = chat_history.strip().split('\n')
recent_lines = lines[-10:] if len(lines) > 10 else lines
formatted_history = '\n'.join(recent_lines)
stream = self.chat_stream
user_info_obj = getattr(stream, "user_info", None)
group_info_obj = getattr(stream, "group_info", None)
memory_user_id = str(stream.stream_id)
memory_user_display = None
memory_aliases = []
user_info_dict = {}
if user_info_obj is not None:
raw_user_id = getattr(user_info_obj, "user_id", None)
if raw_user_id:
memory_user_id = str(raw_user_id)
if hasattr(user_info_obj, "to_dict"):
try:
user_info_dict = user_info_obj.to_dict() # type: ignore[attr-defined]
except Exception:
user_info_dict = {}
candidate_keys = [
"user_cardname",
"user_nickname",
"nickname",
"remark",
"display_name",
"user_name",
]
for key in candidate_keys:
value = user_info_dict.get(key)
if isinstance(value, str) and value.strip():
stripped = value.strip()
if memory_user_display is None:
memory_user_display = stripped
elif stripped not in memory_aliases:
memory_aliases.append(stripped)
attr_keys = [
"user_cardname",
"user_nickname",
"nickname",
"remark",
"display_name",
"name",
]
for attr in attr_keys:
value = getattr(user_info_obj, attr, None)
if isinstance(value, str) and value.strip():
stripped = value.strip()
if memory_user_display is None:
memory_user_display = stripped
elif stripped not in memory_aliases:
memory_aliases.append(stripped)
alias_values = (
user_info_dict.get("aliases")
or user_info_dict.get("alias_names")
or user_info_dict.get("alias")
query_context = {
"chat_history": formatted_history,
"sender": sender_name,
"participants": participants,
}
# 使用记忆管理器的智能检索(多查询策略)
memories = await manager.search_memories(
query=target,
top_k=10,
min_importance=0.3,
include_forgotten=False,
use_multi_query=True,
context=query_context,
)
if isinstance(alias_values, list | tuple | set):
for alias in alias_values:
if isinstance(alias, str) and alias.strip():
stripped = alias.strip()
if stripped not in memory_aliases and stripped != memory_user_display:
memory_aliases.append(stripped)
memory_context = {
"user_id": memory_user_id,
"user_display_name": memory_user_display or "",
"user_name": memory_user_display or "",
"nickname": memory_user_display or "",
"sender_name": memory_user_display or "",
"platform": getattr(stream, "platform", None),
"chat_id": stream.stream_id,
"stream_id": stream.stream_id,
}
if memory_aliases:
memory_context["user_aliases"] = memory_aliases
if group_info_obj is not None:
group_name = getattr(group_info_obj, "group_name", None) or getattr(
group_info_obj, "group_nickname", None
)
if group_name:
memory_context["group_name"] = str(group_name)
group_id = getattr(group_info_obj, "group_id", None)
if group_id:
memory_context["group_id"] = str(group_id)
memory_context = {key: value for key, value in memory_context.items() if value}
# 获取记忆系统实例
memory_system = get_memory_system()
# 使用统一记忆系统检索相关记忆
enhanced_memories = await memory_system.retrieve_relevant_memories(
query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10
)
# 注意:记忆存储已迁移到回复生成完成后进行,不在查询阶段执行
# 转换格式以兼容现有代码
running_memories = []
if enhanced_memories:
logger.debug(f"[记忆转换] 收到 {len(enhanced_memories)} 条原始记忆")
for idx, memory_chunk in enumerate(enhanced_memories, 1):
# 获取结构化内容的字符串表示
structure_display = str(memory_chunk.content) if hasattr(memory_chunk, "content") else "unknown"
# 获取记忆内容优先使用display
content = memory_chunk.display or memory_chunk.text_content or ""
# 调试:记录每条记忆的内容获取情况
logger.debug(
f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}"
if memories:
logger.info(f"[记忆图] 检索到 {len(memories)} 条相关记忆")
# 使用新的格式化工具构建完整的记忆描述
from src.memory_graph.utils.memory_formatter import (
format_memory_for_prompt,
get_memory_type_label,
)
running_memories.append(
{
"content": content,
"memory_type": memory_chunk.memory_type.value,
"confidence": memory_chunk.metadata.confidence.value,
"importance": memory_chunk.metadata.importance.value,
"relevance": getattr(memory_chunk.metadata, "relevance_score", 0.5),
"source": memory_chunk.metadata.source,
"structure": structure_display,
}
)
# 构建瞬时记忆字符串
if running_memories:
top_memory = running_memories[:1]
if top_memory:
instant_memory = top_memory[0].get("content", "")
logger.info(
f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆"
)
except Exception as e:
logger.warning(f"增强记忆系统检索失败: {e}")
running_memories = []
instant_memory = ""
for memory in memories:
# 使用格式化工具生成完整的主谓宾描述
content = format_memory_for_prompt(memory, include_metadata=False)
# 获取记忆类型
mem_type = memory.memory_type.value if memory.memory_type else "未知"
if content:
all_memories.append({
"content": content,
"memory_type": mem_type,
"importance": memory.importance,
"relevance": 0.7,
"source": "memory_graph",
})
logger.debug(f"[记忆构建] 格式化记忆: [{mem_type}] {content[:50]}...")
else:
logger.debug("[记忆图] 未找到相关记忆")
except Exception as e:
logger.debug(f"[记忆图] 检索失败: {e}")
all_memories = []
# 构建记忆字符串,使用方括号格式
memory_str = ""
has_any_memory = False
# 添加长期记忆(来自增强记忆系统)
if running_memories:
# 添加长期记忆(来自记忆系统)
if all_memories:
# 使用方括号格式
memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
# 按相关度排序,并记录相关度信息用于调试
sorted_memories = sorted(running_memories, key=lambda x: x.get("relevance", 0.0), reverse=True)
sorted_memories = sorted(all_memories, key=lambda x: x.get("relevance", 0.0), reverse=True)
# 调试相关度信息
relevance_info = [(m.get("memory_type", "unknown"), m.get("relevance", 0.0)) for m in sorted_memories]
@@ -738,8 +664,13 @@ class DefaultReplyer:
logger.debug(f"[记忆构建] 空记忆详情: {running_memory}")
continue
# 使用全局记忆类型映射表
chinese_type = get_memory_type_chinese_label(memory_type)
# 使用记忆图的类型映射(优先)或全局映射
try:
from src.memory_graph.utils.memory_formatter import get_memory_type_label
chinese_type = get_memory_type_label(memory_type)
except ImportError:
# 回退到全局映射
chinese_type = get_memory_type_chinese_label(memory_type)
# 提取纯净内容(如果包含旧格式的元数据)
clean_content = content
@@ -753,13 +684,7 @@ class DefaultReplyer:
has_any_memory = True
logger.debug(f"[记忆构建] 成功构建记忆字符串,包含 {len(memory_parts) - 2} 条记忆")
# 添加瞬时记忆
if instant_memory:
if not any(rm["content"] == instant_memory for rm in running_memories):
if not memory_str:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
memory_str += f"- 最相关记忆:{instant_memory}\n"
has_any_memory = True
# 瞬时记忆由另一套系统处理,这里不再添加
# 只有当完全没有任何记忆时才返回空字符串
return memory_str if has_any_memory else ""
@@ -780,32 +705,46 @@ class DefaultReplyer:
return ""
try:
# 使用工具执行器获取信息
# 首先获取当前的历史记录(在执行新工具调用之前)
tool_history_str = self.tool_executor.history_manager.format_for_prompt(max_records=3, include_results=True)
# 然后执行工具调用
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
sender=sender, target_message=target, chat_history=chat_history, return_details=False
)
info_parts = []
# 显示之前的工具调用历史(不包括当前这次调用)
if tool_history_str:
info_parts.append(tool_history_str)
# 显示当前工具调用的结果(简要信息)
if tool_results:
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
current_results_parts = ["## 🔧 刚获取的工具信息"]
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}{result_type}: {content}\n"
# 不进行截断,让工具自己处理结果长度
current_results_parts.append(f"- **{tool_name}**: {content}")
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
info_parts.append("\n".join(current_results_parts))
logger.info(f"获取到 {len(tool_results)} 个工具结果")
return tool_info_str
else:
logger.debug("未获取到任何工具结果")
# 如果没有任何信息,返回空字符串
if not info_parts:
logger.debug("未获取到任何工具结果或历史记录")
return ""
return "\n\n".join(info_parts)
except Exception as e:
logger.error(f"工具信息获取失败: {e}")
return ""
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt
@@ -1145,29 +1084,6 @@ class DefaultReplyer:
return read_history_prompt, unread_history_prompt
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
"""为消息获取兴趣度评分(使用预计算的兴趣值)"""
interest_scores = {}
try:
# 直接使用消息中的预计算兴趣值
for msg_dict in messages:
message_id = msg_dict.get("message_id", "")
interest_value = msg_dict.get("interest_value")
if interest_value is not None:
interest_scores[message_id] = float(interest_value)
logger.debug(f"使用预计算兴趣度 - 消息 {message_id}: {interest_value:.3f}")
else:
interest_scores[message_id] = 0.5 # 默认值
logger.debug(f"消息 {message_id} 无预计算兴趣值,使用默认值 0.5")
except Exception as e:
logger.warning(f"处理预计算兴趣值失败: {e}")
return interest_scores
async def build_prompt_reply_context(
self,
reply_to: str,
@@ -1976,14 +1892,22 @@ class DefaultReplyer:
return f"你与{sender}是普通朋友关系。"
# 已废弃:旧的自动记忆存储逻辑
# 新的记忆图系统通过LLM工具(CreateMemoryTool)主动创建记忆,而非自动存储
async def _store_chat_memory_async(self, reply_to: str, reply_message: DatabaseMessages | dict[str, Any] | None = None):
"""
异步存储聊天记忆从build_memory_block迁移而来
[已废弃] 异步存储聊天记忆从build_memory_block迁移而来
此函数已被记忆图系统的工具调用方式替代。
记忆现在由LLM在对话过程中通过CreateMemoryTool主动创建。
Args:
reply_to: 回复对象
reply_message: 回复的原始消息
"""
return # 已禁用,保留函数签名以防其他地方有引用
# 以下代码已废弃,不再执行
try:
if not global_config.memory.enable_memory:
return
@@ -2121,23 +2045,9 @@ class DefaultReplyer:
show_actions=True,
)
# 异步存储聊天历史(完全非阻塞)
memory_system = get_memory_system()
task = asyncio.create_task(
memory_system.process_conversation_memory(
context={
"conversation_text": chat_history,
"user_id": memory_user_id,
"scope_id": stream.stream_id,
**memory_context,
}
)
)
# 将任务添加到全局集合以防止被垃圾回收
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.debug(f"已启动记忆存储任务,用户: {memory_user_display or memory_user_id}")
# 旧记忆系统的自动存储已禁用
# 新记忆系统通过 LLM 工具调用create_memory来创建记忆
logger.debug(f"记忆创建通过 LLM 工具调用进行,用户: {memory_user_display or memory_user_id}")
except asyncio.CancelledError:
logger.debug("记忆存储任务被取消")

View File

@@ -44,8 +44,8 @@ def replace_user_references_sync(
if name_resolver is None:
def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
return f"{global_config.bot.nickname}(你)"
# 同步函数中无法使用异步的 get_value直接返回 user_id
# 建议调用方使用 replace_user_references_async 以获取完整的用户名
@@ -60,8 +60,8 @@ def replace_user_references_sync(
aaa = match[1]
bbb = match[2]
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
reply_person_name = f"{global_config.bot.nickname}(你)"
else:
reply_person_name = name_resolver(platform, bbb) or aaa
@@ -81,8 +81,8 @@ def replace_user_references_sync(
aaa = m.group(1)
bbb = m.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
at_person_name = f"{global_config.bot.nickname}(你)"
else:
at_person_name = name_resolver(platform, bbb) or aaa
@@ -120,8 +120,8 @@ async def replace_user_references_async(
person_info_manager = get_person_info_manager()
async def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
return f"{global_config.bot.nickname}(你)"
person_id = PersonInfoManager.get_person_id(platform, user_id)
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
@@ -135,8 +135,8 @@ async def replace_user_references_async(
aaa = match.group(1)
bbb = match.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
reply_person_name = f"{global_config.bot.nickname}(你)"
else:
reply_person_name = await name_resolver(platform, bbb) or aaa
@@ -156,8 +156,8 @@ async def replace_user_references_async(
aaa = m.group(1)
bbb = m.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
at_person_name = f"{global_config.bot.nickname}(你)"
else:
at_person_name = await name_resolver(platform, bbb) or aaa
@@ -638,13 +638,14 @@ async def _build_readable_messages_internal(
if not all([platform, user_id, timestamp is not None]):
continue
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager()
# 根据 replace_bot_name 参数决定是否替换机器人名称
person_name: str
if replace_bot_name and user_id == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
person_name = f"{global_config.bot.nickname}(你)"
else:
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name") # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
@@ -656,8 +657,8 @@ async def _build_readable_messages_internal(
else:
person_name = "某人"
# 在用户名后面添加 QQ 号, 但机器人本体不用
if user_id != global_config.bot.qq_account:
# 在用户名后面添加 QQ 号, 但机器人本体不用包括SELF标记
if user_id != global_config.bot.qq_account and user_id != "SELF":
person_name = f"{person_name}({user_id})"
# 使用独立函数处理用户引用格式

View File

@@ -398,6 +398,9 @@ class Prompt:
"""
start_time = time.time()
# 初始化预构建参数字典
pre_built_params = {}
try:
# --- 步骤 1: 准备构建任务 ---
tasks = []
@@ -406,7 +409,6 @@ class Prompt:
# --- 步骤 1.1: 优先使用预构建的参数 ---
# 如果参数对象中已经包含了某些block说明它们是外部预构建的
# 我们将它们存起来,并跳过对应的实时构建任务。
pre_built_params = {}
if self.parameters.expression_habits_block:
pre_built_params["expression_habits_block"] = self.parameters.expression_habits_block
if self.parameters.relation_info_block:
@@ -428,11 +430,9 @@ class Prompt:
tasks.append(self._build_expression_habits())
task_names.append("expression_habits")
# 记忆块构建非常耗时,强烈建议预构建。如果没有预构建,这里会运行一个快速的后备版本。
if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
logger.debug("memory_block未预构建执行快速构建作为后备方案")
tasks.append(self._build_memory_block_fast())
task_names.append("memory_block")
# 记忆块构建已移至 default_generator.py 的 build_memory_block 方法
# 使用新的记忆图系统,不再在 prompt.py 中构建记忆
# 如果需要记忆,必须通过 pre_built_params 传入
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info())
@@ -637,146 +637,6 @@ class Prompt:
logger.error(f"构建表达习惯失败: {e}")
return {"expression_habits_block": ""}
async def _build_memory_block(self) -> dict[str, Any]:
"""构建与当前对话相关的记忆上下文块(完整版)."""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
try:
from src.chat.memory_system.enhanced_memory_activator import enhanced_memory_activator
# 准备用于记忆激活的聊天历史
chat_history = ""
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-20:]
chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 并行查询长期记忆和即时记忆以提高性能
import asyncio
memory_tasks = [
enhanced_memory_activator.activate_memory_with_chat_history(
target_message=self.parameters.target, chat_history_prompt=chat_history
),
enhanced_memory_activator.get_instant_memory(
target_message=self.parameters.target, chat_id=self.parameters.chat_id
),
]
try:
# 使用 `return_exceptions=True` 来防止一个任务的失败导致所有任务失败
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
# 单独处理每个任务的结果,如果是异常则记录并使用默认值
if isinstance(running_memories, BaseException):
logger.warning(f"长期记忆查询失败: {running_memories}")
running_memories = []
if isinstance(instant_memory, BaseException):
logger.warning(f"即时记忆查询失败: {instant_memory}")
instant_memory = None
except asyncio.TimeoutError:
logger.warning("记忆查询超时,使用部分结果")
running_memories = []
instant_memory = None
# 将检索到的记忆格式化为提示词
if running_memories:
try:
from src.chat.memory_system.memory_formatter import format_memories_bracket_style
# 将原始记忆数据转换为格式化器所需的标准格式
formatted_memories = []
for memory in running_memories:
content = memory.get("content", "")
display_text = content
# 清理内容,移除元数据括号
if "(类型:" in content and "" in content:
display_text = content.split("(类型:")[0].strip()
# 映射记忆主题到标准类型
topic = memory.get("topic", "personal_fact")
memory_type_mapping = {
"relationship": "personal_fact",
"opinion": "opinion",
"personal_fact": "personal_fact",
"preference": "preference",
"event": "event",
}
mapped_type = memory_type_mapping.get(topic, "personal_fact")
formatted_memories.append(
{
"display": display_text,
"memory_type": mapped_type,
"metadata": {
"confidence": memory.get("confidence", "未知"),
"importance": memory.get("importance", "一般"),
"timestamp": memory.get("timestamp", ""),
"source": memory.get("source", "unknown"),
"relevance_score": memory.get("relevance_score", 0.0),
},
}
)
# 使用指定的风格进行格式化
memory_block = format_memories_bracket_style(
formatted_memories, query_context=self.parameters.target
)
except Exception as e:
# 如果格式化失败,提供一个简化的、健壮的备用格式
logger.warning(f"记忆格式化失败,使用简化格式: {e}")
memory_parts = ["## 相关记忆回顾", ""]
for memory in running_memories:
content = memory.get("content", "")
if "(类型:" in content and "" in content:
clean_content = content.split("(类型:")[0].strip()
memory_parts.append(f"- {clean_content}")
else:
memory_parts.append(f"- {content}")
memory_block = "\n".join(memory_parts)
else:
memory_block = ""
# 将即时记忆附加到记忆块的末尾
if instant_memory:
if memory_block:
memory_block += f"\n- 最相关记忆:{instant_memory}"
else:
memory_block = f"- 最相关记忆:{instant_memory}"
return {"memory_block": memory_block}
except Exception as e:
logger.error(f"构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_memory_block_fast(self) -> dict[str, Any]:
"""快速构建记忆块(简化版),作为未预构建时的后备方案."""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
try:
from src.chat.memory_system.enhanced_memory_activator import enhanced_memory_activator
# 这个快速版本只查询最高优先级的“即时记忆”,速度更快
instant_memory = await enhanced_memory_activator.get_instant_memory(
target_message=self.parameters.target, chat_id=self.parameters.chat_id
)
if instant_memory:
memory_block = f"- 相关记忆:{instant_memory}"
else:
memory_block = ""
return {"memory_block": memory_block}
except Exception as e:
logger.warning(f"快速构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_relation_info(self) -> dict[str, Any]:
"""构建与对话目标相关的关系信息."""
try:

View File

@@ -57,8 +57,16 @@ class CacheManager:
# 嵌入模型
self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
# 工具调用统计
self.tool_stats = {
"total_tool_calls": 0,
"cache_hits_by_tool": {}, # 按工具名称统计缓存命中
"execution_times_by_tool": {}, # 按工具名称统计执行时间
"most_used_tools": {}, # 最常用的工具
}
self._initialized = True
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)")
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB) + 工具统计")
@staticmethod
def _validate_embedding(embedding_result: Any) -> np.ndarray | None:
@@ -363,58 +371,205 @@ class CacheManager:
def get_health_stats(self) -> dict[str, Any]:
"""获取缓存健康统计信息"""
from src.common.memory_utils import format_size
# 简化的健康统计,不包含内存监控(因为相关属性未定义)
return {
"l1_count": len(self.l1_kv_cache),
"l1_memory": self.l1_current_memory,
"l1_memory_formatted": format_size(self.l1_current_memory),
"l1_max_memory": self.l1_max_memory,
"l1_memory_usage_percent": round((self.l1_current_memory / self.l1_max_memory) * 100, 2),
"l1_max_size": self.l1_max_size,
"l1_size_usage_percent": round((len(self.l1_kv_cache) / self.l1_max_size) * 100, 2),
"average_item_size": self.l1_current_memory // len(self.l1_kv_cache) if self.l1_kv_cache else 0,
"average_item_size_formatted": format_size(self.l1_current_memory // len(self.l1_kv_cache)) if self.l1_kv_cache else "0 B",
"largest_item_size": max(self.l1_size_map.values()) if self.l1_size_map else 0,
"largest_item_size_formatted": format_size(max(self.l1_size_map.values())) if self.l1_size_map else "0 B",
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0,
"tool_stats": {
"total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
"tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
"cache_hits": sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
"cache_misses": sum(data.get("misses", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
}
}
def check_health(self) -> tuple[bool, list[str]]:
"""检查缓存健康状态
Returns:
(is_healthy, warnings) - 是否健康,警告列表
"""
warnings = []
# 检查内存使用
memory_usage = (self.l1_current_memory / self.l1_max_memory) * 100
if memory_usage > 90:
warnings.append(f"⚠️ L1缓存内存使用率过高: {memory_usage:.1f}%")
elif memory_usage > 75:
warnings.append(f"⚡ L1缓存内存使用率较高: {memory_usage:.1f}%")
# 检查条目数
size_usage = (len(self.l1_kv_cache) / self.l1_max_size) * 100
if size_usage > 90:
warnings.append(f"⚠️ L1缓存条目数过多: {size_usage:.1f}%")
# 检查平均条目大小
if self.l1_kv_cache:
avg_size = self.l1_current_memory // len(self.l1_kv_cache)
if avg_size > 100 * 1024: # >100KB
from src.common.memory_utils import format_size
warnings.append(f"⚡ 平均缓存条目过大: {format_size(avg_size)}")
# 检查最大单条目
if self.l1_size_map:
max_size = max(self.l1_size_map.values())
if max_size > 500 * 1024: # >500KB
from src.common.memory_utils import format_size
warnings.append(f"⚠️ 发现超大缓存条目: {format_size(max_size)}")
# 检查L1缓存大小
l1_size = len(self.l1_kv_cache)
if l1_size > 1000: # 如果超过1000个条目
warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")
# 检查向量索引大小
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0
if isinstance(vector_count, int) and vector_count > 500:
warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}")
# 检查工具统计健康
total_calls = self.tool_stats.get("total_tool_calls", 0)
if total_calls > 0:
total_hits = sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values())
cache_hit_rate = (total_hits / total_calls) * 100
if cache_hit_rate < 50: # 缓存命中率低于50%
warnings.append(f"⚡ 整体缓存命中率较低: {cache_hit_rate:.1f}%")
return len(warnings) == 0, warnings
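# 使用示意(假设性示例):定期巡检时可以这样消费 check_health 的返回值,
# cache 代表上面的 CacheManager 实例(例如文件末尾的全局 tool_cache)。
def log_cache_health(cache) -> None:
    is_healthy, warnings = cache.check_health()
    if is_healthy:
        logger.info("缓存状态健康")
        return
    for warning in warnings:
        logger.warning(warning)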
async def get_tool_result_with_stats(self,
tool_name: str,
function_args: dict[str, Any],
tool_file_path: str | Path,
semantic_query: str | None = None) -> tuple[Any | None, bool]:
"""获取工具结果并更新统计信息
Args:
tool_name: 工具名称
function_args: 函数参数
tool_file_path: 工具文件路径
semantic_query: 语义查询字符串
Returns:
Tuple[结果, 是否命中缓存]
"""
# 更新总调用次数
self.tool_stats["total_tool_calls"] += 1
# 更新工具使用统计
if tool_name not in self.tool_stats["most_used_tools"]:
self.tool_stats["most_used_tools"][tool_name] = 0
self.tool_stats["most_used_tools"][tool_name] += 1
# 尝试获取缓存
result = await self.get(tool_name, function_args, tool_file_path, semantic_query)
# 更新缓存命中统计
if tool_name not in self.tool_stats["cache_hits_by_tool"]:
self.tool_stats["cache_hits_by_tool"][tool_name] = {"hits": 0, "misses": 0}
if result is not None:
self.tool_stats["cache_hits_by_tool"][tool_name]["hits"] += 1
logger.info(f"工具缓存命中: {tool_name}")
return result, True
else:
self.tool_stats["cache_hits_by_tool"][tool_name]["misses"] += 1
return None, False
async def set_tool_result_with_stats(self,
tool_name: str,
function_args: dict[str, Any],
tool_file_path: str | Path,
data: Any,
execution_time: float | None = None,
ttl: int | None = None,
semantic_query: str | None = None):
"""存储工具结果并更新统计信息
Args:
tool_name: 工具名称
function_args: 函数参数
tool_file_path: 工具文件路径
data: 结果数据
execution_time: 执行时间
ttl: 缓存TTL
semantic_query: 语义查询字符串
"""
# 更新执行时间统计
if execution_time is not None:
if tool_name not in self.tool_stats["execution_times_by_tool"]:
self.tool_stats["execution_times_by_tool"][tool_name] = []
self.tool_stats["execution_times_by_tool"][tool_name].append(execution_time)
# 只保留最近100次的执行时间记录
if len(self.tool_stats["execution_times_by_tool"][tool_name]) > 100:
self.tool_stats["execution_times_by_tool"][tool_name] = \
self.tool_stats["execution_times_by_tool"][tool_name][-100:]
# 存储到缓存
await self.set(tool_name, function_args, tool_file_path, data, ttl, semantic_query)
def get_tool_performance_stats(self) -> dict[str, Any]:
"""获取工具性能统计信息
Returns:
统计信息字典
"""
stats = self.tool_stats.copy()
# 计算平均执行时间
avg_times = {}
for tool_name, times in stats["execution_times_by_tool"].items():
if times:
avg_times[tool_name] = {
"average": sum(times) / len(times),
"min": min(times),
"max": max(times),
"count": len(times),
}
# 计算缓存命中率
cache_hit_rates = {}
for tool_name, hit_data in stats["cache_hits_by_tool"].items():
total = hit_data["hits"] + hit_data["misses"]
if total > 0:
cache_hit_rates[tool_name] = {
"hit_rate": (hit_data["hits"] / total) * 100,
"hits": hit_data["hits"],
"misses": hit_data["misses"],
"total": total,
}
# 按使用频率排序工具
most_used = sorted(stats["most_used_tools"].items(), key=lambda x: x[1], reverse=True)
return {
"total_tool_calls": stats["total_tool_calls"],
"average_execution_times": avg_times,
"cache_hit_rates": cache_hit_rates,
"most_used_tools": most_used[:10], # 前10个最常用工具
"cache_health": self.get_health_stats(),
}
def get_tool_recommendations(self) -> dict[str, Any]:
"""获取工具优化建议
Returns:
优化建议字典
"""
recommendations = []
# 分析缓存命中率低的工具
cache_hit_rates = {}
for tool_name, hit_data in self.tool_stats["cache_hits_by_tool"].items():
total = hit_data["hits"] + hit_data["misses"]
if total >= 5: # 至少调用5次才分析
hit_rate = (hit_data["hits"] / total) * 100
cache_hit_rates[tool_name] = hit_rate
if hit_rate < 30: # 缓存命中率低于30%
recommendations.append({
"tool": tool_name,
"type": "low_cache_hit_rate",
"message": f"工具 {tool_name} 的缓存命中率仅为 {hit_rate:.1f}%,建议检查缓存配置或参数变化频率",
"severity": "medium" if hit_rate > 10 else "high",
})
# 分析执行时间长的工具
for tool_name, times in self.tool_stats["execution_times_by_tool"].items():
if len(times) >= 3: # 至少3次执行才分析
avg_time = sum(times) / len(times)
if avg_time > 5.0: # 平均执行时间超过5秒
recommendations.append({
"tool": tool_name,
"type": "slow_execution",
"message": f"工具 {tool_name} 平均执行时间较长 ({avg_time:.2f}s),建议优化算法或增加缓存",
"severity": "medium" if avg_time < 10.0 else "high",
})
return {
"recommendations": recommendations,
"summary": {
"total_issues": len(recommendations),
"high_priority": len([r for r in recommendations if r["severity"] == "high"]),
"medium_priority": len([r for r in recommendations if r["severity"] == "medium"]),
}
}
# 全局实例
tool_cache = CacheManager()
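# 使用示意(假设性示例):工具执行器可以按如下方式包装一次工具调用,先查缓存并
# 记录统计,未命中时执行真实逻辑并回写结果;run_tool 为调用方提供的假设执行函数。
import time

async def cached_tool_call(tool_name: str, args: dict, tool_file: str, run_tool):
    result, hit = await tool_cache.get_tool_result_with_stats(tool_name, args, tool_file)
    if hit:
        return result
    start = time.time()
    result = await run_tool(args)  # 实际工具逻辑
    await tool_cache.set_tool_result_with_stats(
        tool_name, args, tool_file, result, execution_time=time.time() - start
    )
    return result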

View File

@@ -2,6 +2,7 @@ import os
import shutil
import sys
from datetime import datetime
from typing import Optional
import tomlkit
from pydantic import Field
@@ -380,7 +381,7 @@ class Config(ValidatedConfigBase):
notice: NoticeConfig = Field(..., description="Notice消息配置")
emoji: EmojiConfig = Field(..., description="表情配置")
expression: ExpressionConfig = Field(..., description="表达配置")
memory: MemoryConfig = Field(..., description="记忆配置")
memory: Optional[MemoryConfig] = Field(default=None, description="记忆配置")
mood: MoodConfig = Field(..., description="情绪配置")
reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置")
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")

View File

@@ -120,6 +120,10 @@ class ChatConfig(ValidatedConfigBase):
timestamp_display_mode: Literal["normal", "normal_no_YMD", "relative"] = Field(
default="normal_no_YMD", description="时间戳显示模式"
)
# 消息缓存系统配置
enable_message_cache: bool = Field(
default=True, description="是否启用消息缓存系统(启用后,处理中收到的消息会被缓存,处理完成后统一刷新到未读列表)"
)
# 消息打断系统配置 - 线性概率模型
interruption_enabled: bool = Field(default=True, description="是否启用消息打断系统")
allow_reply_interruption: bool = Field(
@@ -181,6 +185,10 @@ class ExpressionConfig(ValidatedConfigBase):
default="classic",
description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测"
)
expiration_days: int = Field(
default=90,
description="表达方式过期天数,超过此天数未激活的表达方式将被清理"
)
rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
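# 使用示意(假设性示例):上面新增的两个开关在运行期通过 global_config 访问,
# 字段名即此处定义;对应 TOML 配置文件中的节名以实际配置结构为准,这里不作断言。
from src.config.config import global_config

if global_config.chat.enable_message_cache:
    ...  # 处理中收到的消息先进缓存,处理完成后统一刷新到未读列表
cleanup_days = global_config.expression.expiration_days  # 默认 90 天未激活即清理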
@staticmethod
@@ -393,6 +401,66 @@ class MemoryConfig(ValidatedConfigBase):
memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡")
memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流")
memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列")
# === 记忆图系统配置 (Memory Graph System) ===
# 新一代记忆系统的配置项
enable: bool = Field(default=True, description="启用记忆图系统")
data_dir: str = Field(default="data/memory_graph", description="记忆数据存储目录")
# 向量存储配置
vector_collection_name: str = Field(default="memory_nodes", description="向量集合名称")
vector_db_path: str = Field(default="data/memory_graph/chroma_db", description="向量数据库路径")
# 检索配置
search_top_k: int = Field(default=10, description="默认检索返回数量")
search_min_importance: float = Field(default=0.3, description="最小重要性阈值")
search_similarity_threshold: float = Field(default=0.5, description="向量相似度阈值")
search_max_expand_depth: int = Field(default=2, description="检索时图扩展深度0-3")
search_expand_semantic_threshold: float = Field(default=0.3, description="图扩展时语义相似度阈值建议0.3-0.5,过低可能引入无关记忆,过高无法扩展)")
enable_query_optimization: bool = Field(default=True, description="启用查询优化")
# 检索权重配置 (记忆图系统)
search_vector_weight: float = Field(default=0.4, description="向量相似度权重")
search_graph_distance_weight: float = Field(default=0.2, description="图距离权重")
search_importance_weight: float = Field(default=0.2, description="重要性权重")
search_recency_weight: float = Field(default=0.2, description="时效性权重")
# 记忆整合配置
consolidation_enabled: bool = Field(default=False, description="是否启用记忆整合")
consolidation_interval_hours: float = Field(default=2.0, description="整合任务执行间隔(小时)")
consolidation_deduplication_threshold: float = Field(default=0.93, description="相似记忆去重阈值")
consolidation_time_window_hours: float = Field(default=2.0, description="整合时间窗口(小时)- 统一用于去重和关联")
consolidation_max_batch_size: int = Field(default=30, description="单次最多处理的记忆数量")
# 记忆关联配置(整合功能的子模块)
consolidation_linking_enabled: bool = Field(default=True, description="是否启用记忆关联建立")
consolidation_linking_max_candidates: int = Field(default=10, description="每个记忆最多关联的候选数")
consolidation_linking_max_memories: int = Field(default=20, description="单次最多处理的记忆总数")
consolidation_linking_min_importance: float = Field(default=0.5, description="最低重要性阈值")
consolidation_linking_pre_filter_threshold: float = Field(default=0.7, description="向量相似度预筛选阈值")
consolidation_linking_max_pairs_for_llm: int = Field(default=5, description="最多发送给LLM分析的候选对数")
consolidation_linking_min_confidence: float = Field(default=0.7, description="LLM分析最低置信度阈值")
consolidation_linking_llm_temperature: float = Field(default=0.2, description="LLM分析温度参数")
consolidation_linking_llm_max_tokens: int = Field(default=1500, description="LLM分析最大输出长度")
# 遗忘配置 (记忆图系统)
forgetting_enabled: bool = Field(default=True, description="是否启用自动遗忘")
forgetting_activation_threshold: float = Field(default=0.1, description="激活度阈值")
forgetting_min_importance: float = Field(default=0.8, description="最小保护重要性")
# 激活配置
activation_decay_rate: float = Field(default=0.9, description="激活度衰减率")
activation_propagation_strength: float = Field(default=0.5, description="激活传播强度")
activation_propagation_depth: int = Field(default=2, description="激活传播深度")
# 性能配置
max_memory_nodes_per_memory: int = Field(default=10, description="每个记忆最多包含的节点数")
max_related_memories: int = Field(default=5, description="相关记忆最大数量")
# 节点去重合并配置
node_merger_similarity_threshold: float = Field(default=0.85, description="节点去重相似度阈值")
node_merger_context_match_required: bool = Field(default=True, description="节点合并是否要求上下文匹配")
node_merger_merge_batch_size: int = Field(default=50, description="节点合并批量处理大小")
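# 示意(假设性示例):四个检索权重的预期用法是对各分量做线性加权;分量的具体计算
# 与归一化在检索器实现中,这里只演示权重组合的形态,cfg 为 MemoryConfig 实例。
def combined_score(vector_sim: float, graph_dist: float, importance: float, recency: float, cfg) -> float:
    return (
        cfg.search_vector_weight * vector_sim
        + cfg.search_graph_distance_weight * graph_dist
        + cfg.search_importance_weight * importance
        + cfg.search_recency_weight * recency
    )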
class MoodConfig(ValidatedConfigBase):

View File

@@ -9,7 +9,7 @@ class ToolParamType(Enum):
STRING = "string" # 字符串
INTEGER = "integer" # 整型
FLOAT = "number" # 浮点型
BOOLEAN = "bool" # 布尔型
BOOLEAN = "boolean" # 布尔型
class ToolParam:

View File

@@ -13,7 +13,6 @@ from maim_message import MessageServer
from rich.traceback import install
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.memory_system.memory_manager import memory_manager
from src.chat.message_receive.bot import chat_bot
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
@@ -76,8 +75,6 @@ class MainSystem:
"""主系统类,负责协调所有组件"""
def __init__(self) -> None:
# 使用增强记忆系统
self.memory_manager = memory_manager
self.individuality: Individuality = get_individuality()
# 使用消息API替代直接的FastAPI实例
@@ -250,12 +247,6 @@ class MainSystem:
logger.error(f"准备停止消息重组器时出错: {e}")
# 停止增强记忆系统
try:
if global_config.memory.enable_memory:
cleanup_tasks.append(("增强记忆系统", self.memory_manager.shutdown()))
except Exception as e:
logger.error(f"准备停止增强记忆系统时出错: {e}")
# 停止统一调度器
try:
from src.schedule.unified_scheduler import shutdown_scheduler
@@ -468,13 +459,12 @@ MoFox_Bot(第三方修改版)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# 初始化增强记忆系统
if global_config.memory.enable_memory:
from src.chat.memory_system.memory_system import initialize_memory_system
await self._safe_init("增强记忆系统", initialize_memory_system)()
await self._safe_init("记忆管理器", self.memory_manager.initialize)()
else:
logger.info("记忆系统已禁用,跳过初始化")
# 初始化记忆系统
try:
from src.memory_graph.manager_singleton import initialize_memory_manager
await self._safe_init("记忆系统", initialize_memory_manager)()
except Exception as e:
logger.error(f"记忆图系统初始化失败: {e}")
# 初始化消息兴趣值计算组件
await self._initialize_interest_calculator()

View File

@@ -0,0 +1,29 @@
"""
记忆图系统 (Memory Graph System)
基于知识图谱 + 语义向量的混合记忆架构
"""
from src.memory_graph.manager import MemoryManager
from src.memory_graph.models import (
EdgeType,
Memory,
MemoryEdge,
MemoryNode,
MemoryStatus,
MemoryType,
NodeType,
)
__all__ = [
"EdgeType",
"Memory",
"MemoryEdge",
"MemoryManager",
"MemoryNode",
"MemoryStatus",
"MemoryType",
"NodeType",
]
__version__ = "0.1.0"

View File

@@ -0,0 +1,9 @@
"""
核心模块
"""
from src.memory_graph.core.builder import MemoryBuilder
from src.memory_graph.core.extractor import MemoryExtractor
from src.memory_graph.core.node_merger import NodeMerger
__all__ = ["MemoryBuilder", "MemoryExtractor", "NodeMerger"]

View File

@@ -0,0 +1,548 @@
"""
记忆构建器:自动构造记忆子图
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
import numpy as np
from src.common.logger import get_logger
from src.memory_graph.models import (
EdgeType,
Memory,
MemoryEdge,
MemoryNode,
MemoryStatus,
NodeType,
)
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore
logger = get_logger(__name__)
class MemoryBuilder:
"""
记忆构建器
负责:
1. 根据提取的元素自动构造记忆子图
2. 创建节点和边的完整结构
3. 生成语义嵌入向量
4. 检查并复用已存在的相似节点
5. 构造符合层级结构的记忆对象
"""
def __init__(
self,
vector_store: VectorStore,
graph_store: GraphStore,
embedding_generator: Any | None = None,
):
"""
初始化记忆构建器
Args:
vector_store: 向量存储
graph_store: 图存储
embedding_generator: 嵌入向量生成器(可选)
"""
self.vector_store = vector_store
self.graph_store = graph_store
self.embedding_generator = embedding_generator
async def build_memory(self, extracted_params: dict[str, Any]) -> Memory:
"""
构建完整的记忆对象
Args:
extracted_params: 提取器返回的标准化参数
Returns:
Memory 对象(状态为 STAGED
"""
try:
nodes = []
edges = []
memory_id = self._generate_memory_id()
# 1. 创建主体节点 (SUBJECT)
subject_node = await self._create_or_reuse_node(
content=extracted_params["subject"],
node_type=NodeType.SUBJECT,
memory_id=memory_id,
)
nodes.append(subject_node)
# 2. 创建主题节点 (TOPIC) - 需要嵌入向量
topic_node = await self._create_topic_node(
content=extracted_params["topic"], memory_id=memory_id
)
nodes.append(topic_node)
# 3. 连接主体 -> 记忆类型 -> 主题
memory_type_edge = MemoryEdge(
id=self._generate_edge_id(),
source_id=subject_node.id,
target_id=topic_node.id,
relation=extracted_params["memory_type"].value,
edge_type=EdgeType.MEMORY_TYPE,
importance=extracted_params["importance"],
metadata={"memory_id": memory_id},
)
edges.append(memory_type_edge)
# 4. 如果有客体,创建客体节点并连接
if extracted_params.get("object"):
object_node = await self._create_object_node(
content=extracted_params["object"], memory_id=memory_id
)
nodes.append(object_node)
# 连接主题 -> 核心关系 -> 客体
core_relation_edge = MemoryEdge(
id=self._generate_edge_id(),
source_id=topic_node.id,
target_id=object_node.id,
relation="核心关系", # 默认关系名
edge_type=EdgeType.CORE_RELATION,
importance=extracted_params["importance"],
metadata={"memory_id": memory_id},
)
edges.append(core_relation_edge)
# 5. 处理属性
if extracted_params.get("attributes"):
attr_nodes, attr_edges = await self._process_attributes(
attributes=extracted_params["attributes"],
parent_id=topic_node.id,
memory_id=memory_id,
importance=extracted_params["importance"],
)
nodes.extend(attr_nodes)
edges.extend(attr_edges)
# 6. 构建 Memory 对象
memory = Memory(
id=memory_id,
subject_id=subject_node.id,
memory_type=extracted_params["memory_type"],
nodes=nodes,
edges=edges,
importance=extracted_params["importance"],
created_at=extracted_params["timestamp"],
last_accessed=extracted_params["timestamp"],
access_count=0,
status=MemoryStatus.STAGED,
metadata={
"subject": extracted_params["subject"],
"topic": extracted_params["topic"],
},
)
logger.info(
f"构建记忆成功: {memory_id} - {len(nodes)} 节点, {len(edges)}"
)
return memory
except Exception as e:
logger.error(f"记忆构建失败: {e}", exc_info=True)
raise RuntimeError(f"记忆构建失败: {e}")
async def _create_or_reuse_node(
self, content: str, node_type: NodeType, memory_id: str
) -> MemoryNode:
"""
创建新节点或复用已存在的相似节点
对于主体(SUBJECT)和属性(ATTRIBUTE),检查是否已存在相同内容的节点
Args:
content: 节点内容
node_type: 节点类型
memory_id: 所属记忆ID
Returns:
MemoryNode 对象
"""
# 对于主体,尝试查找已存在的节点
if node_type == NodeType.SUBJECT:
existing = await self._find_existing_node(content, node_type)
if existing:
logger.debug(f"复用已存在的主体节点: {existing.id}")
return existing
# 创建新节点
node = MemoryNode(
id=self._generate_node_id(),
content=content,
node_type=node_type,
embedding=None, # 主体和属性不需要嵌入
metadata={"memory_ids": [memory_id]},
)
return node
async def _create_topic_node(self, content: str, memory_id: str) -> MemoryNode:
"""
创建主题节点(需要生成嵌入向量)
Args:
content: 节点内容
memory_id: 所属记忆ID
Returns:
MemoryNode 对象
"""
# 生成嵌入向量
embedding = await self._generate_embedding(content)
# 检查是否存在高度相似的节点
existing = await self._find_similar_topic(content, embedding)
if existing:
logger.debug(f"复用相似的主题节点: {existing.id}")
# 添加当前记忆ID到元数据
if "memory_ids" not in existing.metadata:
existing.metadata["memory_ids"] = []
existing.metadata["memory_ids"].append(memory_id)
return existing
# 创建新节点
node = MemoryNode(
id=self._generate_node_id(),
content=content,
node_type=NodeType.TOPIC,
embedding=embedding,
metadata={"memory_ids": [memory_id]},
)
return node
async def _create_object_node(self, content: str, memory_id: str) -> MemoryNode:
"""
创建客体节点(需要生成嵌入向量)
Args:
content: 节点内容
memory_id: 所属记忆ID
Returns:
MemoryNode 对象
"""
# 生成嵌入向量
embedding = await self._generate_embedding(content)
# 检查是否存在高度相似的节点
existing = await self._find_similar_object(content, embedding)
if existing:
logger.debug(f"复用相似的客体节点: {existing.id}")
if "memory_ids" not in existing.metadata:
existing.metadata["memory_ids"] = []
existing.metadata["memory_ids"].append(memory_id)
return existing
# 创建新节点
node = MemoryNode(
id=self._generate_node_id(),
content=content,
node_type=NodeType.OBJECT,
embedding=embedding,
metadata={"memory_ids": [memory_id]},
)
return node
async def _process_attributes(
self,
attributes: dict[str, Any],
parent_id: str,
memory_id: str,
importance: float,
) -> tuple[list[MemoryNode], list[MemoryEdge]]:
"""
处理属性,构建属性子图
结构TOPIC -> ATTRIBUTE -> VALUE
Args:
attributes: 属性字典
parent_id: 父节点ID通常是TOPIC
memory_id: 所属记忆ID
importance: 重要性
Returns:
(属性节点列表, 属性边列表)
"""
nodes = []
edges = []
for attr_name, attr_value in attributes.items():
# 创建属性节点
attr_node = await self._create_or_reuse_node(
content=attr_name, node_type=NodeType.ATTRIBUTE, memory_id=memory_id
)
nodes.append(attr_node)
# 连接父节点 -> 属性
attr_edge = MemoryEdge(
id=self._generate_edge_id(),
source_id=parent_id,
target_id=attr_node.id,
relation="属性",
edge_type=EdgeType.ATTRIBUTE,
importance=importance * 0.8, # 属性的重要性略低
metadata={"memory_id": memory_id},
)
edges.append(attr_edge)
# 创建值节点
value_node = await self._create_or_reuse_node(
content=str(attr_value), node_type=NodeType.VALUE, memory_id=memory_id
)
nodes.append(value_node)
# 连接属性 -> 值
value_edge = MemoryEdge(
id=self._generate_edge_id(),
source_id=attr_node.id,
target_id=value_node.id,
relation="",
edge_type=EdgeType.ATTRIBUTE,
importance=importance * 0.8,
metadata={"memory_id": memory_id},
)
edges.append(value_edge)
return nodes, edges
async def _generate_embedding(self, text: str) -> np.ndarray:
"""
生成文本的嵌入向量
Args:
text: 文本内容
Returns:
嵌入向量
"""
if self.embedding_generator:
try:
embedding = await self.embedding_generator.generate(text)
return embedding
except Exception as e:
logger.warning(f"嵌入生成失败,使用随机向量: {e}")
# 回退:生成随机向量(仅用于测试)
return np.random.rand(384).astype(np.float32)
async def _find_existing_node(
self, content: str, node_type: NodeType
) -> MemoryNode | None:
"""
查找已存在的完全匹配节点(用于主体和属性)
Args:
content: 节点内容
node_type: 节点类型
Returns:
已存在的节点,如果没有则返回 None
"""
# 在图存储中查找
for node_id in self.graph_store.graph.nodes():
node_data = self.graph_store.graph.nodes[node_id]
if node_data.get("content") == content and node_data.get("node_type") == node_type.value:
# 重建 MemoryNode 对象
return MemoryNode(
id=node_id,
content=node_data["content"],
node_type=NodeType(node_data["node_type"]),
embedding=node_data.get("embedding"),
metadata=node_data.get("metadata", {}),
)
return None
async def _find_similar_topic(
self, content: str, embedding: np.ndarray
) -> MemoryNode | None:
"""
查找相似的主题节点(基于语义相似度)
Args:
content: 内容
embedding: 嵌入向量
Returns:
相似节点,如果没有则返回 None
"""
try:
# 搜索相似节点(阈值 0.95
similar_nodes = await self.vector_store.search_similar_nodes(
query_embedding=embedding,
limit=1,
node_types=[NodeType.TOPIC],
min_similarity=0.95,
)
if similar_nodes and similar_nodes[0][1] >= 0.95:
node_id, similarity, metadata = similar_nodes[0]
logger.debug(
f"找到相似主题节点: {metadata.get('content', '')} (相似度: {similarity:.3f})"
)
# 从图存储中获取完整节点
if node_id in self.graph_store.graph.nodes:
node_data = self.graph_store.graph.nodes[node_id]
existing_node = MemoryNode(
id=node_id,
content=node_data["content"],
node_type=NodeType(node_data["node_type"]),
embedding=node_data.get("embedding"),
metadata=node_data.get("metadata", {}),
)
# 添加当前记忆ID到元数据
return existing_node
except Exception as e:
logger.warning(f"相似节点搜索失败: {e}")
return None
async def _find_similar_object(
self, content: str, embedding: np.ndarray
) -> MemoryNode | None:
"""
查找相似的客体节点(基于语义相似度)
Args:
content: 内容
embedding: 嵌入向量
Returns:
相似节点,如果没有则返回 None
"""
try:
# 搜索相似节点(阈值 0.95
similar_nodes = await self.vector_store.search_similar_nodes(
query_embedding=embedding,
limit=1,
node_types=[NodeType.OBJECT],
min_similarity=0.95,
)
if similar_nodes and similar_nodes[0][1] >= 0.95:
node_id, similarity, metadata = similar_nodes[0]
logger.debug(
f"找到相似客体节点: {metadata.get('content', '')} (相似度: {similarity:.3f})"
)
# 从图存储中获取完整节点
if node_id in self.graph_store.graph.nodes:
node_data = self.graph_store.graph.nodes[node_id]
existing_node = MemoryNode(
id=node_id,
content=node_data["content"],
node_type=NodeType(node_data["node_type"]),
embedding=node_data.get("embedding"),
metadata=node_data.get("metadata", {}),
)
return existing_node
except Exception as e:
logger.warning(f"相似节点搜索失败: {e}")
return None
def _generate_memory_id(self) -> str:
"""生成记忆ID"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
return f"mem_{timestamp}"
def _generate_node_id(self) -> str:
"""生成节点ID"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
return f"node_{timestamp}"
def _generate_edge_id(self) -> str:
"""生成边ID"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
return f"edge_{timestamp}"
async def link_memories(
self,
source_memory: Memory,
target_memory: Memory,
relation_type: str,
importance: float = 0.6,
) -> MemoryEdge:
"""
关联两个记忆(创建因果或引用边)
Args:
source_memory: 源记忆
target_memory: 目标记忆
relation_type: 关系类型(如 "导致", "引用"
importance: 重要性
Returns:
创建的边
"""
try:
# 获取两个记忆的主题节点(作为连接点)
source_topic = self._find_topic_node(source_memory)
target_topic = self._find_topic_node(target_memory)
if not source_topic or not target_topic:
raise ValueError("无法找到记忆的主题节点")
# 确定边的类型
edge_type = self._determine_edge_type(relation_type)
# 创建边
edge_id = f"edge_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
edge = MemoryEdge(
id=edge_id,
source_id=source_topic.id,
target_id=target_topic.id,
relation=relation_type,
edge_type=edge_type,
importance=importance,
metadata={
"source_memory_id": source_memory.id,
"target_memory_id": target_memory.id,
},
)
logger.info(
f"关联记忆: {source_memory.id} --{relation_type}--> {target_memory.id}"
)
return edge
except Exception as e:
logger.error(f"记忆关联失败: {e}", exc_info=True)
raise RuntimeError(f"记忆关联失败: {e}")
def _find_topic_node(self, memory: Memory) -> MemoryNode | None:
"""查找记忆中的主题节点"""
for node in memory.nodes:
if node.node_type == NodeType.TOPIC:
return node
return None
def _determine_edge_type(self, relation_type: str) -> EdgeType:
"""根据关系类型确定边的类型"""
causality_keywords = ["导致", "引起", "造成", "因为", "所以"]
reference_keywords = ["引用", "基于", "关于", "参考"]
for keyword in causality_keywords:
if keyword in relation_type:
return EdgeType.CAUSALITY
for keyword in reference_keywords:
if keyword in relation_type:
return EdgeType.REFERENCE
# 默认为引用类型
return EdgeType.REFERENCE
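# 使用示意(假设性示例):配合 MemoryExtractor 的输出构建一条 STAGED 状态的记忆;
# vector_store / graph_store / embedding_generator 的构造方式见各自模块,此处从略。
from src.memory_graph.models import MemoryType

async def build_example(builder: MemoryBuilder) -> Memory:
    extracted = {
        "subject": "我",
        "memory_type": MemoryType.EVENT,
        "topic": "吃饭",
        "object": "白米饭",
        "attributes": {"地点": "家里"},
        "importance": 0.3,
        "timestamp": datetime.now(),
    }
    return await builder.build_memory(extracted)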

View File

@@ -0,0 +1,311 @@
"""
记忆提取器:从工具参数中提取和验证记忆元素
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
from src.common.logger import get_logger
from src.memory_graph.models import MemoryType
from src.memory_graph.utils.time_parser import TimeParser
logger = get_logger(__name__)
class MemoryExtractor:
"""
记忆提取器
负责:
1. 从工具调用参数中提取记忆元素
2. 验证参数完整性和有效性
3. 标准化时间表达
4. 清洗和格式化数据
"""
def __init__(self, time_parser: TimeParser | None = None):
"""
初始化记忆提取器
Args:
time_parser: 时间解析器(可选)
"""
self.time_parser = time_parser or TimeParser()
def extract_from_tool_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""
从工具参数中提取记忆元素
Args:
params: 工具调用参数,例如:
{
"subject": "",
"memory_type": "事件",
"topic": "吃饭",
"object": "白米饭",
"attributes": {"时间": "今天", "地点": "家里"},
"importance": 0.3
}
Returns:
提取和标准化后的参数字典
"""
try:
# 1. 验证必需参数
self._validate_required_params(params)
# 2. 提取基础元素
extracted = {
"subject": self._clean_text(params["subject"]),
"memory_type": self._parse_memory_type(params["memory_type"]),
"topic": self._clean_text(params["topic"]),
}
# 3. 提取可选的客体
if params.get("object"):
extracted["object"] = self._clean_text(params["object"])
# 4. 提取和标准化属性
if params.get("attributes"):
extracted["attributes"] = self._process_attributes(params["attributes"])
else:
extracted["attributes"] = {}
# 5. 提取重要性
extracted["importance"] = self._parse_importance(params.get("importance", 0.5))
# 6. 添加时间戳
extracted["timestamp"] = datetime.now()
logger.debug(f"提取记忆元素: {extracted['subject']} - {extracted['topic']}")
return extracted
except Exception as e:
logger.error(f"记忆提取失败: {e}", exc_info=True)
raise ValueError(f"记忆提取失败: {e}")
def _validate_required_params(self, params: dict[str, Any]) -> None:
"""
验证必需参数
Args:
params: 参数字典
Raises:
ValueError: 如果缺少必需参数
"""
required_fields = ["subject", "memory_type", "topic"]
for field in required_fields:
if field not in params or not params[field]:
raise ValueError(f"缺少必需参数: {field}")
def _clean_text(self, text: Any) -> str:
"""
清洗文本
Args:
text: 输入文本
Returns:
清洗后的文本
"""
if not text:
return ""
text = str(text).strip()
# 移除多余的空格
text = " ".join(text.split())
# 移除特殊字符(保留基本标点)
# text = re.sub(r'[^\w\s\u4e00-\u9fff,.。!?;::、]', '', text)
return text
def _parse_memory_type(self, type_str: str) -> MemoryType:
"""
解析记忆类型
Args:
type_str: 类型字符串
Returns:
MemoryType 枚举
Raises:
ValueError: 如果类型无效
"""
type_str = type_str.strip()
# 尝试直接匹配
try:
return MemoryType(type_str)
except ValueError:
pass
# 模糊匹配
type_mapping = {
"事件": MemoryType.EVENT,
"event": MemoryType.EVENT,
"事实": MemoryType.FACT,
"fact": MemoryType.FACT,
"关系": MemoryType.RELATION,
"relation": MemoryType.RELATION,
"观点": MemoryType.OPINION,
"opinion": MemoryType.OPINION,
}
if type_str.lower() in type_mapping:
return type_mapping[type_str.lower()]
raise ValueError(f"无效的记忆类型: {type_str}")
def _parse_importance(self, importance: Any) -> float:
"""
解析重要性值
Args:
importance: 重要性值(可以是数字、字符串等)
Returns:
0-1之间的浮点数
"""
try:
value = float(importance)
# 限制在 0-1 范围内
return max(0.0, min(1.0, value))
except (ValueError, TypeError):
logger.warning(f"无效的重要性值: {importance},使用默认值 0.5")
return 0.5
def _process_attributes(self, attributes: dict[str, Any]) -> dict[str, Any]:
"""
处理属性字典
Args:
attributes: 原始属性字典
Returns:
处理后的属性字典
"""
processed = {}
for key, value in attributes.items():
key = key.strip()
# 特殊处理:时间属性
if key in ["时间", "time", "when"]:
parsed_time = self.time_parser.parse(str(value))
if parsed_time:
processed["时间"] = parsed_time.isoformat()
else:
processed["时间"] = str(value)
# 特殊处理:地点属性
elif key in ["地点", "place", "where", "位置"]:
processed["地点"] = self._clean_text(value)
# 特殊处理:原因属性
elif key in ["原因", "reason", "why", "因为"]:
processed["原因"] = self._clean_text(value)
# 特殊处理:方式属性
elif key in ["方式", "how", "manner"]:
processed["方式"] = self._clean_text(value)
# 其他属性
else:
processed[key] = self._clean_text(value)
return processed
def extract_link_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""
提取记忆关联参数(用于 link_memories 工具)
Args:
params: 工具参数,例如:
{
"source_memory_description": "我今天不开心",
"target_memory_description": "我摔东西",
"relation_type": "导致",
"importance": 0.6
}
Returns:
提取后的参数
"""
try:
required = ["source_memory_description", "target_memory_description", "relation_type"]
for field in required:
if field not in params or not params[field]:
raise ValueError(f"缺少必需参数: {field}")
extracted = {
"source_description": self._clean_text(params["source_memory_description"]),
"target_description": self._clean_text(params["target_memory_description"]),
"relation_type": self._clean_text(params["relation_type"]),
"importance": self._parse_importance(params.get("importance", 0.6)),
}
logger.debug(
f"提取关联参数: {extracted['source_description']} --{extracted['relation_type']}--> "
f"{extracted['target_description']}"
)
return extracted
except Exception as e:
logger.error(f"关联参数提取失败: {e}", exc_info=True)
raise ValueError(f"关联参数提取失败: {e}")
def validate_relation_type(self, relation_type: str) -> str:
"""
验证关系类型
Args:
relation_type: 关系类型字符串
Returns:
标准化的关系类型
"""
# 因果关系映射
causality_relations = {
"因为": "因为",
"所以": "所以",
"导致": "导致",
"引起": "导致",
"造成": "导致",
"": "因为",
"": "所以",
}
# 引用关系映射
reference_relations = {
"引用": "引用",
"基于": "基于",
"关于": "关于",
"参考": "引用",
}
# 相关关系
related_relations = {
"相关": "相关",
"有关": "相关",
"联系": "相关",
}
relation_type = relation_type.strip()
# 查找匹配
for mapping in [causality_relations, reference_relations, related_relations]:
if relation_type in mapping:
return mapping[relation_type]
# 未找到映射,返回原值
logger.warning(f"未识别的关系类型: {relation_type},使用原值")
return relation_type
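# Usage sketch (illustrative; the argument values below are hypothetical):
# shows how extract_from_tool_params() normalizes a raw tool call from the LLM.
if __name__ == "__main__":
    extractor = MemoryExtractor()
    elements = extractor.extract_from_tool_params(
        {
            "subject": "小明",
            "memory_type": "event",  # fuzzy-matched to MemoryType.EVENT
            "topic": "吃饭",
            "object": "白米饭",
            "attributes": {"时间": "今天", "地点": "家里"},
            "importance": 0.3,
        }
    )
    # memory_type becomes MemoryType.EVENT, importance is clamped to [0, 1],
    # and attributes["时间"] is an ISO string when TimeParser can resolve "今天".
    print(elements)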

View File

@@ -0,0 +1,355 @@
"""
节点去重合并器:基于语义相似度合并重复节点
"""
from __future__ import annotations
from collections.abc import Callable
from src.common.logger import get_logger
from src.config.official_configs import MemoryConfig
from src.memory_graph.models import MemoryNode, NodeType
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore
logger = get_logger(__name__)
class NodeMerger:
"""
节点合并器
负责:
1. 基于语义相似度查找重复节点
2. 验证上下文匹配
3. 执行节点合并操作
"""
def __init__(
self,
vector_store: VectorStore,
graph_store: GraphStore,
config: MemoryConfig,
):
"""
初始化节点合并器
Args:
vector_store: 向量存储
graph_store: 图存储
config: 记忆配置对象
"""
self.vector_store = vector_store
self.graph_store = graph_store
self.config = config
logger.info(
f"初始化节点合并器: threshold={self.config.node_merger_similarity_threshold}, "
f"context_match={self.config.node_merger_context_match_required}"
)
async def find_similar_nodes(
self,
node: MemoryNode,
threshold: float | None = None,
limit: int = 5,
) -> list[tuple[MemoryNode, float]]:
"""
查找与指定节点相似的节点
Args:
node: 查询节点
threshold: 相似度阈值(可选,默认使用配置值)
limit: 返回结果数量
Returns:
List of (similar_node, similarity)
"""
if not node.has_embedding():
logger.warning(f"节点 {node.id} 没有 embedding无法查找相似节点")
return []
threshold = threshold or self.config.node_merger_similarity_threshold
try:
# 在向量存储中搜索相似节点
results = await self.vector_store.search_similar_nodes(
query_embedding=node.embedding,
limit=limit + 1, # +1 因为可能包含节点自己
node_types=[node.node_type], # 只搜索相同类型的节点
min_similarity=threshold,
)
# 过滤掉节点自己,并构建结果
similar_nodes = []
for node_id, similarity, metadata in results:
if node_id == node.id:
continue # 跳过自己
# 从图存储中获取完整节点信息
memories = self.graph_store.get_memories_by_node(node_id)
if memories:
# 从第一个记忆中获取节点
target_node = memories[0].get_node_by_id(node_id)
if target_node:
similar_nodes.append((target_node, similarity))
logger.debug(f"找到 {len(similar_nodes)} 个相似节点 (阈值: {threshold})")
return similar_nodes
except Exception as e:
logger.error(f"查找相似节点失败: {e}", exc_info=True)
return []
async def should_merge(
self,
source_node: MemoryNode,
target_node: MemoryNode,
similarity: float,
) -> bool:
"""
判断两个节点是否应该合并
Args:
source_node: 源节点
target_node: 目标节点
similarity: 语义相似度
Returns:
是否应该合并
"""
# 1. 检查相似度阈值
if similarity < self.config.node_merger_similarity_threshold:
return False
# 2. 非常高的相似度(>0.95)直接合并
if similarity > 0.95:
logger.debug(f"高相似度 ({similarity:.3f}),直接合并")
return True
# 3. 如果不要求上下文匹配,则通过相似度判断
if not self.config.node_merger_context_match_required:
return True
# 4. 检查上下文匹配
context_match = await self._check_context_match(source_node, target_node)
if context_match:
logger.debug(
f"相似度 {similarity:.3f} + 上下文匹配,决定合并: "
f"'{source_node.content}''{target_node.content}'"
)
return True
logger.debug(
f"相似度 {similarity:.3f} 但上下文不匹配,不合并: "
f"'{source_node.content}''{target_node.content}'"
)
return False
async def _check_context_match(
self,
source_node: MemoryNode,
target_node: MemoryNode,
) -> bool:
"""
检查两个节点的上下文是否匹配
上下文匹配的标准:
1. 节点类型相同
2. 邻居节点有重叠
3. 邻居节点的内容相似
Args:
source_node: 源节点
target_node: 目标节点
Returns:
是否匹配
"""
# 1. 节点类型必须相同
if source_node.node_type != target_node.node_type:
return False
# 2. 获取邻居节点
source_neighbors = self.graph_store.get_neighbors(source_node.id, direction="both")
target_neighbors = self.graph_store.get_neighbors(target_node.id, direction="both")
# 如果都没有邻居,认为上下文不足,保守地不合并
if not source_neighbors or not target_neighbors:
return False
# 3. 检查邻居内容是否有重叠
source_neighbor_contents = set()
for neighbor_id, edge_data in source_neighbors:
neighbor_node = self._get_node_content(neighbor_id)
if neighbor_node:
source_neighbor_contents.add(neighbor_node.lower())
target_neighbor_contents = set()
for neighbor_id, edge_data in target_neighbors:
neighbor_node = self._get_node_content(neighbor_id)
if neighbor_node:
target_neighbor_contents.add(neighbor_node.lower())
# 计算重叠率
intersection = source_neighbor_contents & target_neighbor_contents
union = source_neighbor_contents | target_neighbor_contents
if not union:
return False
overlap_ratio = len(intersection) / len(union)
# 如果有 30% 以上的邻居重叠,认为上下文匹配
return overlap_ratio > 0.3
def _get_node_content(self, node_id: str) -> str | None:
"""获取节点的内容"""
memories = self.graph_store.get_memories_by_node(node_id)
if memories:
node = memories[0].get_node_by_id(node_id)
if node:
return node.content
return None
async def merge_nodes(
self,
source: MemoryNode,
target: MemoryNode,
) -> bool:
"""
合并两个节点
将 source 节点的所有边转移到 target 节点,然后删除 source
Args:
source: 源节点(将被删除)
target: 目标节点(保留)
Returns:
是否成功
"""
try:
logger.info(f"合并节点: '{source.content}' ({source.id}) → '{target.content}' ({target.id})")
# 1. 在图存储中合并节点
self.graph_store.merge_nodes(source.id, target.id)
# 2. 在向量存储中删除源节点
await self.vector_store.delete_node(source.id)
# 3. 更新所有相关记忆的节点引用
self._update_memory_references(source.id, target.id)
logger.info(f"节点合并成功: {source.id}{target.id}")
return True
except Exception as e:
logger.error(f"节点合并失败: {e}", exc_info=True)
return False
def _update_memory_references(self, old_node_id: str, new_node_id: str) -> None:
"""
更新记忆中的节点引用
Args:
old_node_id: 旧节点ID
new_node_id: 新节点ID
"""
# 获取所有包含旧节点的记忆
memories = self.graph_store.get_memories_by_node(old_node_id)
for memory in memories:
# 移除旧节点
memory.nodes = [n for n in memory.nodes if n.id != old_node_id]
# 更新边的引用
for edge in memory.edges:
if edge.source_id == old_node_id:
edge.source_id = new_node_id
if edge.target_id == old_node_id:
edge.target_id = new_node_id
# 更新主体ID如果是主体节点
if memory.subject_id == old_node_id:
memory.subject_id = new_node_id
async def batch_merge_similar_nodes(
self,
nodes: list[MemoryNode],
progress_callback: Callable[[int, int, dict], None] | None = None,
) -> dict:
"""
批量处理节点合并
Args:
nodes: 要处理的节点列表
progress_callback: 进度回调函数
Returns:
统计信息字典
"""
stats = {
"total": len(nodes),
"checked": 0,
"merged": 0,
"skipped": 0,
}
for i, node in enumerate(nodes):
try:
# 只处理有 embedding 的主题和客体节点
if not node.has_embedding() or node.node_type not in [
NodeType.TOPIC,
NodeType.OBJECT,
]:
stats["skipped"] += 1
continue
# 查找相似节点
similar_nodes = await self.find_similar_nodes(node, limit=5)
if similar_nodes:
# 选择最相似的节点
best_match, similarity = similar_nodes[0]
# 判断是否应该合并
if await self.should_merge(node, best_match, similarity):
success = await self.merge_nodes(node, best_match)
if success:
stats["merged"] += 1
stats["checked"] += 1
# 调用进度回调
if progress_callback:
progress_callback(i + 1, stats["total"], stats)
except Exception as e:
logger.error(f"处理节点 {node.id} 时失败: {e}", exc_info=True)
stats["skipped"] += 1
logger.info(
f"批量合并完成: 总数={stats['total']}, 检查={stats['checked']}, "
f"合并={stats['merged']}, 跳过={stats['skipped']}"
)
return stats
def get_merge_candidates(
self,
min_similarity: float = 0.85,
limit: int = 100,
) -> list[tuple[str, str, float]]:
"""
获取待合并的候选节点对
Args:
min_similarity: 最小相似度
limit: 最大返回数量
Returns:
List of (node_id_1, node_id_2, similarity)
"""
# TODO: 实现更智能的候选查找算法
# 目前返回空列表,后续可以基于向量存储进行批量查询
return []
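# Sketch of the context-match rule used by _check_context_match() above
# (illustrative and self-contained): two nodes are treated as context-matched
# when the Jaccard overlap of their neighbour contents exceeds 0.3. The
# neighbour sets below are hypothetical.
if __name__ == "__main__":
    source_neighbours = {"吃饭", "今天", "家里"}
    target_neighbours = {"吃饭", "昨天", "家里", "米饭"}
    overlap = len(source_neighbours & target_neighbours) / len(source_neighbours | target_neighbours)
    print(f"overlap={overlap:.2f}, context_match={overlap > 0.3}")  # 2/5 = 0.40 -> True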

1838
src/memory_graph/manager.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
"""
记忆系统管理单例
提供全局访问的 MemoryManager 实例
"""
from __future__ import annotations
from pathlib import Path
from src.common.logger import get_logger
from src.memory_graph.manager import MemoryManager
logger = get_logger(__name__)
# 全局 MemoryManager 实例
_memory_manager: MemoryManager | None = None
_initialized: bool = False
async def initialize_memory_manager(
data_dir: Path | str | None = None,
) -> MemoryManager | None:
"""
初始化全局 MemoryManager
直接从 global_config.memory 读取配置
Args:
data_dir: 数据目录(可选,默认从配置读取)
Returns:
MemoryManager 实例,如果禁用则返回 None
"""
global _memory_manager, _initialized
if _initialized and _memory_manager:
logger.info("MemoryManager 已经初始化,返回现有实例")
return _memory_manager
try:
from src.config.config import global_config
# 检查是否启用
if not global_config.memory or not getattr(global_config.memory, "enable", False):
logger.info("记忆图系统已在配置中禁用")
_initialized = False
_memory_manager = None
return None
# 处理数据目录
if data_dir is None:
data_dir = getattr(global_config.memory, "data_dir", "data/memory_graph")
if isinstance(data_dir, str):
data_dir = Path(data_dir)
logger.info(f"正在初始化全局 MemoryManager (data_dir={data_dir})...")
_memory_manager = MemoryManager(data_dir=data_dir)
await _memory_manager.initialize()
_initialized = True
logger.info("✅ 全局 MemoryManager 初始化成功")
return _memory_manager
except Exception as e:
logger.error(f"初始化 MemoryManager 失败: {e}", exc_info=True)
_initialized = False
_memory_manager = None
raise
def get_memory_manager() -> MemoryManager | None:
"""
获取全局 MemoryManager 实例
Returns:
MemoryManager 实例,如果未初始化则返回 None
"""
if not _initialized or _memory_manager is None:
logger.warning("MemoryManager 尚未初始化,请先调用 initialize_memory_manager()")
return None
return _memory_manager
async def shutdown_memory_manager():
"""关闭全局 MemoryManager"""
global _memory_manager, _initialized
if _memory_manager:
try:
logger.info("正在关闭全局 MemoryManager...")
await _memory_manager.shutdown()
logger.info("✅ 全局 MemoryManager 已关闭")
except Exception as e:
logger.error(f"关闭 MemoryManager 时出错: {e}", exc_info=True)
finally:
_memory_manager = None
_initialized = False
def is_initialized() -> bool:
"""检查 MemoryManager 是否已初始化"""
return _initialized and _memory_manager is not None
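# Usage sketch (illustrative): the expected lifecycle of the singleton. The data
# directory value is hypothetical; initialize_memory_manager() returns None when
# the memory system is disabled in the configuration.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        manager = await initialize_memory_manager(data_dir="data/memory_graph")
        if manager is not None and is_initialized():
            assert get_memory_manager() is manager
        await shutdown_memory_manager()

    asyncio.run(_demo())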

299
src/memory_graph/models.py Normal file
View File

@@ -0,0 +1,299 @@
"""
记忆图系统核心数据模型
定义节点、边、记忆等核心数据结构
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
import numpy as np
class NodeType(Enum):
"""节点类型枚举"""
SUBJECT = "主体" # 记忆的主语(我、小明、老师)
TOPIC = "主题" # 动作或状态(吃饭、情绪、学习)
OBJECT = "客体" # 宾语(白米饭、学校、书)
ATTRIBUTE = "属性" # 延伸属性(时间、地点、原因)
VALUE = "" # 属性的具体值2025-11-05、不开心
class MemoryType(Enum):
"""记忆类型枚举"""
EVENT = "事件" # 有时间点的动作
FACT = "事实" # 相对稳定的状态
RELATION = "关系" # 人际关系
OPINION = "观点" # 主观评价
class EdgeType(Enum):
"""边类型枚举"""
MEMORY_TYPE = "记忆类型" # 主体 → 主题
CORE_RELATION = "核心关系" # 主题 → 客体(是/做/有)
ATTRIBUTE = "属性关系" # 任意节点 → 属性
CAUSALITY = "因果关系" # 记忆 → 记忆
REFERENCE = "引用关系" # 记忆 → 记忆(转述)
RELATION = "关联关系" # 记忆 → 记忆(自动关联发现的关系)
class MemoryStatus(Enum):
"""记忆状态枚举"""
STAGED = "staged" # 临时状态,未整理
CONSOLIDATED = "consolidated" # 已整理
ARCHIVED = "archived" # 已归档(低价值,很少访问)
@dataclass
class MemoryNode:
"""记忆节点"""
id: str # 节点唯一ID
content: str # 节点内容(如:"我"、"吃饭"、"白米饭"
node_type: NodeType # 节点类型
embedding: np.ndarray | None = None # 语义向量(仅主题/客体需要)
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
created_at: datetime = field(default_factory=datetime.now)
def __post_init__(self):
"""后初始化处理"""
if not self.id:
self.id = str(uuid.uuid4())
def to_dict(self) -> dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"id": self.id,
"content": self.content,
"node_type": self.node_type.value,
"embedding": self.embedding.tolist() if self.embedding is not None else None,
"metadata": self.metadata,
"created_at": self.created_at.isoformat(),
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> MemoryNode:
"""从字典创建节点"""
embedding = None
if data.get("embedding") is not None:
embedding = np.array(data["embedding"])
return cls(
id=data["id"],
content=data["content"],
node_type=NodeType(data["node_type"]),
embedding=embedding,
metadata=data.get("metadata", {}),
created_at=datetime.fromisoformat(data["created_at"]),
)
def has_embedding(self) -> bool:
"""是否有语义向量"""
return self.embedding is not None
def __str__(self) -> str:
return f"Node({self.node_type.value}: {self.content})"
@dataclass
class MemoryEdge:
"""记忆边(节点之间的关系)"""
id: str # 边唯一ID
source_id: str # 源节点ID
target_id: str # 目标节点ID或目标记忆ID
relation: str # 关系名称(如:"是"、"做"、"时间"、"因为"
edge_type: EdgeType # 边类型
importance: float = 0.5 # 重要性 [0-1]
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
created_at: datetime = field(default_factory=datetime.now)
def __post_init__(self):
"""后初始化处理"""
if not self.id:
self.id = str(uuid.uuid4())
# 确保重要性在有效范围内
self.importance = max(0.0, min(1.0, self.importance))
def to_dict(self) -> dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"id": self.id,
"source_id": self.source_id,
"target_id": self.target_id,
"relation": self.relation,
"edge_type": self.edge_type.value,
"importance": self.importance,
"metadata": self.metadata,
"created_at": self.created_at.isoformat(),
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> MemoryEdge:
"""从字典创建边"""
return cls(
id=data["id"],
source_id=data["source_id"],
target_id=data["target_id"],
relation=data["relation"],
edge_type=EdgeType(data["edge_type"]),
importance=data.get("importance", 0.5),
metadata=data.get("metadata", {}),
created_at=datetime.fromisoformat(data["created_at"]),
)
def __str__(self) -> str:
return f"Edge({self.source_id} --{self.relation}--> {self.target_id})"
@dataclass
class Memory:
"""完整记忆(由节点和边组成的子图)"""
id: str # 记忆唯一ID
subject_id: str # 主体节点ID
memory_type: MemoryType # 记忆类型
nodes: list[MemoryNode] # 该记忆包含的所有节点
edges: list[MemoryEdge] # 该记忆包含的所有边
importance: float = 0.5 # 整体重要性 [0-1]
activation: float = 0.0 # 激活度 [0-1],用于记忆整合和遗忘
status: MemoryStatus = MemoryStatus.STAGED # 记忆状态
created_at: datetime = field(default_factory=datetime.now)
last_accessed: datetime = field(default_factory=datetime.now) # 最后访问时间
access_count: int = 0 # 访问次数
decay_factor: float = 1.0 # 衰减因子(随时间变化)
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
def __post_init__(self):
"""后初始化处理"""
if not self.id:
self.id = str(uuid.uuid4())
# 确保重要性和激活度在有效范围内
self.importance = max(0.0, min(1.0, self.importance))
self.activation = max(0.0, min(1.0, self.activation))
def to_dict(self) -> dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"id": self.id,
"subject_id": self.subject_id,
"memory_type": self.memory_type.value,
"nodes": [node.to_dict() for node in self.nodes],
"edges": [edge.to_dict() for edge in self.edges],
"importance": self.importance,
"activation": self.activation,
"status": self.status.value,
"created_at": self.created_at.isoformat(),
"last_accessed": self.last_accessed.isoformat(),
"access_count": self.access_count,
"decay_factor": self.decay_factor,
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Memory:
"""从字典创建记忆"""
return cls(
id=data["id"],
subject_id=data["subject_id"],
memory_type=MemoryType(data["memory_type"]),
nodes=[MemoryNode.from_dict(n) for n in data["nodes"]],
edges=[MemoryEdge.from_dict(e) for e in data["edges"]],
importance=data.get("importance", 0.5),
activation=data.get("activation", 0.0),
status=MemoryStatus(data.get("status", "staged")),
created_at=datetime.fromisoformat(data["created_at"]),
last_accessed=datetime.fromisoformat(data.get("last_accessed", data["created_at"])),
access_count=data.get("access_count", 0),
decay_factor=data.get("decay_factor", 1.0),
metadata=data.get("metadata", {}),
)
def update_access(self) -> None:
"""更新访问记录"""
self.last_accessed = datetime.now()
self.access_count += 1
def get_node_by_id(self, node_id: str) -> MemoryNode | None:
"""根据ID获取节点"""
for node in self.nodes:
if node.id == node_id:
return node
return None
def get_subject_node(self) -> MemoryNode | None:
"""获取主体节点"""
return self.get_node_by_id(self.subject_id)
def to_text(self) -> str:
"""转换为文本描述用于显示和LLM处理"""
subject_node = self.get_subject_node()
if not subject_node:
return f"[记忆 {self.id[:8]}]"
# 简单的文本生成逻辑
parts = [f"{subject_node.content}"]
# 查找主题节点(通过记忆类型边连接)
topic_node = None
for edge in self.edges:
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == self.subject_id:
topic_node = self.get_node_by_id(edge.target_id)
break
if topic_node:
parts.append(topic_node.content)
# 查找客体节点(通过核心关系边连接)
for edge in self.edges:
if edge.edge_type == EdgeType.CORE_RELATION and topic_node and edge.source_id == topic_node.id:
obj_node = self.get_node_by_id(edge.target_id)
if obj_node:
parts.append(f"{edge.relation} {obj_node.content}")
break
return " ".join(parts)
def __str__(self) -> str:
return f"Memory({self.memory_type.value}: {self.to_text()})"
@dataclass
class StagedMemory:
"""临时记忆(未整理状态)"""
memory: Memory # 原始记忆对象
status: MemoryStatus = MemoryStatus.STAGED # 状态
created_at: datetime = field(default_factory=datetime.now)
consolidated_at: datetime | None = None # 整理时间
merge_history: list[str] = field(default_factory=list) # 被合并的节点ID列表
def to_dict(self) -> dict[str, Any]:
"""转换为字典"""
return {
"memory": self.memory.to_dict(),
"status": self.status.value,
"created_at": self.created_at.isoformat(),
"consolidated_at": self.consolidated_at.isoformat() if self.consolidated_at else None,
"merge_history": self.merge_history,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> StagedMemory:
"""从字典创建临时记忆"""
return cls(
memory=Memory.from_dict(data["memory"]),
status=MemoryStatus(data.get("status", "staged")),
created_at=datetime.fromisoformat(data["created_at"]),
consolidated_at=datetime.fromisoformat(data["consolidated_at"]) if data.get("consolidated_at") else None,
merge_history=data.get("merge_history", []),
)
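# Usage sketch (illustrative; ids and contents are hypothetical): assembling a
# minimal memory from the dataclasses above and round-tripping it through
# to_dict()/from_dict().
if __name__ == "__main__":
    subject = MemoryNode(id="", content="我", node_type=NodeType.SUBJECT)
    topic = MemoryNode(id="", content="吃饭", node_type=NodeType.TOPIC)
    obj = MemoryNode(id="", content="白米饭", node_type=NodeType.OBJECT)
    edges = [
        MemoryEdge(id="", source_id=subject.id, target_id=topic.id,
                   relation="事件", edge_type=EdgeType.MEMORY_TYPE),
        MemoryEdge(id="", source_id=topic.id, target_id=obj.id,
                   relation="吃", edge_type=EdgeType.CORE_RELATION),
    ]
    memory = Memory(id="", subject_id=subject.id, memory_type=MemoryType.EVENT,
                    nodes=[subject, topic, obj], edges=edges, importance=0.3)
    restored = Memory.from_dict(memory.to_dict())
    print(memory.to_text(), "|", restored.to_text())  # both print "我 吃饭 吃 白米饭"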

View File

@@ -0,0 +1,258 @@
"""
记忆系统插件工具
将 MemoryTools 适配为 BaseTool 格式,供 LLM 使用
"""
from __future__ import annotations
from typing import Any, ClassVar
from src.common.logger import get_logger
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ToolParamType
logger = get_logger(__name__)
class CreateMemoryTool(BaseTool):
"""创建记忆工具"""
name = "create_memory"
description = """记录对话中有价值的信息,构建长期记忆。
## 应该记录的内容类型:
### 高优先级记录importance 0.7-1.0
- 个人核心信息:姓名、年龄、职业、学历、联系方式
- 重要关系:家人、亲密朋友、恋人关系
- 核心目标:人生规划、职业目标、重要决定
- 关键事件:毕业、入职、搬家、重要成就
### 中等优先级(importance 0.5-0.7)
- 生活状态:工作内容、学习情况、日常习惯
- 兴趣偏好:喜欢/不喜欢的事物、消费偏好
- 观点态度:价值观、对事物的看法
- 技能知识:掌握的技能、专业领域
- 一般事件:日常活动、例行任务
### 低优先级(importance 0.3-0.5)
- 临时状态:今天心情、当前活动
- 一般评价:对产品/服务的简单评价
- 琐碎事件:买东西、看电影等常规活动
### ❌ 不应记录
- 单纯招呼语:"你好""再见""谢谢"
- 无意义语气词:"嗯""哦""好的"
- 纯粹回复确认:没有信息量的回应
## 记忆拆分原则
一句话多个信息点 → 多次调用创建多条记忆
示例:"我最近在学Python想找数据分析的工作"
→ 调用1{{subject:"[从历史提取真实名字]", memory_type:"事实", topic:"学习", object:"Python", attributes:{{时间:"最近", 状态:"进行中"}}, importance:0.7}}
→ 调用2{{subject:"[从历史提取真实名字]", memory_type:"事件", topic:"求职", object:"数据分析岗位", attributes:{{状态:"计划中"}}, importance:0.8}}"""
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
("subject", ToolParamType.STRING, "记忆主体(重要!)。从对话历史中提取真实发送人名字。示例:如果看到'Prou(12345678): 我喜欢...'subject应填'Prou';如果看到'张三: 我在...'subject应填'张三'。❌禁止使用'用户'这种泛指,必须用具体名字!", True, None),
("memory_type", ToolParamType.STRING, "记忆类型。【事件】=有明确时间点的动作(昨天吃饭、明天开会)【事实】=稳定状态(职业是程序员、住在北京)【观点】=主观看法(喜欢/讨厌/认为)【关系】=人际关系(朋友、同事)", True, ["事件", "事实", "关系", "观点"]),
("topic", ToolParamType.STRING, "记忆的核心内容(做什么/是什么状态/什么关系)。必须明确、具体,包含关键动词或状态词", True, None),
("object", ToolParamType.STRING, "记忆涉及的对象或目标。如果topic已经很完整可以不填如果有明确对象建议填写", False, None),
("attributes", ToolParamType.STRING, '详细属性JSON格式字符串。强烈建议包含时间具体到日期和小时分钟、地点、状态、原因等上下文信息。例{"时间":"2025-11-06 12:00","地点":"公司","状态":"进行中","原因":"项目需要"}', False, None),
("importance", ToolParamType.FLOAT, "重要性评分 0.0-1.0。参考日常琐事0.3-0.4一般对话0.5-0.6重要信息0.7-0.8核心记忆0.9-1.0。不确定时用0.5", False, None),
]
available_for_llm = True
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行创建记忆"""
try:
# 获取全局 memory_manager
from src.memory_graph.manager_singleton import get_memory_manager
manager = get_memory_manager()
if not manager:
return {
"name": self.name,
"content": "记忆系统未初始化"
}
# 提取参数
subject = function_args.get("subject", "")
memory_type = function_args.get("memory_type", "")
topic = function_args.get("topic", "")
obj = function_args.get("object")
# 处理 attributes可能是字符串或字典
attributes_raw = function_args.get("attributes", {})
if isinstance(attributes_raw, str):
import orjson
try:
attributes = orjson.loads(attributes_raw)
except Exception:
attributes = {}
else:
attributes = attributes_raw
importance = function_args.get("importance", 0.5)
# 创建记忆
memory = await manager.create_memory(
subject=subject,
memory_type=memory_type,
topic=topic,
object_=obj,
attributes=attributes,
importance=importance,
)
if memory:
logger.info(f"[CreateMemoryTool] 成功创建记忆: {memory.id}")
return {
"name": self.name,
"content": f"成功创建记忆ID: {memory.id}",
"memory_id": memory.id, # 返回记忆ID供后续使用
}
else:
return {
"name": self.name,
"content": "创建记忆失败",
"memory_id": None,
}
except Exception as e:
logger.error(f"[CreateMemoryTool] 执行失败: {e}", exc_info=True)
return {
"name": self.name,
"content": f"创建记忆时出错: {e!s}"
}
class LinkMemoriesTool(BaseTool):
"""关联记忆工具"""
name = "link_memories"
description = "在两个记忆之间建立关联关系。用于连接相关的记忆,形成知识网络。"
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
("source_query", ToolParamType.STRING, "源记忆的搜索查询(如记忆的主题关键词)", True, None),
("target_query", ToolParamType.STRING, "目标记忆的搜索查询", True, None),
("relation", ToolParamType.STRING, "关系类型", True, ["导致", "引用", "相似", "相反", "部分"]),
("strength", ToolParamType.FLOAT, "关系强度0.0-1.0默认0.7", False, None),
]
available_for_llm = False # 暂不对 LLM 开放
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行关联记忆"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
manager = get_memory_manager()
if not manager:
return {
"name": self.name,
"content": "记忆系统未初始化"
}
source_query = function_args.get("source_query", "")
target_query = function_args.get("target_query", "")
relation = function_args.get("relation", "引用")
strength = function_args.get("strength", 0.7)
# 关联记忆
success = await manager.link_memories(
source_description=source_query,
target_description=target_query,
relation_type=relation,
importance=strength,
)
if success:
logger.info(f"[LinkMemoriesTool] 成功关联记忆: {source_query} -> {target_query}")
return {
"name": self.name,
"content": f"成功建立关联: {source_query} --{relation}--> {target_query}"
}
else:
return {
"name": self.name,
"content": "关联记忆失败,可能找不到匹配的记忆"
}
except Exception as e:
logger.error(f"[LinkMemoriesTool] 执行失败: {e}", exc_info=True)
return {
"name": self.name,
"content": f"关联记忆时出错: {e!s}"
}
class SearchMemoriesTool(BaseTool):
"""搜索记忆工具"""
name = "search_memories"
description = "搜索相关的记忆。根据查询词搜索记忆库,返回最相关的记忆。"
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
("query", ToolParamType.STRING, "搜索查询词,描述想要找什么样的记忆", True, None),
("top_k", ToolParamType.INTEGER, "返回的记忆数量默认5", False, None),
("min_importance", ToolParamType.FLOAT, "最低重要性阈值0.0-1.0),只返回重要性不低于此值的记忆", False, None),
]
available_for_llm = False # 暂不对 LLM 开放,记忆检索在提示词构建时自动执行
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行搜索记忆"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
manager = get_memory_manager()
if not manager:
return {
"name": self.name,
"content": "记忆系统未初始化"
}
query = function_args.get("query", "")
top_k = function_args.get("top_k", 5)
min_importance_raw = function_args.get("min_importance")
min_importance = float(min_importance_raw) if min_importance_raw is not None else 0.0
# 搜索记忆
memories = await manager.search_memories(
query=query,
top_k=top_k,
min_importance=min_importance,
)
if memories:
# 格式化结果
result_lines = [f"找到 {len(memories)} 条相关记忆:\n"]
for i, mem in enumerate(memories, 1):
topic = mem.metadata.get("topic", "N/A")
mem_type = mem.metadata.get("memory_type", "N/A")
importance = mem.importance
result_lines.append(
f"{i}. [{mem_type}] {topic} (重要性: {importance:.2f})"
)
result_text = "\n".join(result_lines)
logger.info(f"[SearchMemoriesTool] 搜索成功: 查询='{query}', 结果数={len(memories)}")
return {
"name": self.name,
"content": result_text
}
else:
return {
"name": self.name,
"content": f"未找到与 '{query}' 相关的记忆"
}
except Exception as e:
logger.error(f"[SearchMemoriesTool] 执行失败: {e}", exc_info=True)
return {
"name": self.name,
"content": f"搜索记忆时出错: {e!s}"
}
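# Usage sketch (illustrative): the function_args shape that CreateMemoryTool.execute
# expects from the tool-calling layer. All values are hypothetical; "attributes" may
# arrive either as a JSON string (as here) or as an already-parsed dict.
if __name__ == "__main__":
    example_create_memory_args = {
        "subject": "张三",
        "memory_type": "事实",
        "topic": "职业是程序员",
        "object": "程序员",
        "attributes": '{"时间": "2025-11-06", "地点": "北京"}',
        "importance": 0.7,
    }
    # Passing this dict to CreateMemoryTool.execute() returns a result dict like
    # {"name": "create_memory", "content": "...", "memory_id": "..."} once the
    # global MemoryManager has been initialized; otherwise content reports the error.
    print(example_create_memory_args)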

View File

@@ -0,0 +1,8 @@
"""
存储层模块
"""
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore
__all__ = ["GraphStore", "VectorStore"]

View File

@@ -0,0 +1,505 @@
"""
图存储层:基于 NetworkX 的图结构管理
"""
from __future__ import annotations
import networkx as nx
from src.common.logger import get_logger
from src.memory_graph.models import Memory, MemoryEdge
logger = get_logger(__name__)
class GraphStore:
"""
图存储封装类
负责:
1. 记忆图的构建和维护
2. 节点和边的快速查询
3. 图遍历算法BFS/DFS
4. 邻接关系查询
"""
def __init__(self):
"""初始化图存储"""
# 使用有向图(记忆关系通常是有向的)
self.graph = nx.DiGraph()
# 索引记忆ID -> 记忆对象
self.memory_index: dict[str, Memory] = {}
# 索引节点ID -> 所属记忆ID集合
self.node_to_memories: dict[str, set[str]] = {}
logger.info("初始化图存储")
def add_memory(self, memory: Memory) -> None:
"""
添加记忆到图
Args:
memory: 要添加的记忆
"""
try:
# 1. 添加所有节点到图
for node in memory.nodes:
if not self.graph.has_node(node.id):
self.graph.add_node(
node.id,
content=node.content,
node_type=node.node_type.value,
created_at=node.created_at.isoformat(),
metadata=node.metadata,
)
# 更新节点到记忆的映射
if node.id not in self.node_to_memories:
self.node_to_memories[node.id] = set()
self.node_to_memories[node.id].add(memory.id)
# 2. 添加所有边到图
for edge in memory.edges:
self.graph.add_edge(
edge.source_id,
edge.target_id,
edge_id=edge.id,
relation=edge.relation,
edge_type=edge.edge_type.value,
importance=edge.importance,
metadata=edge.metadata,
created_at=edge.created_at.isoformat(),
)
# 3. 保存记忆对象
self.memory_index[memory.id] = memory
logger.debug(f"添加记忆到图: {memory}")
except Exception as e:
logger.error(f"添加记忆失败: {e}", exc_info=True)
raise
def get_memory_by_id(self, memory_id: str) -> Memory | None:
"""
根据ID获取记忆
Args:
memory_id: 记忆ID
Returns:
记忆对象或 None
"""
return self.memory_index.get(memory_id)
def get_all_memories(self) -> list[Memory]:
"""
获取所有记忆
Returns:
所有记忆的列表
"""
return list(self.memory_index.values())
def get_memories_by_node(self, node_id: str) -> list[Memory]:
"""
获取包含指定节点的所有记忆
Args:
node_id: 节点ID
Returns:
记忆列表
"""
if node_id not in self.node_to_memories:
return []
memory_ids = self.node_to_memories[node_id]
return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]
def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]:
"""
获取从指定节点出发的所有边
Args:
node_id: 源节点ID
relation_types: 关系类型过滤(可选)
Returns:
边信息列表
"""
if not self.graph.has_node(node_id):
return []
edges = []
for _, target_id, edge_data in self.graph.out_edges(node_id, data=True):
# 过滤关系类型
if relation_types and edge_data.get("relation") not in relation_types:
continue
edges.append(
{
"source_id": node_id,
"target_id": target_id,
"relation": edge_data.get("relation"),
"edge_type": edge_data.get("edge_type"),
"importance": edge_data.get("importance", 0.5),
**edge_data,
}
)
return edges
def get_neighbors(
self, node_id: str, direction: str = "out", relation_types: list[str] | None = None
) -> list[tuple[str, dict]]:
"""
获取节点的邻居节点
Args:
node_id: 节点ID
direction: 方向 ("out"=出边, "in"=入边, "both"=双向)
relation_types: 关系类型过滤
Returns:
List of (neighbor_id, edge_data)
"""
if not self.graph.has_node(node_id):
return []
neighbors = []
# 处理出边
if direction in ["out", "both"]:
for _, target_id, edge_data in self.graph.out_edges(node_id, data=True):
if not relation_types or edge_data.get("relation") in relation_types:
neighbors.append((target_id, edge_data))
# 处理入边
if direction in ["in", "both"]:
for source_id, _, edge_data in self.graph.in_edges(node_id, data=True):
if not relation_types or edge_data.get("relation") in relation_types:
neighbors.append((source_id, edge_data))
return neighbors
def find_path(self, source_id: str, target_id: str, max_length: int | None = None) -> list[str] | None:
"""
查找两个节点之间的最短路径
Args:
source_id: 源节点ID
target_id: 目标节点ID
max_length: 最大路径长度(可选)
Returns:
路径节点ID列表或 None如果不存在路径
"""
if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
return None
try:
if max_length:
# 先计算最短路径,再检查其长度是否超过 max_length
path = nx.shortest_path(self.graph, source_id, target_id, weight=None)
if len(path) - 1 <= max_length: # 边数 = 节点数 - 1
return path
return None
else:
return nx.shortest_path(self.graph, source_id, target_id, weight=None)
except nx.NetworkXNoPath:
return None
except Exception as e:
logger.error(f"查找路径失败: {e}", exc_info=True)
return None
def bfs_expand(
self,
start_nodes: list[str],
depth: int = 1,
relation_types: list[str] | None = None,
) -> set[str]:
"""
从起始节点进行广度优先搜索扩展
Args:
start_nodes: 起始节点ID列表
depth: 扩展深度
relation_types: 关系类型过滤
Returns:
扩展到的所有节点ID集合
"""
visited = set()
queue = [(node_id, 0) for node_id in start_nodes if self.graph.has_node(node_id)]
while queue:
current_node, current_depth = queue.pop(0)
if current_node in visited:
continue
visited.add(current_node)
if current_depth >= depth:
continue
# 获取邻居并加入队列
neighbors = self.get_neighbors(current_node, direction="out", relation_types=relation_types)
for neighbor_id, _ in neighbors:
if neighbor_id not in visited:
queue.append((neighbor_id, current_depth + 1))
return visited
def get_subgraph(self, node_ids: list[str]) -> nx.DiGraph:
"""
获取包含指定节点的子图
Args:
node_ids: 节点ID列表
Returns:
NetworkX 子图
"""
return self.graph.subgraph(node_ids).copy()
def merge_nodes(self, source_id: str, target_id: str) -> None:
"""
合并两个节点将source的所有边转移到target然后删除source
Args:
source_id: 源节点ID将被删除
target_id: 目标节点ID保留
"""
if not self.graph.has_node(source_id) or not self.graph.has_node(target_id):
logger.warning(f"合并节点失败: 节点不存在 ({source_id}, {target_id})")
return
try:
# 1. 转移入边
for pred, _, edge_data in self.graph.in_edges(source_id, data=True):
if pred != target_id: # 避免自环
self.graph.add_edge(pred, target_id, **edge_data)
# 2. 转移出边
for _, succ, edge_data in self.graph.out_edges(source_id, data=True):
if succ != target_id: # 避免自环
self.graph.add_edge(target_id, succ, **edge_data)
# 3. 更新节点到记忆的映射
if source_id in self.node_to_memories:
memory_ids = self.node_to_memories[source_id]
if target_id not in self.node_to_memories:
self.node_to_memories[target_id] = set()
self.node_to_memories[target_id].update(memory_ids)
del self.node_to_memories[source_id]
# 4. 删除源节点
self.graph.remove_node(source_id)
logger.info(f"节点合并: {source_id}{target_id}")
except Exception as e:
logger.error(f"合并节点失败: {e}", exc_info=True)
raise
def get_node_degree(self, node_id: str) -> tuple[int, int]:
"""
获取节点的度数
Args:
node_id: 节点ID
Returns:
(in_degree, out_degree)
"""
if not self.graph.has_node(node_id):
return (0, 0)
return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))
def get_statistics(self) -> dict[str, int]:
"""获取图的统计信息"""
return {
"total_nodes": self.graph.number_of_nodes(),
"total_edges": self.graph.number_of_edges(),
"total_memories": len(self.memory_index),
"connected_components": nx.number_weakly_connected_components(self.graph),
}
def to_dict(self) -> dict:
"""
将图转换为字典(用于持久化)
Returns:
图的字典表示
"""
return {
"nodes": [
{"id": node_id, **self.graph.nodes[node_id]} for node_id in self.graph.nodes()
],
"edges": [
{
"source": u,
"target": v,
**data,
}
for u, v, data in self.graph.edges(data=True)
],
"memories": {memory_id: memory.to_dict() for memory_id, memory in self.memory_index.items()},
"node_to_memories": {node_id: list(mem_ids) for node_id, mem_ids in self.node_to_memories.items()},
}
@classmethod
def from_dict(cls, data: dict) -> GraphStore:
"""
从字典加载图
Args:
data: 图的字典表示
Returns:
GraphStore 实例
"""
store = cls()
# 1. 加载节点
for node_data in data.get("nodes", []):
node_id = node_data.pop("id")
store.graph.add_node(node_id, **node_data)
# 2. 加载边
for edge_data in data.get("edges", []):
source = edge_data.pop("source")
target = edge_data.pop("target")
store.graph.add_edge(source, target, **edge_data)
# 3. 加载记忆
for memory_id, memory_dict in data.get("memories", {}).items():
store.memory_index[memory_id] = Memory.from_dict(memory_dict)
# 4. 加载节点到记忆的映射
for node_id, mem_ids in data.get("node_to_memories", {}).items():
store.node_to_memories[node_id] = set(mem_ids)
# 5. 同步图中的边到 Memory.edges保证内存对象和图一致
try:
store._sync_memory_edges_from_graph()
except Exception:
logger.exception("同步图边到记忆.edges 失败")
logger.info(f"从字典加载图: {store.get_statistics()}")
return store
def _sync_memory_edges_from_graph(self) -> None:
"""
将 NetworkX 图中的边重建为 MemoryEdge 并注入到对应的 Memory.edges 列表中。
目的:当从持久化数据加载时,确保 memory_index 中的 Memory 对象的
edges 列表反映图中实际存在的边(避免只有图中存在而 memory.edges 为空的不同步情况)。
规则:对于图中每条边(u, v, data),会尝试将该边注入到所有包含 u 或 v 的记忆中(避免遗漏跨记忆边)。
已存在的边(通过 edge.id 检查)将不会重复添加。
"""
# 构建快速查重索引memory_id -> set(edge_id)
existing_edges = {mid: {e.id for e in mem.edges} for mid, mem in self.memory_index.items()}
for u, v, data in self.graph.edges(data=True):
# 兼容旧数据edge_id 可能在 data 中,或叫 id
edge_id = data.get("edge_id") or data.get("id") or ""
edge_dict = {
"id": edge_id or "",
"source_id": u,
"target_id": v,
"relation": data.get("relation", ""),
"edge_type": data.get("edge_type", data.get("edge_type", "")),
"importance": data.get("importance", 0.5),
"metadata": data.get("metadata", {}),
"created_at": data.get("created_at", "1970-01-01T00:00:00"),
}
# 找到相关记忆(包含源或目标节点)
related_memory_ids = set()
if u in self.node_to_memories:
related_memory_ids.update(self.node_to_memories[u])
if v in self.node_to_memories:
related_memory_ids.update(self.node_to_memories[v])
for mid in related_memory_ids:
mem = self.memory_index.get(mid)
if mem is None:
continue
# 检查是否已存在
if edge_dict["id"] and edge_dict["id"] in existing_edges.get(mid, set()):
continue
try:
# 使用 MemoryEdge.from_dict 构建对象
mem_edge = MemoryEdge.from_dict(edge_dict)
except Exception:
# 兼容性:直接构造对象
mem_edge = MemoryEdge(
id=edge_dict["id"] or "",
source_id=edge_dict["source_id"],
target_id=edge_dict["target_id"],
relation=edge_dict["relation"],
edge_type=edge_dict["edge_type"],
importance=edge_dict.get("importance", 0.5),
metadata=edge_dict.get("metadata", {}),
)
mem.edges.append(mem_edge)
existing_edges.setdefault(mid, set()).add(mem_edge.id)
logger.info("已将图中的边同步到 Memory.edges保证 graph 与 memory 对象一致)")
def remove_memory(self, memory_id: str) -> bool:
"""
从图中删除指定记忆
Args:
memory_id: 要删除的记忆ID
Returns:
是否删除成功
"""
try:
# 1. 检查记忆是否存在
if memory_id not in self.memory_index:
logger.warning(f"记忆不存在,无法删除: {memory_id}")
return False
memory = self.memory_index[memory_id]
# 2. 从节点映射中移除此记忆
for node in memory.nodes:
if node.id in self.node_to_memories:
self.node_to_memories[node.id].discard(memory_id)
# 如果该节点不再属于任何记忆,从图中移除节点
if not self.node_to_memories[node.id]:
if self.graph.has_node(node.id):
self.graph.remove_node(node.id)
del self.node_to_memories[node.id]
# 3. 从记忆索引中移除
del self.memory_index[memory_id]
logger.info(f"成功删除记忆: {memory_id}")
return True
except Exception as e:
logger.error(f"删除记忆失败 {memory_id}: {e}", exc_info=True)
return False
def clear(self) -> None:
"""清空图(危险操作,仅用于测试)"""
self.graph.clear()
self.memory_index.clear()
self.node_to_memories.clear()
logger.warning("图存储已清空")

View File

@@ -0,0 +1,377 @@
"""
持久化管理:负责记忆图数据的保存和加载
"""
from __future__ import annotations
import asyncio
import json
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
import orjson
from src.common.logger import get_logger
from src.memory_graph.models import StagedMemory
from src.memory_graph.storage.graph_store import GraphStore
logger = get_logger(__name__)
class PersistenceManager:
"""
持久化管理器
负责:
1. 图数据的保存和加载
2. 定期自动保存
3. 备份管理
"""
def __init__(
self,
data_dir: Path,
graph_file_name: str = "memory_graph.json",
staged_file_name: str = "staged_memories.json",
auto_save_interval: int = 300, # 自动保存间隔(秒)
):
"""
初始化持久化管理器
Args:
data_dir: 数据存储目录
graph_file_name: 图数据文件名
staged_file_name: 临时记忆文件名
auto_save_interval: 自动保存间隔(秒)
"""
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self.graph_file = self.data_dir / graph_file_name
self.staged_file = self.data_dir / staged_file_name
self.backup_dir = self.data_dir / "backups"
self.backup_dir.mkdir(parents=True, exist_ok=True)
self.auto_save_interval = auto_save_interval
self._auto_save_task: asyncio.Task | None = None
self._running = False
logger.info(f"初始化持久化管理器: data_dir={data_dir}")
async def save_graph_store(self, graph_store: GraphStore) -> None:
"""
保存图存储到文件
Args:
graph_store: 图存储对象
"""
try:
# 转换为字典
data = graph_store.to_dict()
# 添加元数据
data["metadata"] = {
"version": "0.1.0",
"saved_at": datetime.now().isoformat(),
"statistics": graph_store.get_statistics(),
}
# 使用 orjson 序列化(更快)
json_data = orjson.dumps(
data,
option=orjson.OPT_INDENT_2 | orjson.OPT_SERIALIZE_NUMPY,
)
# 原子写入(先写临时文件,再重命名)
temp_file = self.graph_file.with_suffix(".tmp")
temp_file.write_bytes(json_data)
temp_file.replace(self.graph_file)
logger.info(f"图数据已保存: {self.graph_file}, 大小: {len(json_data) / 1024:.2f} KB")
except Exception as e:
logger.error(f"保存图数据失败: {e}", exc_info=True)
raise
async def load_graph_store(self) -> GraphStore | None:
"""
从文件加载图存储
Returns:
GraphStore 对象,如果文件不存在则返回 None
"""
if not self.graph_file.exists():
logger.info("图数据文件不存在,返回空图")
return None
try:
# 读取文件
json_data = self.graph_file.read_bytes()
data = orjson.loads(json_data)
# 检查版本(未来可能需要数据迁移)
version = data.get("metadata", {}).get("version", "unknown")
logger.info(f"加载图数据: version={version}")
# 恢复图存储
graph_store = GraphStore.from_dict(data)
logger.info(f"图数据加载完成: {graph_store.get_statistics()}")
return graph_store
except Exception as e:
logger.error(f"加载图数据失败: {e}", exc_info=True)
# 尝试加载备份
return await self._load_from_backup()
async def save_staged_memories(self, staged_memories: list[StagedMemory]) -> None:
"""
保存临时记忆列表
Args:
staged_memories: 临时记忆列表
"""
try:
data = {
"metadata": {
"version": "0.1.0",
"saved_at": datetime.now().isoformat(),
"count": len(staged_memories),
},
"staged_memories": [sm.to_dict() for sm in staged_memories],
}
json_data = orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SERIALIZE_NUMPY)
temp_file = self.staged_file.with_suffix(".tmp")
temp_file.write_bytes(json_data)
temp_file.replace(self.staged_file)
logger.info(f"临时记忆已保存: {len(staged_memories)}")
except Exception as e:
logger.error(f"保存临时记忆失败: {e}", exc_info=True)
raise
async def load_staged_memories(self) -> list[StagedMemory]:
"""
加载临时记忆列表
Returns:
临时记忆列表
"""
if not self.staged_file.exists():
logger.info("临时记忆文件不存在,返回空列表")
return []
try:
json_data = self.staged_file.read_bytes()
data = orjson.loads(json_data)
staged_memories = [StagedMemory.from_dict(sm) for sm in data.get("staged_memories", [])]
logger.info(f"临时记忆加载完成: {len(staged_memories)}")
return staged_memories
except Exception as e:
logger.error(f"加载临时记忆失败: {e}", exc_info=True)
return []
async def create_backup(self) -> Path | None:
"""
创建当前数据的备份
Returns:
备份文件路径,如果失败则返回 None
"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = self.backup_dir / f"memory_graph_backup_{timestamp}.json"
if self.graph_file.exists():
# 复制图数据文件
import shutil
shutil.copy2(self.graph_file, backup_file)
# 清理旧备份只保留最近10个
await self._cleanup_old_backups(keep=10)
logger.info(f"备份创建成功: {backup_file}")
return backup_file
return None
except Exception as e:
logger.error(f"创建备份失败: {e}", exc_info=True)
return None
async def _load_from_backup(self) -> GraphStore | None:
"""从最新的备份加载数据"""
try:
# 查找最新的备份文件
backup_files = sorted(self.backup_dir.glob("memory_graph_backup_*.json"), reverse=True)
if not backup_files:
logger.warning("没有可用的备份文件")
return None
latest_backup = backup_files[0]
logger.warning(f"尝试从备份恢复: {latest_backup}")
json_data = latest_backup.read_bytes()
data = orjson.loads(json_data)
graph_store = GraphStore.from_dict(data)
logger.info(f"从备份恢复成功: {graph_store.get_statistics()}")
return graph_store
except Exception as e:
logger.error(f"从备份恢复失败: {e}", exc_info=True)
return None
async def _cleanup_old_backups(self, keep: int = 10) -> None:
"""
清理旧备份,只保留最近的几个
Args:
keep: 保留的备份数量
"""
try:
backup_files = sorted(self.backup_dir.glob("memory_graph_backup_*.json"), reverse=True)
# 删除超出数量的备份
for backup_file in backup_files[keep:]:
backup_file.unlink()
logger.debug(f"删除旧备份: {backup_file}")
except Exception as e:
logger.warning(f"清理旧备份失败: {e}")
async def start_auto_save(
self,
graph_store: GraphStore,
staged_memories_getter: Callable[[], list[StagedMemory]] | None = None,
) -> None:
"""
启动自动保存任务
Args:
graph_store: 图存储对象
staged_memories_getter: 获取临时记忆的回调函数
"""
if self._auto_save_task and not self._auto_save_task.done():
logger.warning("自动保存任务已在运行")
return
self._running = True
async def auto_save_loop():
logger.info(f"自动保存任务已启动,间隔: {self.auto_save_interval}")
while self._running:
try:
await asyncio.sleep(self.auto_save_interval)
if not self._running:
break
# 保存图数据
await self.save_graph_store(graph_store)
# 保存临时记忆(如果提供了获取函数)
if staged_memories_getter:
staged_memories = staged_memories_getter()
if staged_memories:
await self.save_staged_memories(staged_memories)
# 定期创建备份(每小时)
current_time = datetime.now()
if current_time.minute == 0: # 每个整点
await self.create_backup()
except Exception as e:
logger.error(f"自动保存失败: {e}", exc_info=True)
logger.info("自动保存任务已停止")
self._auto_save_task = asyncio.create_task(auto_save_loop())
def stop_auto_save(self) -> None:
"""停止自动保存任务"""
self._running = False
if self._auto_save_task:
self._auto_save_task.cancel()
logger.info("自动保存任务已取消")
async def export_to_json(self, output_file: Path, graph_store: GraphStore) -> None:
"""
导出图数据到指定的 JSON 文件(用于数据迁移或分析)
Args:
output_file: 输出文件路径
graph_store: 图存储对象
"""
try:
data = graph_store.to_dict()
data["metadata"] = {
"version": "0.1.0",
"exported_at": datetime.now().isoformat(),
"statistics": graph_store.get_statistics(),
}
# 使用标准 json 以获得更好的可读性
output_file.parent.mkdir(parents=True, exist_ok=True)
with output_file.open("w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"图数据已导出: {output_file}")
except Exception as e:
logger.error(f"导出图数据失败: {e}", exc_info=True)
raise
async def import_from_json(self, input_file: Path) -> GraphStore | None:
"""
从 JSON 文件导入图数据
Args:
input_file: 输入文件路径
Returns:
GraphStore 对象
"""
try:
with input_file.open("r", encoding="utf-8") as f:
data = json.load(f)
graph_store = GraphStore.from_dict(data)
logger.info(f"图数据已导入: {graph_store.get_statistics()}")
return graph_store
except Exception as e:
logger.error(f"导入图数据失败: {e}", exc_info=True)
raise
def get_data_size(self) -> dict[str, int]:
"""
获取数据文件的大小信息
Returns:
文件大小字典(字节)
"""
sizes = {}
if self.graph_file.exists():
sizes["graph"] = self.graph_file.stat().st_size
if self.staged_file.exists():
sizes["staged"] = self.staged_file.stat().st_size
# 计算备份文件总大小
backup_size = sum(f.stat().st_size for f in self.backup_dir.glob("*.json"))
sizes["backups"] = backup_size
return sizes
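# Usage sketch (illustrative): one save/load round trip against a hypothetical
# data directory. In the real system the manager also runs start_auto_save()
# in the background.
if __name__ == "__main__":
    async def _demo():
        pm = PersistenceManager(data_dir=Path("data/memory_graph_demo"))
        store = GraphStore()
        await pm.save_graph_store(store)        # writes memory_graph.json atomically
        reloaded = await pm.load_graph_store()  # None only if the file is missing
        print(pm.get_data_size(), reloaded.get_statistics() if reloaded else None)

    asyncio.run(_demo())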

View File

@@ -0,0 +1,452 @@
"""
向量存储层:基于 ChromaDB 的语义向量存储
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
from src.common.logger import get_logger
from src.memory_graph.models import MemoryNode, NodeType
logger = get_logger(__name__)
class VectorStore:
"""
向量存储封装类
负责:
1. 节点的语义向量存储和检索
2. 基于相似度的向量搜索
3. 节点去重时的相似节点查找
"""
def __init__(
self,
collection_name: str = "memory_nodes",
data_dir: Path | None = None,
embedding_function: Any | None = None,
):
"""
初始化向量存储
Args:
collection_name: ChromaDB 集合名称
data_dir: 数据存储目录
embedding_function: 嵌入函数如果为None则使用默认
"""
self.collection_name = collection_name
self.data_dir = data_dir or Path("data/memory_graph")
self.data_dir.mkdir(parents=True, exist_ok=True)
self.client = None
self.collection = None
self.embedding_function = embedding_function
logger.info(f"初始化向量存储: collection={collection_name}, dir={self.data_dir}")
async def initialize(self) -> None:
"""异步初始化 ChromaDB"""
try:
import chromadb
from chromadb.config import Settings
# 创建持久化客户端
self.client = chromadb.PersistentClient(
path=str(self.data_dir / "chroma"),
settings=Settings(
anonymized_telemetry=False,
allow_reset=True,
),
)
# 获取或创建集合
self.collection = self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Memory graph node embeddings"},
)
logger.info(f"ChromaDB 初始化完成,集合包含 {self.collection.count()} 个节点")
except Exception as e:
logger.error(f"初始化 ChromaDB 失败: {e}", exc_info=True)
raise
async def add_node(self, node: MemoryNode) -> None:
"""
添加节点到向量存储
Args:
node: 要添加的节点
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
if not node.has_embedding():
logger.warning(f"节点 {node.id} 没有 embedding跳过添加")
return
try:
# 准备元数据ChromaDB 只支持 str, int, float, bool
metadata = {
"content": node.content,
"node_type": node.node_type.value,
"created_at": node.created_at.isoformat(),
}
# 处理额外的元数据,将 list 转换为 JSON 字符串
for key, value in node.metadata.items():
if isinstance(value, (list, dict)):
import orjson
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None:
metadata[key] = value
else:
metadata[key] = str(value)
self.collection.add(
ids=[node.id],
embeddings=[node.embedding.tolist()],
metadatas=[metadata],
documents=[node.content], # 文本内容用于检索
)
logger.debug(f"添加节点到向量存储: {node}")
except Exception as e:
logger.error(f"添加节点失败: {e}", exc_info=True)
raise
async def add_nodes_batch(self, nodes: list[MemoryNode]) -> None:
"""
批量添加节点
Args:
nodes: 节点列表
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
# 过滤出有 embedding 的节点
valid_nodes = [n for n in nodes if n.has_embedding()]
if not valid_nodes:
logger.warning("批量添加:没有有效的节点(缺少 embedding")
return
try:
# 准备元数据
import orjson
metadatas = []
for n in valid_nodes:
metadata = {
"content": n.content,
"node_type": n.node_type.value,
"created_at": n.created_at.isoformat(),
}
for key, value in n.metadata.items():
if isinstance(value, (list, dict)):
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None:
metadata[key] = value # type: ignore
else:
metadata[key] = str(value)
metadatas.append(metadata)
self.collection.add(
ids=[n.id for n in valid_nodes],
embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore
metadatas=metadatas,
documents=[n.content for n in valid_nodes],
)
logger.info(f"批量添加 {len(valid_nodes)} 个节点到向量存储")
except Exception as e:
logger.error(f"批量添加节点失败: {e}", exc_info=True)
raise
async def search_similar_nodes(
self,
query_embedding: np.ndarray,
limit: int = 10,
node_types: list[NodeType] | None = None,
min_similarity: float = 0.0,
) -> list[tuple[str, float, dict[str, Any]]]:
"""
搜索相似节点
Args:
query_embedding: 查询向量
limit: 返回结果数量
node_types: 限制节点类型(可选)
min_similarity: 最小相似度阈值
Returns:
List of (node_id, similarity, metadata)
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
try:
# 构建 where 条件
where_filter = None
if node_types:
where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}}
# 执行查询
results = self.collection.query(
query_embeddings=[query_embedding.tolist()],
n_results=limit,
where=where_filter,
)
# 解析结果
import orjson
similar_nodes = []
# 修复:检查 ids 列表长度而不是直接判断真值(避免 numpy 数组歧义)
ids = results.get("ids")
if ids is not None and len(ids) > 0 and len(ids[0]) > 0:
distances = results.get("distances")
metadatas = results.get("metadatas")
for i, node_id in enumerate(ids[0]):
# ChromaDB 返回的是距离,需要转换为相似度
# 余弦距离: distance = 1 - similarity
distance = distances[0][i] if distances is not None and len(distances) > 0 else 0.0 # type: ignore
similarity = 1.0 - distance
if similarity >= min_similarity:
metadata = metadatas[0][i] if metadatas is not None and len(metadatas) > 0 else {} # type: ignore
# 解析 JSON 字符串回列表/字典
for key, value in list(metadata.items()):
if isinstance(value, str) and (value.startswith("[") or value.startswith("{")):
try:
metadata[key] = orjson.loads(value)
except Exception:
pass # 保持原值
similar_nodes.append((node_id, similarity, metadata))
logger.debug(f"相似节点搜索: 找到 {len(similar_nodes)} 个结果")
return similar_nodes
except Exception as e:
logger.error(f"相似节点搜索失败: {e}", exc_info=True)
raise
async def search_with_multiple_queries(
self,
query_embeddings: list[np.ndarray],
query_weights: list[float] | None = None,
limit: int = 10,
node_types: list[NodeType] | None = None,
min_similarity: float = 0.0,
fusion_strategy: str = "weighted_max",
) -> list[tuple[str, float, dict[str, Any]]]:
"""
多查询融合搜索
使用多个查询向量进行搜索,然后融合结果。
这能解决单一查询向量无法同时关注多个关键概念的问题。
Args:
query_embeddings: 查询向量列表
query_weights: 每个查询的权重(可选,默认均等)
limit: 最终返回结果数量
node_types: 限制节点类型(可选)
min_similarity: 最小相似度阈值
fusion_strategy: 融合策略
- "weighted_max": 加权最大值(推荐)
- "weighted_sum": 加权求和
- "rrf": Reciprocal Rank Fusion
Returns:
融合后的节点列表 [(node_id, fused_score, metadata), ...]
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
if not query_embeddings:
return []
# 默认权重均等
if query_weights is None:
query_weights = [1.0 / len(query_embeddings)] * len(query_embeddings)
# 归一化权重
total_weight = sum(query_weights)
if total_weight > 0:
query_weights = [w / total_weight for w in query_weights]
try:
# 1. 对每个查询执行搜索
all_results: dict[str, dict[str, Any]] = {} # node_id -> {scores, metadata}
for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)):
# 搜索更多结果以提高融合质量
search_limit = limit * 3
results = await self.search_similar_nodes(
query_embedding=query_emb,
limit=search_limit,
node_types=node_types,
min_similarity=min_similarity,
)
# 记录每个结果
for rank, (node_id, similarity, metadata) in enumerate(results):
if node_id not in all_results:
all_results[node_id] = {
"scores": [],
"ranks": [],
"metadata": metadata,
}
all_results[node_id]["scores"].append((similarity, weight))
all_results[node_id]["ranks"].append((rank, weight))
# 2. 融合分数
fused_results = []
for node_id, data in all_results.items():
scores = data["scores"]
ranks = data["ranks"]
metadata = data["metadata"]
if fusion_strategy == "weighted_max":
# 加权最大值 + 出现次数奖励
max_weighted_score = max(score * weight for score, weight in scores)
appearance_bonus = len(scores) * 0.05 # 出现多次有奖励
fused_score = max_weighted_score + appearance_bonus
elif fusion_strategy == "weighted_sum":
# 加权求和(可能导致出现多次的结果分数过高)
fused_score = sum(score * weight for score, weight in scores)
elif fusion_strategy == "rrf":
# Reciprocal Rank Fusion
# RRF score = sum(weight / (rank + k))
k = 60 # RRF 常数
fused_score = sum(weight / (rank + k) for rank, weight in ranks)
else:
# 默认使用加权平均
fused_score = sum(score * weight for score, weight in scores) / len(scores)
fused_results.append((node_id, fused_score, metadata))
# 3. 排序并返回 Top-K
fused_results.sort(key=lambda x: x[1], reverse=True)
final_results = fused_results[:limit]
logger.info(
f"多查询融合搜索完成: {len(query_embeddings)} 个查询, "
f"融合后 {len(fused_results)} 个结果, 返回 {len(final_results)}"
)
return final_results
except Exception as e:
logger.error(f"多查询融合搜索失败: {e}", exc_info=True)
raise
async def get_node_by_id(self, node_id: str) -> dict[str, Any] | None:
"""
根据ID获取节点元数据
Args:
node_id: 节点ID
Returns:
节点元数据或 None
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
try:
result = self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])
# 修复:直接检查 ids 列表是否非空(避免 numpy 数组的布尔值歧义)
if result is not None:
ids = result.get("ids")
if ids is not None and len(ids) > 0:
metadatas = result.get("metadatas")
embeddings = result.get("embeddings")
return {
"id": ids[0],
"metadata": metadatas[0] if metadatas is not None and len(metadatas) > 0 else {},
"embedding": np.array(embeddings[0]) if embeddings is not None and len(embeddings) > 0 and embeddings[0] is not None else None,
}
return None
except Exception as e:
logger.error(f"获取节点失败: {e}", exc_info=True)
return None
async def delete_node(self, node_id: str) -> None:
"""
删除节点
Args:
node_id: 节点ID
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
try:
self.collection.delete(ids=[node_id])
logger.debug(f"删除节点: {node_id}")
except Exception as e:
logger.error(f"删除节点失败: {e}", exc_info=True)
raise
async def update_node_embedding(self, node_id: str, embedding: np.ndarray) -> None:
"""
更新节点的 embedding
Args:
node_id: 节点ID
embedding: 新的向量
"""
if not self.collection:
raise RuntimeError("向量存储未初始化")
try:
self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
logger.debug(f"更新节点 embedding: {node_id}")
except Exception as e:
logger.error(f"更新节点 embedding 失败: {e}", exc_info=True)
raise
def get_total_count(self) -> int:
"""获取向量存储中的节点总数"""
if not self.collection:
return 0
return self.collection.count()
async def clear(self) -> None:
"""清空向量存储(危险操作,仅用于测试)"""
if not self.collection:
return
try:
# 删除并重新创建集合
self.client.delete_collection(self.collection_name)
self.collection = self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Memory graph node embeddings"},
)
logger.warning(f"向量存储已清空: {self.collection_name}")
except Exception as e:
logger.error(f"清空向量存储失败: {e}", exc_info=True)
raise
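# Sketch of the "rrf" fusion rule implemented in search_with_multiple_queries()
# (illustrative and self-contained): each query contributes weight / (rank + k)
# for every node it retrieves, and the per-query contributions are summed.
# The ranks and weights below are hypothetical.
if __name__ == "__main__":
    k = 60
    # node_id -> list of (rank_in_that_query, query_weight)
    ranks = {"node_a": [(0, 0.6), (2, 0.4)], "node_b": [(1, 0.6)]}
    fused = {nid: sum(w / (r + k) for r, w in rs) for nid, rs in ranks.items()}
    print(sorted(fused.items(), key=lambda kv: kv[1], reverse=True))
    # node_a ≈ 0.6/60 + 0.4/62 ≈ 0.0165  >  node_b ≈ 0.6/61 ≈ 0.0098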

View File

@@ -0,0 +1,7 @@
"""
记忆系统工具模块
"""
from src.memory_graph.tools.memory_tools import MemoryTools
__all__ = ["MemoryTools"]

View File

@@ -0,0 +1,868 @@
"""
LLM 工具接口:定义记忆系统的工具 schema 和执行逻辑
"""
from __future__ import annotations
from typing import Any
from src.common.logger import get_logger
from src.memory_graph.core.builder import MemoryBuilder
from src.memory_graph.core.extractor import MemoryExtractor
from src.memory_graph.models import Memory
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.persistence import PersistenceManager
from src.memory_graph.storage.vector_store import VectorStore
from src.memory_graph.utils.embeddings import EmbeddingGenerator
from src.memory_graph.utils.graph_expansion import expand_memories_with_semantic_filter
logger = get_logger(__name__)
class MemoryTools:
"""
记忆系统工具集
提供给 LLM 使用的工具接口:
1. create_memory: 创建新记忆
2. link_memories: 关联两个记忆
3. search_memories: 搜索记忆
"""
def __init__(
self,
vector_store: VectorStore,
graph_store: GraphStore,
persistence_manager: PersistenceManager,
embedding_generator: EmbeddingGenerator | None = None,
max_expand_depth: int = 1,
expand_semantic_threshold: float = 0.3,
):
"""
初始化工具集
Args:
vector_store: 向量存储
graph_store: 图存储
persistence_manager: 持久化管理器
embedding_generator: 嵌入生成器(可选)
max_expand_depth: 图扩展深度的默认值(从配置读取)
expand_semantic_threshold: 图扩展时语义相似度阈值(从配置读取)
"""
self.vector_store = vector_store
self.graph_store = graph_store
self.persistence_manager = persistence_manager
self._initialized = False
self.max_expand_depth = max_expand_depth # 保存配置的默认值
self.expand_semantic_threshold = expand_semantic_threshold # 保存配置的语义阈值
logger.info(f"MemoryTools 初始化: max_expand_depth={max_expand_depth}, expand_semantic_threshold={expand_semantic_threshold}")
# 初始化组件
self.extractor = MemoryExtractor()
self.builder = MemoryBuilder(
vector_store=vector_store,
graph_store=graph_store,
embedding_generator=embedding_generator,
)
async def _ensure_initialized(self):
"""确保向量存储已初始化"""
if not self._initialized:
await self.vector_store.initialize()
self._initialized = True
@staticmethod
def get_create_memory_schema() -> dict[str, Any]:
"""
获取 create_memory 工具的 JSON schema
Returns:
工具 schema 定义
"""
return {
"name": "create_memory",
"description": """创建一个新的记忆节点,记录对话中有价值的信息。
🎯 **核心原则**:主动记录、积极构建、丰富细节
✅ **优先创建记忆的场景**(鼓励记录):
1. **个人信息**:姓名、昵称、年龄、职业、身份、所在地、联系方式等
2. **兴趣爱好**:喜欢/不喜欢的事物、娱乐偏好、运动爱好、饮食口味等
3. **生活状态**:工作学习状态、生活习惯、作息时间、日常安排等
4. **经历事件**:正在做的事、完成的任务、参与的活动、遇到的问题等
5. **观点态度**:对事物的看法、价值观、情绪表达、评价意见等
6. **计划目标**:未来打算、学习计划、工作目标、待办事项等
7. **人际关系**:提到的朋友、家人、同事、认识的人等
8. **技能知识**:掌握的技能、学习的知识、专业领域、使用的工具等
9. **物品资源**:拥有的物品、使用的设备、喜欢的品牌等
10. **时间地点**:重要时间节点、常去的地点、活动场所等
⚠️ **暂不创建的情况**(仅限以下):
- 纯粹的招呼语(单纯的"你好""再见")
- 完全无意义的语气词(单纯的"嗯""哦")
- 明确的系统指令(如"切换模式""重启")
📋 **记忆拆分建议**
- 一句话包含多个信息点 → 拆成多条记忆(更利于后续检索)
- 例如:"我最近在学Python和机器学习想找工作"
→ 拆成3条
1. "用户正在学习Python"(事件)
2. "用户正在学习机器学习"(事件)
3. "用户想找工作"(事件/目标)
📌 **记忆质量建议**
- 记录时尽量补充时间("今天""最近""昨天"等)
- 包含具体细节(越具体越好)
- 主体明确(优先使用"用户"或具体人名,避免"我"等代词)
记忆结构:主体 + 类型 + 主题 + 客体(可选)+ 属性(越详细越好)""",
"parameters": {
"type": "object",
"properties": {
"subject": {
"type": "string",
"description": "记忆的主体(谁的信息):\n- 对话中的用户统一使用'用户'\n- 提到的具体人物使用其名字(如'小明''张三'\n- 避免使用''''等代词",
},
"memory_type": {
"type": "string",
"enum": ["事件", "事实", "关系", "观点"],
"description": "选择最合适的记忆类型:\n\n【事件】时间相关的动作或发生的事(用'正在''完成了''参加'等动词)\n正在学习Python、完成了项目、参加会议、去旅行\n\n【事实】相对稳定的客观信息(用''''''等描述状态)\n 例:职业是工程师、住在北京、有一只猫、会说英语\n\n【观点】主观看法、喜好、态度(用'喜欢''认为''觉得'等)\n喜欢Python、认为AI很重要、觉得累、讨厌加班\n\n【关系】人与人之间的关系\n 例:认识了朋友、是同事、家人关系",
},
"topic": {
"type": "string",
"description": "记忆的核心内容(做什么/是什么/关于什么):\n- 尽量具体明确('学习Python编程' 优于 '学习'\n- 包含关键动词或核心概念\n- 可以包含时间状态('正在学习''已完成''计划做'",
},
"object": {
"type": "string",
"description": "可选:记忆涉及的对象或目标:\n- 事件的对象(学习的是什么、购买的是什么)\n- 观点的对象(喜欢的是什么、讨厌的是什么)\n- 可以留空如果topic已经足够完整",
},
"attributes": {
"type": "object",
"description": "记忆的详细属性(建议尽量填写,越详细越好):",
"properties": {
"时间": {
"type": "string",
"description": "时间信息(强烈建议填写):\n- 具体日期:'2025-11-05''2025年11月'\n- 相对时间:'今天''昨天''上周''最近''3天前'\n- 时间段:'今天下午''上个月''这学期'",
},
"地点": {
"type": "string",
"description": "地点信息(如涉及):\n- 具体地址、城市名、国家\n- 场所类型:'在家''公司''学校''咖啡店'"
},
"原因": {
"type": "string",
"description": "为什么这样做/这样想(如明确提到)"
},
"方式": {
"type": "string",
"description": "怎么做的/通过什么方式(如明确提到)"
},
"结果": {
"type": "string",
"description": "结果如何/产生什么影响(如明确提到)"
},
"状态": {
"type": "string",
"description": "当前进展:'进行中''已完成''计划中''暂停'"
},
"程度": {
"type": "string",
"description": "程度描述(如'非常''比较''有点''不太'"
},
},
"additionalProperties": True,
},
"importance": {
"type": "number",
"minimum": 0.0,
"maximum": 1.0,
"description": "重要性评分默认0.5日常对话建议0.5-0.7\n\n0.3-0.4: 次要细节(偶然提及的琐事)\n0.5-0.6: 日常信息(一般性的分享、普通爱好)← 推荐默认值\n0.7-0.8: 重要信息(明确的偏好、重要计划、核心爱好)\n0.9-1.0: 关键信息(身份信息、重大决定、强烈情感)\n\n💡 建议日常对话中大部分记忆使用0.5-0.6,除非用户特别强调",
},
},
"required": ["subject", "memory_type", "topic"],
},
}
@staticmethod
def get_link_memories_schema() -> dict[str, Any]:
"""
获取 link_memories 工具的 JSON schema
Returns:
工具 schema 定义
"""
return {
"name": "link_memories",
"description": """手动关联两个已存在的记忆。
⚠️ 使用建议:
- 系统会自动发现记忆间的关联关系,通常不需要手动调用此工具
- 仅在以下情况使用:
1. 用户明确指出两个记忆之间的关系
2. 发现明显的因果关系但系统未自动关联
3. 需要建立特殊的引用关系
关系类型说明:
- 导致A事件/行为导致B事件/结果(因果关系)
- 引用A记忆引用/基于B记忆知识关联
- 相似A和B描述相似的内容主题相似
- 相反A和B表达相反的观点对比关系
- 关联A和B存在一般性关联其他关系""",
"parameters": {
"type": "object",
"properties": {
"source_memory_description": {
"type": "string",
"description": "源记忆的关键描述(用于搜索定位,需要足够具体)",
},
"target_memory_description": {
"type": "string",
"description": "目标记忆的关键描述(用于搜索定位,需要足够具体)",
},
"relation_type": {
"type": "string",
"enum": ["导致", "引用", "相似", "相反", "关联"],
"description": "关系类型从上述5种类型中选择最合适的",
},
"importance": {
"type": "number",
"minimum": 0.0,
"maximum": 1.0,
"description": "关系的重要性0.0-1.0\n- 0.5-0.6: 一般关联\n- 0.7-0.8: 重要关联\n- 0.9-1.0: 关键关联\n默认0.6",
},
},
"required": [
"source_memory_description",
"target_memory_description",
"relation_type",
],
},
}
@staticmethod
def get_search_memories_schema() -> dict[str, Any]:
"""
获取 search_memories 工具的 JSON schema
Returns:
工具 schema 定义
"""
return {
"name": "search_memories",
"description": """搜索相关的记忆,用于回忆和查找历史信息。
使用场景:
- 用户询问之前的对话内容
- 需要回忆用户的个人信息、偏好、经历
- 查找相关的历史事件或观点
- 基于上下文补充信息
搜索特性:
- 语义搜索:基于内容相似度匹配
- 图遍历:自动扩展相关联的记忆
- 时间过滤:按时间范围筛选
- 类型过滤:按记忆类型筛选""",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索查询(用自然语言描述要查找的内容,如'用户的职业''最近的项目''Python相关的记忆'",
},
"memory_types": {
"type": "array",
"items": {
"type": "string",
"enum": ["事件", "事实", "关系", "观点"],
},
"description": "记忆类型过滤(可选,留空表示搜索所有类型)",
},
"time_range": {
"type": "object",
"properties": {
"start": {
"type": "string",
"description": "开始时间(如'3天前''上周''2025-11-01'",
},
"end": {
"type": "string",
"description": "结束时间(如'今天''现在''2025-11-05'",
},
},
"description": "时间范围(可选,用于查找特定时间段的记忆)",
},
"top_k": {
"type": "integer",
"minimum": 1,
"maximum": 50,
"description": "返回结果数量1-50默认10。根据需求调整\n- 快速查找3-5条\n- 一般搜索10条\n- 全面了解20-30条",
},
"expand_depth": {
"type": "integer",
"minimum": 0,
"maximum": 3,
"description": "图扩展深度0-3默认1\n- 0: 仅返回直接匹配的记忆\n- 1: 包含一度相关的记忆(推荐)\n- 2-3: 包含更多间接相关的记忆(用于深度探索)",
},
},
"required": ["query"],
},
}
async def create_memory(self, **params) -> dict[str, Any]:
"""
执行 create_memory 工具
Args:
**params: 工具参数
Returns:
执行结果
"""
try:
logger.info(f"创建记忆: {params.get('subject')} - {params.get('topic')}")
# 0. 确保初始化
await self._ensure_initialized()
# 1. 提取参数
extracted = self.extractor.extract_from_tool_params(params)
# 2. 构建记忆
memory = await self.builder.build_memory(extracted)
# 3. 添加到存储(暂存状态)
await self._add_memory_to_stores(memory)
# 4. 保存到磁盘
await self.persistence_manager.save_graph_store(self.graph_store)
logger.info(f"记忆创建成功: {memory.id}")
return {
"success": True,
"memory_id": memory.id,
"message": f"记忆已创建: {extracted['subject']} - {extracted['topic']}",
"nodes_count": len(memory.nodes),
"edges_count": len(memory.edges),
}
except Exception as e:
logger.error(f"记忆创建失败: {e}", exc_info=True)
return {
"success": False,
"error": str(e),
"message": "记忆创建失败",
}
async def link_memories(self, **params) -> dict[str, Any]:
"""
执行 link_memories 工具
Args:
**params: 工具参数
Returns:
执行结果
"""
try:
logger.info(
f"关联记忆: {params.get('source_memory_description')} -> "
f"{params.get('target_memory_description')}"
)
# 1. 提取参数
extracted = self.extractor.extract_link_params(params)
# 2. 查找源记忆和目标记忆
source_memory = await self._find_memory_by_description(
extracted["source_description"]
)
target_memory = await self._find_memory_by_description(
extracted["target_description"]
)
if not source_memory:
return {
"success": False,
"error": "找不到源记忆",
"message": f"未找到匹配的源记忆: {extracted['source_description']}",
}
if not target_memory:
return {
"success": False,
"error": "找不到目标记忆",
"message": f"未找到匹配的目标记忆: {extracted['target_description']}",
}
# 3. 创建关联边
edge = await self.builder.link_memories(
source_memory=source_memory,
target_memory=target_memory,
relation_type=extracted["relation_type"],
importance=extracted["importance"],
)
# 4. 添加边到图存储
self.graph_store.graph.add_edge(
edge.source_id,
edge.target_id,
relation=edge.relation,
edge_type=edge.edge_type.value,
importance=edge.importance,
**edge.metadata
)
# 5. 保存
await self.persistence_manager.save_graph_store(self.graph_store)
logger.info(f"记忆关联成功: {source_memory.id} -> {target_memory.id}")
return {
"success": True,
"message": f"记忆已关联: {extracted['relation_type']}",
"source_memory_id": source_memory.id,
"target_memory_id": target_memory.id,
"relation_type": extracted["relation_type"],
}
except Exception as e:
logger.error(f"记忆关联失败: {e}", exc_info=True)
return {
"success": False,
"error": str(e),
"message": "记忆关联失败",
}
async def search_memories(self, **params) -> dict[str, Any]:
"""
执行 search_memories 工具
使用多策略检索优化:
1. 查询分解(识别主要实体和概念)
2. 多查询并行检索
3. 结果融合和重排
Args:
**params: 工具参数
- query: 查询字符串
- top_k: 返回结果数默认10
- expand_depth: 图扩展深度(默认取配置的 max_expand_depth)
- use_multi_query: 是否使用多查询策略默认True
- context: 查询上下文(可选)
Returns:
搜索结果
"""
try:
query = params.get("query", "")
top_k = params.get("top_k", 10)
# 使用配置中的默认值而不是硬编码的 1
expand_depth = params.get("expand_depth", self.max_expand_depth)
use_multi_query = params.get("use_multi_query", True)
context = params.get("context", None)
logger.info(f"搜索记忆: {query} (top_k={top_k}, expand_depth={expand_depth}, multi_query={use_multi_query})")
# 0. 确保初始化
await self._ensure_initialized()
# 1. 根据策略选择检索方式
if use_multi_query:
# 多查询策略
similar_nodes = await self._multi_query_search(query, top_k, context)
else:
# 传统单查询策略
similar_nodes = await self._single_query_search(query, top_k)
# 2. 提取初始记忆ID来自向量搜索
initial_memory_ids = set()
memory_scores = {} # 记录每个记忆的初始分数
for node_id, similarity, metadata in similar_nodes:
if "memory_ids" in metadata:
ids = metadata["memory_ids"]
# 确保是列表
if isinstance(ids, str):
import orjson
try:
ids = orjson.loads(ids)
except Exception:
ids = [ids]
if isinstance(ids, list):
for mem_id in ids:
initial_memory_ids.add(mem_id)
# 记录最高分数
if mem_id not in memory_scores or similarity > memory_scores[mem_id]:
memory_scores[mem_id] = similarity
# 3. 图扩展如果启用且有expand_depth
expanded_memory_scores = {}
if expand_depth > 0 and initial_memory_ids:
logger.info(f"开始图扩展: 初始记忆{len(initial_memory_ids)}个, 深度={expand_depth}")
# 获取查询的embedding用于语义过滤
if self.builder.embedding_generator:
try:
query_embedding = await self.builder.embedding_generator.generate(query)
# 使用共享的图扩展工具函数
expanded_results = await expand_memories_with_semantic_filter(
graph_store=self.graph_store,
vector_store=self.vector_store,
initial_memory_ids=list(initial_memory_ids),
query_embedding=query_embedding,
max_depth=expand_depth,
semantic_threshold=self.expand_semantic_threshold, # 使用配置的阈值
max_expanded=top_k * 2
)
# 合并扩展结果
expanded_memory_scores.update(dict(expanded_results))
logger.info(f"图扩展完成: 新增{len(expanded_memory_scores)}个相关记忆")
except Exception as e:
logger.warning(f"图扩展失败: {e}")
# 4. 合并初始记忆和扩展记忆
all_memory_ids = set(initial_memory_ids) | set(expanded_memory_scores.keys())
# 计算最终分数:初始记忆保持原分数,扩展记忆使用扩展分数
final_scores = {}
for mem_id in all_memory_ids:
if mem_id in memory_scores:
# 初始记忆:使用向量相似度分数
final_scores[mem_id] = memory_scores[mem_id]
elif mem_id in expanded_memory_scores:
# 扩展记忆:使用图扩展分数(稍微降权)
final_scores[mem_id] = expanded_memory_scores[mem_id] * 0.8
# 按分数排序
sorted_memory_ids = sorted(
final_scores.keys(),
key=lambda x: final_scores[x],
reverse=True
)[:top_k * 2] # 取2倍数量用于后续过滤
# 5. 获取完整记忆并进行最终排序
memories_with_scores = []
for memory_id in sorted_memory_ids:
memory = self.graph_store.get_memory_by_id(memory_id)
if memory:
# 综合评分:相似度(60%) + 重要性(30%) + 时效性(10%)
similarity_score = final_scores[memory_id]
importance_score = memory.importance
# 计算时效性分数(最近的记忆得分更高)
from datetime import datetime, timezone
now = datetime.now(timezone.utc)
# 确保 memory.created_at 有时区信息
if memory.created_at.tzinfo is None:
memory_time = memory.created_at.replace(tzinfo=timezone.utc)
else:
memory_time = memory.created_at
age_days = (now - memory_time).total_seconds() / 86400
recency_score = 1.0 / (1.0 + age_days / 30) # 30天半衰期
# 综合分数
final_score = (
similarity_score * 0.6 +
importance_score * 0.3 +
recency_score * 0.1
)
memories_with_scores.append((memory, final_score))
# 按综合分数排序
memories_with_scores.sort(key=lambda x: x[1], reverse=True)
memories = [mem for mem, _ in memories_with_scores[:top_k]]
# 6. 格式化结果
results = []
for memory in memories:
result = {
"memory_id": memory.id,
"importance": memory.importance,
"created_at": memory.created_at.isoformat(),
"summary": self._summarize_memory(memory),
}
results.append(result)
logger.info(
f"搜索完成: 初始{len(initial_memory_ids)}个 → "
f"扩展{len(expanded_memory_scores)}个 → "
f"最终返回{len(results)}条记忆"
)
return {
"success": True,
"results": results,
"total": len(results),
"query": query,
"strategy": "multi_query" if use_multi_query else "single_query",
"expanded_count": len(expanded_memory_scores),
"expand_depth": expand_depth,
}
except Exception as e:
logger.error(f"记忆搜索失败: {e}", exc_info=True)
return {
"success": False,
"error": str(e),
"message": "记忆搜索失败",
"results": [],
}
async def _generate_multi_queries_simple(
self, query: str, context: dict[str, Any] | None = None
) -> list[tuple[str, float]]:
"""
简化版多查询生成(直接在 Tools 层实现,避免循环依赖)
让小模型直接生成3-5个不同角度的查询语句。
"""
try:
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.multi_query"
)
# 获取上下文信息
participants = context.get("participants", []) if context else []
chat_history = context.get("chat_history", "") if context else ""
sender = context.get("sender", "") if context else ""
# 处理聊天历史提取最近5条左右的对话
recent_chat = ""
if chat_history:
lines = chat_history.strip().split("\n")
# 取最近5条消息
recent_lines = lines[-5:] if len(lines) > 5 else lines
recent_chat = "\n".join(recent_lines)
prompt = f"""基于聊天上下文为查询生成3-5个不同角度的搜索语句JSON格式
**当前查询:** {query}
**发送者:** {sender if sender else '未知'}
**参与者:** {', '.join(participants) if participants else ''}
**最近聊天记录(最近5条):**
{recent_chat if recent_chat else '无聊天历史'}
**分析原则:**
1. **上下文理解**:根据聊天历史理解查询的真实意图
2. **指代消解**:识别并代换"""""""那个"等指代词
3. **话题关联**:结合最近讨论的话题生成更精准的查询
4. **查询分解**:对复杂查询分解为多个子查询
**生成策略:**
1. **完整查询**权重1.0):结合上下文的完整查询,包含指代消解
2. **关键概念查询**权重0.8):查询中的核心概念,特别是聊天中提到的实体
3. **话题扩展查询**权重0.7):基于最近聊天话题的相关查询
4. **动作/情感查询**权重0.6):如果涉及情感或动作,生成相关查询
**输出JSON格式**
```json
{{"queries": [{{"text": "查询语句", "weight": 1.0}}, {{"text": "查询语句", "weight": 0.8}}]}}
```
**示例:**
- 查询:"他怎么样了?" + 聊天中提到"小明生病了""小明身体恢复情况"
- 查询:"那个项目" + 聊天中讨论"记忆系统开发""记忆系统项目进展"
"""
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)
import re
import orjson
response = re.sub(r"```json\s*", "", response)
response = re.sub(r"```\s*$", "", response).strip()
data = orjson.loads(response)
queries = data.get("queries", [])
result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
for item in queries if item.get("text", "").strip()]
if result:
logger.info(f"生成查询: {[q for q, _ in result]}")
return result
except Exception as e:
logger.warning(f"多查询生成失败: {e}")
return [(query, 1.0)]
async def _single_query_search(
self, query: str, top_k: int
) -> list[tuple[str, float, dict[str, Any]]]:
"""
传统的单查询搜索
Args:
query: 查询字符串
top_k: 返回结果数
Returns:
相似节点列表 [(node_id, similarity, metadata), ...]
"""
# 生成查询嵌入
if self.builder.embedding_generator:
query_embedding = await self.builder.embedding_generator.generate(query)
else:
logger.warning("未配置嵌入生成器,使用随机向量")
import numpy as np
query_embedding = np.random.rand(384).astype(np.float32)
# 向量搜索
similar_nodes = await self.vector_store.search_similar_nodes(
query_embedding=query_embedding,
limit=top_k * 2, # 多取一些,后续过滤
)
return similar_nodes
async def _multi_query_search(
self, query: str, top_k: int, context: dict[str, Any] | None = None
) -> list[tuple[str, float, dict[str, Any]]]:
"""
多查询策略搜索(简化版)
直接使用小模型生成多个查询,无需复杂的分解和组合。
步骤:
1. 让小模型生成3-5个不同角度的查询
2. 为每个查询生成嵌入
3. 并行搜索并融合结果
Args:
query: 查询字符串
top_k: 返回结果数
context: 查询上下文
Returns:
融合后的相似节点列表
"""
try:
# 1. 使用小模型生成多个查询
multi_queries = await self._generate_multi_queries_simple(query, context)
logger.debug(f"生成 {len(multi_queries)} 个查询: {multi_queries}")
# 2. 生成所有查询的嵌入
if not self.builder.embedding_generator:
logger.warning("未配置嵌入生成器,回退到单查询模式")
return await self._single_query_search(query, top_k)
query_embeddings = []
query_weights = []
for sub_query, weight in multi_queries:
embedding = await self.builder.embedding_generator.generate(sub_query)
query_embeddings.append(embedding)
query_weights.append(weight)
# 3. 多查询融合搜索
similar_nodes = await self.vector_store.search_with_multiple_queries(
query_embeddings=query_embeddings,
query_weights=query_weights,
limit=top_k * 2, # 多取一些,后续过滤
fusion_strategy="weighted_max",
)
logger.info(f"多查询检索完成: {len(similar_nodes)} 个节点")
return similar_nodes
except Exception as e:
logger.warning(f"多查询搜索失败,回退到单查询模式: {e}", exc_info=True)
return await self._single_query_search(query, top_k)
async def _add_memory_to_stores(self, memory: Memory):
"""将记忆添加到存储"""
# 1. 添加到图存储
self.graph_store.add_memory(memory)
# 2. 添加有嵌入的节点到向量存储
for node in memory.nodes:
if node.embedding is not None:
await self.vector_store.add_node(node)
async def _find_memory_by_description(self, description: str) -> Memory | None:
"""
通过描述查找记忆
Args:
description: 记忆描述
Returns:
找到的记忆,如果没有则返回 None
"""
# 使用语义搜索查找最相关的记忆
if self.builder.embedding_generator:
query_embedding = await self.builder.embedding_generator.generate(description)
else:
import numpy as np
query_embedding = np.random.rand(384).astype(np.float32)
# 搜索相似节点
similar_nodes = await self.vector_store.search_similar_nodes(
query_embedding=query_embedding,
limit=5,
)
if not similar_nodes:
return None
# 获取最相似节点关联的记忆
_node_id, _similarity, metadata = similar_nodes[0]
if "memory_ids" not in metadata or not metadata["memory_ids"]:
return None
ids = metadata["memory_ids"]
# 确保是列表
if isinstance(ids, str):
import orjson
try:
ids = orjson.loads(ids)
except Exception as e:
logger.warning(f"JSON 解析失败: {e}")
ids = [ids]
if isinstance(ids, list) and ids:
memory_id = ids[0]
return self.graph_store.get_memory_by_id(memory_id)
return None
def _summarize_memory(self, memory: Memory) -> str:
"""生成记忆摘要"""
if not memory.metadata:
return "未知记忆"
subject = memory.metadata.get("subject", "")
topic = memory.metadata.get("topic", "")
memory_type = memory.metadata.get("memory_type", "")
return f"{subject} - {memory_type}: {topic}"
@staticmethod
def get_all_tool_schemas() -> list[dict[str, Any]]:
"""
获取所有工具的 schema
Returns:
工具 schema 列表
"""
return [
MemoryTools.get_create_memory_schema(),
MemoryTools.get_link_memories_schema(),
MemoryTools.get_search_memories_schema(),
]
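# 使用示意(示例草稿):组装 MemoryTools 并创建 / 检索记忆。
# 注意:下面 VectorStore / GraphStore / PersistenceManager 均按无参构造书写,属于假设,
# 实际构造参数请以各自模块的定义为准;create_memory / search_memories 的调用方式来自本文件。
async def _memory_tools_demo() -> None:
    tools = MemoryTools(
        vector_store=VectorStore(),                # 假设:无参构造
        graph_store=GraphStore(),                  # 假设:无参构造
        persistence_manager=PersistenceManager(),  # 假设:无参构造
        embedding_generator=EmbeddingGenerator(),
        max_expand_depth=1,
        expand_semantic_threshold=0.3,
    )
    # 注册给 LLM 的工具 schema(create_memory / link_memories / search_memories)
    schemas = MemoryTools.get_all_tool_schemas()
    logger.debug(f"可用工具数: {len(schemas)}")
    # 创建一条记忆
    await tools.create_memory(
        subject="用户",
        memory_type="事件",
        topic="正在学习Python",
        attributes={"时间": "今天"},
        importance=0.6,
    )
    # 检索记忆:最终排序分数 = 相似度*0.6 + 重要性*0.3 + 时效性*0.1(时效性按约 30 天半衰期衰减)
    result = await tools.search_memories(query="用户最近在学什么", top_k=5, expand_depth=1)
    logger.debug(f"检索到 {result.get('total', 0)} 条记忆")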

View File

@@ -0,0 +1,9 @@
"""
工具模块
"""
from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
from src.memory_graph.utils.similarity import cosine_similarity
from src.memory_graph.utils.time_parser import TimeParser
__all__ = ["EmbeddingGenerator", "TimeParser", "cosine_similarity", "get_embedding_generator"]

View File

@@ -0,0 +1,297 @@
"""
嵌入向量生成器:优先使用配置的 embedding APIsentence-transformers 作为备选
"""
from __future__ import annotations
import asyncio
import numpy as np
from src.common.logger import get_logger
logger = get_logger(__name__)
class EmbeddingGenerator:
"""
嵌入向量生成器
策略:
1. 优先使用配置的 embedding API通过 LLMRequest
2. 如果 API 不可用,回退到本地 sentence-transformers
3. 如果 sentence-transformers 未安装,使用随机向量(仅测试)
优点:
- 降低本地运算负载
- 即使未安装 sentence-transformers 也可正常运行
- 保持与现有系统的一致性
"""
def __init__(
self,
use_api: bool = True,
fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
):
"""
初始化嵌入生成器
Args:
use_api: 是否优先使用 API默认 True
fallback_model_name: 回退本地模型名称
"""
self.use_api = use_api
self.fallback_model_name = fallback_model_name
# API 相关
self._llm_request = None
self._api_available = False
self._api_dimension = None
# 本地模型相关
self._local_model = None
self._local_model_loaded = False
async def _initialize_api(self):
"""初始化 embedding API"""
if self._api_available:
return
try:
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
embedding_config = model_config.model_task_config.embedding
self._llm_request = LLMRequest(
model_set=embedding_config,
request_type="memory_graph.embedding"
)
# 获取嵌入维度
if hasattr(embedding_config, "embedding_dimension") and embedding_config.embedding_dimension:
self._api_dimension = embedding_config.embedding_dimension
self._api_available = True
logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")
except Exception as e:
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
self._api_available = False
def _load_local_model(self):
"""延迟加载本地模型"""
if not self._local_model_loaded:
try:
from sentence_transformers import SentenceTransformer
logger.info(f"📦 加载本地嵌入模型: {self.fallback_model_name}")
self._local_model = SentenceTransformer(self.fallback_model_name)
self._local_model_loaded = True
logger.info("✅ 本地嵌入模型加载成功")
except ImportError:
logger.warning(
"⚠️ sentence-transformers 未安装,将使用随机向量(仅测试用)\n"
" 安装方法: pip install sentence-transformers"
)
self._local_model_loaded = False
except Exception as e:
logger.warning(f"⚠️ 本地模型加载失败: {e}")
self._local_model_loaded = False
async def generate(self, text: str) -> np.ndarray:
"""
生成单个文本的嵌入向量
策略:
1. 优先使用 API
2. API 失败则使用本地模型
3. 本地模型不可用则使用随机向量
Args:
text: 输入文本
Returns:
嵌入向量
"""
if not text or not text.strip():
logger.warning("输入文本为空,返回零向量")
dim = self._get_dimension()
return np.zeros(dim, dtype=np.float32)
try:
# 策略 1: 使用 API
if self.use_api:
embedding = await self._generate_with_api(text)
if embedding is not None:
return embedding
# 策略 2: 使用本地模型
embedding = await self._generate_with_local_model(text)
if embedding is not None:
return embedding
# 策略 3: 随机向量(仅测试)
logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...")
dim = self._get_dimension()
return np.random.rand(dim).astype(np.float32)
except Exception as e:
logger.error(f"❌ 嵌入生成失败: {e}", exc_info=True)
dim = self._get_dimension()
return np.random.rand(dim).astype(np.float32)
async def _generate_with_api(self, text: str) -> np.ndarray | None:
"""使用 API 生成嵌入"""
try:
# 初始化 API
if not self._api_available:
await self._initialize_api()
if not self._api_available or not self._llm_request:
return None
# 调用 API
embedding_list, model_name = await self._llm_request.get_embedding(text)
if embedding_list and len(embedding_list) > 0:
embedding = np.array(embedding_list, dtype=np.float32)
logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
return embedding
return None
except Exception as e:
logger.debug(f"API 嵌入生成失败: {e}")
return None
async def _generate_with_local_model(self, text: str) -> np.ndarray | None:
"""使用本地模型生成嵌入"""
try:
# 加载本地模型
if not self._local_model_loaded:
self._load_local_model()
if not self._local_model_loaded or not self._local_model:
return None
# 在线程池中运行
loop = asyncio.get_event_loop()
embedding = await loop.run_in_executor(None, self._encode_single_local, text)
logger.debug(f"💻 本地生成嵌入: {text[:30]}... -> {len(embedding)}")
return embedding
except Exception as e:
logger.debug(f"本地模型嵌入生成失败: {e}")
return None
def _encode_single_local(self, text: str) -> np.ndarray:
"""同步编码单个文本(本地模型)"""
if self._local_model is None:
raise RuntimeError("本地模型未加载")
embedding = self._local_model.encode(text, convert_to_numpy=True) # type: ignore
return embedding.astype(np.float32)
def _get_dimension(self) -> int:
"""获取嵌入维度"""
# 优先使用 API 维度
if self._api_dimension:
return self._api_dimension
# 其次使用本地模型维度
if self._local_model_loaded and self._local_model:
try:
return self._local_model.get_sentence_embedding_dimension()
except Exception:
pass
# 默认 384sentence-transformers 常用维度)
return 384
async def generate_batch(self, texts: list[str]) -> list[np.ndarray]:
"""
批量生成嵌入向量
Args:
texts: 文本列表
Returns:
嵌入向量列表
"""
if not texts:
return []
try:
# 过滤空文本
valid_texts = [t for t in texts if t and t.strip()]
if not valid_texts:
logger.warning("所有文本为空,返回零向量列表")
dim = self._get_dimension()
return [np.zeros(dim, dtype=np.float32) for _ in texts]
# 使用 API 批量生成(如果可用)
if self.use_api:
results = await self._generate_batch_with_api(valid_texts)
if results:
return results
# 回退到逐个生成
results = []
for text in valid_texts:
embedding = await self.generate(text)
results.append(embedding)
logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本")
return results
except Exception as e:
logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True)
dim = self._get_dimension()
return [np.random.rand(dim).astype(np.float32) for _ in texts]
async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None:
"""使用 API 批量生成"""
try:
# 对于大多数 API,批量调用就是多次单独调用
# 这里保持简单,逐个调用
results = []
for text in texts:
embedding = await self._generate_with_api(text)
if embedding is None:
return None # 如果任何一个失败,返回 None 触发回退
results.append(embedding)
return results
except Exception as e:
logger.debug(f"API 批量生成失败: {e}")
return None
def get_embedding_dimension(self) -> int:
"""获取嵌入向量维度"""
return self._get_dimension()
# 全局单例
_global_generator: EmbeddingGenerator | None = None
def get_embedding_generator(
use_api: bool = True,
fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
) -> EmbeddingGenerator:
"""
获取全局嵌入生成器单例
Args:
use_api: 是否优先使用 API
fallback_model_name: 回退本地模型名称
Returns:
EmbeddingGenerator 实例
"""
global _global_generator
if _global_generator is None:
_global_generator = EmbeddingGenerator(
use_api=use_api,
fallback_model_name=fallback_model_name
)
return _global_generator
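# 使用示意(示例草稿):通过全局单例生成嵌入向量。
# generate / generate_batch / get_embedding_dimension 均为本模块已定义的接口;
# 未配置 embedding API 且未安装 sentence-transformers 时会退化为随机向量(仅供测试)。
async def _embedding_demo() -> None:
    generator = get_embedding_generator(use_api=True)
    vec = await generator.generate("用户正在学习Python")
    batch = await generator.generate_batch(["今天天气不错", "我在写代码"])
    logger.debug(f"单条维度: {vec.shape}, 批量条数: {len(batch)}, 配置维度: {generator.get_embedding_dimension()}")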

View File

@@ -0,0 +1,156 @@
"""
图扩展工具
提供记忆图的扩展算法,用于从初始记忆集合沿图结构扩展查找相关记忆
"""
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from src.memory_graph.utils.similarity import cosine_similarity
if TYPE_CHECKING:
import numpy as np
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.vector_store import VectorStore
logger = get_logger(__name__)
async def expand_memories_with_semantic_filter(
graph_store: "GraphStore",
vector_store: "VectorStore",
initial_memory_ids: list[str],
query_embedding: "np.ndarray",
max_depth: int = 2,
semantic_threshold: float = 0.5,
max_expanded: int = 20,
) -> list[tuple[str, float]]:
"""
从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤
这个方法解决了纯向量搜索可能遗漏的"语义相关且图结构相关"的记忆。
Args:
graph_store: 图存储
vector_store: 向量存储
initial_memory_ids: 初始记忆ID集合由向量搜索得到
query_embedding: 查询向量
max_depth: 最大扩展深度1-3推荐
semantic_threshold: 语义相似度阈值0.5推荐)
max_expanded: 最多扩展多少个记忆
Returns:
List[(memory_id, relevance_score)] 按相关度排序
"""
if not initial_memory_ids or query_embedding is None:
return []
try:
# 记录已访问的记忆,避免重复
visited_memories = set(initial_memory_ids)
# 记录扩展的记忆及其分数
expanded_memories: dict[str, float] = {}
# BFS扩展
current_level = initial_memory_ids
for depth in range(max_depth):
next_level = []
for memory_id in current_level:
memory = graph_store.get_memory_by_id(memory_id)
if not memory:
continue
# 遍历该记忆的所有节点
for node in memory.nodes:
if not node.has_embedding():
continue
# 获取邻居节点
try:
neighbors = list(graph_store.graph.neighbors(node.id))
except Exception:
continue
for neighbor_id in neighbors:
# 获取邻居节点信息
neighbor_node_data = graph_store.graph.nodes.get(neighbor_id)
if not neighbor_node_data:
continue
# 获取邻居节点的向量(从向量存储)
neighbor_vector_data = await vector_store.get_node_by_id(neighbor_id)
if not neighbor_vector_data or neighbor_vector_data.get("embedding") is None:
continue
neighbor_embedding = neighbor_vector_data["embedding"]
# 计算与查询的语义相似度
semantic_sim = cosine_similarity(query_embedding, neighbor_embedding)
# 获取边的权重
try:
edge_data = graph_store.graph.get_edge_data(node.id, neighbor_id)
edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
except Exception:
edge_importance = 0.5
# 综合评分:语义相似度(70%) + 图结构权重(20%) + 深度衰减(10%)
depth_decay = 1.0 / (depth + 1) # 深度越深,权重越低
relevance_score = semantic_sim * 0.7 + edge_importance * 0.2 + depth_decay * 0.1
# 只保留超过阈值的节点
if relevance_score < semantic_threshold:
continue
# 提取邻居节点所属的记忆
neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
if isinstance(neighbor_memory_ids, str):
import json
try:
neighbor_memory_ids = json.loads(neighbor_memory_ids)
except Exception:
neighbor_memory_ids = [neighbor_memory_ids]
for neighbor_mem_id in neighbor_memory_ids:
if neighbor_mem_id in visited_memories:
continue
# 记录这个扩展记忆
if neighbor_mem_id not in expanded_memories:
expanded_memories[neighbor_mem_id] = relevance_score
visited_memories.add(neighbor_mem_id)
next_level.append(neighbor_mem_id)
else:
# 如果已存在,取最高分
expanded_memories[neighbor_mem_id] = max(
expanded_memories[neighbor_mem_id], relevance_score
)
# 如果没有新节点或已达到数量限制,提前终止
if not next_level or len(expanded_memories) >= max_expanded:
break
current_level = next_level[:max_expanded] # 限制每层的扩展数量
# 排序并返回
sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded]
logger.info(
f"图扩展完成: 初始{len(initial_memory_ids)}个 → "
f"扩展{len(sorted_results)}个新记忆 "
f"(深度={max_depth}, 阈值={semantic_threshold:.2f})"
)
return sorted_results
except Exception as e:
logger.error(f"语义图扩展失败: {e}", exc_info=True)
return []
__all__ = ["expand_memories_with_semantic_filter"]

View File

@@ -0,0 +1,320 @@
"""
记忆格式化工具
用于将记忆图系统的Memory对象转换为适合提示词的自然语言描述
"""
import logging
from datetime import datetime
from src.memory_graph.models import EdgeType, Memory, MemoryType, NodeType
logger = logging.getLogger(__name__)
def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> str:
"""
将记忆对象格式化为适合提示词的自然语言描述
根据记忆的图结构,构建完整的主谓宾描述,包含:
- 主语subject node
- 谓语/动作topic node
- 宾语/对象object node如果存在
- 属性信息attributes如时间、地点等
- 关系信息(记忆之间的关系)
Args:
memory: 记忆对象
include_metadata: 是否包含元数据(时间、重要性等)
Returns:
格式化后的自然语言描述
"""
try:
# 1. 获取主体节点(主语)
subject_node = memory.get_subject_node()
if not subject_node:
logger.warning(f"记忆 {memory.id} 缺少主体节点")
return "(记忆格式错误:缺少主体)"
subject_text = subject_node.content
# 2. 查找主题节点(谓语/动作)
topic_node = None
for edge in memory.edges:
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
topic_node = memory.get_node_by_id(edge.target_id)
break
if not topic_node:
logger.warning(f"记忆 {memory.id} 缺少主题节点")
return f"{subject_text}(记忆格式错误:缺少主题)"
topic_text = topic_node.content
# 3. 查找客体节点(宾语)和核心关系
object_node = None
core_relation = None
for edge in memory.edges:
if edge.edge_type == EdgeType.CORE_RELATION and edge.source_id == topic_node.id:
object_node = memory.get_node_by_id(edge.target_id)
core_relation = edge.relation if edge.relation else ""
break
# 4. 收集属性节点
attributes: dict[str, str] = {}
for edge in memory.edges:
if edge.edge_type == EdgeType.ATTRIBUTE:
# 查找属性节点和值节点
attr_node = memory.get_node_by_id(edge.target_id)
if attr_node and attr_node.node_type == NodeType.ATTRIBUTE:
# 查找这个属性的值
for value_edge in memory.edges:
if (value_edge.edge_type == EdgeType.ATTRIBUTE
and value_edge.source_id == attr_node.id):
value_node = memory.get_node_by_id(value_edge.target_id)
if value_node and value_node.node_type == NodeType.VALUE:
attributes[attr_node.content] = value_node.content
break
# 5. 构建自然语言描述
parts = []
# 主谓宾结构
if object_node is not None:
# 有完整的主谓宾
if core_relation:
parts.append(f"{subject_text}{topic_text}{core_relation}{object_node.content}")
else:
parts.append(f"{subject_text}{topic_text}{object_node.content}")
else:
# 只有主谓
parts.append(f"{subject_text}{topic_text}")
# 添加属性信息
if attributes:
attr_parts = []
# 优先显示时间和地点
if "时间" in attributes:
attr_parts.append(f"{attributes['时间']}")
if "地点" in attributes:
attr_parts.append(f"{attributes['地点']}")
# 其他属性
for key, value in attributes.items():
if key not in ["时间", "地点"]:
attr_parts.append(f"{key}{value}")
if attr_parts:
parts.append(f"{' '.join(attr_parts)}")
description = "".join(parts)
# 6. 添加元数据(可选)
if include_metadata:
metadata_parts = []
# 记忆类型
if memory.memory_type:
metadata_parts.append(f"类型:{memory.memory_type.value}")
# 重要性
if memory.importance >= 0.8:
metadata_parts.append("重要")
elif memory.importance >= 0.6:
metadata_parts.append("一般")
# 时间(如果没有在属性中)
if "时间" not in attributes:
time_str = _format_relative_time(memory.created_at)
if time_str:
metadata_parts.append(time_str)
if metadata_parts:
description += f" [{', '.join(metadata_parts)}]"
return description
except Exception as e:
logger.error(f"格式化记忆失败: {e}", exc_info=True)
return f"(记忆格式化错误: {str(e)[:50]}"
def format_memories_for_prompt(
memories: list[Memory],
max_count: int | None = None,
include_metadata: bool = False,
group_by_type: bool = False
) -> str:
"""
批量格式化多条记忆为提示词文本
Args:
memories: 记忆列表
max_count: 最大记忆数量(可选)
include_metadata: 是否包含元数据
group_by_type: 是否按类型分组
Returns:
格式化后的文本,包含标题和列表
"""
if not memories:
return ""
# 限制数量
if max_count:
memories = memories[:max_count]
# 按类型分组
if group_by_type:
type_groups: dict[MemoryType, list[Memory]] = {}
for memory in memories:
if memory.memory_type not in type_groups:
type_groups[memory.memory_type] = []
type_groups[memory.memory_type].append(memory)
# 构建分组文本
parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
type_order = [MemoryType.FACT, MemoryType.EVENT, MemoryType.RELATION, MemoryType.OPINION]
for mem_type in type_order:
if mem_type in type_groups:
parts.append(f"#### {mem_type.value}")
for memory in type_groups[mem_type]:
desc = format_memory_for_prompt(memory, include_metadata)
parts.append(f"- {desc}")
parts.append("")
return "\n".join(parts)
else:
# 不分组,直接列出
parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
for memory in memories:
# 获取类型标签
type_label = memory.memory_type.value if memory.memory_type else "未知"
# 格式化记忆内容
desc = format_memory_for_prompt(memory, include_metadata)
# 添加类型标签
parts.append(f"- **[{type_label}]** {desc}")
return "\n".join(parts)
def get_memory_type_label(memory_type: str) -> str:
"""
获取记忆类型的中文标签
Args:
memory_type: 记忆类型(可能是英文或中文)
Returns:
中文标签
"""
# 映射表
type_mapping = {
# 英文到中文
"event": "事件",
"fact": "事实",
"relation": "关系",
"opinion": "观点",
"preference": "偏好",
"emotion": "情绪",
"knowledge": "知识",
"skill": "技能",
"goal": "目标",
"experience": "经历",
"contextual": "情境",
# 中文(保持不变)
"事件": "事件",
"事实": "事实",
"关系": "关系",
"观点": "观点",
"偏好": "偏好",
"情绪": "情绪",
"知识": "知识",
"技能": "技能",
"目标": "目标",
"经历": "经历",
"情境": "情境",
}
# 转换为小写进行匹配
memory_type_lower = memory_type.lower() if memory_type else ""
return type_mapping.get(memory_type_lower, "未知")
def _format_relative_time(timestamp: datetime) -> str | None:
"""
格式化相对时间(如"2天前""刚才"
Args:
timestamp: 时间戳
Returns:
相对时间描述如果太久远则返回None
"""
try:
now = datetime.now()
delta = now - timestamp
if delta.total_seconds() < 60:
return "刚才"
elif delta.total_seconds() < 3600:
minutes = int(delta.total_seconds() / 60)
return f"{minutes}分钟前"
elif delta.total_seconds() < 86400:
hours = int(delta.total_seconds() / 3600)
return f"{hours}小时前"
elif delta.days < 7:
return f"{delta.days}天前"
elif delta.days < 30:
weeks = delta.days // 7
return f"{weeks}周前"
elif delta.days < 365:
months = delta.days // 30
return f"{months}个月前"
else:
# 超过一年不显示相对时间
return None
except Exception:
return None
def format_memory_summary(memory: Memory) -> str:
"""
生成记忆的简短摘要(用于日志和调试)
Args:
memory: 记忆对象
Returns:
简短摘要
"""
try:
subject_node = memory.get_subject_node()
subject_text = subject_node.content if subject_node else "?"
topic_text = "?"
for edge in memory.edges:
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
topic_node = memory.get_node_by_id(edge.target_id)
if topic_node:
topic_text = topic_node.content
break
return f"{subject_text} - {memory.memory_type.value if memory.memory_type else '?'}: {topic_text}"
except Exception:
return f"记忆 {memory.id[:8]}"
# 导出主要函数
__all__ = [
"format_memories_for_prompt",
"format_memory_for_prompt",
"format_memory_summary",
"get_memory_type_label",
]
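# 使用示意(示例草稿):将检索到的记忆渲染进提示词。
# Memory 对象通常来自记忆检索(如 MemoryTools.search_memories),此处不构造具体实例。
def _render_memories_demo(memories: list[Memory]) -> str:
    # 按类型分组并附带元数据(时间、重要性)
    block = format_memories_for_prompt(memories, max_count=10, include_metadata=True, group_by_type=True)
    # 单条摘要常用于日志排查
    for memory in memories:
        logger.debug(format_memory_summary(memory))
    return block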

View File

@@ -0,0 +1,50 @@
"""
相似度计算工具
提供统一的向量相似度计算函数
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
"""
计算两个向量的余弦相似度
Args:
vec1: 第一个向量
vec2: 第二个向量
Returns:
余弦相似度 (0.0-1.0)
"""
try:
import numpy as np
# 确保是numpy数组
if not isinstance(vec1, np.ndarray):
vec1 = np.array(vec1)
if not isinstance(vec2, np.ndarray):
vec2 = np.array(vec2)
# 归一化
vec1_norm = np.linalg.norm(vec1)
vec2_norm = np.linalg.norm(vec2)
if vec1_norm == 0 or vec2_norm == 0:
return 0.0
# 余弦相似度
similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
# 确保在 [0, 1] 范围内(处理浮点误差)
return float(np.clip(similarity, 0.0, 1.0))
except Exception:
return 0.0
__all__ = ["cosine_similarity"]

View File

@@ -0,0 +1,493 @@
"""
时间解析器:将相对时间转换为绝对时间
支持的时间表达:
- 今天、明天、昨天、前天、后天
- X天前、X天后
- X小时前、X小时后
- 上周、上个月、去年
- 具体日期2025-11-05, 11月5日
- 时间点早上8点、下午3点、晚上9点
"""
from __future__ import annotations
import re
from datetime import datetime, timedelta
from src.common.logger import get_logger
logger = get_logger(__name__)
class TimeParser:
"""
时间解析器
负责将自然语言时间表达转换为标准化的绝对时间
"""
def __init__(self, reference_time: datetime | None = None):
"""
初始化时间解析器
Args:
reference_time: 参考时间(通常是当前时间)
"""
self.reference_time = reference_time or datetime.now()
def parse(self, time_str: str) -> datetime | None:
"""
解析时间字符串
Args:
time_str: 时间字符串
Returns:
标准化的datetime对象如果解析失败则返回None
"""
if not time_str or not isinstance(time_str, str):
return None
time_str = time_str.strip()
# 先尝试组合解析(如"今天下午"、"昨天晚上"
combined_result = self._parse_combined_time(time_str)
if combined_result:
logger.debug(f"时间解析: '{time_str}'{combined_result.isoformat()}")
return combined_result
# 尝试各种解析方法
parsers = [
self._parse_relative_day,
self._parse_days_ago,
self._parse_hours_ago,
self._parse_week_month_year,
self._parse_specific_date,
self._parse_time_of_day,
]
for parser in parsers:
try:
result = parser(time_str)
if result:
logger.debug(f"时间解析: '{time_str}'{result.isoformat()}")
return result
except Exception as e:
logger.debug(f"解析器 {parser.__name__} 失败: {e}")
continue
logger.warning(f"无法解析时间: '{time_str}',使用当前时间")
return self.reference_time
def _parse_relative_day(self, time_str: str) -> datetime | None:
"""
解析相对日期:今天、明天、昨天、前天、后天
"""
relative_days = {
"今天": 0,
"今日": 0,
"明天": 1,
"明日": 1,
"昨天": -1,
"昨日": -1,
"前天": -2,
"前日": -2,
"后天": 2,
"后日": 2,
"大前天": -3,
"大后天": 3,
}
for keyword, days in relative_days.items():
if keyword in time_str:
result = self.reference_time + timedelta(days=days)
# 归一化到当天 0 点,只保留日期
return result.replace(hour=0, minute=0, second=0, microsecond=0)
return None
def _parse_days_ago(self, time_str: str) -> datetime | None:
"""
解析 X天前/X天后、X周前/X周后、X个月前/X个月后
"""
# 匹配3天前、5天后、一天前
pattern_day = r"([一二三四五六七八九十\d]+)天(前|后)"
match = re.search(pattern_day, time_str)
if match:
num_str, direction = match.groups()
num = self._chinese_num_to_int(num_str)
if direction == "前":
num = -num
result = self.reference_time + timedelta(days=num)
return result.replace(hour=0, minute=0, second=0, microsecond=0)
# 匹配2周前、3周后、一周前
pattern_week = r"([一二三四五六七八九十\d]+)[个]?周(前|后)"
match = re.search(pattern_week, time_str)
if match:
num_str, direction = match.groups()
num = self._chinese_num_to_int(num_str)
if direction == "前":
num = -num
result = self.reference_time + timedelta(weeks=num)
return result.replace(hour=0, minute=0, second=0, microsecond=0)
# 匹配2个月前、3月后
pattern_month = r"([一二三四五六七八九十\d]+)[个]?月(前|后)"
match = re.search(pattern_month, time_str)
if match:
num_str, direction = match.groups()
num = self._chinese_num_to_int(num_str)
if direction == "前":
num = -num
# 简单处理1个月 = 30天
result = self.reference_time + timedelta(days=num * 30)
return result.replace(hour=0, minute=0, second=0, microsecond=0)
# 匹配2年前、3年后
pattern_year = r"([一二三四五六七八九十\d]+)[个]?年(前|后)"
match = re.search(pattern_year, time_str)
if match:
num_str, direction = match.groups()
num = self._chinese_num_to_int(num_str)
if direction == "前":
num = -num
# 简单处理1年 = 365天
result = self.reference_time + timedelta(days=num * 365)
return result.replace(hour=0, minute=0, second=0, microsecond=0)
return None
def _parse_hours_ago(self, time_str: str) -> datetime | None:
"""
解析 X小时前/X小时后、X分钟前/X分钟后
"""
# 小时
pattern_hour = r"([一二三四五六七八九十\d]+)小?时(前|后)"
match = re.search(pattern_hour, time_str)
if match:
num_str, direction = match.groups()
num = self._chinese_num_to_int(num_str)
if direction == "前":
num = -num
return self.reference_time + timedelta(hours=num)
# 分钟
pattern_minute = r"([一二三四五六七八九十\d]+)分钟(前|后)"
match = re.search(pattern_minute, time_str)
if match:
num_str, direction = match.groups()
num = self._chinese_num_to_int(num_str)
if direction == "前":
num = -num
return self.reference_time + timedelta(minutes=num)
return None
def _parse_week_month_year(self, time_str: str) -> datetime | None:
"""
解析:上周、上个月、去年、本周、本月、今年
"""
now = self.reference_time
if "上周" in time_str or "上星期" in time_str:
return now - timedelta(days=7)
if "上个月" in time_str or "上月" in time_str:
# 简单处理减30天
return now - timedelta(days=30)
if "去年" in time_str or "上年" in time_str:
return now.replace(year=now.year - 1)
if "本周" in time_str or "这周" in time_str:
# 返回本周一
return now - timedelta(days=now.weekday())
if "本月" in time_str or "这个月" in time_str:
return now.replace(day=1)
if "今年" in time_str or "这年" in time_str:
return now.replace(month=1, day=1)
return None
def _parse_specific_date(self, time_str: str) -> datetime | None:
"""
解析具体日期:
- 2025-11-05
- 2025/11/05
- 11月5日
- 11-05
"""
# ISO 格式2025-11-05
pattern_iso = r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})"
match = re.search(pattern_iso, time_str)
if match:
year, month, day = map(int, match.groups())
return datetime(year, month, day)
# 中文格式11月5日、11月5号
pattern_cn = r"(\d{1,2})月(\d{1,2})[日号]"
match = re.search(pattern_cn, time_str)
if match:
month, day = map(int, match.groups())
# 使用参考时间的年份
year = self.reference_time.year
return datetime(year, month, day)
# 短格式11-05使用当前年份
pattern_short = r"(\d{1,2})[-/](\d{1,2})"
match = re.search(pattern_short, time_str)
if match:
month, day = map(int, match.groups())
year = self.reference_time.year
return datetime(year, month, day)
return None
def _parse_time_of_day(self, time_str: str) -> datetime | None:
"""
解析一天中的时间:
- 早上、上午、中午、下午、晚上、深夜
- 早上8点、下午3点
- 8点、15点
"""
now = self.reference_time
result = now.replace(minute=0, second=0, microsecond=0)
# 时间段映射
time_periods = {
"早上": 8,
"早晨": 8,
"上午": 10,
"中午": 12,
"下午": 15,
"傍晚": 18,
"晚上": 20,
"深夜": 23,
"凌晨": 2,
}
# 先检查是否有具体时间点早上8点、下午3点
for period in time_periods.keys():
pattern = rf"{period}(\d{{1,2}})点?"
match = re.search(pattern, time_str)
if match:
hour = int(match.group(1))
# 下午时间需要+12
if period in ["下午", "晚上"] and hour < 12:
hour += 12
return result.replace(hour=hour)
# 检查时间段关键词
for period, hour in time_periods.items():
if period in time_str:
return result.replace(hour=hour)
# 直接的时间点8点、15点
pattern = r"(\d{1,2})点"
match = re.search(pattern, time_str)
if match:
hour = int(match.group(1))
return result.replace(hour=hour)
return None
def _parse_combined_time(self, time_str: str) -> datetime | None:
"""
解析组合时间表达:今天下午、昨天晚上、明天早上
"""
# 先解析日期部分
date_result = None
# 相对日期关键词
relative_days = {
"今天": 0, "今日": 0,
"明天": 1, "明日": 1,
"昨天": -1, "昨日": -1,
"前天": -2, "前日": -2,
"后天": 2, "后日": 2,
"大前天": -3, "大后天": 3,
}
for keyword, days in relative_days.items():
if keyword in time_str:
date_result = self.reference_time + timedelta(days=days)
date_result = date_result.replace(hour=0, minute=0, second=0, microsecond=0)
break
if not date_result:
return None
# 再解析时间段部分
time_periods = {
"早上": 8, "早晨": 8,
"上午": 10,
"中午": 12,
"下午": 15,
"傍晚": 18,
"晚上": 20,
"深夜": 23,
"凌晨": 2,
}
for period, hour in time_periods.items():
if period in time_str:
# 检查是否有具体时间点
pattern = rf"{period}(\d{{1,2}})点?"
match = re.search(pattern, time_str)
if match:
hour = int(match.group(1))
# 下午时间需要+12
if period in ["下午", "晚上"] and hour < 12:
hour += 12
return date_result.replace(hour=hour)
# 如果没有时间段返回日期默认0点
return date_result
def _chinese_num_to_int(self, num_str: str) -> int:
"""
将中文数字转换为阿拉伯数字
Args:
num_str: 中文数字字符串(如:"""""3"
Returns:
整数
"""
# 如果已经是数字,直接返回
if num_str.isdigit():
return int(num_str)
# 中文数字映射
chinese_nums = {
"": 1,
"": 2,
"": 3,
"": 4,
"": 5,
"": 6,
"": 7,
"": 8,
"": 9,
"": 10,
"": 0,
}
if num_str in chinese_nums:
return chinese_nums[num_str]
# 处理 "十X" 的情况(如"十五"=15
if num_str.startswith(""):
if len(num_str) == 1:
return 10
return 10 + chinese_nums.get(num_str[1], 0)
# 处理 "X十" 的情况(如"三十"=30
if "" in num_str:
parts = num_str.split("")
tens = chinese_nums.get(parts[0], 1) * 10
ones = chinese_nums.get(parts[1], 0) if len(parts) > 1 and parts[1] else 0
return tens + ones
# 默认返回1
return 1
def format_time(self, dt: datetime, format_type: str = "iso") -> str:
"""
格式化时间
Args:
dt: datetime对象
format_type: 格式类型 ("iso", "cn", "relative")
Returns:
格式化的时间字符串
"""
if format_type == "iso":
return dt.isoformat()
elif format_type == "cn":
return dt.strftime("%Y年%m月%d日 %H:%M:%S")
elif format_type == "relative":
# 相对时间表达
diff = self.reference_time - dt
days = diff.days
if days == 0:
hours = diff.seconds // 3600
if hours == 0:
minutes = diff.seconds // 60
return f"{minutes}分钟前" if minutes > 0 else "刚刚"
return f"{hours}小时前"
elif days == 1:
return "昨天"
elif days == 2:
return "前天"
elif days < 7:
return f"{days}天前"
elif days < 30:
weeks = days // 7
return f"{weeks}周前"
elif days < 365:
months = days // 30
return f"{months}个月前"
else:
years = days // 365
return f"{years}年前"
return str(dt)
def parse_time_range(self, time_str: str) -> tuple[datetime | None, datetime | None]:
"""
解析时间范围最近一周、最近3天
Args:
time_str: 时间范围字符串
Returns:
(start_time, end_time)
"""
pattern = r"最近(\d+)(天|周|月|年)"
match = re.search(pattern, time_str)
if match:
num, unit = match.groups()
num = int(num)
unit_map = {"": "days", "": "weeks", "": "days", "": "days"}
if unit == "":
num *= 7
elif unit == "":
num *= 30
elif unit == "":
num *= 365
end_time = self.reference_time
start_time = end_time - timedelta(**{unit_map[unit]: num})
return (start_time, end_time)
return (None, None)
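# 使用示意(示例草稿):以固定参考时间演示解析效果。
if __name__ == "__main__":
    parser = TimeParser(reference_time=datetime(2025, 11, 5, 12, 0, 0))
    print(parser.parse("昨天下午3点"))   # 2025-11-04 15:00:00
    print(parser.parse("3天前"))         # 2025-11-02 00:00:00
    print(parser.parse_time_range("最近3天"))  # (2025-11-02 12:00:00, 2025-11-05 12:00:00)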

View File

@@ -7,7 +7,7 @@
"""
import atexit
import json
import orjson
import os
import threading
from typing import Any, ClassVar
@@ -100,10 +100,10 @@ class PluginStorage:
if os.path.exists(self.file_path):
with open(self.file_path, encoding="utf-8") as f:
content = f.read()
self._data = json.loads(content) if content else {}
self._data = orjson.loads(content) if content else {}
else:
self._data = {}
except (json.JSONDecodeError, Exception) as e:
except (orjson.JSONDecodeError, Exception) as e:
logger.warning(f"'{self.file_path}' 加载数据失败: {e},将初始化为空数据。")
self._data = {}
@@ -125,7 +125,7 @@ class PluginStorage:
try:
with open(self.file_path, "w", encoding="utf-8") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))
self._dirty = False # 保存后重置标志
logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。")
except Exception as e:

View File

@@ -5,7 +5,7 @@ MCP Client Manager
"""
import asyncio
import json
import orjson
import shutil
from pathlib import Path
from typing import Any
@@ -89,7 +89,7 @@ class MCPClientManager:
try:
with open(self.config_path, encoding="utf-8") as f:
config_data = json.load(f)
config_data = orjson.loads(f.read())
servers = {}
mcp_servers = config_data.get("mcpServers", {})
@@ -106,7 +106,7 @@ class MCPClientManager:
logger.info(f"成功加载 {len(servers)} 个 MCP 服务器配置")
return servers
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
logger.error(f"解析 MCP 配置文件失败: {e}")
return {}
except Exception as e:

View File

@@ -0,0 +1,414 @@
"""
流式工具历史记录管理器
用于在聊天流级别管理工具调用历史,支持智能缓存和上下文感知
"""
import time
from typing import Any, Optional
from dataclasses import dataclass, asdict, field
import orjson
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
logger = get_logger("stream_tool_history")
@dataclass
class ToolCallRecord:
"""工具调用记录"""
tool_name: str
args: dict[str, Any]
result: Optional[dict[str, Any]] = None
status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time)
execution_time: Optional[float] = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息
def __post_init__(self):
"""后处理:生成结果预览"""
if self.result and not self.result_preview:
content = self.result.get("content", "")
if isinstance(content, str):
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
elif isinstance(content, (list, dict)):
try:
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
except Exception:
self.result_preview = str(content)[:500] + "..."
else:
self.result_preview = str(content)[:500] + "..."
class StreamToolHistoryManager:
"""流式工具历史记录管理器
提供以下功能:
1. 工具调用历史的持久化管理
2. 智能缓存集成和结果去重
3. 上下文感知的历史记录检索
4. 性能监控和统计
"""
def __init__(self, chat_id: str, max_history: int = 20, enable_memory_cache: bool = True):
"""初始化历史记录管理器
Args:
chat_id: 聊天ID用于隔离不同聊天流的历史记录
max_history: 最大历史记录数量
enable_memory_cache: 是否启用内存缓存
"""
self.chat_id = chat_id
self.max_history = max_history
self.enable_memory_cache = enable_memory_cache
# 内存中的历史记录,按时间顺序排列
self._history: list[ToolCallRecord] = []
# 性能统计
self._stats = {
"total_calls": 0,
"cache_hits": 0,
"cache_misses": 0,
"total_execution_time": 0.0,
"average_execution_time": 0.0,
}
logger.info(f"[{chat_id}] 工具历史记录管理器初始化完成,最大历史: {max_history}")
async def add_tool_call(self, record: ToolCallRecord) -> None:
"""添加工具调用记录
Args:
record: 工具调用记录
"""
# 维护历史记录大小
if len(self._history) >= self.max_history:
# 移除最旧的记录
removed_record = self._history.pop(0)
logger.debug(f"[{self.chat_id}] 移除旧记录: {removed_record.tool_name}")
# 添加新记录
self._history.append(record)
# 更新统计
self._stats["total_calls"] += 1
if record.cache_hit:
self._stats["cache_hits"] += 1
else:
self._stats["cache_misses"] += 1
if record.execution_time is not None:
self._stats["total_execution_time"] += record.execution_time
self._stats["average_execution_time"] = self._stats["total_execution_time"] / self._stats["total_calls"]
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
"""从缓存或历史记录中获取结果
Args:
tool_name: 工具名称
args: 工具参数
Returns:
缓存的结果如果不存在则返回None
"""
# 首先检查内存中的历史记录
if self.enable_memory_cache:
memory_result = self._search_memory_cache(tool_name, args)
if memory_result:
logger.info(f"[{self.chat_id}] 内存缓存命中: {tool_name}")
return memory_result
# 然后检查全局缓存系统
try:
# 这里需要工具实例来获取文件路径,但为了解耦,我们先尝试从历史记录中推断
tool_file_path = self._infer_tool_path(tool_name)
# 尝试语义缓存(如果可以推断出语义查询参数)
semantic_query = self._extract_semantic_query(tool_name, args)
cached_result = await tool_cache.get(
tool_name=tool_name,
function_args=args,
tool_file_path=tool_file_path,
semantic_query=semantic_query,
)
if cached_result:
logger.info(f"[{self.chat_id}] 全局缓存命中: {tool_name}")
# 将结果同步到内存缓存
if self.enable_memory_cache:
record = ToolCallRecord(
tool_name=tool_name,
args=args,
result=cached_result,
status="success",
cache_hit=True,
timestamp=time.time(),
)
await self.add_tool_call(record)
return cached_result
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存查询失败: {e}")
return None
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
execution_time: Optional[float] = None,
tool_file_path: Optional[str] = None,
ttl: Optional[int] = None) -> None:
"""缓存工具调用结果
Args:
tool_name: 工具名称
args: 工具参数
result: 执行结果
execution_time: 执行耗时
tool_file_path: 工具文件路径
ttl: 缓存TTL
"""
# 添加到内存历史记录
record = ToolCallRecord(
tool_name=tool_name,
args=args,
result=result,
status="success",
execution_time=execution_time,
cache_hit=False,
timestamp=time.time(),
)
await self.add_tool_call(record)
# 同步到全局缓存系统
try:
if tool_file_path is None:
tool_file_path = self._infer_tool_path(tool_name)
# 尝试语义缓存
semantic_query = self._extract_semantic_query(tool_name, args)
await tool_cache.set(
tool_name=tool_name,
function_args=args,
tool_file_path=tool_file_path,
data=result,
ttl=ttl,
semantic_query=semantic_query,
)
logger.debug(f"[{self.chat_id}] 结果已缓存: {tool_name}")
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
"""获取最近的历史记录
Args:
count: 返回的记录数量
status_filter: 状态过滤器可选值success, error, pending
Returns:
历史记录列表
"""
history = self._history.copy()
# 应用状态过滤
if status_filter:
history = [record for record in history if record.status == status_filter]
# 返回最近的记录
return history[-count:] if history else []
def format_for_prompt(self, max_records: int = 5, include_results: bool = True) -> str:
"""格式化历史记录为提示词
Args:
max_records: 最大记录数量
include_results: 是否包含结果预览
Returns:
格式化的提示词字符串
"""
if not self._history:
return ""
recent_records = self._history[-max_records:]
lines = ["## 🔧 最近工具调用记录"]
for i, record in enumerate(recent_records, 1):
status_icon = "" if record.status == "success" else "" if record.status == "error" else ""
# 格式化参数
args_preview = self._format_args_preview(record.args)
# 基础信息
lines.append(f"{i}. {status_icon} **{record.tool_name}**({args_preview})")
# 添加执行时间和缓存信息
if record.execution_time is not None:
time_info = f"{record.execution_time:.2f}s"
cache_info = "🎯缓存" if record.cache_hit else "🔍执行"
lines.append(f" ⏱️ {time_info} | {cache_info}")
# 添加结果预览
if include_results and record.result_preview:
lines.append(f" 📝 结果: {record.result_preview}")
# 添加错误信息
if record.status == "error" and record.error_message:
lines.append(f" ❌ 错误: {record.error_message}")
# 添加统计信息
if self._stats["total_calls"] > 0:
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
avg_time = self._stats["average_execution_time"]
lines.append(f"\n📊 工具统计: 总计{self._stats['total_calls']}次 | 缓存命中率{cache_hit_rate:.1f}% | 平均耗时{avg_time:.2f}s")
return "\n".join(lines)
def get_stats(self) -> dict[str, Any]:
"""获取性能统计信息
Returns:
统计信息字典
"""
cache_hit_rate = 0.0
if self._stats["total_calls"] > 0:
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
return {
**self._stats,
"cache_hit_rate": cache_hit_rate,
"history_size": len(self._history),
"chat_id": self.chat_id,
}
def clear_history(self) -> None:
"""清除历史记录"""
self._history.clear()
logger.info(f"[{self.chat_id}] 工具历史记录已清除")
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
"""在内存历史记录中搜索缓存
Args:
tool_name: 工具名称
args: 工具参数
Returns:
匹配的结果如果不存在则返回None
"""
for record in reversed(self._history): # 从最新的开始搜索
if (record.tool_name == tool_name and
record.status == "success" and
record.args == args):
return record.result
return None
def _infer_tool_path(self, tool_name: str) -> str:
"""推断工具文件路径
Args:
tool_name: 工具名称
Returns:
推断的文件路径
"""
# 基于工具名称推断路径,这是一个简化的实现
# 在实际使用中,可能需要更复杂的映射逻辑
tool_path_mapping = {
"web_search": "src/plugins/built_in/web_search_tool/tools/web_search.py",
"memory_create": "src/memory_graph/tools/memory_tools.py",
"memory_search": "src/memory_graph/tools/memory_tools.py",
"user_profile_update": "src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py",
"chat_stream_impression_update": "src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py",
}
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
"""提取语义查询参数
Args:
tool_name: 工具名称
args: 工具参数
Returns:
语义查询字符串如果不存在则返回None
"""
# 为不同工具定义语义查询参数映射
semantic_query_mapping = {
"web_search": "query",
"memory_search": "query",
"knowledge_search": "query",
}
query_key = semantic_query_mapping.get(tool_name)
if query_key and query_key in args:
return str(args[query_key])
return None
def _format_args_preview(self, args: dict[str, Any], max_length: int = 100) -> str:
"""格式化参数预览
Args:
args: 参数字典
max_length: 最大长度
Returns:
格式化的参数预览字符串
"""
if not args:
return ""
try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
if len(args_str) > max_length:
args_str = args_str[:max_length] + "..."
return args_str
except Exception:
# 如果序列化失败,使用简单格式
parts = []
for k, v in list(args.items())[:3]: # 最多显示3个参数
parts.append(f"{k}={str(v)[:20]}")
result = ", ".join(parts)
if len(parts) >= 3 or len(result) > max_length:
result += "..."
return result
# 全局管理器字典按chat_id索引
_stream_managers: dict[str, StreamToolHistoryManager] = {}
def get_stream_tool_history_manager(chat_id: str) -> StreamToolHistoryManager:
"""获取指定聊天的工具历史记录管理器
Args:
chat_id: 聊天ID
Returns:
工具历史记录管理器实例
"""
if chat_id not in _stream_managers:
_stream_managers[chat_id] = StreamToolHistoryManager(chat_id)
return _stream_managers[chat_id]
def cleanup_stream_manager(chat_id: str) -> None:
"""清理指定聊天的管理器
Args:
chat_id: 聊天ID
"""
if chat_id in _stream_managers:
del _stream_managers[chat_id]
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")

View File

@@ -3,7 +3,6 @@ import time
from typing import Any
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.payload_content import ToolCall
@@ -11,6 +10,8 @@ from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
from dataclasses import asdict
logger = get_logger("tool_use")
@@ -18,20 +19,50 @@ logger = get_logger("tool_use")
def init_tool_executor_prompt():
"""初始化工具执行器的提示词"""
tool_executor_prompt = """
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}
群里正在进行的聊天内容:
# 工具调用系统
## 📋 你的身份
- **名字**: {bot_name}
- **核心人设**: {personality_core}
- **人格特质**: {personality_side}
- **当前时间**: {time_now}
## 💬 上下文信息
### 对话历史
{chat_history}
现在,{sender}发送了内容:{target_message},你想要回复ta。
请仔细分析聊天内容,考虑以下几点:
1. 内容中是否包含需要查询信息的问题
2. 是否有明确的工具使用指令
3. 之前的工具调用是否提供了有用的信息
4. 是否需要基于之前的工具结果进行进一步的查询
### 当前消息
**{sender}** 说: {target_message}
{tool_history}
If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed".
## 🔧 工具决策指南
**核心原则:**
- 根据上下文智能判断是否需要使用工具
- 每个工具都有详细的description说明其用途和参数
- 避免重复调用历史记录中已执行的工具(除非参数不同)
- 优先考虑使用已有的缓存结果,避免重复调用
**历史记录说明:**
- 上方显示的是**之前**的工具调用记录
- 请参考历史记录避免重复调用相同参数的工具
- 如果历史记录中已有相关结果,可以考虑直接回答而不调用工具
**⚠️ 记忆创建特别提醒:**
创建记忆时subject主体必须使用对话历史中显示的**真实发送人名字**
- ✅ 正确:从"Prou(12345678): ..."中提取"Prou"作为subject
- ❌ 错误:使用"用户""对方"等泛指词
**工具调用策略:**
1. **避免重复调用**:查看历史记录,如果最近已调用过相同工具且参数一致,无需重复调用
2. **智能选择工具**:根据消息内容选择最合适的工具,避免过度使用
3. **参数优化**:确保工具参数简洁有效,避免冗余信息
**执行指令:**
- 需要使用工具 → 直接调用相应的工具函数
- 不需要工具 → 输出 "No tool needed"
"""
Prompt(tool_executor_prompt, "tool_executor_prompt")
@@ -65,9 +96,8 @@ class ToolExecutor:
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
self._log_prefix_initialized = False
# 工具调用历史
self.tool_call_history: list[dict[str, Any]] = []
"""工具调用历史,包含工具名称、参数和结果"""
# 流式工具历史记录管理器
self.history_manager = get_stream_tool_history_manager(chat_id)
# logger.info(f"{self.log_prefix}工具执行器初始化完成") # 移到异步初始化中
@@ -109,7 +139,11 @@ class ToolExecutor:
bot_name = global_config.bot.nickname
# 构建工具调用历史文本
tool_history = self._format_tool_history()
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
# 获取人设信息
personality_core = global_config.personality.personality_core
personality_side = global_config.personality.personality_side
# 构建工具调用提示词
prompt = await global_prompt_manager.format_prompt(
@@ -120,6 +154,8 @@ class ToolExecutor:
bot_name=bot_name,
time_now=time_now,
tool_history=tool_history,
personality_core=personality_core,
personality_side=personality_side,
)
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
@@ -161,83 +197,7 @@ class ToolExecutor:
return tool_definitions
def _format_tool_history(self, max_history: int = 5) -> str:
"""格式化工具调用历史为文本
Args:
max_history: 最多显示的历史记录数量
Returns:
格式化的工具历史文本
"""
if not self.tool_call_history:
return ""
# 只取最近的几条历史
recent_history = self.tool_call_history[-max_history:]
history_lines = ["历史工具调用记录:"]
for i, record in enumerate(recent_history, 1):
tool_name = record.get("tool_name", "unknown")
args = record.get("args", {})
result_preview = record.get("result_preview", "")
status = record.get("status", "success")
# 格式化参数
args_str = ", ".join([f"{k}={v}" for k, v in args.items()])
# 格式化记录
status_emoji = "" if status == "success" else ""
history_lines.append(f"{i}. {status_emoji} {tool_name}({args_str})")
if result_preview:
# 限制结果预览长度
if len(result_preview) > 200:
result_preview = result_preview[:200] + "..."
history_lines.append(f" 结果: {result_preview}")
return "\n".join(history_lines)
def _add_tool_to_history(self, tool_name: str, args: dict, result: dict | None, status: str = "success"):
"""添加工具调用到历史记录
Args:
tool_name: 工具名称
args: 工具参数
result: 工具结果
status: 执行状态 (success/error)
"""
# 生成结果预览
result_preview = ""
if result:
content = result.get("content", "")
if isinstance(content, str):
result_preview = content
elif isinstance(content, list | dict):
import json
try:
result_preview = json.dumps(content, ensure_ascii=False)
except Exception:
result_preview = str(content)
else:
result_preview = str(content)
record = {
"tool_name": tool_name,
"args": args,
"result_preview": result_preview,
"status": status,
"timestamp": time.time(),
}
self.tool_call_history.append(record)
# 限制历史记录数量,避免内存溢出
max_history_size = 5
if len(self.tool_call_history) > max_history_size:
self.tool_call_history = self.tool_call_history[-max_history_size:]
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
"""执行工具调用
@@ -298,10 +258,20 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
# 记录到历史
self._add_tool_to_history(tool_name, tool_args, result, status="success")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=result,
status="success"
))
else:
# 工具返回空结果也记录到历史
self._add_tool_to_history(tool_name, tool_args, None, status="success")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=None,
status="success"
))
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
@@ -316,62 +286,72 @@ class ToolExecutor:
tool_results.append(error_info)
# 记录失败到历史
self._add_tool_to_history(tool_name, tool_args, None, status="error")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=None,
status="error",
error_message=str(e)
))
return tool_results, used_tools
async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
) -> dict[str, Any] | None:
"""执行单个工具调用,并处理缓存"""
"""执行单个工具调用,集成流式历史记录管理器"""
start_time = time.time()
function_args = tool_call.args or {}
tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream)
# 如果工具不存在或未启用缓存,则直接执行
if not tool_instance or not tool_instance.enable_cache:
return await self._original_execute_tool_call(tool_call, tool_instance)
# 尝试从历史记录管理器获取缓存结果
if tool_instance and tool_instance.enable_cache:
try:
cached_result = await self.history_manager.get_cached_result(
tool_name=tool_call.func_name,
args=function_args
)
if cached_result:
execution_time = time.time() - start_time
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
# --- 缓存逻辑开始 ---
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
# 记录缓存命中到历史
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_call.func_name,
args=function_args,
result=cached_result,
status="success",
execution_time=execution_time,
cache_hit=True
))
cached_result = await tool_cache.get(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
semantic_query=semantic_query,
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查历史缓存时出错: {e}")
# 缓存未命中,执行原始工具调用
# 缓存未命中,执行工具调用
result = await self._original_execute_tool_call(tool_call, tool_instance)
# 将结果存入缓存
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
# 记录执行结果到历史管理器
execution_time = time.time() - start_time
if tool_instance and result and tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query,
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
# --- 缓存逻辑结束 ---
await self.history_manager.cache_result(
tool_name=tool_call.func_name,
args=function_args,
result=result,
execution_time=execution_time,
tool_file_path=tool_file_path,
ttl=tool_instance.cache_ttl
)
except Exception as e:
logger.error(f"{self.log_prefix}缓存结果到历史管理器时出错: {e}")
return result
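The cache-key scheme used by the history manager is not shown in this diff; a plausible sketch, assuming keys are derived from the tool name plus canonicalised arguments, could look like this (make_cache_key is a hypothetical helper, not part of the codebase):
import hashlib
import orjson
def make_cache_key(tool_name: str, args: dict) -> str:
    # OPT_SORT_KEYS makes logically identical argument dicts serialize identically
    canonical = orjson.dumps(args, option=orjson.OPT_SORT_KEYS)
    digest = hashlib.sha256(canonical).hexdigest()[:16]
    return f"tool_cache:{tool_name}:{digest}"
print(make_cache_key("web_search", {"query": "orjson", "num_results": 3}))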
@@ -506,21 +486,31 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
# 记录到历史
self._add_tool_to_history(tool_name, tool_args, result, status="success")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=result,
status="success"
))
return tool_info
except Exception as e:
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
# 记录失败到历史
self._add_tool_to_history(tool_name, tool_args, None, status="error")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=None,
status="error",
error_message=str(e)
))
return None
def clear_tool_history(self):
"""清除工具调用历史"""
self.tool_call_history.clear()
logger.debug(f"{self.log_prefix}已清除工具调用历史")
self.history_manager.clear_history()
def get_tool_history(self) -> list[dict[str, Any]]:
"""获取工具调用历史
@@ -528,7 +518,17 @@ class ToolExecutor:
Returns:
工具调用历史列表
"""
return self.tool_call_history.copy()
# 返回最近的历史记录
records = self.history_manager.get_recent_history(count=10)
return [asdict(record) for record in records]
def get_tool_stats(self) -> dict[str, Any]:
"""获取工具统计信息
Returns:
工具统计信息字典
"""
return self.history_manager.get_stats()
"""

View File

@@ -639,18 +639,20 @@ class ChatterPlanFilter:
else:
keywords.append("晚上")
# 使用新的统一记忆系统检索记忆
# 使用记忆系统检索记忆
try:
from src.chat.memory_system import get_memory_system
from src.memory_graph.manager_singleton import get_memory_manager
memory_system = get_memory_system()
memory_manager = get_memory_manager()
if not memory_manager:
return "记忆系统未初始化。"
# 将关键词转换为查询字符串
query = " ".join(keywords)
enhanced_memories = await memory_system.retrieve_relevant_memories(
query_text=query,
user_id="system", # 系统查询
scope_id="system",
limit=5,
enhanced_memories = await memory_manager.search_memories(
query=query,
top_k=5,
use_multi_query=False, # 直接使用关键词查询
)
if not enhanced_memories:
@@ -658,9 +660,14 @@ class ChatterPlanFilter:
# 转换格式以兼容现有代码
retrieved_memories = []
for memory_chunk in enhanced_memories:
content = memory_chunk.display or memory_chunk.text_content or ""
memory_type = memory_chunk.memory_type.value if memory_chunk.memory_type else "unknown"
for memory in enhanced_memories:
# 从记忆图的节点中提取内容
content_parts = []
for node in memory.nodes:
if node.content:
content_parts.append(node.content)
content = " ".join(content_parts) if content_parts else "无内容"
memory_type = memory.memory_type.value
retrieved_memories.append((memory_type, content))
memory_statements = [

View File

@@ -3,7 +3,7 @@
当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复
"""
import json
import orjson
from datetime import datetime
from typing import Any, Literal

View File

@@ -3,7 +3,7 @@
负责记录和管理已回复过的评论ID避免重复回复
"""
import json
import orjson
import time
from pathlib import Path
from typing import Any
@@ -71,7 +71,7 @@ class ReplyTrackerService:
self.replied_comments = {}
return
data = json.loads(file_content)
data = orjson.loads(file_content)
if self._validate_data(data):
self.replied_comments = data
logger.info(
@@ -81,7 +81,7 @@ class ReplyTrackerService:
else:
logger.error("加载的数据格式无效,将创建新的记录")
self.replied_comments = {}
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
logger.error(f"解析回复记录文件失败: {e}")
self._backup_corrupted_file()
self.replied_comments = {}
@@ -118,7 +118,7 @@ class ReplyTrackerService:
# 先写入临时文件
with open(temp_file, "w", encoding="utf-8") as f:
json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
f.write(orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))
# 如果写入成功,重命名为正式文件
if temp_file.stat().st_size > 0: # 确保写入成功
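Note that orjson has no json.dump-style file-writing API and orjson.dumps returns bytes, which is why the temp-file write above must pass the encoded payload to f.write explicitly. A self-contained sketch of the same atomic-save pattern, assuming only the standard library plus orjson (path and data are illustrative):
from pathlib import Path
import orjson
def atomic_save(path: Path, data: dict) -> None:
    temp_file = path.with_suffix(".tmp")
    # OPT_INDENT_2 pretty-prints; OPT_NON_STR_KEYS allows int keys such as comment IDs
    payload = orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)
    temp_file.write_bytes(payload)
    if temp_file.stat().st_size > 0:  # only promote a non-empty temp file
        temp_file.replace(path)
atomic_save(Path("replied_comments.json"), {12345: {"replied_at": 1731000000}})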

View File

@@ -1,6 +1,6 @@
import asyncio
import inspect
import json
import orjson
from typing import ClassVar, List
import websockets as Server
@@ -44,10 +44,10 @@ async def message_recv(server_connection: Server.ServerConnection):
# 只在debug模式下记录原始消息
if logger.level <= 10: # DEBUG level
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
decoded_raw_message: dict = json.loads(raw_message)
decoded_raw_message: dict = orjson.loads(raw_message)
try:
# 首先尝试解析原始消息
decoded_raw_message: dict = json.loads(raw_message)
decoded_raw_message: dict = orjson.loads(raw_message)
# 检查是否是切片消息 (来自 MMC)
if chunker.is_chunk_message(decoded_raw_message):
@@ -71,7 +71,7 @@ async def message_recv(server_connection: Server.ServerConnection):
elif post_type is None:
await put_response(decoded_raw_message)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
logger.error(f"消息解析失败: {e}")
logger.debug(f"原始消息: {raw_message[:500]}...")
except Exception as e:

View File

@@ -5,7 +5,7 @@
"""
import asyncio
import json
import orjson
import time
import uuid
from typing import Any, Dict, List, Optional, Union
@@ -34,7 +34,7 @@ class MessageChunker:
"""判断消息是否需要切片"""
try:
if isinstance(message, dict):
message_str = json.dumps(message, ensure_ascii=False)
message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
else:
message_str = message
return len(message_str.encode("utf-8")) > self.max_chunk_size
@@ -58,7 +58,7 @@ class MessageChunker:
try:
# 统一转换为字符串
if isinstance(message, dict):
message_str = json.dumps(message, ensure_ascii=False)
message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
else:
message_str = message
@@ -116,7 +116,7 @@ class MessageChunker:
"""判断是否是切片消息"""
try:
if isinstance(message, str):
data = json.loads(message)
data = orjson.loads(message)
else:
data = message
@@ -126,7 +126,7 @@ class MessageChunker:
and "__mmc_chunk_data__" in data
and "__mmc_is_chunked__" in data
)
except (json.JSONDecodeError, TypeError):
except (orjson.JSONDecodeError, TypeError):
return False
@@ -187,7 +187,7 @@ class MessageReassembler:
try:
# 统一转换为字典
if isinstance(message, str):
chunk_data = json.loads(message)
chunk_data = orjson.loads(message)
else:
chunk_data = message
@@ -197,8 +197,8 @@ class MessageReassembler:
if "_original_message" in chunk_data:
# 这是一个被包装的非切片消息,解包返回
try:
return json.loads(chunk_data["_original_message"])
except json.JSONDecodeError:
return orjson.loads(chunk_data["_original_message"])
except orjson.JSONDecodeError:
return {"text_message": chunk_data["_original_message"]}
else:
return chunk_data
@@ -251,14 +251,14 @@ class MessageReassembler:
# 尝试反序列化重组后的消息
try:
return json.loads(reassembled_message)
except json.JSONDecodeError:
return orjson.loads(reassembled_message)
except orjson.JSONDecodeError:
# 如果不能反序列化为JSON则作为文本消息返回
return {"text_message": reassembled_message}
return None
except (json.JSONDecodeError, KeyError, TypeError) as e:
except (orjson.JSONDecodeError, KeyError, TypeError) as e:
logger.error(f"处理切片消息时出错: {e}")
return None
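One behavioural detail behind the except clauses above: orjson.JSONDecodeError subclasses ValueError (and json.JSONDecodeError), and orjson.loads accepts both str and bytes input, so the swapped-in handlers behave like their json counterparts. A tiny self-contained check:
import orjson
try:
    orjson.loads("{not valid json")
except orjson.JSONDecodeError as exc:
    # Broader ValueError handlers elsewhere keep working unchanged
    assert isinstance(exc, ValueError)
# str and bytes inputs are both accepted and produce the same object
assert orjson.loads('{"a": 1}') == orjson.loads(b'{"a": 1}') == {"a": 1}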

View File

@@ -1,5 +1,5 @@
import base64
import json
import orjson
import time
import uuid
from pathlib import Path
@@ -783,11 +783,11 @@ class MessageHandler:
# 检查JSON消息格式
if not message_data or "data" not in message_data:
logger.warning("JSON消息格式不正确")
return Seg(type="json", data=json.dumps(message_data))
return Seg(type="json", data=orjson.dumps(message_data).decode('utf-8'))
try:
# 尝试将json_data解析为Python对象
nested_data = json.loads(json_data)
nested_data = orjson.loads(json_data)
# 检查是否是机器人自己上传文件的回声
if self._is_file_upload_echo(nested_data):
@@ -912,7 +912,7 @@ class MessageHandler:
# 如果没有提取到关键信息返回None
return None
except json.JSONDecodeError:
except orjson.JSONDecodeError:
# 如果解析失败我们假设它不是我们关心的任何一种结构化JSON
# 而是普通的文本或者无法解析的格式。
logger.debug(f"无法将data字段解析为JSON: {json_data}")
@@ -1146,13 +1146,13 @@ class MessageHandler:
return None
forward_message_id = forward_message_data.get("id")
request_uuid = str(uuid.uuid4())
payload = json.dumps(
payload = orjson.dumps(
{
"action": "get_forward_msg",
"params": {"message_id": forward_message_id},
"echo": request_uuid,
}
)
).decode('utf-8')
try:
connection = self.get_server_connection()
if not connection:
@@ -1167,9 +1167,9 @@ class MessageHandler:
logger.error(f"获取转发消息失败: {str(e)}")
return None
logger.debug(
f"转发消息原始格式:{json.dumps(response)[:80]}..."
if len(json.dumps(response)) > 80
else json.dumps(response)
f"转发消息原始格式:{orjson.dumps(response).decode('utf-8')[:80]}..."
if len(orjson.dumps(response).decode('utf-8')) > 80
else orjson.dumps(response).decode('utf-8')
)
response_data: Dict = response.get("data")
if not response_data:

View File

@@ -1,5 +1,5 @@
import asyncio
import json
import orjson
import time
from typing import ClassVar, Optional, Tuple
@@ -241,7 +241,7 @@ class NoticeHandler:
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=handled_message,
raw_message=json.dumps(raw_message),
raw_message=orjson.dumps(raw_message).decode('utf-8'),
)
if system_notice:
@@ -602,7 +602,7 @@ class NoticeHandler:
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=seg_message,
raw_message=json.dumps(
raw_message=orjson.dumps(
{
"post_type": "notice",
"notice_type": "group_ban",
@@ -611,7 +611,7 @@ class NoticeHandler:
"user_id": user_id,
"operator_id": None, # 自然解除禁言没有操作者
}
),
).decode('utf-8'),
)
await self.put_notice(message_base)

View File

@@ -1,4 +1,5 @@
import json
import orjson
import random
import time
import websockets as Server
@@ -603,7 +604,7 @@ class SendHandler:
async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict:
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
payload = orjson.dumps({"action": action, "params": params, "echo": request_uuid}).decode('utf-8')
# 获取当前连接
connection = self.get_server_connection()

View File

@@ -1,6 +1,6 @@
import base64
import io
import json
import orjson
import ssl
import uuid
from typing import List, Optional, Tuple, Union
@@ -34,7 +34,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
"""
logger.debug("获取群聊信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
payload = orjson.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
try:
await websocket.send(payload)
socket_response: dict = await get_response(request_uuid)
@@ -56,7 +56,7 @@ async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in
"""
logger.debug("获取群详细信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
payload = orjson.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
try:
await websocket.send(payload)
socket_response: dict = await get_response(request_uuid)
@@ -78,13 +78,13 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use
"""
logger.debug("获取群成员信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps(
payload = orjson.dumps(
{
"action": "get_group_member_info",
"params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
"echo": request_uuid,
}
)
).decode('utf-8')
try:
await websocket.send(payload)
socket_response: dict = await get_response(request_uuid)
@@ -146,7 +146,7 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
"""
logger.debug("获取自身信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
payload = orjson.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid}).decode('utf-8')
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid)
@@ -183,7 +183,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) ->
"""
logger.debug("获取陌生人信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
payload = orjson.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid}).decode('utf-8')
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid)
@@ -208,7 +208,7 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: Uni
"""
logger.debug("获取消息详情中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
payload = orjson.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}).decode('utf-8')
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
@@ -236,13 +236,13 @@ async def get_record_detail(
"""
logger.debug("获取语音消息详情中")
request_uuid = str(uuid.uuid4())
payload = json.dumps(
payload = orjson.dumps(
{
"action": "get_record",
"params": {"file": file, "file_id": file_id, "out_format": "wav"},
"echo": request_uuid,
}
)
).decode('utf-8')
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
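The payload/echo pattern repeated in the helpers above could be factored into one coroutine; a sketch, assuming the module-level get_response coroutine already defined in this file (napcat_request itself is hypothetical):
import uuid
import orjson
import websockets as Server
async def napcat_request(websocket: Server.ServerConnection, action: str, params: dict, timeout: int = 30) -> dict | None:
    """Send one OneBot-style action frame and wait for the reply whose echo matches."""
    request_uuid = str(uuid.uuid4())
    payload = orjson.dumps({"action": action, "params": params, "echo": request_uuid}).decode("utf-8")
    await websocket.send(payload)
    return await get_response(request_uuid, timeout)  # get_response is defined in this module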

View File

@@ -39,15 +39,23 @@ class ExaSearchEngine(BaseSearchEngine):
return self.api_manager.is_available()
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""执行Exa搜索"""
"""执行优化的Exa搜索使用answer模式"""
if not self.is_available():
return []
query = args["query"]
num_results = args.get("num_results", 3)
num_results = min(args.get("num_results", 5), 5) # 默认5个结果但限制最多5个
time_range = args.get("time_range", "any")
exa_args = {"num_results": num_results, "text": True, "highlights": True}
# 优化的搜索参数 - 更注重答案质量
exa_args = {
"num_results": num_results,
"text": True,
"highlights": True,
"summary": True, # 启用自动摘要
}
# 时间范围过滤
if time_range != "any":
today = datetime.now()
start_date = today - timedelta(days=7 if time_range == "week" else 30)
@@ -61,18 +69,89 @@ class ExaSearchEngine(BaseSearchEngine):
return []
loop = asyncio.get_running_loop()
# 使用search_and_contents获取完整内容优化为answer模式
func = functools.partial(exa_client.search_and_contents, query, **exa_args)
search_response = await loop.run_in_executor(None, func)
return [
{
# 优化结果处理 - 更注重答案质量
results = []
for res in search_response.results:
# 获取最佳内容片段
highlights = getattr(res, "highlights", [])
summary = getattr(res, "summary", "")
text = getattr(res, "text", "")
# 智能内容选择:摘要 > 高亮 > 文本开头
if summary and len(summary) > 50:
snippet = summary.strip()
elif highlights:
snippet = " ".join(highlights).strip()
elif text:
snippet = text[:300] + "..." if len(text) > 300 else text
else:
snippet = "内容获取失败"
# 只保留有意义的摘要
if len(snippet) < 30:
snippet = text[:200] + "..." if text and len(text) > 200 else snippet
results.append({
"title": res.title,
"url": res.url,
"snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
"snippet": snippet,
"provider": "Exa",
}
for res in search_response.results
]
"answer_focused": True, # 标记为答案导向的搜索
})
return results
except Exception as e:
logger.error(f"Exa 搜索失败: {e}")
logger.error(f"Exa answer模式搜索失败: {e}")
return []
async def answer_search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""执行Exa快速答案搜索 - 最精简的搜索模式"""
if not self.is_available():
return []
query = args["query"]
num_results = min(args.get("num_results", 3), 3) # answer模式默认3个结果专注质量
# 精简的搜索参数 - 专注快速答案
exa_args = {
"num_results": num_results,
"text": False, # 不需要全文
"highlights": True, # 只要关键高亮
"summary": True, # 优先摘要
}
try:
exa_client = self.api_manager.get_next_client()
if not exa_client:
return []
loop = asyncio.get_running_loop()
func = functools.partial(exa_client.search_and_contents, query, **exa_args)
search_response = await loop.run_in_executor(None, func)
# 极简结果处理 - 只保留最核心信息
results = []
for res in search_response.results:
summary = getattr(res, "summary", "")
highlights = getattr(res, "highlights", [])
# 优先使用摘要,否则使用高亮
answer_text = summary.strip() if summary and len(summary) > 30 else " ".join(highlights).strip()
if answer_text and len(answer_text) > 20:
results.append({
"title": res.title,
"url": res.url,
"snippet": answer_text[:400] + "..." if len(answer_text) > 400 else answer_text,
"provider": "Exa-Answer",
"answer_mode": True # 标记为纯答案模式
})
return results
except Exception as e:
logger.error(f"Exa快速答案搜索失败: {e}")
return []
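A usage sketch of the new answer mode, assuming an already-constructed ExaSearchEngine instance named engine (engine construction and api_manager setup are outside this hunk):
import asyncio
async def quick_answer(engine) -> None:
    results = await engine.answer_search({"query": "what is retrieval augmented generation", "num_results": 3})
    for item in results:
        print(item["provider"], item["title"], item["snippet"][:80])
# asyncio.run(quick_answer(engine))  # run inside an event loop once engine is available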

View File

@@ -1,7 +1,7 @@
"""
Metaso Search Engine (Chat Completions Mode)
"""
import json
import orjson
from typing import Any
import httpx
@@ -43,12 +43,12 @@ class MetasoClient:
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
data = orjson.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {})
content_chunk = delta.get("content")
if content_chunk:
full_response_content += content_chunk
except json.JSONDecodeError:
except orjson.JSONDecodeError:
logger.warning(f"Metaso stream: could not decode JSON line: {data_str}")
continue

View File

@@ -41,6 +41,13 @@ class WebSurfingTool(BaseTool):
False,
["any", "week", "month"],
),
(
"answer_mode",
ToolParamType.BOOLEAN,
"是否启用答案模式仅适用于Exa搜索引擎。启用后将返回更精简、直接的答案减少冗余信息。默认为False。",
False,
None,
),
] # type: ignore
def __init__(self, plugin_config=None, chat_stream=None):
@@ -97,13 +104,19 @@ class WebSurfingTool(BaseTool):
) -> dict[str, Any]:
"""并行搜索策略:同时使用所有启用的搜索引擎"""
search_tasks = []
answer_mode = function_args.get("answer_mode", False)
for engine_name in enabled_engines:
engine = self.engines.get(engine_name)
if engine and engine.is_available():
custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5)
search_tasks.append(engine.search(custom_args))
# 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
search_tasks.append(engine.answer_search(custom_args))
else:
search_tasks.append(engine.search(custom_args))
if not search_tasks:
@@ -137,17 +150,23 @@ class WebSurfingTool(BaseTool):
self, function_args: dict[str, Any], enabled_engines: list[str]
) -> dict[str, Any]:
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
answer_mode = function_args.get("answer_mode", False)
for engine_name in enabled_engines:
engine = self.engines.get(engine_name)
if not engine or not engine.is_available():
continue
try:
custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5)
results = await engine.search(custom_args)
# 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
logger.info("使用Exa答案模式进行搜索fallback策略")
results = await engine.answer_search(custom_args)
else:
results = await engine.search(custom_args)
if results: # 如果有结果,直接返回
formatted_content = format_search_results(results)
@@ -164,22 +183,30 @@ class WebSurfingTool(BaseTool):
async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]:
"""单一搜索策略:只使用第一个可用的搜索引擎"""
answer_mode = function_args.get("answer_mode", False)
for engine_name in enabled_engines:
engine = self.engines.get(engine_name)
if not engine or not engine.is_available():
continue
try:
custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5)
results = await engine.search(custom_args)
formatted_content = format_search_results(results)
return {
"type": "web_search_result",
"content": formatted_content,
}
# 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
logger.info("使用Exa答案模式进行搜索")
results = await engine.answer_search(custom_args)
else:
results = await engine.search(custom_args)
if results:
formatted_content = format_search_results(results)
return {
"type": "web_search_result",
"content": formatted_content,
}
except Exception as e:
logger.error(f"{engine_name} 搜索失败: {e}")