Merge branch 'dev' into dev

This commit is contained in:
拾风
2025-11-07 13:14:27 +08:00
committed by GitHub
98 changed files with 16116 additions and 8718 deletions


@@ -132,6 +132,56 @@ class ExpressionLearner:
self.chat_name = stream_name or self.chat_id
self._chat_name_initialized = True
async def cleanup_expired_expressions(self, expiration_days: int | None = None) -> int:
"""
Clean up expired expressions.
Args:
expiration_days: Expressions inactive for more than this many days are deleted (read from the config when not specified).
Returns:
int: Number of expressions deleted.
"""
# Read the expiration window from the config
if expiration_days is None:
expiration_days = global_config.expression.expiration_days
current_time = time.time()
expiration_threshold = current_time - (expiration_days * 24 * 3600)
try:
deleted_count = 0
async with get_db_session() as session:
# Query expired expressions (only those belonging to the current chat_id)
query = await session.execute(
select(Expression).where(
(Expression.chat_id == self.chat_id)
& (Expression.last_active_time < expiration_threshold)
)
)
expired_expressions = list(query.scalars())
if expired_expressions:
for expr in expired_expressions:
await session.delete(expr)
deleted_count += 1
await session.commit()
logger.info(f"Cleaned up {deleted_count} expired expressions (unused for more than {expiration_days} days)")
# Invalidate the cache
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
await cache.delete(generate_cache_key("chat_expressions", self.chat_id))
else:
logger.debug(f"No expired expressions found (threshold: {expiration_days} days)")
return deleted_count
except Exception as e:
logger.error(f"Failed to clean up expired expressions: {e}")
return 0
def can_learn_for_chat(self) -> bool:
"""
Check whether the given chat stream is allowed to learn expressions
@@ -214,6 +264,9 @@ class ExpressionLearner:
try:
logger.info(f"Triggering expression learning for chat stream {self.chat_name}")
# 🔥 Improvement 3: clean up expired expressions before learning
await self.cleanup_expired_expressions()
# Learn language style
learnt_style = await self.learn_and_store(type="style", num=25)
@@ -397,9 +450,29 @@ class ExpressionLearner:
for chat_id, expr_list in chat_dict.items():
async with get_db_session() as session:
for new_expr in expr_list:
# Check whether a similar expression already exists
# Note: get_all_by does not support complex conditions, so the session is still needed here
query = await session.execute(
# 🔥 Improvement 1: check for records with the same situation or the same style
# Case 1: same chat_id + type + situation (same situation, different style)
query_same_situation = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.situation == new_expr["situation"])
)
)
same_situation_expr = query_same_situation.scalar()
# Case 2: same chat_id + type + style (same style, different situation)
query_same_style = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.style == new_expr["style"])
)
)
same_style_expr = query_same_style.scalar()
# Case 3: exact match (same situation + same style)
query_exact_match = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
@@ -407,16 +480,29 @@ class ExpressionLearner:
& (Expression.style == new_expr["style"])
)
)
existing_expr = query.scalar()
if existing_expr:
expr_obj = existing_expr
# Replace the content with 50% probability
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
exact_match_expr = query_exact_match.scalar()
# Handle the exact-match case first
if exact_match_expr:
# Exact match: increment count and refresh the timestamp
expr_obj = exact_match_expr
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
logger.debug(f"Exact match, count updated: {expr_obj.count}")
elif same_situation_expr:
# Same situation, different style: overwrite the old style
logger.info(f"Same-situation overwrite: style for '{same_situation_expr.situation}' updated from '{same_situation_expr.style}' to '{new_expr['style']}'")
same_situation_expr.style = new_expr["style"]
same_situation_expr.count = same_situation_expr.count + 1
same_situation_expr.last_active_time = current_time
elif same_style_expr:
# Same style, different situation: overwrite the old situation
logger.info(f"Same-style overwrite: situation for '{same_style_expr.style}' updated from '{same_style_expr.situation}' to '{new_expr['situation']}'")
same_style_expr.situation = new_expr["situation"]
same_style_expr.count = same_style_expr.count + 1
same_style_expr.last_active_time = current_time
else:
# Entirely new expression: create a new record
new_expression = Expression(
situation=new_expr["situation"],
style=new_expr["style"],
@@ -427,6 +513,7 @@ class ExpressionLearner:
create_date=current_time,  # set the creation date manually
)
session.add(new_expression)
logger.debug(f"Added new expression: {new_expr['situation']} -> {new_expr['style']}")
# Cap the total count - use get_all_by_sorted to fetch sorted results
exprs_result = await session.execute(

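The three-way match above reads as a small decision table: an exact match bumps the counter, a situation-only match keeps the newer style, a style-only match keeps the newer situation, and anything else becomes a new record. A minimal sketch of the same rules over plain dicts (hypothetical helper, not part of the diff), mirroring the elif ordering:

def merge_expression(existing: list[dict], new: dict, now: float) -> None:
    # Apply the diff's dedup rules to an in-memory list of expressions
    exact = next((e for e in existing if e["situation"] == new["situation"] and e["style"] == new["style"]), None)
    same_situation = next((e for e in existing if e["situation"] == new["situation"]), None)
    same_style = next((e for e in existing if e["style"] == new["style"]), None)
    if exact:  # exact match: just bump the counter
        target = exact
    elif same_situation:  # same situation: the new style wins
        same_situation["style"] = new["style"]
        target = same_situation
    elif same_style:  # same style: the new situation wins
        same_style["situation"] = new["situation"]
        target = same_style
    else:  # brand new expression
        existing.append({**new, "count": 0, "last_active_time": now})
        target = existing[-1]
    target["count"] = target.get("count", 0) + 1
    target["last_active_time"] = now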

@@ -61,6 +61,34 @@ class ExpressorModel:
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def remove_candidate(self, cid: str) -> bool:
"""
Remove a candidate text
Args:
cid: candidate ID
Returns:
Whether the removal succeeded
"""
removed = False
if cid in self._candidates:
del self._candidates[cid]
removed = True
if cid in self._situations:
del self._situations[cid]
# Remove from the naive Bayes model
if cid in self.nb.cls_counts:
del self.nb.cls_counts[cid]
if cid in self.nb.token_counts:
del self.nb.token_counts[cid]
return removed
def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]:
"""
Score all candidates directly with naive Bayes

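remove_candidate has to delete the candidate from every parallel structure, otherwise predict would keep scoring a ghost class. A condensed sketch of that invariant using dict.pop (assuming dict-backed counters, as in the diff):

def remove_everywhere(cid: str, candidates: dict, situations: dict, cls_counts: dict, token_counts: dict) -> bool:
    removed = candidates.pop(cid, None) is not None
    situations.pop(cid, None)      # keep the side table in sync
    cls_counts.pop(cid, None)      # drop the class prior mass
    token_counts.pop(cid, None)    # drop the per-class token statistics
    return removed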

@@ -36,6 +36,8 @@ class StyleLearner:
# Dynamic style management
self.max_styles = 2000  # at most 2000 styles per chat_id
self.cleanup_threshold = 0.9  # trigger cleanup at 90% capacity
self.cleanup_ratio = 0.2  # clean up 20% of the styles per pass
self.style_to_id: dict[str, str] = {}  # style text -> style_id
self.id_to_style: dict[str, str] = {}  # style_id -> style text
self.id_to_situation: dict[str, str] = {}  # style_id -> situation text
@@ -45,6 +47,7 @@ class StyleLearner:
self.learning_stats = {
"total_samples": 0,
"style_counts": {},
"style_last_used": {},  # last-used timestamp of each style
"last_update": time.time(),
}
@@ -66,10 +69,19 @@ class StyleLearner:
if style in self.style_to_id:
return True
# Check whether the maximum limit has been exceeded
if len(self.style_to_id) >= self.max_styles:
logger.warning(f"Maximum style count reached ({self.max_styles})")
return False
# Check whether cleanup is needed
current_count = len(self.style_to_id)
cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
if current_count >= cleanup_trigger:
if current_count >= self.max_styles:
# Hard limit reached; cleanup is mandatory
logger.warning(f"Maximum style count reached ({self.max_styles}); starting cleanup")
self._cleanup_styles()
elif current_count >= cleanup_trigger:
# Approaching the limit; clean up preemptively
logger.info(f"Style count at {current_count}/{self.max_styles}; triggering preventive cleanup")
self._cleanup_styles()
# Generate a new style_id
style_id = f"style_{self.next_style_id}"
@@ -94,6 +106,80 @@ class StyleLearner:
logger.error(f"Failed to add style: {e}")
return False
def _cleanup_styles(self):
"""
Clean up low-value styles to make room for new ones
Cleanup strategy:
1. Combine usage count and last-used time into a score
2. Delete the lowest-scoring styles
3. By default, remove cleanup_ratio (20%) of the styles
"""
try:
import math  # hoisted out of the loop body
current_time = time.time()
cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio))
# Compute a value score for each style
style_scores = []
for style_id in self.style_to_id.values():
# Usage count
usage_count = self.learning_stats["style_counts"].get(style_id, 0)
# Last-used time (more recent is better)
last_used = self.learning_stats["style_last_used"].get(style_id, 0)
time_since_used = current_time - last_used if last_used > 0 else float('inf')
# Combined score: more usage is better, and more recent usage is better
# Use a logarithm to smooth the influence of the usage count
usage_score = math.log1p(usage_count)  # log(1 + count)
# Time score: convert to days and apply exponential decay
days_unused = time_since_used / 86400  # convert to days
time_score = math.exp(-days_unused / 30)  # 30-day decay constant
# Combined score: 80% usage frequency + 20% recency
total_score = 0.8 * usage_score + 0.2 * time_score
style_scores.append((style_id, total_score, usage_count, days_unused))
# Sort by score; the lowest-scoring styles are deleted first
style_scores.sort(key=lambda x: x[1])
# Delete the lowest-scoring styles
deleted_styles = []
for style_id, score, usage, days in style_scores[:cleanup_count]:
style_text = self.id_to_style.get(style_id)
if style_text:
# Remove from the mappings
del self.style_to_id[style_text]
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
# Remove from the statistics
if style_id in self.learning_stats["style_counts"]:
del self.learning_stats["style_counts"][style_id]
if style_id in self.learning_stats["style_last_used"]:
del self.learning_stats["style_last_used"][style_id]
# Remove from the expressor model
self.expressor.remove_candidate(style_id)
deleted_styles.append((style_text[:30], usage, f"{days:.1f}"))
logger.info(
f"Style cleanup finished: deleted {len(deleted_styles)}/{len(style_scores)} styles, "
f"{len(self.style_to_id)} styles remaining"
)
# Log the first 5 deleted styles for debugging
if deleted_styles:
logger.debug(f"Sample of deleted styles (first 5): {deleted_styles[:5]}")
except Exception as e:
logger.error(f"Style cleanup failed: {e}", exc_info=True)
def learn_mapping(self, up_content: str, style: str) -> bool:
"""
Learn a mapping from up_content to style
@@ -118,9 +204,11 @@ class StyleLearner:
self.expressor.update_positive(up_content, style_id)
# Update statistics
current_time = time.time()
self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1
self.learning_stats["last_update"] = time.time()
self.learning_stats["style_last_used"][style_id] = current_time  # refresh the last-used time
self.learning_stats["last_update"] = current_time
logger.debug(f"Mapping learned: {up_content[:20]}... -> {style}")
return True
@@ -171,6 +259,10 @@ class StyleLearner:
else:
logger.warning(f"Skipping style_id that cannot be converted: {sid}")
# Refresh the last-used time (best style only)
if best_style_id:
self.learning_stats["style_last_used"][best_style_id] = time.time()
logger.debug(
f"Prediction succeeded: up_content={up_content[:30]}..., "
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
@@ -208,6 +300,30 @@ class StyleLearner:
"""
return list(self.style_to_id.keys())
def cleanup_old_styles(self, ratio: float | None = None) -> int:
"""
Manually clean up old styles
Args:
ratio: cleanup ratio (falls back to the default cleanup_ratio when None)
Returns:
Number of styles cleaned up
"""
old_count = len(self.style_to_id)
if ratio is not None:
old_cleanup_ratio = self.cleanup_ratio
self.cleanup_ratio = ratio
self._cleanup_styles()
self.cleanup_ratio = old_cleanup_ratio
else:
self._cleanup_styles()
new_count = len(self.style_to_id)
cleaned = old_count - new_count
logger.info(f"Manual cleanup finished: chat_id={self.chat_id}, removed {cleaned} styles")
return cleaned
def apply_decay(self, factor: float | None = None):
"""
Apply knowledge decay
@@ -241,6 +357,11 @@ class StyleLearner:
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
# Make sure learning_stats contains all required fields
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
meta_data = {
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
@@ -295,6 +416,10 @@ class StyleLearner:
self.id_to_situation = meta_data["id_to_situation"]
self.next_style_id = meta_data["next_style_id"]
self.learning_stats = meta_data["learning_stats"]
# Backward compatibility with old data: add the style_last_used field if it is missing
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
logger.info(f"StyleLearner loaded successfully: {save_dir}")
return True
@@ -398,6 +523,26 @@ class StyleLearnerManager:
logger.info(f"Saved all StyleLearners: {'success' if success else 'partial failure'}")
return success
def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]:
"""
Clean up old styles across all learners
Args:
ratio: cleanup ratio
Returns:
{chat_id: number of styles cleaned}
"""
cleanup_results = {}
for chat_id, learner in self.learners.items():
cleaned = learner.cleanup_old_styles(ratio)
if cleaned > 0:
cleanup_results[chat_id] = cleaned
total_cleaned = sum(cleanup_results.values())
logger.info(f"Finished cleaning up all StyleLearners: {total_cleaned} styles removed in total")
return cleanup_results
def apply_decay_all(self, factor: float | None = None):
"""
Apply knowledge decay to all learners

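The cleanup score above boils down to score = 0.8 * log1p(usage_count) + 0.2 * exp(-days_unused / 30). A standalone sketch of the formula with illustrative numbers (not taken from the repo):

import math

def cleanup_score(usage_count: int, days_unused: float) -> float:
    # 80% smoothed usage frequency, 20% exponential recency with a 30-day constant
    return 0.8 * math.log1p(usage_count) + 0.2 * math.exp(-days_unused / 30)

print(cleanup_score(50, 90))  # ~3.16: heavily used but idle for 90 days
print(cleanup_score(2, 1))    # ~1.07: barely used, active yesterday

Under this weighting a frequently used style survives long idle stretches, since the recency term can contribute at most 0.2.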

@@ -1,73 +0,0 @@
"""
Simplified memory system module
Removes the instant/long-term memory split in favor of a unified memory architecture with an intelligent forgetting mechanism
"""
# Core data structures
# Activators
from .enhanced_memory_activator import MemoryActivator, enhanced_memory_activator, memory_activator
from .memory_chunk import (
ConfidenceLevel,
ContentStructure,
ImportanceLevel,
MemoryChunk,
MemoryMetadata,
MemoryType,
create_memory_chunk,
)
# Compatibility alias
from .memory_chunk import MemoryChunk as Memory
# Forgetting engine
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
from .memory_formatter import format_memories_bracket_style
# Memory manager
from .memory_manager import MemoryManager, MemoryResult, memory_manager
# Core memory system
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
# Vector DB storage system
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
__all__ = [
"ConfidenceLevel",
"ContentStructure",
"ForgettingConfig",
"ImportanceLevel",
"Memory",  # compatibility alias
# Activators
"MemoryActivator",
# Core data structures
"MemoryChunk",
# Forgetting engine
"MemoryForgettingEngine",
# Memory manager
"MemoryManager",
"MemoryMetadata",
"MemoryResult",
# Memory system
"MemorySystem",
"MemorySystemConfig",
"MemoryType",
# Vector DB storage
"VectorMemoryStorage",
"VectorStorageConfig",
"create_memory_chunk",
"enhanced_memory_activator",  # compatibility alias
# Formatting utilities
"format_memories_bracket_style",
"get_memory_forgetting_engine",
"get_memory_system",
"get_vector_memory_storage",
"initialize_memory_system",
"memory_activator",
"memory_manager",
]
# Version info
__version__ = "3.0.0"
__author__ = "MoFox Team"
__description__ = "Simplified memory system - unified memory architecture with intelligent forgetting"


@@ -1,240 +0,0 @@
"""
Memory activator
The activator component of the memory system
"""
import difflib
from datetime import datetime
import orjson
from json_repair import repair_json
from src.chat.memory_system.memory_manager import MemoryResult
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> list:
"""
Extract the keyword list from a JSON string
Args:
json_str: a JSON-formatted string
Returns:
List[str]: list of keywords
"""
try:
# Repair the JSON format with repair_json
fixed_json = repair_json(json_str)
# If repair_json returns a string, parse it into a Python object
result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
return result.get("keywords", [])
except Exception as e:
logger.error(f"Failed to parse keyword JSON: {e}")
return []
def init_prompt():
# --- Memory Activator Prompt ---
memory_activator_prompt = """
You are a memory analyzer. Use the information below to perform memory retrieval.
Below is a chat log. From it, summarize a few keywords to serve as trigger words for memory retrieval.
Chat log:
{obs_info_text}
Message the user wants to reply to:
{target_message}
Historical keywords (avoid extracting these again):
{cached_keywords}
Output JSON containing the following field:
{{
"keywords": ["keyword1", "keyword2", "keyword3", ...]
}}
Do not output anything else; output only the JSON.
"""
Prompt(memory_activator_prompt, "memory_activator_prompt")
class MemoryActivator:
"""Memory activator"""
def __init__(self):
self.key_words_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.activator",
)
self.running_memory = []
self.cached_keywords = set()  # cache of historical keywords
self.last_memory_query_time = 0  # time of the last memory query
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
"""
Activate memories
"""
# If the memory system is disabled, return an empty list immediately
if not global_config.memory.enable_memory:
return []
# Convert the cached keywords to a string for the prompt
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "no historical keywords yet"
prompt = await global_prompt_manager.format_prompt(
"memory_activator_prompt",
obs_info_text=chat_history_prompt,
target_message=target_message,
cached_keywords=cached_keywords_str,
)
# Generate keywords
response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
prompt, temperature=0.5
)
keywords = list(get_keywords_from_json(response))
# Update the keyword cache
if keywords:
# Cap the cache size at 10 keywords
if len(self.cached_keywords) > 10:
# Convert to a list and drop the oldest keywords
cached_list = list(self.cached_keywords)
self.cached_keywords = set(cached_list[-8:])
# Add the new keywords to the cache
self.cached_keywords.update(keywords)
logger.debug(f"Memory keywords: {self.cached_keywords}")
# Fetch relevant memories from the memory system
memory_results = await self._query_unified_memory(keywords, target_message)
# Process the memory results
if memory_results:
for result in memory_results:
# Check whether a memory with similar content already exists
exists = any(
m["content"] == result.content
or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
for m in self.running_memory
)
if not exists:
memory_entry = {
"topic": result.memory_type,
"content": result.content,
"timestamp": datetime.fromtimestamp(result.timestamp).isoformat(),
"duration": 1,
"confidence": result.confidence,
"importance": result.importance,
"source": result.source,
"relevance_score": result.relevance_score,  # include the relevance score
}
self.running_memory.append(memory_entry)
logger.debug(f"Added new memory: {result.memory_type} - {result.content}")
# On each activation, increment every held memory's duration; remove it once it reaches 3
for m in self.running_memory[:]:
m["duration"] = m.get("duration", 1) + 1
self.running_memory = [m for m in self.running_memory if m["duration"] < 3]
# Cap the number of concurrently loaded memories at the last 5
if len(self.running_memory) > 5:
self.running_memory = self.running_memory[-5:]
return self.running_memory
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
"""Query the unified memory system"""
try:
# Use the memory system
from src.chat.memory_system.memory_system import get_memory_system
memory_system = get_memory_system()
if not memory_system or memory_system.status.value != "ready":
logger.warning("Memory system not ready")
return []
# Build the query context
context = {"keywords": keywords, "query_intent": "conversation_response"}
# Query memories
memories = await memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="global",  # use the global scope
context=context,
limit=5,
)
# Convert to MemoryResult format
memory_results = []
for memory in memories:
result = MemoryResult(
content=memory.display,
memory_type=memory.memory_type.value,
confidence=memory.metadata.confidence.value,
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="unified_memory",
relevance_score=memory.metadata.relevance_score,
)
memory_results.append(result)
logger.debug(f"Unified memory query returned {len(memory_results)} results")
return memory_results
except Exception as e:
logger.error(f"Unified memory query failed: {e}")
return []
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
"""
Get instant memory - compatible with the legacy interface (uses unified storage)
"""
try:
# Fetch relevant memories via the unified storage system
from src.chat.memory_system.memory_system import get_memory_system
memory_system = get_memory_system()
if not memory_system or memory_system.status.value != "ready":
return None
context = {"query_intent": "instant_response", "chat_id": chat_id}
memories = await memory_system.retrieve_relevant_memories(
query_text=target_message, user_id="global", context=context, limit=1
)
if memories:
return memories[0].display
return None
except Exception as e:
logger.error(f"Failed to get instant memory: {e}")
return None
def clear_cache(self):
"""Clear the caches"""
self.cached_keywords.clear()
self.running_memory.clear()
logger.debug("Memory activator caches cleared")
# Create the global instance
memory_activator = MemoryActivator()
# Compatibility alias
enhanced_memory_activator = memory_activator
init_prompt()

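The running_memory list in the deleted activator implements a small working-set policy: every activation ages each held entry by one, entries expire once duration reaches 3, and only the newest 5 survive. A compact sketch of that policy (hypothetical standalone function):

def age_running_memory(running: list[dict], max_duration: int = 3, max_size: int = 5) -> list[dict]:
    for entry in running:
        entry["duration"] = entry.get("duration", 1) + 1  # age every held memory
    running = [e for e in running if e["duration"] < max_duration]  # expire stale entries
    return running[-max_size:]  # keep only the most recent few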

@@ -1,721 +0,0 @@
"""
Hippocampus bimodal-distribution sampler
Based on the legacy hippocampus sampling strategy, adapted to the new memory system
Implements a low-cost, high-efficiency memory sampling mode
"""
import asyncio
import random
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
import numpy as np
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
)
from src.chat.utils.utils import translate_timestamp_to_human_readable
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
# Global set of background tasks
_background_tasks = set()
@dataclass
class HippocampusSampleConfig:
"""Hippocampus sampling configuration"""
# Bimodal distribution parameters
recent_mean_hours: float = 12.0  # mean of the recent distribution (hours)
recent_std_hours: float = 8.0  # std dev of the recent distribution (hours)
recent_weight: float = 0.7  # weight of the recent distribution
distant_mean_hours: float = 48.0  # mean of the distant distribution (hours)
distant_std_hours: float = 24.0  # std dev of the distant distribution (hours)
distant_weight: float = 0.3  # weight of the distant distribution
# Sampling parameters
total_samples: int = 50  # total number of samples
sample_interval: int = 1800  # sampling interval (seconds)
max_sample_length: int = 30  # maximum number of messages per sample
batch_size: int = 5  # batch size
@classmethod
def from_global_config(cls) -> "HippocampusSampleConfig":
"""Create the hippocampus sampling config from the global config"""
config = global_config.memory.hippocampus_distribution_config
return cls(
recent_mean_hours=config[0],
recent_std_hours=config[1],
recent_weight=config[2],
distant_mean_hours=config[3],
distant_std_hours=config[4],
distant_weight=config[5],
total_samples=global_config.memory.hippocampus_sample_size,
sample_interval=global_config.memory.hippocampus_sample_interval,
max_sample_length=global_config.memory.hippocampus_batch_size,
batch_size=global_config.memory.hippocampus_batch_size,
)
class HippocampusSampler:
"""Hippocampus bimodal-distribution sampler"""
def __init__(self, memory_system=None):
self.memory_system = memory_system
self.config = HippocampusSampleConfig.from_global_config()
self.last_sample_time = 0
self.is_running = False
# Memory-building model
self.memory_builder_model: LLMRequest | None = None
# 统计信息
self.sample_count = 0
self.success_count = 0
self.last_sample_results: list[dict[str, Any]] = []
async def initialize(self):
"""Initialize the sampler"""
try:
# Initialize the LLM model
from src.config.config import model_config
task_config = getattr(model_config.model_task_config, "utils", None)
if task_config:
self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
task = asyncio.create_task(self.start_background_sampling())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.info("✅ Hippocampus sampler initialized successfully")
else:
raise RuntimeError("Memory-building model configuration not found")
except Exception as e:
logger.error(f"❌ Hippocampus sampler initialization failed: {e}")
raise
def generate_time_samples(self) -> list[datetime]:
"""Generate bimodally distributed time sample points"""
# Compute the sample count for each distribution
recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight))
distant_samples = max(1, self.config.total_samples - recent_samples)
# Draw hour offsets from the two normal distributions
recent_offsets = np.random.normal(
loc=self.config.recent_mean_hours, scale=self.config.recent_std_hours, size=recent_samples
)
distant_offsets = np.random.normal(
loc=self.config.distant_mean_hours, scale=self.config.distant_std_hours, size=distant_samples
)
# Merge the offsets from both distributions
all_offsets = np.concatenate([recent_offsets, distant_offsets])
# Convert to timestamps (absolute values keep the points in the past)
base_time = datetime.now()
timestamps = [base_time - timedelta(hours=abs(offset)) for offset in all_offsets]
# Sort chronologically (earliest to latest)
return sorted(timestamps)
async def collect_message_samples(self, target_timestamp: float) -> list[dict[str, Any]] | None:
"""Collect message samples around the given timestamp"""
try:
# Random time window: 5-30 minutes
time_window_seconds = random.randint(300, 1800)
# Try up to 3 times to fetch messages
for attempt in range(3):
timestamp_start = target_timestamp
timestamp_end = target_timestamp + time_window_seconds
# Fetch a single message as the anchor
anchor_messages = await get_raw_msg_by_timestamp(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=1,
limit_mode="earliest",
)
if not anchor_messages:
target_timestamp -= 120  # shift back 2 minutes
continue
anchor_message = anchor_messages[0]
chat_id = anchor_message.get("chat_id")
if not chat_id:
continue
# Fetch multiple messages from the same chat
messages = await get_raw_msg_by_timestamp_with_chat(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=self.config.max_sample_length,
limit_mode="earliest",
chat_id=chat_id,
)
if messages and len(messages) >= 2:  # need at least 2 messages
# Filter out messages that have already been memorized
filtered_messages = [
msg
for msg in messages
if msg.get("memorized_times", 0) < 2  # memorize each message at most twice
]
if filtered_messages:
logger.debug(f"Collected {len(filtered_messages)} message samples")
return filtered_messages
target_timestamp -= 120  # shift back and retry
logger.debug(f"No valid message samples found near timestamp {target_timestamp}")
return None
except Exception as e:
logger.error(f"Failed to collect message samples: {e}")
return None
async def build_memory_from_samples(self, messages: list[dict[str, Any]], target_timestamp: float) -> str | None:
"""Build memories from message samples"""
if not messages or not self.memory_system or not self.memory_builder_model:
return None
try:
# Build readable message text
readable_text = await build_readable_messages(
messages,
merge_messages=True,
timestamp_mode="normal_no_YMD",
replace_bot_name=False,
)
if not readable_text:
logger.warning("Could not generate readable text from the message samples")
return None
# Use the conversation text directly, without system markers
input_text = readable_text
logger.debug(f"Starting memory build, text length: {len(input_text)}")
# Build the context
context = {
"user_id": "hippocampus_sampler",
"timestamp": time.time(),
"source": "hippocampus_sampling",
"message_count": len(messages),
"sample_mode": "bimodal_distribution",
"is_hippocampus_sample": True,  # mark as a hippocampus sample
"bypass_value_threshold": True,  # bypass the value-threshold check
"hippocampus_sample_time": target_timestamp,  # record the sample time
}
# Build memories with the memory system (bypassing the build-interval check)
memories = await self.memory_system.build_memory_from_conversation(
conversation_text=input_text,
context=context,
timestamp=time.time(),
bypass_interval=True,  # the hippocampus sampler bypasses the build-interval limit
)
if memories:
memory_count = len(memories)
self.success_count += 1
# Record the sampling result
result = {
"timestamp": time.time(),
"memory_count": memory_count,
"message_count": len(messages),
"text_preview": readable_text[:100] + "..." if len(readable_text) > 100 else readable_text,
"memory_types": [m.memory_type.value for m in memories],
}
self.last_sample_results.append(result)
# Cap the result history length
if len(self.last_sample_results) > 10:
self.last_sample_results.pop(0)
logger.info(f"✅ Hippocampus sampling built {memory_count} memories")
return f"Built {memory_count} memories"
else:
logger.debug("Hippocampus sampling produced no valid memories")
return None
except Exception as e:
logger.error(f"Hippocampus sampling failed to build memories: {e}")
return None
async def perform_sampling_cycle(self) -> dict[str, Any]:
"""Run one full sampling cycle (optimized: batched fusion build)"""
if not self.should_sample():
return {"status": "skipped", "reason": "interval_not_met"}
start_time = time.time()
self.sample_count += 1
try:
# Generate time sample points
time_samples = self.generate_time_samples()
logger.debug(f"Generated {len(time_samples)} time sample points")
# Log the time sample points (for debugging)
readable_timestamps = [
translate_timestamp_to_human_readable(int(ts.timestamp()), mode="normal")
for ts in time_samples[:5]  # show only the first 5
]
logger.debug(f"Sample time points (examples): {readable_timestamps}")
# Step 1: collect all message samples in batch
logger.debug("Starting batched message-sample collection...")
collected_messages = await self._collect_all_message_samples(time_samples)
if not collected_messages:
logger.info("No valid message samples collected; skipping this sampling run")
self.last_sample_time = time.time()
return {
"status": "success",
"sample_count": self.sample_count,
"success_count": self.success_count,
"processed_samples": len(time_samples),
"successful_builds": 0,
"duration": time.time() - start_time,
"samples_generated": len(time_samples),
"message": "No valid message samples collected",
}
logger.info(f"Collected {len(collected_messages)} groups of message samples")
# Step 2: fuse and deduplicate messages
logger.debug("Starting message fusion and deduplication...")
fused_messages = await self._fuse_and_deduplicate_messages(collected_messages)
if not fused_messages:
logger.info("No messages left after fusion; skipping memory build")
self.last_sample_time = time.time()
return {
"status": "success",
"sample_count": self.sample_count,
"success_count": self.success_count,
"processed_samples": len(time_samples),
"successful_builds": 0,
"duration": time.time() - start_time,
"samples_generated": len(time_samples),
"message": "No messages left after fusion",
}
logger.info(f"Fusion produced {len(fused_messages)} groups of valid messages")
# Step 3: build memories in a single pass
logger.debug("Starting batched memory build...")
build_result = await self._build_batch_memory(fused_messages, time_samples)
# Update the last sampling time
self.last_sample_time = time.time()
duration = time.time() - start_time
result = {
"status": "success",
"sample_count": self.sample_count,
"success_count": self.success_count,
"processed_samples": len(time_samples),
"successful_builds": build_result.get("memory_count", 0),
"duration": duration,
"samples_generated": len(time_samples),
"messages_collected": len(collected_messages),
"messages_fused": len(fused_messages),
"optimization_mode": "batch_fusion",
}
logger.info(
f"✅ Hippocampus sampling cycle finished (batched fusion mode) | "
f"sample points: {len(time_samples)} | "
f"messages collected: {len(collected_messages)} | "
f"messages fused: {len(fused_messages)} | "
f"memories built: {build_result.get('memory_count', 0)} | "
f"duration: {duration:.2f}s"
)
return result
except Exception as e:
logger.error(f"❌ Hippocampus sampling cycle failed: {e}")
return {
"status": "error",
"error": str(e),
"sample_count": self.sample_count,
"duration": time.time() - start_time,
}
async def _collect_all_message_samples(self, time_samples: list[datetime]) -> list[list[dict[str, Any]]]:
"""Collect message samples for all time points in batch"""
collected_messages = []
max_concurrent = min(5, len(time_samples))  # raise concurrency to 5
for i in range(0, len(time_samples), max_concurrent):
batch = time_samples[i : i + max_concurrent]
tasks = []
# Create concurrent collection tasks
for timestamp in batch:
target_ts = timestamp.timestamp()
task = self.collect_message_samples(target_ts)
tasks.append(task)
# Run the collections concurrently
results = await asyncio.gather(*tasks, return_exceptions=True)
# Process the collection results
for result in results:
if isinstance(result, list) and result:
collected_messages.append(result)
elif isinstance(result, Exception):
logger.debug(f"Message collection raised: {result}")
# Brief delay between batches
if i + max_concurrent < len(time_samples):
await asyncio.sleep(0.5)
return collected_messages
async def _fuse_and_deduplicate_messages(
self, collected_messages: list[list[dict[str, Any]]]
) -> list[list[dict[str, Any]]]:
"""Fuse and deduplicate message samples"""
if not collected_messages:
return []
try:
# Flatten all messages
all_messages = []
for message_group in collected_messages:
all_messages.extend(message_group)
logger.debug(f"Total messages after flattening: {len(all_messages)}")
# Deduplication based on message content and timestamp
unique_messages = []
seen_hashes = set()
for message in all_messages:
# Create a message hash for deduplication
content = message.get("processed_plain_text", "") or message.get("display_message", "")
timestamp = message.get("time", 0)
chat_id = message.get("chat_id", "")
# Simple hash: first 50 chars of content + timestamp (minute precision) + chat ID
hash_key = f"{content[:50]}_{int(timestamp // 60)}_{chat_id}"
if hash_key not in seen_hashes and len(content.strip()) > 10:
seen_hashes.add(hash_key)
unique_messages.append(message)
logger.debug(f"Messages after deduplication: {len(unique_messages)}")
# Sort by time
unique_messages.sort(key=lambda x: x.get("time", 0))
# Regroup by chat ID
chat_groups = {}
for message in unique_messages:
chat_id = message.get("chat_id", "unknown")
if chat_id not in chat_groups:
chat_groups[chat_id] = []
chat_groups[chat_id].append(message)
# Merge messages within adjacent time ranges
fused_groups = []
for chat_id, messages in chat_groups.items():
fused_groups.extend(self._merge_adjacent_messages(messages))
logger.debug(f"Message groups after fusion: {len(fused_groups)}")
return fused_groups
except Exception as e:
logger.error(f"Message fusion failed: {e}")
# Fall back to the original message groups
return collected_messages[:5]  # cap the number returned
def _merge_adjacent_messages(
self, messages: list[dict[str, Any]], time_gap: int = 1800
) -> list[list[dict[str, Any]]]:
"""Merge messages that fall within the given time gap"""
if not messages:
return []
merged_groups = []
current_group = [messages[0]]
for i in range(1, len(messages)):
current_time = messages[i].get("time", 0)
prev_time = current_group[-1].get("time", 0)
# If the gap is below the threshold, merge into the current group
if current_time - prev_time <= time_gap:
current_group.append(messages[i])
else:
# Otherwise start a new group
merged_groups.append(current_group)
current_group = [messages[i]]
# Append the last group
merged_groups.append(current_group)
# Filter out single-message groups (unless the content is long)
result_groups = [
group for group in merged_groups
if len(group) > 1 or any(len(msg.get("processed_plain_text", "")) > 100 for msg in group)
]
return result_groups
async def _build_batch_memory(
self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]
) -> dict[str, Any]:
"""Build memories in batch"""
if not fused_messages:
return {"memory_count": 0, "memories": []}
try:
total_memories = []
total_memory_count = 0
# Build the fused text
batch_input_text = await self._build_fused_conversation_text(fused_messages)
if not batch_input_text:
logger.warning("Could not build the fused text; falling back to individual builds")
# Fallback: build each group separately
return await self._fallback_individual_build(fused_messages)
# Create the batch context
batch_context = {
"user_id": "hippocampus_batch_sampler",
"timestamp": time.time(),
"source": "hippocampus_batch_sampling",
"message_groups_count": len(fused_messages),
"total_messages": sum(len(group) for group in fused_messages),
"sample_count": len(time_samples),
"is_hippocampus_sample": True,
"bypass_value_threshold": True,
"optimization_mode": "batch_fusion",
}
logger.debug(f"Batch memory build, text length: {len(batch_input_text)}")
# Build all memories in one pass
memories = await self.memory_system.build_memory_from_conversation(
conversation_text=batch_input_text, context=batch_context, timestamp=time.time(), bypass_interval=True
)
if memories:
memory_count = len(memories)
self.success_count += 1
total_memory_count += memory_count
total_memories.extend(memories)
logger.info(f"✅ Batched hippocampus sampling built {memory_count} memories")
else:
logger.debug("Batched hippocampus sampling produced no valid memories")
# Record the sampling result
result = {
"timestamp": time.time(),
"memory_count": total_memory_count,
"message_groups_count": len(fused_messages),
"total_messages": sum(len(group) for group in fused_messages),
"text_preview": batch_input_text[:200] + "..." if len(batch_input_text) > 200 else batch_input_text,
"memory_types": [m.memory_type.value for m in total_memories],
}
self.last_sample_results.append(result)
# Cap the result history length
if len(self.last_sample_results) > 10:
self.last_sample_results.pop(0)
return {"memory_count": total_memory_count, "memories": total_memories, "result": result}
except Exception as e:
logger.error(f"Batched memory build failed: {e}")
return {"memory_count": 0, "error": str(e)}
async def _build_fused_conversation_text(self, fused_messages: list[list[dict[str, Any]]]) -> str:
"""Build the fused conversation text"""
try:
conversation_parts = []
for group_idx, message_group in enumerate(fused_messages):
if not message_group:
continue
# Add a separator for each message group
group_header = f"\n=== Conversation fragment {group_idx + 1} ==="
conversation_parts.append(group_header)
# Build readable messages
group_text = await build_readable_messages(
message_group,
merge_messages=True,
timestamp_mode="normal_no_YMD",
replace_bot_name=False,
)
if group_text and len(group_text.strip()) > 10:
conversation_parts.append(group_text.strip())
return "\n".join(conversation_parts)
except Exception as e:
logger.error(f"Failed to build the fused text: {e}")
return ""
async def _fallback_individual_build(self, fused_messages: list[list[dict[str, Any]]]) -> dict[str, Any]:
"""Fallback: build each message group separately"""
total_memories = []
total_count = 0
for group in fused_messages[:5]:  # at most 5 groups
try:
memories = await self.build_memory_from_samples(group, time.time())
if memories:
total_memories.extend(memories)
total_count += len(memories)
except Exception as e:
logger.debug(f"Individual build failed: {e}")
return {"memory_count": total_count, "memories": total_memories, "fallback_mode": True}
async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
"""Process a single timestamp sample (kept as a fallback method)"""
try:
# Collect message samples
messages = await self.collect_message_samples(target_timestamp)
if not messages:
return None
# Build memories
result = await self.build_memory_from_samples(messages, target_timestamp)
return result
except Exception as e:
logger.debug(f"Failed to process timestamp sample {target_timestamp}: {e}")
return None
def should_sample(self) -> bool:
"""Check whether sampling should run"""
current_time = time.time()
# Check the time interval
if current_time - self.last_sample_time < self.config.sample_interval:
return False
# Check whether the sampler has been initialized
if not self.memory_builder_model:
logger.warning("Hippocampus sampler not initialized")
return False
return True
async def start_background_sampling(self):
"""Start background sampling"""
if self.is_running:
logger.warning("Hippocampus background sampling is already running")
return
self.is_running = True
logger.info("🚀 Starting the hippocampus background sampling task")
try:
while self.is_running:
try:
# Run a sampling cycle
result = await self.perform_sampling_cycle()
# If the cycle was skipped, sleep briefly
if result.get("status") == "skipped":
await asyncio.sleep(60)  # retry after 1 minute
else:
# Otherwise wait for the next sampling interval
await asyncio.sleep(self.config.sample_interval)
except Exception as e:
logger.error(f"Hippocampus background sampling error: {e}")
await asyncio.sleep(300)  # wait 5 minutes after an error
except asyncio.CancelledError:
logger.info("Hippocampus background sampling task cancelled")
finally:
self.is_running = False
def stop_background_sampling(self):
"""Stop background sampling"""
self.is_running = False
logger.info("🛑 Stopping the hippocampus background sampling task")
def get_sampling_stats(self) -> dict[str, Any]:
"""Get sampling statistics"""
success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
# Compute recent averages
recent_avg_messages = 0
recent_avg_memory_count = 0
if self.last_sample_results:
recent_results = self.last_sample_results[-5:]  # last 5 runs
recent_avg_messages = sum(r.get("total_messages", 0) for r in recent_results) / len(recent_results)
recent_avg_memory_count = sum(r.get("memory_count", 0) for r in recent_results) / len(recent_results)
return {
"is_running": self.is_running,
"sample_count": self.sample_count,
"success_count": self.success_count,
"success_rate": f"{success_rate:.1f}%",
"last_sample_time": self.last_sample_time,
"optimization_mode": "batch_fusion",  # show the optimization mode
"performance_metrics": {
"avg_messages_per_sample": f"{recent_avg_messages:.1f}",
"avg_memories_per_sample": f"{recent_avg_memory_count:.1f}",
"fusion_efficiency": f"{(recent_avg_messages / max(recent_avg_memory_count, 1)):.1f}x"
if recent_avg_messages > 0
else "N/A",
},
"config": {
"sample_interval": self.config.sample_interval,
"total_samples": self.config.total_samples,
"recent_weight": f"{self.config.recent_weight:.1%}",
"distant_weight": f"{self.config.distant_weight:.1%}",
"max_concurrent": 5,  # concurrency in batch mode
"fusion_time_gap": "30 minutes",  # message-fusion time gap
},
"recent_results": self.last_sample_results[-5:],  # last 5 results
}
# Global hippocampus sampler instance
_hippocampus_sampler: HippocampusSampler | None = None
def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
"""Get the global hippocampus sampler instance"""
global _hippocampus_sampler
if _hippocampus_sampler is None:
_hippocampus_sampler = HippocampusSampler(memory_system)
return _hippocampus_sampler
async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
"""Initialize the global hippocampus sampler"""
sampler = get_hippocampus_sampler(memory_system)
await sampler.initialize()
return sampler

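The deleted sampler's core idea is drawing timestamps from a weighted mixture of two normal distributions, one peaked in the recent past and one further back. A minimal reproduction of generate_time_samples using the file's default parameters (a sketch, not the original code):

from datetime import datetime, timedelta
import numpy as np

def bimodal_time_samples(n: int = 50, recent=(12.0, 8.0, 0.7), distant=(48.0, 24.0)) -> list[datetime]:
    n_recent = max(1, int(n * recent[2]))  # 70% of samples from the recent peak
    n_distant = max(1, n - n_recent)       # the rest from the distant peak
    offsets = np.concatenate([
        np.random.normal(recent[0], recent[1], n_recent),
        np.random.normal(distant[0], distant[1], n_distant),
    ])
    now = datetime.now()
    # abs() keeps every sample point in the past even when a draw goes negative
    return sorted(now - timedelta(hours=abs(h)) for h in offsets)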

@@ -1,238 +0,0 @@
"""
Memory activator
The activator component of the memory system
"""
import difflib
from datetime import datetime
import orjson
from json_repair import repair_json
from src.chat.memory_system.memory_manager import MemoryResult
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> list:
"""
Extract the keyword list from a JSON string
Args:
json_str: a JSON-formatted string
Returns:
List[str]: list of keywords
"""
try:
# Repair the JSON format with repair_json
fixed_json = repair_json(json_str)
# If repair_json returns a string, parse it into a Python object
result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
return result.get("keywords", [])
except Exception as e:
logger.error(f"Failed to parse keyword JSON: {e}")
return []
def init_prompt():
# --- Memory Activator Prompt ---
memory_activator_prompt = """
You are a memory analyzer. Use the information below to perform memory retrieval.
Below is a chat log. From it, summarize a few keywords to serve as trigger words for memory retrieval.
Chat log:
{obs_info_text}
Message the user wants to reply to:
{target_message}
Historical keywords (avoid extracting these again):
{cached_keywords}
Output JSON containing the following field:
{{
"keywords": ["keyword1", "keyword2", "keyword3", ...]
}}
Do not output anything else; output only the JSON.
"""
Prompt(memory_activator_prompt, "memory_activator_prompt")
class MemoryActivator:
"""Memory activator"""
def __init__(self):
self.key_words_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="memory.activator",
)
self.running_memory = []
self.cached_keywords = set()  # cache of historical keywords
self.last_memory_query_time = 0  # time of the last memory query
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
"""
Activate memories
"""
# If the memory system is disabled, return an empty list immediately
if not global_config.memory.enable_memory:
return []
# Convert the cached keywords to a string for the prompt
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "no historical keywords yet"
prompt = await global_prompt_manager.format_prompt(
"memory_activator_prompt",
obs_info_text=chat_history_prompt,
target_message=target_message,
cached_keywords=cached_keywords_str,
)
# Generate keywords
response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
prompt, temperature=0.5
)
keywords = list(get_keywords_from_json(response))
# Update the keyword cache
if keywords:
# Cap the cache size at 10 keywords
if len(self.cached_keywords) > 10:
# Convert to a list and drop the oldest keywords
cached_list = list(self.cached_keywords)
self.cached_keywords = set(cached_list[-8:])
# Add the new keywords to the cache
self.cached_keywords.update(keywords)
logger.debug(f"Memory keywords: {self.cached_keywords}")
# Fetch relevant memories from the memory system
memory_results = await self._query_unified_memory(keywords, target_message)
# Process the memory results
if memory_results:
for result in memory_results:
# Check whether a memory with similar content already exists
exists = any(
m["content"] == result.content
or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
for m in self.running_memory
)
if not exists:
memory_entry = {
"topic": result.memory_type,
"content": result.content,
"timestamp": datetime.fromtimestamp(result.timestamp).isoformat(),
"duration": 1,
"confidence": result.confidence,
"importance": result.importance,
"source": result.source,
"relevance_score": result.relevance_score,  # include the relevance score
}
self.running_memory.append(memory_entry)
logger.debug(f"Added new memory: {result.memory_type} - {result.content}")
# On each activation, increment every held memory's duration; remove it once it reaches 3
for m in self.running_memory[:]:
m["duration"] = m.get("duration", 1) + 1
self.running_memory = [m for m in self.running_memory if m["duration"] < 3]
# Cap the number of concurrently loaded memories at the last 5
if len(self.running_memory) > 5:
self.running_memory = self.running_memory[-5:]
return self.running_memory
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
"""Query the unified memory system"""
try:
# Use the memory system
from src.chat.memory_system.memory_system import get_memory_system
memory_system = get_memory_system()
if not memory_system or memory_system.status.value != "ready":
logger.warning("Memory system not ready")
return []
# Build the query context
context = {"keywords": keywords, "query_intent": "conversation_response"}
# Query memories
memories = await memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="global",  # use the global scope
context=context,
limit=5,
)
# Convert to MemoryResult format
memory_results = []
for memory in memories:
result = MemoryResult(
content=memory.display,
memory_type=memory.memory_type.value,
confidence=memory.metadata.confidence.value,
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="unified_memory",
relevance_score=memory.metadata.relevance_score,
)
memory_results.append(result)
logger.debug(f"Unified memory query returned {len(memory_results)} results")
return memory_results
except Exception as e:
logger.error(f"Unified memory query failed: {e}")
return []
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
"""
Get instant memory - compatible with the legacy interface (uses unified storage)
"""
try:
# Fetch relevant memories via the unified storage system
from src.chat.memory_system.memory_system import get_memory_system
memory_system = get_memory_system()
if not memory_system or memory_system.status.value != "ready":
return None
context = {"query_intent": "instant_response", "chat_id": chat_id}
memories = await memory_system.retrieve_relevant_memories(
query_text=target_message, user_id="global", context=context, limit=1
)
if memories:
return memories[0].display
return None
except Exception as e:
logger.error(f"Failed to get instant memory: {e}")
return None
def clear_cache(self):
"""Clear the caches"""
self.cached_keywords.clear()
self.running_memory.clear()
logger.debug("Memory activator caches cleared")
# Create the global instance
memory_activator = MemoryActivator()
init_prompt()

File diff suppressed because it is too large.


@@ -1,647 +0,0 @@
"""
Structured memory chunk design
Implements high-quality, structured memory units in line with the design document
"""
import hashlib
import time
import uuid
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
import numpy as np
import orjson
from src.common.logger import get_logger
logger = get_logger(__name__)
class MemoryType(Enum):
"""Memory type classification"""
PERSONAL_FACT = "personal_fact"  # personal facts (name, occupation, address, ...)
EVENT = "event"  # events (important experiences, appointments, ...)
PREFERENCE = "preference"  # preferences (likes, habits, ...)
OPINION = "opinion"  # opinions (views on things)
RELATIONSHIP = "relationship"  # relationships (with other people)
EMOTION = "emotion"  # emotional states
KNOWLEDGE = "knowledge"  # knowledge and information
SKILL = "skill"  # skills and abilities
GOAL = "goal"  # goals and plans
EXPERIENCE = "experience"  # lessons learned
CONTEXTUAL = "contextual"  # contextual information
class ConfidenceLevel(Enum):
"""Confidence levels"""
LOW = 1  # low confidence, possibly inaccurate
MEDIUM = 2  # medium confidence, has some grounding
HIGH = 3  # high confidence, has a clear source
VERIFIED = 4  # verified, highly reliable
class ImportanceLevel(Enum):
"""Importance levels"""
LOW = 1  # low importance, ordinary information
NORMAL = 2  # normal importance, everyday information
HIGH = 3  # high importance, important information
CRITICAL = 4  # critical importance, core information
@dataclass
class ContentStructure:
"""Subject-predicate-object structure with a natural-language description"""
subject: str | list[str]
predicate: str
object: str | dict
display: str = ""
def to_dict(self) -> dict[str, Any]:
"""Convert to dict format"""
return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ContentStructure":
"""Create an instance from a dict"""
return cls(
subject=data.get("subject", ""),
predicate=data.get("predicate", ""),
object=data.get("object", ""),
display=data.get("display", ""),
)
def to_subject_list(self) -> list[str]:
"""Convert the subject to list form"""
if isinstance(self.subject, list):
return [s for s in self.subject if isinstance(s, str) and s.strip()]
if isinstance(self.subject, str) and self.subject.strip():
return [self.subject.strip()]
return []
def __str__(self) -> str:
"""String representation"""
if self.display:
return self.display
subjects = "".join(self.to_subject_list()) or str(self.subject)
object_str = self.object if isinstance(self.object, str) else str(self.object)
return f"{subjects} {self.predicate} {object_str}".strip()
@dataclass
class MemoryMetadata:
"""Memory metadata - simplified version"""
# Basic information
memory_id: str  # unique identifier
user_id: str  # user ID
chat_id: str | None = None  # chat ID (group or private chat)
# Time information
created_at: float = 0.0  # creation timestamp
last_accessed: float = 0.0  # last access time
last_modified: float = 0.0  # last modification time
# Activation frequency management
last_activation_time: float = 0.0  # last activation time
activation_frequency: int = 0  # activation frequency (activations per unit time)
total_activations: int = 0  # total activation count
# Statistics
access_count: int = 0  # access count
relevance_score: float = 0.0  # relevance score
# Confidence and importance (core fields)
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM
importance: ImportanceLevel = ImportanceLevel.NORMAL
# Forgetting mechanism
forgetting_threshold: float = 0.0  # forgetting threshold (computed dynamically)
last_forgetting_check: float = 0.0  # time of the last forgetting check
# Source information
source_context: str | None = None  # source context snippet
# Legacy field: some code or older versions may access metadata.source directly
source: str | None = None
def __post_init__(self):
"""Post-init processing"""
if not self.memory_id:
self.memory_id = str(uuid.uuid4())
current_time = time.time()
if self.created_at == 0:
self.created_at = current_time
if self.last_accessed == 0:
self.last_accessed = current_time
if self.last_modified == 0:
self.last_modified = current_time
if self.last_activation_time == 0:
self.last_activation_time = current_time
if self.last_forgetting_check == 0:
self.last_forgetting_check = current_time
# Compatibility: if the legacy source field is used, keep source in sync with source_context
if not getattr(self, "source", None) and getattr(self, "source_context", None):
try:
self.source = str(self.source_context)
except Exception:
self.source = None
# If source is set but source_context is empty, sync it back as well
if not getattr(self, "source_context", None) and getattr(self, "source", None):
try:
self.source_context = str(self.source)
except Exception:
self.source_context = None
def update_access(self):
"""Update access information"""
current_time = time.time()
self.last_accessed = current_time
self.access_count += 1
self.total_activations += 1
# Update the activation frequency
self._update_activation_frequency(current_time)
def _update_activation_frequency(self, current_time: float):
"""Update the activation frequency (activations within 24 hours)"""
# Reset the activation frequency after more than 24 hours
if current_time - self.last_activation_time > 86400:  # 24 hours = 86400 seconds
self.activation_frequency = 1
else:
self.activation_frequency += 1
self.last_activation_time = current_time
def update_relevance(self, new_score: float):
"""Update the relevance score"""
self.relevance_score = max(0.0, min(1.0, new_score))
self.last_modified = time.time()
def calculate_forgetting_threshold(self) -> float:
"""Compute the forgetting threshold (in days)"""
# Base days
base_days = 30.0
# Importance weight (1-4 -> 0-3)
importance_weight = (self.importance.value - 1) * 15  # 0, 15, 30, 45
# Confidence weight (1-4 -> 0-3)
confidence_weight = (self.confidence.value - 1) * 10  # 0, 10, 20, 30
# Activation frequency weight (each activation adds 0.5 days)
frequency_weight = min(self.activation_frequency, 20) * 0.5  # at most 10 days
# Compute the final threshold
threshold = base_days + importance_weight + confidence_weight + frequency_weight
# Apply the minimum and maximum bounds
return max(7.0, min(threshold, 365.0))  # between 7 days and 1 year
def should_forget(self, current_time: float | None = None) -> bool:
"""Decide whether this memory should be forgotten"""
if current_time is None:
current_time = time.time()
# Compute the forgetting threshold
self.forgetting_threshold = self.calculate_forgetting_threshold()
# Compute the time since the last activation
days_since_activation = (current_time - self.last_activation_time) / 86400
return days_since_activation > self.forgetting_threshold
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
"""Check whether the memory is dormant (not activated for a long time)"""
if current_time is None:
current_time = time.time()
days_since_last_access = (current_time - self.last_accessed) / 86400
return days_since_last_access > inactive_days
def to_dict(self) -> dict[str, Any]:
"""Convert to dict format"""
return {
"memory_id": self.memory_id,
"user_id": self.user_id,
"chat_id": self.chat_id,
"created_at": self.created_at,
"last_accessed": self.last_accessed,
"last_modified": self.last_modified,
"last_activation_time": self.last_activation_time,
"activation_frequency": self.activation_frequency,
"total_activations": self.total_activations,
"access_count": self.access_count,
"relevance_score": self.relevance_score,
"confidence": self.confidence.value,
"importance": self.importance.value,
"forgetting_threshold": self.forgetting_threshold,
"last_forgetting_check": self.last_forgetting_check,
"source_context": self.source_context,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MemoryMetadata":
"""Create an instance from a dict"""
return cls(
memory_id=data.get("memory_id", ""),
user_id=data.get("user_id", ""),
chat_id=data.get("chat_id"),
created_at=data.get("created_at", 0),
last_accessed=data.get("last_accessed", 0),
last_modified=data.get("last_modified", 0),
last_activation_time=data.get("last_activation_time", 0),
activation_frequency=data.get("activation_frequency", 0),
total_activations=data.get("total_activations", 0),
access_count=data.get("access_count", 0),
relevance_score=data.get("relevance_score", 0.0),
confidence=ConfidenceLevel(data.get("confidence", ConfidenceLevel.MEDIUM.value)),
importance=ImportanceLevel(data.get("importance", ImportanceLevel.NORMAL.value)),
forgetting_threshold=data.get("forgetting_threshold", 0.0),
last_forgetting_check=data.get("last_forgetting_check", 0),
source_context=data.get("source_context"),
)
@dataclass
class MemoryChunk:
"""Structured memory chunk - the core data structure"""
# Metadata
metadata: MemoryMetadata
# Content structure
content: ContentStructure  # subject-predicate-object structure
memory_type: MemoryType  # memory type
# Extended information
keywords: list[str] = field(default_factory=list)  # keyword list
tags: list[str] = field(default_factory=list)  # tag list
categories: list[str] = field(default_factory=list)  # category list
# Semantic information
embedding: list[float] | None = None  # semantic vector
semantic_hash: str | None = None  # semantic hash value
# Associations
related_memories: list[str] = field(default_factory=list)  # IDs of related memories
temporal_context: dict[str, Any] | None = None  # temporal context
def __post_init__(self):
"""Post-init processing"""
if self.embedding and len(self.embedding) > 0:
self._generate_semantic_hash()
def _generate_semantic_hash(self):
"""Generate the semantic hash value"""
if not self.embedding:
return
try:
# Generate a stable hash from the vector and the content
content_str = f"{self.content.subject}:{self.content.predicate}:{self.content.object!s}"
embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))
hash_input = f"{content_str}|{embedding_str}"
hash_object = hashlib.sha256(hash_input.encode("utf-8"))
self.semantic_hash = hash_object.hexdigest()[:16]
except Exception as e:
logger.warning(f"Failed to generate the semantic hash: {e}")
self.semantic_hash = str(uuid.uuid4())[:16]
@property
def memory_id(self) -> str:
"""Get the memory ID"""
return self.metadata.memory_id
@property
def user_id(self) -> str:
"""Get the user ID"""
return self.metadata.user_id
@property
def text_content(self) -> str:
"""Get the text content (prefers display)"""
return str(self.content)
@property
def display(self) -> str:
"""Get the display text"""
return self.content.display or str(self.content)
@property
def subjects(self) -> list[str]:
"""Get the subject list"""
return self.content.to_subject_list()
def update_access(self):
"""Update access information"""
self.metadata.update_access()
def update_relevance(self, new_score: float):
"""Update the relevance score"""
self.metadata.update_relevance(new_score)
def should_forget(self, current_time: float | None = None) -> bool:
"""Decide whether this memory should be forgotten"""
return self.metadata.should_forget(current_time)
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
"""Check whether the memory is dormant (not activated for a long time)"""
return self.metadata.is_dormant(current_time, inactive_days)
def calculate_forgetting_threshold(self) -> float:
"""Compute the forgetting threshold (in days)"""
return self.metadata.calculate_forgetting_threshold()
def add_keyword(self, keyword: str):
"""Add a keyword"""
if keyword and keyword not in self.keywords:
self.keywords.append(keyword.strip())
def add_tag(self, tag: str):
"""Add a tag"""
if tag and tag not in self.tags:
self.tags.append(tag.strip())
def add_category(self, category: str):
"""Add a category"""
if category and category not in self.categories:
self.categories.append(category.strip())
def add_related_memory(self, memory_id: str):
"""Add a related memory"""
if memory_id and memory_id not in self.related_memories:
self.related_memories.append(memory_id)
def set_embedding(self, embedding: list[float]):
"""Set the semantic vector"""
self.embedding = embedding
self._generate_semantic_hash()
def calculate_similarity(self, other: "MemoryChunk") -> float:
"""Compute the similarity to another memory chunk"""
if not self.embedding or not other.embedding:
return 0.0
try:
# Compute cosine similarity
v1 = np.array(self.embedding)
v2 = np.array(other.embedding)
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
if norm1 == 0 or norm2 == 0:
return 0.0
similarity = dot_product / (norm1 * norm2)
return max(0.0, min(1.0, similarity))
except Exception as e:
logger.warning(f"Failed to compute memory similarity: {e}")
return 0.0
def to_dict(self) -> dict[str, Any]:
"""Convert to a complete dict format"""
return {
"metadata": self.metadata.to_dict(),
"content": self.content.to_dict(),
"memory_type": self.memory_type.value,
"keywords": self.keywords,
"tags": self.tags,
"categories": self.categories,
"embedding": self.embedding,
"semantic_hash": self.semantic_hash,
"related_memories": self.related_memories,
"temporal_context": self.temporal_context,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MemoryChunk":
"""Create an instance from a dict"""
metadata = MemoryMetadata.from_dict(data.get("metadata", {}))
content = ContentStructure.from_dict(data.get("content", {}))
chunk = cls(
metadata=metadata,
content=content,
memory_type=MemoryType(data.get("memory_type", MemoryType.CONTEXTUAL.value)),
keywords=data.get("keywords", []),
tags=data.get("tags", []),
categories=data.get("categories", []),
embedding=data.get("embedding"),
semantic_hash=data.get("semantic_hash"),
related_memories=data.get("related_memories", []),
temporal_context=data.get("temporal_context"),
)
return chunk
def to_json(self) -> str:
"""Convert to a JSON string"""
return orjson.dumps(self.to_dict()).decode("utf-8")
@classmethod
def from_json(cls, json_str: str) -> "MemoryChunk":
"""Create an instance from a JSON string"""
try:
data = orjson.loads(json_str)
return cls.from_dict(data)
except Exception as e:
logger.error(f"Failed to create a memory chunk from JSON: {e}")
raise
def is_similar_to(self, other: "MemoryChunk", threshold: float = 0.8) -> bool:
"""Check whether this chunk is similar to another memory chunk"""
if self.semantic_hash and other.semantic_hash:
return self.semantic_hash == other.semantic_hash
return self.calculate_similarity(other) >= threshold
def merge_with(self, other: "MemoryChunk") -> bool:
"""Merge with another memory chunk (if similar)"""
if not self.is_similar_to(other):
return False
try:
# Merge keywords
for keyword in other.keywords:
self.add_keyword(keyword)
# Merge tags
for tag in other.tags:
self.add_tag(tag)
# Merge categories
for category in other.categories:
self.add_category(category)
# Merge related memories
for memory_id in other.related_memories:
self.add_related_memory(memory_id)
# Update metadata
self.metadata.last_modified = time.time()
self.metadata.access_count += other.metadata.access_count
self.metadata.relevance_score = max(self.metadata.relevance_score, other.metadata.relevance_score)
# Update confidence
if other.metadata.confidence.value > self.metadata.confidence.value:
self.metadata.confidence = other.metadata.confidence
# Update importance
if other.metadata.importance.value > self.metadata.importance.value:
self.metadata.importance = other.metadata.importance
logger.debug(f"Memory chunk {self.memory_id} merged memory chunk {other.memory_id}")
return True
except Exception as e:
logger.error(f"Failed to merge memory chunks: {e}")
return False
def __str__(self) -> str:
"""String representation"""
type_emoji = {
MemoryType.PERSONAL_FACT: "👤",
MemoryType.EVENT: "📅",
MemoryType.PREFERENCE: "❤️",
MemoryType.OPINION: "💭",
MemoryType.RELATIONSHIP: "👥",
MemoryType.EMOTION: "😊",
MemoryType.KNOWLEDGE: "📚",
MemoryType.SKILL: "🛠️",
MemoryType.GOAL: "🎯",
MemoryType.EXPERIENCE: "💡",
MemoryType.CONTEXTUAL: "📝",
}
emoji = type_emoji.get(self.memory_type, "📝")
confidence_icon = "" * self.metadata.confidence.value
importance_icon = "" * self.metadata.importance.value
return f"{emoji} [{self.memory_type.value}] {self.display} {confidence_icon} {importance_icon}"
def __repr__(self) -> str:
"""Debug representation"""
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
def _build_display_text(subjects: Iterable[str], predicate: str, obj: str | dict) -> str:
"""Build a natural-language description from subject, predicate, and object"""
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
subject_part = "".join(subjects_clean) if subjects_clean else "conversation participants"
if isinstance(obj, dict):
object_candidates = []
for key, value in obj.items():
if isinstance(value, str | int | float):
object_candidates.append(f"{key}:{value}")
elif isinstance(value, list):
compact = "".join(str(item) for item in value[:3])
object_candidates.append(f"{key}:{compact}")
object_part = "".join(object_candidates) if object_candidates else str(obj)
else:
object_part = str(obj).strip()
predicate_clean = predicate.strip()
if not predicate_clean:
return f"{subject_part} {object_part}".strip()
if object_part:
return f"{subject_part}{predicate_clean}{object_part}".strip()
return f"{subject_part}{predicate_clean}".strip()
def create_memory_chunk(
user_id: str,
subject: str | list[str],
predicate: str,
obj: str | dict,
memory_type: MemoryType,
chat_id: str | None = None,
source_context: str | None = None,
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
display: str | None = None,
**kwargs,
) -> MemoryChunk:
"""便捷的内存块创建函数"""
metadata = MemoryMetadata(
memory_id="",
user_id=user_id,
chat_id=chat_id,
created_at=time.time(),
last_accessed=0,
last_modified=0,
confidence=confidence,
importance=importance,
source_context=source_context,
)
subjects: list[str]
if isinstance(subject, list):
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
subject_payload: str | list[str] = subjects
else:
cleaned = subject.strip() if isinstance(subject, str) else ""
subjects = [cleaned] if cleaned else []
subject_payload = cleaned
display_text = display or _build_display_text(subjects, predicate, obj)
content = ContentStructure(subject=subject_payload, predicate=predicate, object=obj, display=display_text)
chunk = MemoryChunk(metadata=metadata, content=content, memory_type=memory_type, **kwargs)
return chunk
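# Hedged usage sketch of create_memory_chunk(), wrapped in a helper so the
# module does not run example code on import; all literal values below are
# illustrative, not taken from the original source.
def _example_create_memory_chunk() -> MemoryChunk:
    chunk = create_memory_chunk(
        user_id="u_123",
        subject="小明",
        predicate="likes",
        obj="黑咖啡",
        memory_type=MemoryType.PREFERENCE,
        chat_id="g_456",
    )
    # display falls back to _build_display_text() ("小明likes黑咖啡" here), so
    # callers usually pass an explicit display= for natural phrasing.
    return chunk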
@dataclass
class MessageCollection:
"""消息集合数据结构"""
collection_id: str = field(default_factory=lambda: str(uuid.uuid4()))
chat_id: str | None = None # 聊天ID群聊或私聊
messages: list[str] = field(default_factory=list)
combined_text: str = ""
created_at: float = field(default_factory=time.time)
embedding: list[float] | None = None
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"collection_id": self.collection_id,
"chat_id": self.chat_id,
"messages": self.messages,
"combined_text": self.combined_text,
"created_at": self.created_at,
"embedding": self.embedding,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MessageCollection":
"""从字典创建实例"""
return cls(
collection_id=data.get("collection_id", str(uuid.uuid4())),
chat_id=data.get("chat_id"),
messages=data.get("messages", []),
combined_text=data.get("combined_text", ""),
created_at=data.get("created_at", time.time()),
embedding=data.get("embedding"),
)

View File

@@ -1,355 +0,0 @@
"""
智能记忆遗忘引擎
基于重要程度、置信度和激活频率的智能遗忘机制
"""
import asyncio
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class ForgettingStats:
"""遗忘统计信息"""
total_checked: int = 0
marked_for_forgetting: int = 0
actually_forgotten: int = 0
dormant_memories: int = 0
last_check_time: float = 0.0
check_duration: float = 0.0
@dataclass
class ForgettingConfig:
"""遗忘引擎配置"""
# 检查频率配置
check_interval_hours: int = 24 # 定期检查间隔(小时)
batch_size: int = 100 # 批处理大小
# 遗忘阈值配置
base_forgetting_days: float = 30.0 # 基础遗忘天数
min_forgetting_days: float = 7.0 # 最小遗忘天数
max_forgetting_days: float = 365.0 # 最大遗忘天数
# 重要程度权重
critical_importance_bonus: float = 45.0 # 关键重要性额外天数
high_importance_bonus: float = 30.0 # 高重要性额外天数
normal_importance_bonus: float = 15.0 # 一般重要性额外天数
low_importance_bonus: float = 0.0 # 低重要性额外天数
# 置信度权重
verified_confidence_bonus: float = 30.0 # 已验证置信度额外天数
high_confidence_bonus: float = 20.0 # 高置信度额外天数
medium_confidence_bonus: float = 10.0 # 中等置信度额外天数
low_confidence_bonus: float = 0.0 # 低置信度额外天数
# 激活频率权重
activation_frequency_weight: float = 0.5 # 每次激活增加的天数权重
max_frequency_bonus: float = 10.0 # 最大激活频率奖励天数
# 休眠配置
dormant_threshold_days: int = 90 # 休眠状态判定天数
force_forget_dormant_days: int = 180 # 强制遗忘休眠记忆的天数
class MemoryForgettingEngine:
"""智能记忆遗忘引擎"""
def __init__(self, config: ForgettingConfig | None = None):
self.config = config or ForgettingConfig()
self.stats = ForgettingStats()
self._last_forgetting_check = 0.0
self._forgetting_lock = asyncio.Lock()
logger.info("MemoryForgettingEngine 初始化完成")
def calculate_forgetting_threshold(self, memory: MemoryChunk) -> float:
"""
计算记忆的遗忘阈值(天数)
Args:
memory: 记忆块
Returns:
遗忘阈值(天数)
"""
# 基础天数
threshold = self.config.base_forgetting_days
# 重要性权重
importance = memory.metadata.importance
if importance == ImportanceLevel.CRITICAL:
threshold += self.config.critical_importance_bonus
elif importance == ImportanceLevel.HIGH:
threshold += self.config.high_importance_bonus
elif importance == ImportanceLevel.NORMAL:
threshold += self.config.normal_importance_bonus
# LOW 级别不增加额外天数
# 置信度权重
confidence = memory.metadata.confidence
if confidence == ConfidenceLevel.VERIFIED:
threshold += self.config.verified_confidence_bonus
elif confidence == ConfidenceLevel.HIGH:
threshold += self.config.high_confidence_bonus
elif confidence == ConfidenceLevel.MEDIUM:
threshold += self.config.medium_confidence_bonus
# LOW 级别不增加额外天数
# 激活频率权重
frequency_bonus = min(
memory.metadata.activation_frequency * self.config.activation_frequency_weight,
self.config.max_frequency_bonus,
)
threshold += frequency_bonus
# 确保在合理范围内
return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days))
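# Worked example under the default config above (hedged, illustration only):
# a HIGH-importance, MEDIUM-confidence memory activated 6 times gets
#   30.0 (base) + 30.0 (high importance) + 10.0 (medium confidence)
#   + min(6 * 0.5, 10.0) = 73.0 days
# before it counts as expired; the result always stays inside the
# [min_forgetting_days, max_forgetting_days] clamp.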
def should_forget_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
判断记忆是否应该被遗忘
Args:
memory: 记忆块
current_time: 当前时间戳
Returns:
是否应该遗忘
"""
if current_time is None:
current_time = time.time()
# 关键重要性的记忆永不遗忘
if memory.metadata.importance == ImportanceLevel.CRITICAL:
return False
# 计算遗忘阈值
forgetting_threshold = self.calculate_forgetting_threshold(memory)
# 计算距离最后激活的时间
days_since_activation = (current_time - memory.metadata.last_activation_time) / 86400
# 判断是否超过阈值
should_forget = days_since_activation > forgetting_threshold
if should_forget:
logger.debug(
f"记忆 {memory.memory_id[:8]} 触发遗忘条件: "
f"重要性={memory.metadata.importance.name}, "
f"置信度={memory.metadata.confidence.name}, "
f"激活频率={memory.metadata.activation_frequency}, "
f"阈值={forgetting_threshold:.1f}天, "
f"未激活天数={days_since_activation:.1f}"
)
return should_forget
def is_dormant_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
判断记忆是否处于休眠状态
Args:
memory: 记忆块
current_time: 当前时间戳
Returns:
是否处于休眠状态
"""
return memory.is_dormant(current_time, self.config.dormant_threshold_days)
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
判断是否应该强制遗忘休眠记忆
Args:
memory: 记忆块
current_time: 当前时间戳
Returns:
是否应该强制遗忘
"""
if current_time is None:
current_time = time.time()
# 只有非关键重要性的记忆才会被强制遗忘
if memory.metadata.importance == ImportanceLevel.CRITICAL:
return False
days_since_last_access = (current_time - memory.metadata.last_accessed) / 86400
return days_since_last_access > self.config.force_forget_dormant_days
async def check_memories_for_forgetting(self, memories: list[MemoryChunk]) -> tuple[list[str], list[str]]:
"""
检查记忆列表,识别需要遗忘的记忆
Args:
memories: 记忆块列表
Returns:
(普通遗忘列表, 强制遗忘列表)
"""
start_time = time.time()
current_time = start_time
normal_forgetting_ids = []
force_forgetting_ids = []
self.stats.total_checked = len(memories)
self.stats.last_check_time = current_time
for memory in memories:
try:
# 检查休眠状态
if self.is_dormant_memory(memory, current_time):
self.stats.dormant_memories += 1
# 检查是否应该强制遗忘休眠记忆
if self.should_force_forget_dormant(memory, current_time):
force_forgetting_ids.append(memory.memory_id)
logger.debug(f"休眠记忆 {memory.memory_id[:8]} 被标记为强制遗忘")
continue
# 检查普通遗忘条件
if self.should_forget_memory(memory, current_time):
normal_forgetting_ids.append(memory.memory_id)
self.stats.marked_for_forgetting += 1
except Exception as e:
logger.warning(f"检查记忆 {memory.memory_id[:8]} 遗忘状态失败: {e}")
continue
self.stats.check_duration = time.time() - start_time
logger.info(
f"遗忘检查完成 | 总数={self.stats.total_checked}, "
f"标记遗忘={len(normal_forgetting_ids)}, "
f"强制遗忘={len(force_forgetting_ids)}, "
f"休眠={self.stats.dormant_memories}, "
f"耗时={self.stats.check_duration:.3f}s"
)
return normal_forgetting_ids, force_forgetting_ids
async def perform_forgetting_check(self, memories: list[MemoryChunk]) -> dict[str, Any]:
"""
执行完整的遗忘检查流程
Args:
memories: 记忆块列表
Returns:
检查结果统计
"""
async with self._forgetting_lock:
normal_forgetting, force_forgetting = await self.check_memories_for_forgetting(memories)
# 更新统计
self.stats.actually_forgotten = len(normal_forgetting) + len(force_forgetting)
return {
"normal_forgetting": normal_forgetting,
"force_forgetting": force_forgetting,
"stats": {
"total_checked": self.stats.total_checked,
"marked_for_forgetting": self.stats.marked_for_forgetting,
"actually_forgotten": self.stats.actually_forgotten,
"dormant_memories": self.stats.dormant_memories,
"check_duration": self.stats.check_duration,
"last_check_time": self.stats.last_check_time,
},
}
def is_forgetting_check_needed(self) -> bool:
"""检查是否需要进行遗忘检查"""
current_time = time.time()
hours_since_last_check = (current_time - self._last_forgetting_check) / 3600
return hours_since_last_check >= self.config.check_interval_hours
async def schedule_periodic_check(self, memories_provider, enable_auto_cleanup: bool = True):
"""
定期执行遗忘检查(可以在后台任务中调用)
Args:
memories_provider: 提供记忆列表的函数
enable_auto_cleanup: 是否启用自动清理
"""
if not self.is_forgetting_check_needed():
return
try:
logger.info("开始执行定期遗忘检查...")
# 获取记忆列表
memories = await memories_provider()
if not memories:
logger.debug("无记忆数据需要检查")
return
# 执行遗忘检查
result = await self.perform_forgetting_check(memories)
# 如果启用自动清理,执行实际的遗忘操作
if enable_auto_cleanup and (result["normal_forgetting"] or result["force_forgetting"]):
logger.info(
f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆"
)
# 这里可以调用实际的删除逻辑
# await self.cleanup_forgotten_memories(result["normal_forgetting"] + result["force_forgetting"])
self._last_forgetting_check = time.time()
except Exception as e:
logger.error(f"定期遗忘检查失败: {e}", exc_info=True)
def get_forgetting_stats(self) -> dict[str, Any]:
"""获取遗忘统计信息"""
return {
"total_checked": self.stats.total_checked,
"marked_for_forgetting": self.stats.marked_for_forgetting,
"actually_forgotten": self.stats.actually_forgotten,
"dormant_memories": self.stats.dormant_memories,
"last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat()
if self.stats.last_check_time
else None,
"last_check_duration": self.stats.check_duration,
"config": {
"check_interval_hours": self.config.check_interval_hours,
"base_forgetting_days": self.config.base_forgetting_days,
"min_forgetting_days": self.config.min_forgetting_days,
"max_forgetting_days": self.config.max_forgetting_days,
},
}
def reset_stats(self):
"""重置统计信息"""
self.stats = ForgettingStats()
logger.debug("遗忘统计信息已重置")
def update_config(self, **kwargs):
"""更新配置"""
for key, value in kwargs.items():
if hasattr(self.config, key):
setattr(self.config, key, value)
logger.debug(f"遗忘配置更新: {key} = {value}")
else:
logger.warning(f"未知的配置项: {key}")
# 创建全局遗忘引擎实例
memory_forgetting_engine = MemoryForgettingEngine()
def get_memory_forgetting_engine() -> MemoryForgettingEngine:
"""获取全局遗忘引擎实例"""
return memory_forgetting_engine
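# Hedged usage sketch (not in the original file): wiring the engine into a
# background task. `load_all_memories` is an assumed coroutine returning
# list[MemoryChunk]; substitute the real storage accessor.
async def _example_forgetting_tick(load_all_memories) -> None:
    engine = get_memory_forgetting_engine()
    await engine.schedule_periodic_check(load_all_memories, enable_auto_cleanup=True)
    logger.info(f"遗忘统计: {engine.get_forgetting_stats()}")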

View File

@@ -1,120 +0,0 @@
"""记忆格式化工具
提供统一的记忆块格式化函数,供构建 Prompt 时使用。
当前使用的函数: format_memories_bracket_style
输入: list[dict] 其中每个元素包含:
- display: str 记忆可读内容
- memory_type: str 记忆类型 (personal_fact/opinion/preference/event 等)
- metadata: dict 可选,包括
- confidence: 置信度 (str|float)
- importance: 重要度 (str|float)
- timestamp: 时间戳 (float|str)
- source: 来源 (str)
- relevance_score: 相关度 (float)
返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。
"""
from __future__ import annotations
import time
from collections.abc import Iterable
from typing import Any
def _format_timestamp(ts: Any) -> str:
try:
if ts in (None, ""):
return ""
if isinstance(ts, int | float) and ts > 0:
return time.strftime("%Y-%m-%d %H:%M", time.localtime(float(ts)))
return str(ts)
except Exception:
return ""
def _coerce_str(v: Any) -> str:
if v is None:
return ""
return str(v)
def format_memories_bracket_style(
memories: Iterable[dict[str, Any]] | None,
query_context: str | None = None,
max_items: int = 15,
) -> str:
"""以方括号 + 标注字段的方式格式化记忆列表。
例子输出:
## 相关记忆回顾
- [类型:personal_fact|重要:高|置信:0.83|相关:0.72] 他喜欢黑咖啡 (来源: chat, 2025-10-05 09:30)
Args:
memories: 记忆字典迭代器
query_context: 当前查询/用户的消息,用于在首行提示(可选)
max_items: 最多输出的记忆条数
Returns:
str: 格式化文本;若无内容返回空串
"""
if not memories:
return ""
lines: list[str] = ["## 相关记忆回顾"]
if query_context:
lines.append(f"(与当前消息相关:{query_context[:60]}{'...' if len(query_context) > 60 else ''}")
lines.append("")
count = 0
for mem in memories:
if count >= max_items:
break
if not isinstance(mem, dict):
continue
display = _coerce_str(mem.get("display", "")).strip()
if not display:
continue
mtype = _coerce_str(mem.get("memory_type", "fact")) or "fact"
meta = mem.get("metadata", {}) if isinstance(mem.get("metadata"), dict) else {}
confidence = _coerce_str(meta.get("confidence", ""))
importance = _coerce_str(meta.get("importance", ""))
source = _coerce_str(meta.get("source", ""))
rel = meta.get("relevance_score")
try:
rel_str = f"{float(rel):.2f}" if rel is not None else ""
except Exception:
rel_str = ""
ts = _format_timestamp(meta.get("timestamp"))
# 构建标签段
tags: list[str] = [f"类型:{mtype}"]
if importance:
tags.append(f"重要:{importance}")
if confidence:
tags.append(f"置信:{confidence}")
if rel_str:
tags.append(f"相关:{rel_str}")
tag_block = "|".join(tags)
suffix_parts = []
if source:
suffix_parts.append(source)
if ts:
suffix_parts.append(ts)
suffix = (" (" + ", ".join(suffix_parts) + ")") if suffix_parts else ""
lines.append(f"- [{tag_block}] {display}{suffix}")
count += 1
if count == 0:
return ""
if count >= max_items:
lines.append(f"\n(已截断,仅显示前 {max_items} 条相关记忆)")
return "\n".join(lines)
__all__ = ["format_memories_bracket_style"]

View File

@@ -1,505 +0,0 @@
"""
记忆融合与去重机制
避免记忆碎片化,确保长期记忆库的高质量
"""
import time
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class FusionResult:
"""融合结果"""
original_count: int
fused_count: int
removed_duplicates: int
merged_memories: list[MemoryChunk]
fusion_time: float
details: list[str]
@dataclass
class DuplicateGroup:
"""重复记忆组"""
group_id: str
memories: list[MemoryChunk]
similarity_matrix: list[list[float]]
representative_memory: MemoryChunk | None = None
class MemoryFusionEngine:
"""记忆融合引擎"""
def __init__(self, similarity_threshold: float = 0.85):
self.similarity_threshold = similarity_threshold
self.fusion_stats = {
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0,
}
# 融合策略配置
self.fusion_strategies = {
"semantic_similarity": True, # 语义相似性融合
"temporal_proximity": True, # 时间接近性融合
"logical_consistency": True, # 逻辑一致性融合
"confidence_boosting": True, # 置信度提升
"importance_preservation": True, # 重要性保持
}
async def fuse_memories(
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] | None = None
) -> list[MemoryChunk]:
"""融合记忆列表"""
start_time = time.time()
try:
if not new_memories:
return []
logger.info(f"开始记忆融合,新记忆: {len(new_memories)},现有记忆: {len(existing_memories or [])}")
# 1. 检测重复记忆组
duplicate_groups = await self._detect_duplicate_groups(new_memories, existing_memories or [])
if not duplicate_groups:
fusion_time = time.time() - start_time
self._update_fusion_stats(len(new_memories), 0, fusion_time)
logger.info("✅ 记忆融合完成: %d 条记忆,移除 0 条重复", len(new_memories))
return new_memories
# 2. 对每个重复组进行融合
fused_memories = []
removed_count = 0
for group in duplicate_groups:
if len(group.memories) == 1:
# 单个记忆,直接添加
fused_memories.append(group.memories[0])
else:
# 多个记忆,进行融合
fused_memory = await self._fuse_memory_group(group)
if fused_memory:
fused_memories.append(fused_memory)
removed_count += len(group.memories) - 1
# 3. 更新统计
fusion_time = time.time() - start_time
self._update_fusion_stats(len(new_memories), removed_count, fusion_time)
logger.info(f"✅ 记忆融合完成: {len(fused_memories)} 条记忆,移除 {removed_count} 条重复")
return fused_memories
except Exception as e:
logger.error(f"❌ 记忆融合失败: {e}", exc_info=True)
return new_memories # 失败时返回原始记忆
async def _detect_duplicate_groups(
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk]
) -> list[DuplicateGroup]:
"""检测重复记忆组"""
all_memories = new_memories + existing_memories
new_memory_ids = {memory.memory_id for memory in new_memories}
groups = []
processed_ids = set()
for i, memory1 in enumerate(all_memories):
if memory1.memory_id in processed_ids:
continue
# 创建新的重复组
group = DuplicateGroup(group_id=f"group_{len(groups)}", memories=[memory1], similarity_matrix=[[1.0]])
processed_ids.add(memory1.memory_id)
# 寻找相似记忆
for j, memory2 in enumerate(all_memories[i + 1 :], i + 1):
if memory2.memory_id in processed_ids:
continue
similarity = self._calculate_comprehensive_similarity(memory1, memory2)
if similarity >= self.similarity_threshold:
group.memories.append(memory2)
processed_ids.add(memory2.memory_id)
# 更新相似度矩阵
self._update_similarity_matrix(group, memory2, similarity)
if len(group.memories) > 1:
# 选择代表性记忆
group.representative_memory = self._select_representative_memory(group)
groups.append(group)
else:
# 仅包含单条记忆,只有当其来自新记忆列表时保留
if memory1.memory_id in new_memory_ids:
groups.append(group)
logger.debug(f"检测到 {len(groups)} 个重复记忆组")
return groups
def _calculate_comprehensive_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float:
"""计算综合相似度"""
similarity_scores = []
# 1. 语义向量相似度
if self.fusion_strategies["semantic_similarity"]:
semantic_sim = mem1.calculate_similarity(mem2)
similarity_scores.append(("semantic", semantic_sim))
# 2. 文本相似度
text_sim = self._calculate_text_similarity(mem1.text_content, mem2.text_content)
similarity_scores.append(("text", text_sim))
# 3. 关键词重叠度
keyword_sim = self._calculate_keyword_similarity(mem1.keywords, mem2.keywords)
similarity_scores.append(("keyword", keyword_sim))
# 4. 类型一致性
type_consistency = 1.0 if mem1.memory_type == mem2.memory_type else 0.0
similarity_scores.append(("type", type_consistency))
# 5. 时间接近性
if self.fusion_strategies["temporal_proximity"]:
temporal_sim = self._calculate_temporal_similarity(mem1.metadata.created_at, mem2.metadata.created_at)
similarity_scores.append(("temporal", temporal_sim))
# 6. 逻辑一致性
if self.fusion_strategies["logical_consistency"]:
logical_sim = self._calculate_logical_similarity(mem1, mem2)
similarity_scores.append(("logical", logical_sim))
# 计算加权平均相似度
weights = {"semantic": 0.35, "text": 0.25, "keyword": 0.15, "type": 0.10, "temporal": 0.10, "logical": 0.05}
weighted_sum = 0.0
total_weight = 0.0
for score_type, score in similarity_scores:
weight = weights.get(score_type, 0.1)
weighted_sum += weight * score
total_weight += weight
final_similarity = weighted_sum / total_weight if total_weight > 0 else 0.0
logger.debug(f"综合相似度计算: {final_similarity:.3f} - {[(t, f'{s:.3f}') for t, s in similarity_scores]}")
return final_similarity
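# Worked example of the weighting above (hedged, illustrative scores):
# semantic=0.90, text=0.60, keyword=0.50, type=1.0, temporal=0.80, logical=0.40
#   weighted_sum = 0.35*0.90 + 0.25*0.60 + 0.15*0.50 + 0.10*1.0 + 0.10*0.80 + 0.05*0.40
#                = 0.315 + 0.150 + 0.075 + 0.100 + 0.080 + 0.020 = 0.74
# total_weight = 1.00, so the final similarity is 0.74, which is below the
# default 0.85 threshold: this pair would NOT be grouped as duplicates.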
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""计算文本相似度"""
# 简单的词汇重叠度计算
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1 & words2
union = words1 | words2
jaccard_similarity = len(intersection) / len(union)
return jaccard_similarity
def _calculate_keyword_similarity(self, keywords1: list[str], keywords2: list[str]) -> float:
"""计算关键词相似度"""
if not keywords1 or not keywords2:
return 0.0
set1 = set(k.lower() for k in keywords1) # noqa: C401
set2 = set(k.lower() for k in keywords2) # noqa: C401
intersection = set1 & set2
union = set1 | set2
return len(intersection) / len(union) if union else 0.0
def _calculate_temporal_similarity(self, time1: float, time2: float) -> float:
"""计算时间相似度"""
time_diff = abs(time1 - time2)
hours_diff = time_diff / 3600
# 24小时内相似度较高
if hours_diff <= 24:
return 1.0 - (hours_diff / 24)
elif hours_diff <= 168: # 一周内
return 0.7 - ((hours_diff - 24) / 168) * 0.5
else:
return 0.2
def _calculate_logical_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float:
"""计算逻辑一致性"""
# 检查主谓宾结构的逻辑一致性
consistency_score = 0.0
# 主语一致性
subjects1 = set(mem1.subjects)
subjects2 = set(mem2.subjects)
if subjects1 or subjects2:
overlap = len(subjects1 & subjects2)
union_count = max(len(subjects1 | subjects2), 1)
consistency_score += (overlap / union_count) * 0.4
# 谓语相似性
predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate)
consistency_score += predicate_sim * 0.3
# 宾语相似性
if isinstance(mem1.content.object, str) and isinstance(mem2.content.object, str):
object_sim = self._calculate_text_similarity(str(mem1.content.object), str(mem2.content.object))
consistency_score += object_sim * 0.3
return consistency_score
def _update_similarity_matrix(self, group: DuplicateGroup, new_memory: MemoryChunk, similarity: float):
"""更新组的相似度矩阵"""
# 为新记忆添加行和列
for i in range(len(group.similarity_matrix)):
group.similarity_matrix[i].append(similarity)
# 添加新行
new_row = [similarity] + [1.0] * len(group.similarity_matrix)
group.similarity_matrix.append(new_row)
def _select_representative_memory(self, group: DuplicateGroup) -> MemoryChunk | None:
"""选择代表性记忆"""
if not group.memories:
return None
# 评分标准
best_memory = None
best_score = -1.0
for memory in group.memories:
score = 0.0
# 置信度权重
score += memory.metadata.confidence.value * 0.3
# 重要性权重
score += memory.metadata.importance.value * 0.3
# 访问次数权重
score += min(memory.metadata.access_count * 0.1, 0.2)
# 相关度权重
score += memory.metadata.relevance_score * 0.2
if score > best_score:
best_score = score
best_memory = memory
return best_memory
async def _fuse_memory_group(self, group: DuplicateGroup) -> MemoryChunk | None:
"""融合记忆组"""
if not group.memories:
return None
if len(group.memories) == 1:
return group.memories[0]
try:
# 选择基础记忆(通常是代表性记忆)
base_memory = group.representative_memory or group.memories[0]
# 融合其他记忆的属性
fused_memory = await self._merge_memory_attributes(base_memory, group.memories)
# 更新元数据
self._update_fused_metadata(fused_memory, group)
logger.debug(f"成功融合记忆组,包含 {len(group.memories)} 条原始记忆")
return fused_memory
except Exception as e:
logger.error(f"融合记忆组失败: {e}")
# 返回置信度最高的记忆
return max(group.memories, key=lambda m: m.metadata.confidence.value)
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: list[MemoryChunk]) -> MemoryChunk:
"""合并记忆属性"""
# 创建基础记忆的深拷贝
fused_memory = MemoryChunk.from_dict(base_memory.to_dict())
# 合并关键词
all_keywords = set()
for memory in memories:
all_keywords.update(memory.keywords)
fused_memory.keywords = sorted(all_keywords)
# 合并标签
all_tags = set()
for memory in memories:
all_tags.update(memory.tags)
fused_memory.tags = sorted(all_tags)
# 合并分类
all_categories = set()
for memory in memories:
all_categories.update(memory.categories)
fused_memory.categories = sorted(all_categories)
# 合并关联记忆
all_related = set()
for memory in memories:
all_related.update(memory.related_memories)
# 移除对自身和组内记忆的引用
all_related = {rid for rid in all_related if rid not in [m.memory_id for m in memories]}
fused_memory.related_memories = sorted(all_related)
# 合并时间上下文
if self.fusion_strategies["temporal_proximity"]:
fused_memory.temporal_context = self._merge_temporal_context(memories)
return fused_memory
def _update_fused_metadata(self, fused_memory: MemoryChunk, group: DuplicateGroup):
"""更新融合记忆的元数据"""
# 更新修改时间
fused_memory.metadata.last_modified = time.time()
# 计算平均访问次数
total_access = sum(m.metadata.access_count for m in group.memories)
fused_memory.metadata.access_count = total_access
# 提升置信度(如果有多个来源支持)
if self.fusion_strategies["confidence_boosting"] and len(group.memories) > 1:
max_confidence = max(m.metadata.confidence.value for m in group.memories)
if max_confidence < ConfidenceLevel.VERIFIED.value:
fused_memory.metadata.confidence = ConfidenceLevel(
min(max_confidence + 1, ConfidenceLevel.VERIFIED.value)
)
# 保持最高重要性
if self.fusion_strategies["importance_preservation"]:
max_importance = max(m.metadata.importance.value for m in group.memories)
fused_memory.metadata.importance = ImportanceLevel(max_importance)
# 计算平均相关度
avg_relevance = sum(m.metadata.relevance_score for m in group.memories) / len(group.memories)
fused_memory.metadata.relevance_score = min(avg_relevance * 1.1, 1.0) # 稍微提升相关度
# 设置来源信息
source_ids = [m.memory_id[:8] for m in group.memories]
fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}"
def _merge_temporal_context(self, memories: list[MemoryChunk]) -> dict[str, Any]:
"""合并时间上下文"""
contexts = [m.temporal_context for m in memories if m.temporal_context]
if not contexts:
return {}
# 计算时间范围
timestamps = [m.metadata.created_at for m in memories]
earliest_time = min(timestamps)
latest_time = max(timestamps)
merged_context = {
"earliest_timestamp": earliest_time,
"latest_timestamp": latest_time,
"time_span_hours": (latest_time - earliest_time) / 3600,
"source_memories": len(memories),
}
# 合并其他上下文信息
for context in contexts:
for key, value in context.items():
if key not in ["timestamp", "earliest_timestamp", "latest_timestamp"]:
if key not in merged_context:
merged_context[key] = value
elif merged_context[key] != value:
merged_context[key] = f"multiple: {value}"
return merged_context
async def incremental_fusion(
self, new_memory: MemoryChunk, existing_memories: list[MemoryChunk]
) -> tuple[MemoryChunk, list[MemoryChunk]]:
"""增量融合(单个新记忆与现有记忆融合)"""
# 寻找相似记忆
similar_memories = []
for existing in existing_memories:
similarity = self._calculate_comprehensive_similarity(new_memory, existing)
if similarity >= self.similarity_threshold:
similar_memories.append((existing, similarity))
if not similar_memories:
# 没有相似记忆,直接返回
return new_memory, existing_memories
# 按相似度排序
similar_memories.sort(key=lambda x: x[1], reverse=True)
# 与最相似的记忆融合
best_match, similarity = similar_memories[0]
# 创建融合组
group = DuplicateGroup(
group_id=f"incremental_{int(time.time())}",
memories=[new_memory, best_match],
similarity_matrix=[[1.0, similarity], [similarity, 1.0]],
)
# 执行融合
fused_memory = await self._fuse_memory_group(group)
# 从现有记忆中移除被融合的记忆
updated_existing = [m for m in existing_memories if m.memory_id != best_match.memory_id]
updated_existing.append(fused_memory)
logger.debug(f"增量融合完成,相似度: {similarity:.3f}")
return fused_memory, updated_existing
def _update_fusion_stats(self, original_count: int, removed_count: int, fusion_time: float):
"""更新融合统计"""
self.fusion_stats["total_fusions"] += 1
self.fusion_stats["memories_fused"] += original_count
self.fusion_stats["duplicates_removed"] += removed_count
# 更新平均相似度(估算)
if removed_count > 0:
avg_similarity = 0.9 # 假设平均相似度较高
total_similarity = self.fusion_stats["average_similarity"] * (self.fusion_stats["total_fusions"] - 1)
total_similarity += avg_similarity
self.fusion_stats["average_similarity"] = total_similarity / self.fusion_stats["total_fusions"]
async def maintenance(self):
"""维护操作"""
try:
logger.info("开始记忆融合引擎维护...")
# 可以在这里添加定期维护任务,如:
# - 重新评估低置信度记忆
# - 清理孤立记忆引用
# - 优化融合策略参数
logger.info("✅ 记忆融合引擎维护完成")
except Exception as e:
logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True)
def get_fusion_stats(self) -> dict[str, Any]:
"""获取融合统计信息"""
return self.fusion_stats.copy()
def reset_stats(self):
"""重置统计信息"""
self.fusion_stats = {
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0,
}
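# Hedged usage sketch (not in the original file): one-shot dedup of freshly
# extracted chunks against stored ones; both arguments are list[MemoryChunk].
async def _example_fuse(new_chunks, stored_chunks):
    engine = MemoryFusionEngine(similarity_threshold=0.85)
    deduped = await engine.fuse_memories(new_chunks, existing_memories=stored_chunks)
    logger.info(f"融合统计: {engine.get_fusion_stats()}")
    return deduped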

View File

@@ -1,512 +0,0 @@
"""
记忆系统管理器
替代原有的 Hippocampus 和 instant_memory 系统
"""
import re
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.message_collection_processor import MessageCollectionProcessor
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
from src.common.logger import get_logger
logger = get_logger(__name__)
@dataclass
class MemoryResult:
"""记忆查询结果"""
content: str
memory_type: str
confidence: float
importance: float
timestamp: float
source: str = "memory"
relevance_score: float = 0.0
structure: dict[str, Any] | None = None
class MemoryManager:
"""记忆系统管理器 - 替代原有的 HippocampusManager"""
def __init__(self):
self.memory_system: MemorySystem | None = None
self.message_collection_storage: MessageCollectionStorage | None = None
self.message_collection_processor: MessageCollectionProcessor | None = None
self.is_initialized = False
self.user_cache = {} # 用户记忆缓存
def _clean_text(self, text: Any) -> str:
if text is None:
return ""
cleaned = re.sub(r"[\s\u3000]+", " ", str(text)).strip()
cleaned = re.sub(r"[、,,;]+$", "", cleaned)
return cleaned
async def initialize(self):
"""初始化记忆系统"""
if self.is_initialized:
return
try:
from src.config.config import global_config
# 检查是否启用记忆系统
if not global_config.memory.enable_memory:
logger.info("记忆系统已禁用,跳过初始化")
self.is_initialized = True
return
logger.info("正在初始化记忆系统...")
# 初始化记忆系统
from src.chat.memory_system.memory_system import get_memory_system
self.memory_system = get_memory_system()
# 初始化消息集合系统
self.message_collection_storage = MessageCollectionStorage()
self.message_collection_processor = MessageCollectionProcessor(self.message_collection_storage)
self.is_initialized = True
logger.info(" 记忆系统初始化完成")
except Exception as e:
logger.error(f"记忆系统初始化失败: {e}")
# 如果系统初始化失败,创建一个空的管理器避免系统崩溃
self.memory_system = None
self.message_collection_storage = None
self.message_collection_processor = None
self.is_initialized = True # 标记为已初始化但系统不可用
def get_hippocampus(self):
"""兼容原有接口 - 返回空"""
logger.debug("get_hippocampus 调用 - 记忆系统不使用此方法")
return {}
async def build_memory(self):
"""兼容原有接口 - 构建记忆"""
if not self.is_initialized or not self.memory_system:
return
try:
# 记忆系统使用实时构建,不需要定时构建
logger.debug("build_memory 调用 - 记忆系统使用实时构建")
except Exception as e:
logger.error(f"build_memory 失败: {e}")
async def forget_memory(self, percentage: float = 0.005):
"""兼容原有接口 - 遗忘机制"""
if not self.is_initialized or not self.memory_system:
return
try:
# 增强记忆系统有内置的遗忘机制
logger.debug(f"forget_memory 调用 - 参数: {percentage}")
# 可以在这里调用增强系统的维护功能
await self.memory_system.maintenance()
except Exception as e:
logger.error(f"forget_memory 失败: {e}")
async def get_memory_from_text(
self,
text: str,
chat_id: str,
user_id: str,
max_memory_num: int = 3,
max_memory_length: int = 2,
time_weight: float = 1.0,
keyword_weight: float = 1.0,
) -> list[tuple[str, str]]:
"""从文本获取相关记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
try:
# 使用增强记忆系统检索
context = {
"chat_id": chat_id,
"expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE],
}
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query=text, user_id=user_id, context=context, limit=max_memory_num
)
# 转换为原有格式 (topic, content)
results = []
for memory in relevant_memories:
topic = memory.memory_type.value
content = memory.text_content
results.append((topic, content))
logger.debug(f"从文本检索到 {len(results)} 条相关记忆")
# 如果检索到有效记忆,打印详细信息
if results:
logger.info(f"📚 从文本 '{text[:50]}...' 检索到 {len(results)} 条有效记忆:")
for i, (topic, content) in enumerate(results, 1):
# 处理长内容如果超过150字符则截断
display_content = content
if len(content) > 150:
display_content = content[:150] + "..."
logger.info(f" 记忆#{i} [{topic}]: {display_content}")
return results
except Exception as e:
logger.error(f"get_memory_from_text 失败: {e}")
return []
async def get_memory_from_topic(
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> list[tuple[str, str]]:
"""从关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
try:
# 将关键词转换为查询文本
query_text = " ".join(valid_keywords)
# 使用增强记忆系统检索
context = {
"keywords": valid_keywords,
"expected_memory_types": [
MemoryType.PERSONAL_FACT,
MemoryType.EVENT,
MemoryType.PREFERENCE,
MemoryType.OPINION,
],
}
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query=query_text,
user_id="default_user", # 可以根据实际需要传递
context=context,
limit=max_memory_num,
)
# 转换为原有格式 (topic, content)
results = []
for memory in relevant_memories:
topic = memory.memory_type.value
content = memory.text_content
results.append((topic, content))
logger.debug(f"从关键词 {valid_keywords} 检索到 {len(results)} 条相关记忆")
# 如果检索到有效记忆,打印详细信息
if results:
keywords_str = ", ".join(valid_keywords[:5]) # 最多显示5个关键词
if len(valid_keywords) > 5:
keywords_str += f" ... (共{len(valid_keywords)}个关键词)"
logger.info(f"🔍 从关键词 [{keywords_str}] 检索到 {len(results)} 条有效记忆:")
for i, (topic, content) in enumerate(results, 1):
# 处理长内容如果超过150字符则截断
display_content = content
if len(content) > 150:
display_content = content[:150] + "..."
logger.info(f" 记忆#{i} [{topic}]: {display_content}")
return results
except Exception as e:
logger.error(f"get_memory_from_topic 失败: {e}")
return []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
"""从单个关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
try:
# 同步方法,返回空列表
logger.debug(f"get_memory_from_keyword 调用 - 关键词: {keyword}")
return []
except Exception as e:
logger.error(f"get_memory_from_keyword 失败: {e}")
return []
async def process_conversation(
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float | None = None
) -> list[MemoryChunk]:
"""处理对话并构建记忆 - 新增功能"""
if not self.is_initialized or not self.memory_system:
return []
try:
# 将消息添加到消息集合处理器
chat_id = context.get("chat_id")
if self.message_collection_processor and chat_id:
await self.message_collection_processor.add_message(conversation_text, chat_id)
payload_context = dict(context or {})
payload_context.setdefault("conversation_text", conversation_text)
if timestamp is not None:
payload_context.setdefault("timestamp", timestamp)
result = await self.memory_system.process_conversation_memory(payload_context)
# 从结果中提取记忆块
memory_chunks = []
if result.get("success"):
memory_chunks = result.get("created_memories", [])
logger.info(f"从对话构建了 {len(memory_chunks)} 条记忆")
return memory_chunks
except Exception as e:
logger.error(f"process_conversation 失败: {e}")
return []
async def get_enhanced_memory_context(
self, query_text: str, user_id: str, context: dict[str, Any] | None = None, limit: int = 5
) -> list[MemoryResult]:
"""获取增强记忆上下文 - 新增功能"""
if not self.is_initialized or not self.memory_system:
return []
try:
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query=query_text, user_id=None, context=context or {}, limit=limit
)
results = []
for memory in relevant_memories:
formatted_content, structure = self._format_memory_chunk(memory)
result = MemoryResult(
content=formatted_content,
memory_type=memory.memory_type.value,
confidence=memory.metadata.confidence.value,
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="enhanced_memory",
relevance_score=memory.metadata.relevance_score,
structure=structure,
)
results.append(result)
return results
except Exception as e:
logger.error(f"get_enhanced_memory_context 失败: {e}")
return []
def _format_memory_chunk(self, memory: MemoryChunk) -> tuple[str, dict[str, Any]]:
"""将记忆块转换为更易读的文本描述"""
structure = memory.content.to_dict()
if memory.display:
return self._clean_text(memory.display), structure
subject = structure.get("subject")
predicate = structure.get("predicate") or ""
obj = structure.get("object")
subject_display = self._format_subject(subject, memory)
formatted = self._apply_predicate_format(subject_display, predicate, obj)
if not formatted:
predicate_display = self._format_predicate(predicate)
object_display = self._format_object(obj)
formatted = f"{subject_display}{predicate_display}{object_display}".strip()
formatted = self._clean_text(formatted)
return formatted, structure
def _format_subject(self, subject: str | None, memory: MemoryChunk) -> str:
if not subject:
return "该用户"
if subject == memory.metadata.user_id:
return "该用户"
if memory.metadata.chat_id and subject == memory.metadata.chat_id:
return "该聊天"
return self._clean_text(subject)
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> str | None:
predicate = (predicate or "").strip()
obj_value = obj
if predicate == "is_named":
name = self._extract_from_object(obj_value, ["name", "nickname"]) or self._format_object(obj_value)
name = self._clean_text(name)
if not name:
return None
name_display = name if (name.startswith("「") and name.endswith("」")) else f"「{name}」"
return f"{subject}的昵称是{name_display}"
if predicate == "is_age":
age = self._extract_from_object(obj_value, ["age"]) or self._format_object(obj_value)
age = self._clean_text(age)
if not age:
return None
return f"{subject}今年{age}"
if predicate == "is_profession":
profession = self._extract_from_object(obj_value, ["profession", "job"]) or self._format_object(obj_value)
profession = self._clean_text(profession)
if not profession:
return None
return f"{subject}的职业是{profession}"
if predicate == "lives_in":
location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object(
obj_value
)
location = self._clean_text(location)
if not location:
return None
return f"{subject}居住在{location}"
if predicate == "has_phone":
phone = self._extract_from_object(obj_value, ["phone", "number"]) or self._format_object(obj_value)
phone = self._clean_text(phone)
if not phone:
return None
return f"{subject}的电话号码是{phone}"
if predicate == "has_email":
email = self._extract_from_object(obj_value, ["email"]) or self._format_object(obj_value)
email = self._clean_text(email)
if not email:
return None
return f"{subject}的邮箱是{email}"
if predicate == "likes":
liked = self._format_object(obj_value)
if not liked:
return None
return f"{subject}喜欢{liked}"
if predicate == "likes_food":
food = self._format_object(obj_value)
if not food:
return None
return f"{subject}爱吃{food}"
if predicate == "dislikes":
disliked = self._format_object(obj_value)
if not disliked:
return None
return f"{subject}不喜欢{disliked}"
if predicate == "hates":
hated = self._format_object(obj_value)
if not hated:
return None
return f"{subject}讨厌{hated}"
if predicate == "favorite_is":
favorite = self._format_object(obj_value)
if not favorite:
return None
return f"{subject}最喜欢{favorite}"
if predicate == "mentioned_event":
event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object(
obj_value
)
event_text = self._clean_text(self._truncate(event_text))
if not event_text:
return None
return f"{subject}提到了计划或事件:{event_text}"
if predicate in {"正在", "", "正在进行"}:
action = self._format_object(obj_value)
if not action:
return None
return f"{subject}{predicate}{action}"
if predicate in {"感到", "觉得", "表示", "提到", "说道", ""}:
feeling = self._format_object(obj_value)
if not feeling:
return None
return f"{subject}{predicate}{feeling}"
if predicate in {"", "", ""}:
counterpart = self._format_object(obj_value)
if counterpart:
return f"{subject}{predicate}{counterpart}"
return f"{subject}{predicate}"
return None
def _format_predicate(self, predicate: str) -> str:
if not predicate:
return ""
predicate_map = {
"is_named": "的昵称是",
"is_profession": "的职业是",
"lives_in": "居住在",
"has_phone": "的电话是",
"has_email": "的邮箱是",
"likes": "喜欢",
"dislikes": "不喜欢",
"likes_food": "爱吃",
"hates": "讨厌",
"favorite_is": "最喜欢",
"mentioned_event": "提到的事件",
}
if predicate in predicate_map:
connector = predicate_map[predicate]
if connector.startswith(""):
return connector
return f" {connector} "
cleaned = predicate.replace("_", " ").strip()
if re.search(r"[\u4e00-\u9fff]", cleaned):
return cleaned
return f" {cleaned} "
def _format_object(self, obj: Any) -> str:
if obj is None:
return ""
if isinstance(obj, dict):
parts = []
for key, value in obj.items():
formatted_value = self._format_object(value)
if not formatted_value:
continue
pretty_key = {
"name": "名字",
"profession": "职业",
"location": "位置",
"event_text": "内容",
"timestamp": "时间",
}.get(key, key)
parts.append(f"{pretty_key}: {formatted_value}")
return self._clean_text("".join(parts))
if isinstance(obj, list):
formatted_items = [self._format_object(item) for item in obj]
filtered = [item for item in formatted_items if item]
return self._clean_text("".join(filtered)) if filtered else ""
if isinstance(obj, int | float):
return str(obj)
text = self._truncate(str(obj).strip())
return self._clean_text(text)
def _extract_from_object(self, obj: Any, keys: list[str]) -> str | None:
if isinstance(obj, dict):
for key in keys:
if obj.get(key):
value = obj[key]
if isinstance(value, dict | list):
return self._clean_text(self._format_object(value))
return self._clean_text(value)
if isinstance(obj, list) and obj:
return self._clean_text(self._format_object(obj[0]))
if isinstance(obj, str | int | float):
return self._clean_text(obj)
return None
def _truncate(self, text: str, max_length: int = 80) -> str:
if len(text) <= max_length:
return text
return text[: max_length - 1] + "…"
async def shutdown(self):
"""关闭增强记忆系统"""
if not self.is_initialized:
return
try:
if self.memory_system:
await self.memory_system.shutdown()
logger.info(" 记忆系统已关闭")
except Exception as e:
logger.error(f"关闭记忆系统失败: {e}")
# 全局记忆管理器实例
memory_manager = MemoryManager()
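# Hedged usage sketch (not in the original file): the legacy-compatible call
# path; return values keep the old (topic, content) tuple shape.
async def _example_legacy_lookup() -> None:
    await memory_manager.initialize()
    pairs = await memory_manager.get_memory_from_text(
        text="他说他住在上海", chat_id="g_456", user_id="u_123", max_memory_num=3
    )
    for topic, content in pairs:
        logger.info(f"{topic}: {content}")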

View File

@@ -1,122 +0,0 @@
"""
记忆元数据索引。
"""
from dataclasses import asdict, dataclass
from typing import Any
from src.common.logger import get_logger
logger = get_logger(__name__)
from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
@dataclass
class MemoryMetadataIndexEntry:
memory_id: str
user_id: str
memory_type: str
subjects: list[str]
objects: list[str]
keywords: list[str]
tags: list[str]
importance: int
confidence: int
created_at: float
access_count: int
chat_id: str | None = None
content_preview: str | None = None
class MemoryMetadataIndex:
"""Rust 加速版本唯一实现。"""
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
self._rust = _RustIndex(index_file)
# 仅为向量层和调试提供最小缓存(长度判断、get_entry 返回)
self.index: dict[str, MemoryMetadataIndexEntry] = {}
logger.info("✅ MemoryMetadataIndex (Rust) 初始化完成,仅支持加速实现")
# 向后兼容代码仍调用的接口:batch_add_or_update / add_or_update
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
if not entries:
return
payload = []
for e in entries:
if not e.memory_id:
continue
self.index[e.memory_id] = e
payload.append(asdict(e))
if payload:
try:
self._rust.batch_add(payload)
except Exception as ex:
logger.error(f"Rust 元数据批量添加失败: {ex}")
def add_or_update(self, entry: MemoryMetadataIndexEntry):
self.batch_add_or_update([entry])
def search(
self,
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
keywords: list[str] | None = None,
tags: list[str] | None = None,
importance_min: int | None = None,
importance_max: int | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
flexible_mode: bool = True,
) -> list[str]:
params: dict[str, Any] = {
"user_id": user_id,
"memory_types": memory_types,
"subjects": subjects,
"keywords": keywords,
"tags": tags,
"importance_min": importance_min,
"importance_max": importance_max,
"created_after": created_after,
"created_before": created_before,
"limit": limit,
}
params = {k: v for k, v in params.items() if v is not None}
try:
if flexible_mode:
return list(self._rust.search_flexible(params))
return list(self._rust.search_strict(params))
except Exception as ex:
logger.error(f"Rust 搜索失败返回空: {ex}")
return []
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
return self.index.get(memory_id)
def get_stats(self) -> dict[str, Any]:
try:
raw = self._rust.stats()
return {
"total_memories": raw.get("total", 0),
"types": raw.get("types_dist", {}),
"subjects_count": raw.get("subjects_indexed", 0),
"keywords_count": raw.get("keywords_indexed", 0),
"tags_count": raw.get("tags_indexed", 0),
}
except Exception as ex:
logger.warning(f"读取 Rust stats 失败: {ex}")
return {"total_memories": 0}
def save(self): # 仅调用 rust save
try:
self._rust.save()
except Exception as ex:
logger.warning(f"Rust save 失败: {ex}")
__all__ = [
"MemoryMetadataIndex",
"MemoryMetadataIndexEntry",
]
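# Hedged usage sketch (not part of the original file), using only the API
# defined above; all literal values are illustrative:
#
# index = MemoryMetadataIndex("data/memory_metadata_index.json")
# index.add_or_update(MemoryMetadataIndexEntry(
#     memory_id="m1", user_id="u_123", memory_type="preference",
#     subjects=["小明"], objects=["黑咖啡"], keywords=["咖啡"], tags=[],
#     importance=2, confidence=2, created_at=1760000000.0, access_count=0,
# ))
# ids = index.search(keywords=["咖啡"], user_id="u_123", limit=5)
# index.save()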

View File

@@ -1,219 +0,0 @@
"""记忆检索查询规划器"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any
import orjson
from src.chat.memory_system.memory_chunk import MemoryType
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.utils.json_parser import extract_and_parse_json
logger = get_logger(__name__)
@dataclass
class MemoryQueryPlan:
"""查询规划结果"""
semantic_query: str
memory_types: list[MemoryType] = field(default_factory=list)
subject_includes: list[str] = field(default_factory=list)
object_includes: list[str] = field(default_factory=list)
required_keywords: list[str] = field(default_factory=list)
optional_keywords: list[str] = field(default_factory=list)
owner_filters: list[str] = field(default_factory=list)
recency_preference: str = "any"
limit: int = 10
emphasis: str | None = None
raw_plan: dict[str, Any] = field(default_factory=dict)
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
if not self.semantic_query:
self.semantic_query = fallback_query
if self.limit <= 0:
self.limit = default_limit
self.recency_preference = (self.recency_preference or "any").lower()
if self.recency_preference not in {"any", "recent", "historical"}:
self.recency_preference = "any"
self.emphasis = (self.emphasis or "balanced").lower()
class MemoryQueryPlanner:
"""基于小模型的记忆检索查询规划器"""
def __init__(self, planner_model: LLMRequest | None, default_limit: int = 10):
self.model = planner_model
self.default_limit = default_limit
async def plan_query(self, query_text: str, context: dict[str, Any]) -> MemoryQueryPlan:
if not self.model:
logger.debug("未提供查询规划模型,使用默认规划")
return self._default_plan(query_text)
prompt = self._build_prompt(query_text, context)
try:
response, _ = await self.model.generate_response_async(prompt, temperature=0.2)
# 使用统一的 JSON 解析工具
data = extract_and_parse_json(response, strict=False)
if not data or not isinstance(data, dict):
logger.debug("查询规划模型未返回有效的结构化结果,使用默认规划")
return self._default_plan(query_text)
plan = self._parse_plan_dict(data, query_text)
plan.ensure_defaults(query_text, self.default_limit)
return plan
except Exception as exc:
logger.error("查询规划模型调用失败: %s", exc, exc_info=True)
return self._default_plan(query_text)
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit)
def _parse_plan_dict(self, data: dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
def _collect_list(key: str) -> list[str]:
value = data.get(key)
if isinstance(value, str):
return [value]
if isinstance(value, list):
return [self._safe_str(item) for item in value if self._safe_str(item)]
return []
memory_type_values = _collect_list("memory_types")
memory_types: list[MemoryType] = []
for item in memory_type_values:
if not item:
continue
try:
memory_types.append(MemoryType(item))
except ValueError:
# 尝试匹配value值
normalized = item.lower()
for mt in MemoryType:
if mt.value == normalized:
memory_types.append(mt)
break
plan = MemoryQueryPlan(
semantic_query=semantic_query,
memory_types=memory_types,
subject_includes=_collect_list("subject_includes"),
object_includes=_collect_list("object_includes"),
required_keywords=_collect_list("required_keywords"),
optional_keywords=_collect_list("optional_keywords"),
owner_filters=_collect_list("owner_filters"),
recency_preference=self._safe_str(data.get("recency")) or "any",
limit=self._safe_int(data.get("limit"), self.default_limit),
emphasis=self._safe_str(data.get("emphasis")) or "balanced",
raw_plan=data,
)
return plan
def _build_prompt(self, query_text: str, context: dict[str, Any]) -> str:
participants = context.get("participants") or context.get("speaker_names") or []
if isinstance(participants, str):
participants = [participants]
participants = [p for p in participants if isinstance(p, str) and p.strip()]
participant_preview = "".join(participants[:5]) or "未知"
persona = context.get("bot_personality") or context.get("bot_identity") or "未知"
# 构建未读消息上下文信息
context_section = ""
if context.get("has_unread_context") and context.get("unread_messages_context"):
unread_context = context["unread_messages_context"]
unread_messages = unread_context.get("messages", [])
unread_keywords = unread_context.get("keywords", [])
unread_participants = unread_context.get("participants", [])
context_summary = unread_context.get("context_summary", "")
if unread_messages:
# 构建未读消息摘要
message_previews = []
for msg in unread_messages[:5]: # 最多显示5条
sender = msg.get("sender", "未知")
content = msg.get("content", "")[:100] # 限制每条消息长度
message_previews.append(f"{sender}: {content}")
context_section = f"""
## 📋 未读消息上下文 (共{unread_context.get("total_count", 0)}条未读消息)
### 最近消息预览:
{chr(10).join(message_previews)}
### 上下文关键词:
{", ".join(unread_keywords[:15]) if unread_keywords else ""}
### 对话参与者:
{", ".join(unread_participants) if unread_participants else ""}
### 上下文摘要:
{context_summary[:300] if context_summary else ""}
"""
else:
context_section = """
## 📋 未读消息上下文:
无未读消息或上下文信息不可用
"""
return f"""
你是一名记忆检索规划助手,请基于输入生成一个简洁的 JSON 检索计划。
你的任务是分析当前查询并结合未读消息的上下文,生成更精准的记忆检索策略。
仅需提供以下字段:
- semantic_query: 用于向量召回的自然语言描述,要求具体且贴合当前查询和上下文;
- memory_types: 建议检索的记忆类型列表,取值范围来自 MemoryType 枚举 (personal_fact,event,preference,opinion,relationship,emotion,knowledge,skill,goal,experience,contextual)
- subject_includes: 建议出现在记忆主语中的人物或角色;
- object_includes: 建议关注的对象、主题或关键信息;
- required_keywords: 建议必须包含的关键词(从上下文中提取);
- recency: 推荐的时间偏好,可选 recent/any/historical
- limit: 推荐的最大返回数量 (1-15)
- emphasis: 检索重点,可选 balanced/contextual/recent/comprehensive。
请不要生成谓语字段,也不要额外补充其它参数。
## 当前查询:
"{query_text}"
## 已知对话参与者:
{participant_preview}
## 机器人设定:
{persona}{context_section}
## 🎯 指导原则:
1. **上下文关联**: 优先分析与当前查询相关的未读消息内容和关键词
2. **语义理解**: 结合上下文理解查询的真实意图,而非字面意思
3. **参与者感知**: 考虑未读消息中的参与者,检索与他们相关的记忆
4. **主题延续**: 关注未读消息中讨论的主题,检索相关的历史记忆
5. **时间相关性**: 如果未读消息讨论最近的事件,偏向检索相关时期的记忆
请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。
"""
@staticmethod
def _safe_str(value: Any) -> str:
if isinstance(value, str):
return value.strip()
if value is None:
return ""
return str(value).strip()
@staticmethod
def _safe_int(value: Any, default: int) -> int:
try:
number = int(value)
if number <= 0:
return default
return number
except (TypeError, ValueError):
return default
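# Hedged usage sketch (not in the original file). With planner_model=None the
# planner degrades to a default plan built from the raw query:
#
# planner = MemoryQueryPlanner(planner_model=None, default_limit=10)
# plan = await planner.plan_query("他之前说过住在哪里?", context={})
# # plan.semantic_query == "他之前说过住在哪里?", plan.limit == 10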

File diff suppressed because it is too large

View File

@@ -1,75 +0,0 @@
"""
消息集合处理器
负责收集消息、创建集合并将其存入向量存储。
"""
import asyncio
from collections import deque
from typing import Any
from src.chat.memory_system.memory_chunk import MessageCollection
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
from src.common.logger import get_logger
logger = get_logger(__name__)
class MessageCollectionProcessor:
"""处理消息集合的创建和存储"""
def __init__(self, storage: MessageCollectionStorage, buffer_size: int = 5):
self.storage = storage
self.buffer_size = buffer_size
self.message_buffers: dict[str, deque[str]] = {}
self._lock = asyncio.Lock()
async def add_message(self, message_text: str, chat_id: str):
"""添加一条新消息到指定聊天的缓冲区,并在满时触发处理"""
async with self._lock:
if not isinstance(message_text, str) or not message_text.strip():
return
if chat_id not in self.message_buffers:
self.message_buffers[chat_id] = deque(maxlen=self.buffer_size)
buffer = self.message_buffers[chat_id]
buffer.append(message_text)
logger.debug(f"消息已添加到聊天 '{chat_id}' 的缓冲区,当前数量: {len(buffer)}/{self.buffer_size}")
if len(buffer) == self.buffer_size:
await self._process_buffer(chat_id)
async def _process_buffer(self, chat_id: str):
"""处理指定聊天缓冲区中的消息,创建并存储一个集合"""
buffer = self.message_buffers.get(chat_id)
if not buffer or len(buffer) < self.buffer_size:
return
messages_to_process = list(buffer)
buffer.clear()
logger.info(f"聊天 '{chat_id}' 的消息缓冲区已满,开始创建消息集合...")
try:
combined_text = "\n".join(messages_to_process)
collection = MessageCollection(
chat_id=chat_id,
messages=messages_to_process,
combined_text=combined_text,
)
await self.storage.add_collection(collection)
logger.info(f"成功为聊天 '{chat_id}' 创建并存储了新的消息集合: {collection.collection_id}")
except Exception as e:
logger.error(f"处理聊天 '{chat_id}' 的消息缓冲区失败: {e}", exc_info=True)
def get_stats(self) -> dict[str, Any]:
"""获取处理器统计信息"""
total_buffered_messages = sum(len(buf) for buf in self.message_buffers.values())
return {
"active_buffers": len(self.message_buffers),
"total_buffered_messages": total_buffered_messages,
"buffer_capacity_per_chat": self.buffer_size,
}
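# Hedged usage sketch (not in the original file): the fifth message fills the
# buffer and triggers one stored MessageCollection for that chat.
async def _example_buffering(storage: MessageCollectionStorage) -> None:
    processor = MessageCollectionProcessor(storage, buffer_size=5)
    for text in ["早", "在吗", "今晚开黑?", "几点", "八点吧"]:
        await processor.add_message(text, chat_id="g_456")
    logger.info(f"缓冲区统计: {processor.get_stats()}")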

View File

@@ -1,193 +0,0 @@
"""
消息集合向量存储系统
专用于存储和检索消息集合,以提供即时上下文。
"""
import time
from typing import Any
from src.chat.memory_system.memory_chunk import MessageCollection
from src.chat.utils.utils import get_embedding
from src.common.logger import get_logger
from src.common.vector_db import vector_db_service
from src.config.config import global_config
logger = get_logger(__name__)
class MessageCollectionStorage:
"""消息集合向量存储"""
def __init__(self):
self.config = global_config.memory
self.vector_db_service = vector_db_service
self.collection_name = "message_collections"
self._initialize_storage()
def _initialize_storage(self):
"""初始化存储"""
try:
self.vector_db_service.get_or_create_collection(
name=self.collection_name,
metadata={"description": "短期消息集合记忆", "hnsw:space": "cosine"},
)
logger.info(f"消息集合存储初始化完成,集合: '{self.collection_name}'")
except Exception as e:
logger.error(f"消息集合存储初始化失败: {e}", exc_info=True)
raise
async def add_collection(self, collection: MessageCollection):
"""添加一个新的消息集合,并处理容量和时间限制"""
try:
# 清理过期和超额的集合
await self._cleanup_collections()
# 向量化并存储
embedding = await get_embedding(collection.combined_text)
if not embedding:
logger.warning(f"无法为消息集合 {collection.collection_id} 生成向量,跳过存储。")
return
collection.embedding = embedding
self.vector_db_service.add(
collection_name=self.collection_name,
embeddings=[embedding],
ids=[collection.collection_id],
documents=[collection.combined_text],
metadatas=[collection.to_dict()],
)
logger.debug(f"成功存储消息集合: {collection.collection_id}")
except Exception as e:
logger.error(f"存储消息集合失败: {e}", exc_info=True)
async def _cleanup_collections(self):
"""清理超额和过期的消息集合"""
try:
# 基于时间清理
if self.config.instant_memory_retention_hours > 0:
expiration_time = time.time() - self.config.instant_memory_retention_hours * 3600
expired_docs = self.vector_db_service.get(
collection_name=self.collection_name,
where={"created_at": {"$lt": expiration_time}},
include=[], # 只获取ID
)
if expired_docs and expired_docs.get("ids"):
self.vector_db_service.delete(collection_name=self.collection_name, ids=expired_docs["ids"])
logger.info(f"删除了 {len(expired_docs['ids'])} 个过期的瞬时记忆")
# 基于数量清理
current_count = self.vector_db_service.count(self.collection_name)
if current_count > self.config.instant_memory_max_collections:
num_to_delete = current_count - self.config.instant_memory_max_collections
# 获取所有文档的元数据以进行排序
all_docs = self.vector_db_service.get(
collection_name=self.collection_name,
include=["metadatas"]
)
if all_docs and all_docs.get("ids"):
# 在内存中排序找到最旧的文档
sorted_docs = sorted(
zip(all_docs["ids"], all_docs["metadatas"]),
key=lambda item: item[1].get("created_at", 0),
)
ids_to_delete = [doc[0] for doc in sorted_docs[:num_to_delete]]
if ids_to_delete:
self.vector_db_service.delete(collection_name=self.collection_name, ids=ids_to_delete)
logger.info(f"消息集合已满,删除最旧的 {len(ids_to_delete)} 个集合")
except Exception as e:
logger.error(f"清理消息集合失败: {e}", exc_info=True)
async def get_relevant_collection(self, query_text: str, n_results: int = 1) -> list[MessageCollection]:
"""根据查询文本检索最相关的消息集合"""
if not query_text.strip():
return []
try:
query_embedding = await get_embedding(query_text)
if not query_embedding:
return []
results = self.vector_db_service.query(
collection_name=self.collection_name,
query_embeddings=[query_embedding],
n_results=n_results,
)
collections = []
if results and results.get("ids") and results["ids"][0]:
collections.extend(MessageCollection.from_dict(metadata) for metadata in results["metadatas"][0])
return collections
except Exception as e:
logger.error(f"检索相关消息集合失败: {e}", exc_info=True)
return []
async def get_message_collection_context(self, query_text: str, chat_id: str) -> str:
"""获取消息集合上下文,用于添加到 prompt 中。优先展示当前聊天的上下文。"""
try:
collections = await self.get_relevant_collection(query_text, n_results=5)
if not collections:
return ""
# 根据传入的 chat_id 对集合进行排序
collections.sort(key=lambda c: c.chat_id == chat_id, reverse=True)
context_parts = []
for collection in collections:
if not collection.combined_text:
continue
header = "## 📝 相关对话上下文\n"
if collection.chat_id == chat_id:
# 匹配的ID使用更明显的标识
context_parts.append(
f"{header} [🔥 来自当前聊天的上下文]\n```\n{collection.combined_text}\n```"
)
else:
# 不匹配的ID
context_parts.append(
f"{header} [💡 来自其他聊天的相关上下文 (ID: {collection.chat_id})]\n```\n{collection.combined_text}\n```"
)
if not context_parts:
return ""
# 格式化消息集合为 prompt 上下文
final_context = "\n\n---\n\n".join(context_parts) + "\n\n---"
logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文")
return f"\n{final_context}\n"
except Exception as e:
logger.error(f"get_message_collection_context 失败: {e}")
return ""
def clear_all(self):
"""清空所有消息集合"""
try:
# In ChromaDB, the easiest way to clear a collection is to delete and recreate it.
self.vector_db_service.delete_collection(name=self.collection_name)
self._initialize_storage()
logger.info(f"已清空所有消息集合: '{self.collection_name}'")
except Exception as e:
logger.error(f"清空消息集合失败: {e}", exc_info=True)
def get_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
try:
count = self.vector_db_service.count(self.collection_name)
return {
"collection_name": self.collection_name,
"total_collections": count,
"storage_limit": self.config.instant_memory_max_collections,
}
except Exception as e:
logger.error(f"获取消息集合存储统计失败: {e}")
return {}
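
The two-phase eviction in `_cleanup_collections` (time-based expiry first, then oldest-first trimming down to `instant_memory_max_collections`) can be exercised in isolation. Below is a minimal sketch with a dict-backed stand-in for `vector_db_service` — the `FakeVectorStore` class and its method signatures are illustrative only, not the real service API:

```python
import time

# Stand-in for the vector store; only the calls _cleanup_collections relies on.
class FakeVectorStore:
    def __init__(self):
        self.docs: dict[str, dict] = {}  # id -> metadata

    def count(self, collection_name: str) -> int:
        return len(self.docs)

    def get(self, collection_name: str, where: dict | None = None, include=None):
        ids = list(self.docs)
        if where and "created_at" in where:
            threshold = where["created_at"]["$lt"]
            ids = [i for i in ids if self.docs[i]["created_at"] < threshold]
        return {"ids": ids, "metadatas": [self.docs[i] for i in ids]}

    def delete(self, collection_name: str, ids: list[str]):
        for i in ids:
            self.docs.pop(i, None)

store = FakeVectorStore()
now = time.time()
for n in range(5):
    store.docs[f"m{n}"] = {"created_at": now - n * 3600}  # m4 is the oldest

# Phase 1: time-based — drop anything older than 2.5 hours.
expired = store.get("instant_memory", where={"created_at": {"$lt": now - 2.5 * 3600}})
store.delete("instant_memory", expired["ids"])

# Phase 2: count-based — keep at most 2 documents, evicting the oldest first.
max_collections = 2
overflow = store.count("instant_memory") - max_collections
if overflow > 0:
    all_docs = store.get("instant_memory")
    by_age = sorted(zip(all_docs["ids"], all_docs["metadatas"]), key=lambda it: it[1].get("created_at", 0))
    store.delete("instant_memory", [doc_id for doc_id, _ in by_age[:overflow]])

assert sorted(store.docs) == ["m0", "m1"]  # only the newest two survive
```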

File diff suppressed because it is too large

View File

@@ -69,7 +69,11 @@ class SingleStreamContextManager:
try:
from .message_manager import message_manager as mm
message_manager = mm
use_cache_system = message_manager.is_running
# 检查配置是否启用消息缓存系统
cache_enabled = global_config.chat.enable_message_cache
use_cache_system = message_manager.is_running and cache_enabled
if not cache_enabled:
logger.debug(f"消息缓存系统已在配置中禁用")
except Exception as e:
logger.debug(f"MessageManager不可用使用直接添加: {e}")
use_cache_system = False

View File

@@ -323,8 +323,8 @@ class GlobalNoticeManager:
return message.additional_config.get("is_notice", False)
elif isinstance(message.additional_config, str):
# 兼容JSON字符串格式
import json
config = json.loads(message.additional_config)
import orjson
config = orjson.loads(message.additional_config)
return config.get("is_notice", False)
# 检查消息类型或其他标识
@@ -349,8 +349,8 @@ class GlobalNoticeManager:
if isinstance(message.additional_config, dict):
return message.additional_config.get("notice_type")
elif isinstance(message.additional_config, str):
import json
config = json.loads(message.additional_config)
import orjson
config = orjson.loads(message.additional_config)
return config.get("notice_type")
return None
except Exception:
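
The `json` → `orjson` swap above is behavior-preserving on the `loads` path, since `orjson.loads` accepts both `str` and `bytes` and returns plain Python objects. The difference shows up on the `dumps` side, where `orjson` returns `bytes` — which is why the message-storage changes in the next file append `.decode("utf-8")`. A quick sketch:

```python
import json
import orjson

payload = {"is_notice": True, "notice_type": "poke"}

# loads: drop-in replacement for this use case — both parse the same JSON text.
text = json.dumps(payload)
assert orjson.loads(text) == json.loads(text)

# dumps: orjson returns bytes, hence the .decode("utf-8") before writing
# the value into a text column.
assert isinstance(orjson.dumps(payload), bytes)
assert orjson.dumps(payload).decode("utf-8") == '{"is_notice":true,"notice_type":"poke"}'
```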

View File

@@ -12,6 +12,7 @@ from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.core import get_db_session
from src.common.database.core.models import Images, Messages
from src.common.logger import get_logger
from src.config.config import global_config
from .chat_stream import ChatStream
from .message import MessageSending
@@ -181,12 +182,14 @@ class MessageStorageBatcher:
is_command = message.is_command or False
is_public_notice = message.is_public_notice or False
notice_type = message.notice_type
actions = message.actions
# 序列化actions列表为JSON字符串
actions = orjson.dumps(message.actions).decode("utf-8") if message.actions else None
should_reply = message.should_reply
should_act = message.should_act
additional_config = message.additional_config
key_words = ""
key_words_lite = ""
# 确保关键词字段是字符串格式(如果不是,则序列化)
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
memorized_times = 0
user_platform = message.user_info.platform if message.user_info else ""
@@ -253,7 +256,8 @@ class MessageStorageBatcher:
is_command = message.is_command
is_public_notice = getattr(message, "is_public_notice", False)
notice_type = getattr(message, "notice_type", None)
actions = getattr(message, "actions", None)
# 序列化actions列表为JSON字符串
actions = orjson.dumps(getattr(message, "actions", None)).decode("utf-8") if getattr(message, "actions", None) else None
should_reply = getattr(message, "should_reply", None)
should_act = getattr(message, "should_act", None)
additional_config = getattr(message, "additional_config", None)
@@ -275,6 +279,9 @@ class MessageStorageBatcher:
user_platform = user_info_dict.get("platform")
user_id = user_info_dict.get("user_id")
# 将机器人自己的user_id标记为"SELF",增强对自我身份的识别
if user_id == global_config.bot.qq_account:
user_id = "SELF"
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
@@ -576,6 +583,11 @@ class MessageStorage:
is_picid = False
is_notify = False
is_command = False
is_public_notice = False
notice_type = None
actions = None
should_reply = False
should_act = False
key_words = ""
key_words_lite = ""
else:
@@ -589,6 +601,12 @@ class MessageStorage:
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
is_public_notice = getattr(message, "is_public_notice", False)
notice_type = getattr(message, "notice_type", None)
# 序列化actions列表为JSON字符串
actions = orjson.dumps(getattr(message, "actions", None)).decode("utf-8") if getattr(message, "actions", None) else None
should_reply = getattr(message, "should_reply", False)
should_act = getattr(message, "should_act", False)
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
@@ -612,6 +630,9 @@ class MessageStorage:
user_platform = user_info_dict.get("platform")
user_id = user_info_dict.get("user_id")
# 将机器人自己的user_id标记为"SELF",增强对自我身份的识别
if user_id == global_config.bot.qq_account:
user_id = "SELF"
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
@@ -659,6 +680,11 @@ class MessageStorage:
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
is_public_notice=is_public_notice,
notice_type=notice_type,
actions=actions,
should_reply=should_reply,
should_act=should_act,
key_words=key_words,
key_words_lite=key_words_lite,
)
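
The None-safe `orjson.dumps(...).decode("utf-8")` ternary now appears in three places in this file. If it spreads further, a small helper would keep the NULL-vs-JSON-string semantics in one spot — a sketch with a hypothetical name, not code from this commit:

```python
import orjson

def serialize_json_field(value) -> str | None:
    """None-safe JSON serialization for text columns (hypothetical helper).

    Mirrors the inline pattern used above: empty or None values are stored
    as NULL rather than as the JSON strings "null" or "[]".
    """
    if not value:
        return None
    return orjson.dumps(value).decode("utf-8")

assert serialize_json_field(None) is None
assert serialize_json_field([]) is None          # falsy, matches the inline ternary
assert serialize_json_field(["reply", "poke"]) == '["reply","poke"]'
```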

View File

@@ -255,8 +255,6 @@ class DefaultReplyer:
self._chat_info_initialized = False
self.heart_fc_sender = HeartFCSender()
# 使用新的增强记忆系统
# from src.chat.memory_system.enhanced_memory_activator import EnhancedMemoryActivator
self._chat_info_initialized = False
async def _initialize_chat_info(self):
@@ -393,19 +391,9 @@ class DefaultReplyer:
f"插件{result.get_summary().get('stopped_handlers', '')}于请求后取消了内容生成"
)
# 回复生成成功后,异步存储聊天记忆(不阻塞返回)
try:
# 将记忆存储作为子任务创建,可以被取消
memory_task = asyncio.create_task(
self._store_chat_memory_async(reply_to, reply_message),
name=f"store_memory_{self.chat_stream.stream_id}"
)
# 不等待完成,让它在后台运行
# 如果父任务被取消,这个子任务也会被垃圾回收
logger.debug(f"创建记忆存储子任务: {memory_task.get_name()}")
except Exception as memory_e:
# 记忆存储失败不应该影响回复生成的成功返回
logger.warning(f"记忆存储失败,但不影响回复生成: {memory_e}")
# 旧的自动记忆存储已移除,现在使用记忆图系统通过工具创建记忆
# 记忆由LLM在对话过程中通过CreateMemoryTool主动创建而非自动存储
pass
return True, llm_response, prompt
@@ -550,178 +538,116 @@ class DefaultReplyer:
Returns:
str: 记忆信息字符串
"""
if not global_config.memory.enable_memory:
return ""
# 使用新的记忆图系统检索记忆(带智能查询优化)
all_memories = []
try:
from src.memory_graph.manager_singleton import get_memory_manager, is_initialized
if is_initialized():
manager = get_memory_manager()
if manager:
# 构建查询上下文
stream = self.chat_stream
user_info_obj = getattr(stream, "user_info", None)
sender_name = ""
if user_info_obj:
sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "")
# 获取参与者信息
participants = []
try:
# 尝试从聊天流中获取参与者信息
if hasattr(stream, 'chat_history_manager'):
history_manager = stream.chat_history_manager
# 获取最近的参与者列表
recent_records = history_manager.get_memory_chat_history(
user_id=getattr(stream, "user_id", ""),
count=10,
memory_types=["chat_message", "system_message"]
)
# 提取唯一的参与者名称
for record in recent_records[:5]: # 最近5条记录
content = record.get("content", {})
participant = content.get("participant_name")
if participant and participant not in participants:
participants.append(participant)
instant_memory = None
# 如果消息包含发送者信息,也添加到参与者列表
if content.get("sender_name") and content.get("sender_name") not in participants:
participants.append(content.get("sender_name"))
except Exception as e:
logger.debug(f"获取参与者信息失败: {e}")
# 使用新的增强记忆系统检索记忆
running_memories = []
instant_memory = None
# 如果发送者不在参与者列表中,添加进去
if sender_name and sender_name not in participants:
participants.insert(0, sender_name)
if global_config.memory.enable_memory:
try:
# 使用新的统一记忆系统
from src.chat.memory_system import get_memory_system
# 格式化聊天历史为更友好的格式
formatted_history = ""
if chat_history:
# 移除过长的历史记录,只保留最近部分
lines = chat_history.strip().split('\n')
recent_lines = lines[-10:] if len(lines) > 10 else lines
formatted_history = '\n'.join(recent_lines)
stream = self.chat_stream
user_info_obj = getattr(stream, "user_info", None)
group_info_obj = getattr(stream, "group_info", None)
memory_user_id = str(stream.stream_id)
memory_user_display = None
memory_aliases = []
user_info_dict = {}
if user_info_obj is not None:
raw_user_id = getattr(user_info_obj, "user_id", None)
if raw_user_id:
memory_user_id = str(raw_user_id)
if hasattr(user_info_obj, "to_dict"):
try:
user_info_dict = user_info_obj.to_dict() # type: ignore[attr-defined]
except Exception:
user_info_dict = {}
candidate_keys = [
"user_cardname",
"user_nickname",
"nickname",
"remark",
"display_name",
"user_name",
]
for key in candidate_keys:
value = user_info_dict.get(key)
if isinstance(value, str) and value.strip():
stripped = value.strip()
if memory_user_display is None:
memory_user_display = stripped
elif stripped not in memory_aliases:
memory_aliases.append(stripped)
attr_keys = [
"user_cardname",
"user_nickname",
"nickname",
"remark",
"display_name",
"name",
]
for attr in attr_keys:
value = getattr(user_info_obj, attr, None)
if isinstance(value, str) and value.strip():
stripped = value.strip()
if memory_user_display is None:
memory_user_display = stripped
elif stripped not in memory_aliases:
memory_aliases.append(stripped)
alias_values = (
user_info_dict.get("aliases")
or user_info_dict.get("alias_names")
or user_info_dict.get("alias")
query_context = {
"chat_history": formatted_history,
"sender": sender_name,
"participants": participants,
}
# 使用记忆管理器的智能检索(多查询策略)
memories = await manager.search_memories(
query=target,
top_k=10,
min_importance=0.3,
include_forgotten=False,
use_multi_query=True,
context=query_context,
)
if isinstance(alias_values, list | tuple | set):
for alias in alias_values:
if isinstance(alias, str) and alias.strip():
stripped = alias.strip()
if stripped not in memory_aliases and stripped != memory_user_display:
memory_aliases.append(stripped)
memory_context = {
"user_id": memory_user_id,
"user_display_name": memory_user_display or "",
"user_name": memory_user_display or "",
"nickname": memory_user_display or "",
"sender_name": memory_user_display or "",
"platform": getattr(stream, "platform", None),
"chat_id": stream.stream_id,
"stream_id": stream.stream_id,
}
if memory_aliases:
memory_context["user_aliases"] = memory_aliases
if group_info_obj is not None:
group_name = getattr(group_info_obj, "group_name", None) or getattr(
group_info_obj, "group_nickname", None
)
if group_name:
memory_context["group_name"] = str(group_name)
group_id = getattr(group_info_obj, "group_id", None)
if group_id:
memory_context["group_id"] = str(group_id)
memory_context = {key: value for key, value in memory_context.items() if value}
# 获取记忆系统实例
memory_system = get_memory_system()
# 使用统一记忆系统检索相关记忆
enhanced_memories = await memory_system.retrieve_relevant_memories(
query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10
)
# 注意:记忆存储已迁移到回复生成完成后进行,不在查询阶段执行
# 转换格式以兼容现有代码
running_memories = []
if enhanced_memories:
logger.debug(f"[记忆转换] 收到 {len(enhanced_memories)} 条原始记忆")
for idx, memory_chunk in enumerate(enhanced_memories, 1):
# 获取结构化内容的字符串表示
structure_display = str(memory_chunk.content) if hasattr(memory_chunk, "content") else "unknown"
# 获取记忆内容优先使用display
content = memory_chunk.display or memory_chunk.text_content or ""
# 调试:记录每条记忆的内容获取情况
logger.debug(
f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}"
if memories:
logger.info(f"[记忆图] 检索到 {len(memories)} 条相关记忆")
# 使用新的格式化工具构建完整的记忆描述
from src.memory_graph.utils.memory_formatter import (
format_memory_for_prompt,
get_memory_type_label,
)
running_memories.append(
{
"content": content,
"memory_type": memory_chunk.memory_type.value,
"confidence": memory_chunk.metadata.confidence.value,
"importance": memory_chunk.metadata.importance.value,
"relevance": getattr(memory_chunk.metadata, "relevance_score", 0.5),
"source": memory_chunk.metadata.source,
"structure": structure_display,
}
)
# 构建瞬时记忆字符串
if running_memories:
top_memory = running_memories[:1]
if top_memory:
instant_memory = top_memory[0].get("content", "")
logger.info(
f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆"
)
except Exception as e:
logger.warning(f"增强记忆系统检索失败: {e}")
running_memories = []
instant_memory = ""
for memory in memories:
# 使用格式化工具生成完整的主谓宾描述
content = format_memory_for_prompt(memory, include_metadata=False)
# 获取记忆类型
mem_type = memory.memory_type.value if memory.memory_type else "未知"
if content:
all_memories.append({
"content": content,
"memory_type": mem_type,
"importance": memory.importance,
"relevance": 0.7,
"source": "memory_graph",
})
logger.debug(f"[记忆构建] 格式化记忆: [{mem_type}] {content[:50]}...")
else:
logger.debug("[记忆图] 未找到相关记忆")
except Exception as e:
logger.debug(f"[记忆图] 检索失败: {e}")
all_memories = []
# 构建记忆字符串,使用方括号格式
memory_str = ""
has_any_memory = False
# 添加长期记忆(来自增强记忆系统)
if running_memories:
# 添加长期记忆(来自记忆系统)
if all_memories:
# 使用方括号格式
memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
# 按相关度排序,并记录相关度信息用于调试
sorted_memories = sorted(running_memories, key=lambda x: x.get("relevance", 0.0), reverse=True)
sorted_memories = sorted(all_memories, key=lambda x: x.get("relevance", 0.0), reverse=True)
# 调试相关度信息
relevance_info = [(m.get("memory_type", "unknown"), m.get("relevance", 0.0)) for m in sorted_memories]
@@ -738,8 +664,13 @@ class DefaultReplyer:
logger.debug(f"[记忆构建] 空记忆详情: {running_memory}")
continue
# 使用全局记忆类型映射表
chinese_type = get_memory_type_chinese_label(memory_type)
# 使用记忆图的类型映射(优先)或全局映射
try:
from src.memory_graph.utils.memory_formatter import get_memory_type_label
chinese_type = get_memory_type_label(memory_type)
except ImportError:
# 回退到全局映射
chinese_type = get_memory_type_chinese_label(memory_type)
# 提取纯净内容(如果包含旧格式的元数据)
clean_content = content
@@ -753,13 +684,7 @@ class DefaultReplyer:
has_any_memory = True
logger.debug(f"[记忆构建] 成功构建记忆字符串,包含 {len(memory_parts) - 2} 条记忆")
# 添加瞬时记忆
if instant_memory:
if not any(rm["content"] == instant_memory for rm in running_memories):
if not memory_str:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
memory_str += f"- 最相关记忆:{instant_memory}\n"
has_any_memory = True
# 瞬时记忆由另一套系统处理,这里不再添加
# 只有当完全没有任何记忆时才返回空字符串
return memory_str if has_any_memory else ""
@@ -780,32 +705,46 @@ class DefaultReplyer:
return ""
try:
# 使用工具执行器获取信息
# 首先获取当前的历史记录(在执行新工具调用之前)
tool_history_str = self.tool_executor.history_manager.format_for_prompt(max_records=3, include_results=True)
# 然后执行工具调用
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
sender=sender, target_message=target, chat_history=chat_history, return_details=False
)
info_parts = []
# 显示之前的工具调用历史(不包括当前这次调用)
if tool_history_str:
info_parts.append(tool_history_str)
# 显示当前工具调用的结果(简要信息)
if tool_results:
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
current_results_parts = ["## 🔧 刚获取的工具信息"]
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}{result_type}: {content}\n"
# 不进行截断,让工具自己处理结果长度
current_results_parts.append(f"- **{tool_name}**: {content}")
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
info_parts.append("\n".join(current_results_parts))
logger.info(f"获取到 {len(tool_results)} 个工具结果")
return tool_info_str
else:
logger.debug("未获取到任何工具结果")
# 如果没有任何信息,返回空字符串
if not info_parts:
logger.debug("未获取到任何工具结果或历史记录")
return ""
return "\n\n".join(info_parts)
except Exception as e:
logger.error(f"工具信息获取失败: {e}")
return ""
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt
@@ -1145,29 +1084,6 @@ class DefaultReplyer:
return read_history_prompt, unread_history_prompt
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
"""为消息获取兴趣度评分(使用预计算的兴趣值)"""
interest_scores = {}
try:
# 直接使用消息中的预计算兴趣值
for msg_dict in messages:
message_id = msg_dict.get("message_id", "")
interest_value = msg_dict.get("interest_value")
if interest_value is not None:
interest_scores[message_id] = float(interest_value)
logger.debug(f"使用预计算兴趣度 - 消息 {message_id}: {interest_value:.3f}")
else:
interest_scores[message_id] = 0.5 # 默认值
logger.debug(f"消息 {message_id} 无预计算兴趣值,使用默认值 0.5")
except Exception as e:
logger.warning(f"处理预计算兴趣值失败: {e}")
return interest_scores
async def build_prompt_reply_context(
self,
reply_to: str,
@@ -1976,14 +1892,22 @@ class DefaultReplyer:
return f"你与{sender}是普通朋友关系。"
# 已废弃:旧的自动记忆存储逻辑
# 新的记忆图系统通过LLM工具(CreateMemoryTool)主动创建记忆,而非自动存储
async def _store_chat_memory_async(self, reply_to: str, reply_message: DatabaseMessages | dict[str, Any] | None = None):
"""
异步存储聊天记忆从build_memory_block迁移而来
[已废弃] 异步存储聊天记忆从build_memory_block迁移而来
此函数已被记忆图系统的工具调用方式替代。
记忆现在由LLM在对话过程中通过CreateMemoryTool主动创建。
Args:
reply_to: 回复对象
reply_message: 回复的原始消息
"""
return # 已禁用,保留函数签名以防其他地方有引用
# 以下代码已废弃,不再执行
try:
if not global_config.memory.enable_memory:
return
@@ -2121,23 +2045,9 @@ class DefaultReplyer:
show_actions=True,
)
# 异步存储聊天历史(完全非阻塞)
memory_system = get_memory_system()
task = asyncio.create_task(
memory_system.process_conversation_memory(
context={
"conversation_text": chat_history,
"user_id": memory_user_id,
"scope_id": stream.stream_id,
**memory_context,
}
)
)
# 将任务添加到全局集合以防止被垃圾回收
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.debug(f"已启动记忆存储任务,用户: {memory_user_display or memory_user_id}")
# 旧记忆系统的自动存储已禁用
# 新记忆系统通过 LLM 工具调用create_memory来创建记忆
logger.debug(f"记忆创建通过 LLM 工具调用进行,用户: {memory_user_display or memory_user_id}")
except asyncio.CancelledError:
logger.debug("记忆存储任务被取消")

View File

@@ -44,8 +44,8 @@ def replace_user_references_sync(
if name_resolver is None:
def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
return f"{global_config.bot.nickname}(你)"
# 同步函数中无法使用异步的 get_value直接返回 user_id
# 建议调用方使用 replace_user_references_async 以获取完整的用户名
@@ -60,8 +60,8 @@ def replace_user_references_sync(
aaa = match[1]
bbb = match[2]
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
reply_person_name = f"{global_config.bot.nickname}(你)"
else:
reply_person_name = name_resolver(platform, bbb) or aaa
@@ -81,8 +81,8 @@ def replace_user_references_sync(
aaa = m.group(1)
bbb = m.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
at_person_name = f"{global_config.bot.nickname}(你)"
else:
at_person_name = name_resolver(platform, bbb) or aaa
@@ -120,8 +120,8 @@ async def replace_user_references_async(
person_info_manager = get_person_info_manager()
async def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
return f"{global_config.bot.nickname}(你)"
person_id = PersonInfoManager.get_person_id(platform, user_id)
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
@@ -135,8 +135,8 @@ async def replace_user_references_async(
aaa = match.group(1)
bbb = match.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
reply_person_name = f"{global_config.bot.nickname}(你)"
else:
reply_person_name = await name_resolver(platform, bbb) or aaa
@@ -156,8 +156,8 @@ async def replace_user_references_async(
aaa = m.group(1)
bbb = m.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (bbb == "SELF" or bbb == global_config.bot.qq_account):
at_person_name = f"{global_config.bot.nickname}(你)"
else:
at_person_name = await name_resolver(platform, bbb) or aaa
@@ -638,13 +638,14 @@ async def _build_readable_messages_internal(
if not all([platform, user_id, timestamp is not None]):
continue
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager()
# 根据 replace_bot_name 参数决定是否替换机器人名称
person_name: str
if replace_bot_name and user_id == global_config.bot.qq_account:
# 检查是否是机器人自己支持SELF标记或直接比对QQ号
if replace_bot_name and (user_id == "SELF" or user_id == global_config.bot.qq_account):
person_name = f"{global_config.bot.nickname}(你)"
else:
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager()
person_name = await person_info_manager.get_value(person_id, "person_name") # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
@@ -656,8 +657,8 @@ async def _build_readable_messages_internal(
else:
person_name = "某人"
# 在用户名后面添加 QQ 号, 但机器人本体不用
if user_id != global_config.bot.qq_account:
# 在用户名后面添加 QQ 号, 但机器人本体不用包括SELF标记
if user_id != global_config.bot.qq_account and user_id != "SELF":
person_name = f"{person_name}({user_id})"
# 使用独立函数处理用户引用格式
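
Taken together with the storage-side change earlier in this commit (the bot's own `user_id` is rewritten to "SELF" on write), the resolvers above accept either the marker or the raw QQ account, so rows written before this change still resolve. A minimal round-trip sketch with a stand-in config object — the account and nickname values here are made up:

```python
# Stand-in for global_config.bot; values are illustrative.
class BotConfig:
    qq_account = "10001"
    nickname = "麦麦"

bot = BotConfig()

def normalize_user_id(user_id: str) -> str:
    # Write side: mark the bot's own messages as "SELF" before storage.
    return "SELF" if user_id == bot.qq_account else user_id

def display_name(user_id: str, fallback: str) -> str:
    # Read side: accept either the marker or the raw account, like the
    # resolvers above, so legacy rows keep working.
    if user_id == "SELF" or user_id == bot.qq_account:
        return f"{bot.nickname}(你)"
    return fallback

stored = normalize_user_id("10001")
assert stored == "SELF"
assert display_name(stored, "unknown") == "麦麦(你)"
assert display_name("10001", "unknown") == "麦麦(你)"  # row written before this commit
assert display_name("20002", "张三") == "张三"
```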

View File

@@ -398,6 +398,9 @@ class Prompt:
"""
start_time = time.time()
# 初始化预构建参数字典
pre_built_params = {}
try:
# --- 步骤 1: 准备构建任务 ---
tasks = []
@@ -406,7 +409,6 @@ class Prompt:
# --- 步骤 1.1: 优先使用预构建的参数 ---
# 如果参数对象中已经包含了某些block说明它们是外部预构建的
# 我们将它们存起来,并跳过对应的实时构建任务。
pre_built_params = {}
if self.parameters.expression_habits_block:
pre_built_params["expression_habits_block"] = self.parameters.expression_habits_block
if self.parameters.relation_info_block:
@@ -428,11 +430,9 @@ class Prompt:
tasks.append(self._build_expression_habits())
task_names.append("expression_habits")
# 记忆块构建非常耗时,强烈建议预构建。如果没有预构建,这里会运行一个快速的后备版本。
if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
logger.debug("memory_block未预构建执行快速构建作为后备方案")
tasks.append(self._build_memory_block_fast())
task_names.append("memory_block")
# 记忆块构建已移至 default_generator.py 的 build_memory_block 方法
# 使用新的记忆图系统,不再在 prompt.py 中构建记忆
# 如果需要记忆,必须通过 pre_built_params 传入
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info())
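
The net effect of these two hunks: `memory_block` is no longer built inside `Prompt` at all, while the other blocks keep the "pre-built value wins, otherwise schedule a live build" dispatch. A minimal sketch of that dispatch, with a stand-in builder rather than the real `_build_*` methods:

```python
import asyncio

async def build_relation() -> dict:
    await asyncio.sleep(0.01)  # stands in for an expensive build
    return {"relation_info_block": "built live"}

async def assemble(parameters: dict) -> dict:
    # Blocks supplied by the caller short-circuit the matching build task.
    pre_built = {k: v for k, v in parameters.items() if v}
    tasks = []
    if "relation_info_block" not in pre_built:
        tasks.append(build_relation())
    results = await asyncio.gather(*tasks)
    blocks = dict(pre_built)
    for result in results:
        blocks.update(result)
    return blocks

# memory_block must now arrive pre-built; there is no fallback builder.
blocks = asyncio.run(assemble({"memory_block": "precomputed memories"}))
assert blocks["memory_block"] == "precomputed memories"
assert blocks["relation_info_block"] == "built live"
```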
@@ -637,146 +637,6 @@ class Prompt:
logger.error(f"构建表达习惯失败: {e}")
return {"expression_habits_block": ""}
async def _build_memory_block(self) -> dict[str, Any]:
"""构建与当前对话相关的记忆上下文块(完整版)."""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
try:
from src.chat.memory_system.enhanced_memory_activator import enhanced_memory_activator
# 准备用于记忆激活的聊天历史
chat_history = ""
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-20:]
chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 并行查询长期记忆和即时记忆以提高性能
import asyncio
memory_tasks = [
enhanced_memory_activator.activate_memory_with_chat_history(
target_message=self.parameters.target, chat_history_prompt=chat_history
),
enhanced_memory_activator.get_instant_memory(
target_message=self.parameters.target, chat_id=self.parameters.chat_id
),
]
try:
# 使用 `return_exceptions=True` 来防止一个任务的失败导致所有任务失败
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
# 单独处理每个任务的结果,如果是异常则记录并使用默认值
if isinstance(running_memories, BaseException):
logger.warning(f"长期记忆查询失败: {running_memories}")
running_memories = []
if isinstance(instant_memory, BaseException):
logger.warning(f"即时记忆查询失败: {instant_memory}")
instant_memory = None
except asyncio.TimeoutError:
logger.warning("记忆查询超时,使用部分结果")
running_memories = []
instant_memory = None
# 将检索到的记忆格式化为提示词
if running_memories:
try:
from src.chat.memory_system.memory_formatter import format_memories_bracket_style
# 将原始记忆数据转换为格式化器所需的标准格式
formatted_memories = []
for memory in running_memories:
content = memory.get("content", "")
display_text = content
# 清理内容,移除元数据括号
if "(类型:" in content and "" in content:
display_text = content.split("(类型:")[0].strip()
# 映射记忆主题到标准类型
topic = memory.get("topic", "personal_fact")
memory_type_mapping = {
"relationship": "personal_fact",
"opinion": "opinion",
"personal_fact": "personal_fact",
"preference": "preference",
"event": "event",
}
mapped_type = memory_type_mapping.get(topic, "personal_fact")
formatted_memories.append(
{
"display": display_text,
"memory_type": mapped_type,
"metadata": {
"confidence": memory.get("confidence", "未知"),
"importance": memory.get("importance", "一般"),
"timestamp": memory.get("timestamp", ""),
"source": memory.get("source", "unknown"),
"relevance_score": memory.get("relevance_score", 0.0),
},
}
)
# 使用指定的风格进行格式化
memory_block = format_memories_bracket_style(
formatted_memories, query_context=self.parameters.target
)
except Exception as e:
# 如果格式化失败,提供一个简化的、健壮的备用格式
logger.warning(f"记忆格式化失败,使用简化格式: {e}")
memory_parts = ["## 相关记忆回顾", ""]
for memory in running_memories:
content = memory.get("content", "")
if "(类型:" in content and "" in content:
clean_content = content.split("(类型:")[0].strip()
memory_parts.append(f"- {clean_content}")
else:
memory_parts.append(f"- {content}")
memory_block = "\n".join(memory_parts)
else:
memory_block = ""
# 将即时记忆附加到记忆块的末尾
if instant_memory:
if memory_block:
memory_block += f"\n- 最相关记忆:{instant_memory}"
else:
memory_block = f"- 最相关记忆:{instant_memory}"
return {"memory_block": memory_block}
except Exception as e:
logger.error(f"构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_memory_block_fast(self) -> dict[str, Any]:
"""快速构建记忆块(简化版),作为未预构建时的后备方案."""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
try:
from src.chat.memory_system.enhanced_memory_activator import enhanced_memory_activator
# 这个快速版本只查询最高优先级的“即时记忆”,速度更快
instant_memory = await enhanced_memory_activator.get_instant_memory(
target_message=self.parameters.target, chat_id=self.parameters.chat_id
)
if instant_memory:
memory_block = f"- 相关记忆:{instant_memory}"
else:
memory_block = ""
return {"memory_block": memory_block}
except Exception as e:
logger.warning(f"快速构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_relation_info(self) -> dict[str, Any]:
"""构建与对话目标相关的关系信息."""
try: