From bb4ff48e26725bb47afa265626db751d1d291ce5 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:44:31 +0800 Subject: [PATCH 1/5] =?UTF-8?q?refactor(proactive=5Fthinker):=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E4=B8=BB=E5=8A=A8=E6=80=9D=E8=80=83=E7=9A=84=E5=86=B3?= =?UTF-8?q?=E7=AD=96=E4=B8=8E=E4=BA=A4=E4=BA=92=EF=BC=8C=E4=BD=BF=E5=85=B6?= =?UTF-8?q?=E6=9B=B4=E8=87=AA=E7=84=B6=E4=B8=94=E9=81=BF=E5=85=8D=E6=89=93?= =?UTF-8?q?=E6=89=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次提交对主动思考插件进行了多项核心优化,旨在提升其交互的自然度和人性化,并引入了关键的防打扰机制。 主要变更包括: 1. **重构冷启动任务 (`ColdStartTask`)**: - 任务逻辑从一个长期运行的周期性任务,重构为在机器人启动时执行一次的“唤醒”任务。 - 新逻辑不仅能为白名单中的全新用户发起首次问候,还能智能地“唤醒”那些因机器人重启而“沉睡”的聊天流,确保了主动思考功能的连续性。 2. **增强决策提示词 (`_build_plan_prompt`)**: - 引入了更精细的决策原则,核心是增加了防打扰机制。现在模型在决策时会检查上一条消息是否为自己发送,如果对方尚未回复,则倾向于不发起新对话,以表现出耐心和体贴。 - 优化了示例,引导模型优先利用上下文信息,并在无切入点时使用简单的问候,避免创造生硬抽象的话题。 3. **改善回复生成逻辑 (`_build_*_reply_prompt`)**: - 在生成回复的指令中,明确要求模型必须先用一句通用的礼貌问候语(如“在吗?”、“下午好!”)作为开场白,然后再衔接具体话题。这使得主动发起的对话更加自然、流畅,符合人类的沟通习惯。 4. **模型调整**: - 将决策规划阶段的 LLM 模型从 `utils` 调整为 `replyer`,以更好地适应生成对话策略的任务。 --- .../proacive_thinker_event.py | 97 +++++++++---------- .../proactive_thinker_executor.py | 38 ++++++-- 2 files changed, 75 insertions(+), 60 deletions(-) diff --git a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py index c310e5c45..ffd663d18 100644 --- a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py +++ b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py @@ -21,76 +21,69 @@ logger = get_logger(__name__) class ColdStartTask(AsyncTask): """ - 冷启动任务,专门用于处理那些在白名单里,但从未与机器人发生过交互的用户。 - 它的核心职责是“破冰”,主动创建聊天流并发起第一次问候。 + “冷启动”任务,在机器人启动时执行一次。 + 它的核心职责是“唤醒”那些因重启而“沉睡”的聊天流,确保它们能够接收主动思考。 + 对于在白名单中但从未有过记录的全新用户,它也会发起第一次“破冰”问候。 """ - def __init__(self): + def __init__(self, bot_start_time: float): super().__init__(task_name="ColdStartTask") self.chat_manager = get_chat_manager() self.executor = ProactiveThinkerExecutor() + self.bot_start_time = bot_start_time async def run(self): - """任务主循环,周期性地检查是否有需要“破冰”的新用户。""" - logger.info("冷启动任务已启动,将周期性检查白名单中的新朋友。") - # 初始等待一段时间,确保其他服务(如数据库)完全启动 - await asyncio.sleep(100) + """任务主逻辑,在启动后执行一次白名单扫描。""" + logger.info("冷启动任务已启动,将在短暂延迟后开始唤醒沉睡的聊天流...") + await asyncio.sleep(30) # 延迟以确保所有服务和聊天流已从数据库加载完毕 - while True: - try: - #开始就先暂停一小时,等bot聊一会再说() - await asyncio.sleep(3600) - logger.info("【冷启动】开始扫描白名单,寻找从未聊过的用户...") + try: + logger.info("【冷启动】开始扫描白名单,唤醒沉睡的聊天流...") - # 从全局配置中获取私聊白名单 - enabled_private_chats = global_config.proactive_thinking.enabled_private_chats - if not enabled_private_chats: - logger.debug("【冷启动】私聊白名单为空,任务暂停一小时。") - await asyncio.sleep(3600) # 白名单为空时,没必要频繁检查 - continue + enabled_private_chats = global_config.proactive_thinking.enabled_private_chats + if not enabled_private_chats: + logger.debug("【冷启动】私聊白名单为空,任务结束。") + return - # 遍历白名单中的每一个用户 - for chat_id in enabled_private_chats: - try: - platform, user_id_str = chat_id.split(":") - user_id = int(user_id_str) + for chat_id in enabled_private_chats: + try: + platform, user_id_str = chat_id.split(":") + user_id = int(user_id_str) - # 【核心逻辑】使用 chat_api 检查该用户是否已经存在聊天流(ChatStream) - # 如果返回了 ChatStream 对象,说明已经聊过天了,不是本次任务的目标 - if chat_api.get_stream_by_user_id(user_id_str, platform): - continue # 跳过已存在的用户 + should_wake_up = False + stream = chat_api.get_stream_by_user_id(user_id_str, platform) - logger.info(f"【冷启动】发现白名单新用户 {chat_id},准备发起第一次问候。") + if not stream: + should_wake_up = True + logger.info(f"【冷启动】发现全新用户 {chat_id},准备发起第一次问候。") + elif stream.last_active_time < self.bot_start_time: + should_wake_up = True + logger.info(f"【冷启动】发现沉睡的聊天流 {chat_id} (最后活跃于 {datetime.fromtimestamp(stream.last_active_time)}),准备唤醒。") - # 【增强体验】尝试从关系数据库中获取该用户的昵称 - # 这样打招呼时可以更亲切,而不是只知道一个冷冰冰的ID + if should_wake_up: person_id = person_api.get_person_id(platform, user_id) nickname = await person_api.get_person_value(person_id, "nickname") - - # 如果数据库里有昵称,就用数据库里的;如果没有,就用 "用户+ID" 作为备用 user_nickname = nickname or f"用户{user_id}" - - # 创建 UserInfo 对象,这是创建聊天流的必要信息 user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname) - - # 【关键步骤】主动创建聊天流。 - # 创建后,该用户就进入了机器人的“好友列表”,后续将由 ProactiveThinkingTask 接管 + + # 使用 get_or_create_stream 来安全地获取或创建流 stream = await self.chat_manager.get_or_create_stream(platform, user_info) + + formatted_stream_id = f"{stream.user_info.platform}:{stream.user_info.user_id}:private" + await self.executor.execute(stream_id=formatted_stream_id, start_mode="cold_start") + logger.info(f"【冷启动】已为用户 {chat_id} (昵称: {user_nickname}) 发送唤醒/问候消息。") - await self.executor.execute(stream_id=stream.stream_id, start_mode="cold_start") - logger.info(f"【冷启动】已为新用户 {chat_id} (昵称: {user_nickname}) 创建聊天流并发送问候。") + except ValueError: + logger.warning(f"【冷启动】白名单条目格式错误或用户ID无效,已跳过: {chat_id}") + except Exception as e: + logger.error(f"【冷启动】处理用户 {chat_id} 时发生未知错误: {e}", exc_info=True) - except ValueError: - logger.warning(f"【冷启动】白名单条目格式错误或用户ID无效,已跳过: {chat_id}") - except Exception as e: - logger.error(f"【冷启动】处理用户 {chat_id} 时发生未知错误: {e}", exc_info=True) - - except asyncio.CancelledError: - logger.info("冷启动任务被正常取消。") - break - except Exception as e: - logger.error(f"【冷启动】任务出现严重错误,将在5分钟后重试: {e}", exc_info=True) - await asyncio.sleep(300) + except asyncio.CancelledError: + logger.info("冷启动任务被正常取消。") + except Exception as e: + logger.error(f"【冷启动】任务出现严重错误: {e}", exc_info=True) + finally: + logger.info("【冷启动】任务执行完毕。") class ProactiveThinkingTask(AsyncTask): @@ -222,13 +215,15 @@ class ProactiveThinkerEventHandler(BaseEventHandler): logger.info("检测到插件启动事件,正在初始化【主动思考】") # 检查总开关 if global_config.proactive_thinking.enable: + bot_start_time = time.time() # 记录“诞生时刻” + # 启动负责“日常唤醒”的核心任务 proactive_task = ProactiveThinkingTask() await async_task_manager.add_task(proactive_task) # 检查“冷启动”功能的独立开关 if global_config.proactive_thinking.enable_cold_start: - cold_start_task = ColdStartTask() + cold_start_task = ColdStartTask(bot_start_time) await async_task_manager.add_task(cold_start_task) else: diff --git a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py index ea5187f1f..96377c800 100644 --- a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py +++ b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py @@ -80,7 +80,7 @@ class ProactiveThinkerExecutor: plan_prompt = self._build_plan_prompt(context, start_mode, topic, reason) is_success, response, _, _ = await llm_api.generate_with_model( - prompt=plan_prompt, model_config=model_config.model_task_config.utils + prompt=plan_prompt, model_config=model_config.model_task_config.replyer ) if is_success and response: @@ -158,12 +158,12 @@ class ProactiveThinkerExecutor: ) # 2. 构建基础上下文 + mood_state = "暂时没有" if global_config.mood.enable_mood: try: mood_state = mood_manager.get_mood_by_chat_id(stream.stream_id).mood_state except Exception as e: logger.error(f"获取情绪失败,原因:{e}") - mood_state = "暂时没有" base_context = { "schedule_context": schedule_context, "recent_chat_history": recent_chat_history, @@ -281,29 +281,47 @@ class ProactiveThinkerExecutor: # 构建通用尾部 prompt += """ # 决策指令 -请综合以上所有信息,做出决策。你的决策需要以JSON格式输出,包含以下字段: +请综合以上所有信息,以稳定、真实、拟人的方式做出决策。你的决策需要以JSON格式输出,包含以下字段: - `should_reply`: bool, 是否应该发起对话。 -- `topic`: str, 如果 `should_reply` 为 true,你打算聊什么话题?(例如:问候一下今天的日程、关心一下昨天的某件事、分享一个你自己的趣事等) +- `topic`: str, 如果 `should_reply` 为 true,你打算聊什么话题? - `reason`: str, 做出此决策的简要理由。 # 决策原则 -- **避免打扰**: 如果你最近(尤其是在最近的几次决策中)已经主动发起过对话,请倾向于选择“不回复”,除非有非常重要和紧急的事情。 +- **谨慎对待未回复的对话**: 在发起新话题前,请检查【最近的聊天摘要】。如果最后一条消息是你自己发送的,请仔细评估等待的时间和上下文,判断再次主动发起对话是否礼貌和自然。如果等待时间很短(例如几分钟或半小时内),通常应该选择“不回复”。 +- **优先利用上下文**: 优先从【情境分析】中已有的信息(如最近的聊天摘要、你的日程、你对Ta的关系印象)寻找自然的话题切入点。 +- **简单问候作为备选**: 如果上下文中没有合适的话题,可以生成一个简单、真诚的日常问候(例如“在忙吗?”,“下午好呀~”)。 +- **避免抽象**: 避免创造过于复杂、抽象或需要对方思考很久才能明白的话题。目标是轻松、自然地开启对话。 +- **避免过于频繁**: 如果你最近(尤其是在最近的几次决策中)已经主动发起过对话,请倾向于选择“不回复”,除非有非常重要和紧急的事情。 --- -示例1 (应该回复): +示例1 (基于上下文): {{ "should_reply": true, - "topic": "提醒大家今天下午有'项目会议'的日程", - "reason": "现在是上午,下午有个重要会议,我觉得应该主动提醒一下大家,这会显得我很贴心。" + "topic": "关心一下Ta昨天提到的那个项目进展如何了", + "reason": "用户昨天在聊天中提到了一个重要的项目,现在主动关心一下进展,会显得很体贴,也能自然地开启对话。" }} -示例2 (不应回复): +示例2 (简单问候): +{{ + "should_reply": true, + "topic": "打个招呼,问问Ta现在在忙些什么", + "reason": "最近没有聊天记录,日程也很常规,没有特别的切入点。一个简单的日常问候是最安全和自然的方式来重新连接。" +}} + +示例3 (不应回复 - 过于频繁): {{ "should_reply": false, "topic": null, "reason": "虽然群里很活跃,但现在是深夜,而且最近的聊天话题我也不熟悉,没有合适的理由去打扰大家。" }} + +示例4 (不应回复 - 等待回应): +{{ + "should_reply": false, + "topic": null, + "reason": "我注意到上一条消息是我几分钟前主动发送的,对方可能正在忙。为了表现出耐心和体贴,我现在最好保持安静,等待对方的回应。" +}} --- 请输出你的决策: @@ -399,6 +417,7 @@ class ProactiveThinkerExecutor: # 对话指引 - 你决定和Ta聊聊关于“{topic}”的话题。 +- **重要**: 在开始你的话题前,必须先用一句通用的、礼貌的开场白进行问候(例如:“在吗?”、“上午好!”、“晚上好呀~”),然后再自然地衔接你的话题,确保整个回复在一条消息内流畅、自然、像人类的说话方式。 - 请结合以上所有情境信息,自然地开启对话。 - 你的语气应该符合你的人设({context["mood_state"]})以及你对Ta的好感度。 """ @@ -436,6 +455,7 @@ class ProactiveThinkerExecutor: # 对话指引 - 你决定和大家聊聊关于“{topic}”的话题。 +- **重要**: 在开始你的话题前,必须先用一句通用的、礼貌的开场白进行问候(例如:“哈喽,大家好呀~”、“下午好!”),然后再自然地衔接你的话题,确保整个回复在一条消息内流畅、自然、像人类的说话方式。 - 你的语气应该更活泼、更具包容性,以吸引更多群成员参与讨论。你的语气应该符合你的人设)。 - 请结合以上所有情境信息,自然地开启对话。 - 可以分享你的看法、提出相关问题,或者开个合适的玩笑。 From 1eb41f8372176c8bcba349014e54e372deeaaf96 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Fri, 3 Oct 2025 22:00:53 +0800 Subject: [PATCH 2/5] =?UTF-8?q?refactor(memory):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E8=AE=B0=E5=BF=86=E7=B3=BB=E7=BB=9F=E6=9E=B6=E6=9E=84=EF=BC=8C?= =?UTF-8?q?=E5=BC=95=E5=85=A5=E5=8F=AF=E9=85=8D=E7=BD=AE=E7=9A=84=E9=87=87?= =?UTF-8?q?=E6=A0=B7=E7=AD=96=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将记忆系统从单一构建模式重构为多策略可配置架构,支持灵活的采样行为控制: 核心架构改进: - 重构记忆构建流程,分离采样策略与核心逻辑 - 引入MemorySamplingMode枚举,标准化采样模式定义 - 设计插件化采样器接口,支持海马体后台定时采样 - 优化记忆处理管道,添加bypass_interval机制支持后台采样 配置系统增强: - 新增memory_sampling_mode配置项,支持hippocampus/immediate/all三种模式 - 添加海马体双峰采样参数配置,支持自定义采样间隔和分布 - 引入自适应采样阈值控制,动态调整记忆构建频率 - 完善精准记忆配置,支持基于价值评分的触发机制 兼容性与性能优化: - 修复Python 3.9兼容性问题,替换strict=False参数 - 优化记忆检索逻辑,统一使用filters参数替代scope_id - 改进错误处理机制,增强系统稳定性 BREAKING CHANGE: 新增memory_sampling_mode配置项,默认值从adaptive改为immediate --- scripts/log_viewer_optimized.py | 6 +- src/chat/memory_system/hippocampus_sampler.py | 731 ++++++++++++++++++ src/chat/memory_system/memory_system.py | 199 ++++- src/config/official_configs.py | 35 + template/bot_config_template.toml | 15 +- 5 files changed, 969 insertions(+), 17 deletions(-) create mode 100644 src/chat/memory_system/hippocampus_sampler.py diff --git a/scripts/log_viewer_optimized.py b/scripts/log_viewer_optimized.py index f38dafa64..65cf579c0 100644 --- a/scripts/log_viewer_optimized.py +++ b/scripts/log_viewer_optimized.py @@ -373,7 +373,11 @@ class VirtualLogDisplay: # 为每个部分应用正确的标签 current_len = 0 - for part, tag_name in zip(parts, tags, strict=False): + # Python 3.9 兼容性:不使用 strict=False 参数 + min_len = min(len(parts), len(tags)) + for i in range(min_len): + part = parts[i] + tag_name = tags[i] start_index = f"{start_pos}+{current_len}c" end_index = f"{start_pos}+{current_len + len(part)}c" self.text_widget.tag_add(tag_name, start_index, end_index) diff --git a/src/chat/memory_system/hippocampus_sampler.py b/src/chat/memory_system/hippocampus_sampler.py new file mode 100644 index 000000000..0cc6b61d5 --- /dev/null +++ b/src/chat/memory_system/hippocampus_sampler.py @@ -0,0 +1,731 @@ +# -*- coding: utf-8 -*- +""" +海马体双峰分布采样器 +基于旧版海马体的采样策略,适配新版记忆系统 +实现低消耗、高效率的记忆采样模式 +""" + +import asyncio +import random +import time +from datetime import datetime, timedelta +from typing import List, Optional, Tuple, Dict, Any +from dataclasses import dataclass + +import numpy as np +import orjson + +from src.chat.utils.chat_message_builder import ( + get_raw_msg_by_timestamp, + build_readable_messages, + get_raw_msg_by_timestamp_with_chat, +) +from src.chat.utils.utils import translate_timestamp_to_human_readable +from src.common.logger import get_logger +from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest + +logger = get_logger(__name__) + + +@dataclass +class HippocampusSampleConfig: + """海马体采样配置""" + # 双峰分布参数 + recent_mean_hours: float = 12.0 # 近期分布均值(小时) + recent_std_hours: float = 8.0 # 近期分布标准差(小时) + recent_weight: float = 0.7 # 近期分布权重 + + distant_mean_hours: float = 48.0 # 远期分布均值(小时) + distant_std_hours: float = 24.0 # 远期分布标准差(小时) + distant_weight: float = 0.3 # 远期分布权重 + + # 采样参数 + total_samples: int = 50 # 总采样数 + sample_interval: int = 1800 # 采样间隔(秒) + max_sample_length: int = 30 # 每次采样的最大消息数量 + batch_size: int = 5 # 批处理大小 + + @classmethod + def from_global_config(cls) -> 'HippocampusSampleConfig': + """从全局配置创建海马体采样配置""" + config = global_config.memory.hippocampus_distribution_config + return cls( + recent_mean_hours=config[0], + recent_std_hours=config[1], + recent_weight=config[2], + distant_mean_hours=config[3], + distant_std_hours=config[4], + distant_weight=config[5], + total_samples=global_config.memory.hippocampus_sample_size, + sample_interval=global_config.memory.hippocampus_sample_interval, + max_sample_length=global_config.memory.hippocampus_batch_size, + batch_size=global_config.memory.hippocampus_batch_size, + ) + + +class HippocampusSampler: + """海马体双峰分布采样器""" + + def __init__(self, memory_system=None): + self.memory_system = memory_system + self.config = HippocampusSampleConfig.from_global_config() + self.last_sample_time = 0 + self.is_running = False + + # 记忆构建模型 + self.memory_builder_model: Optional[LLMRequest] = None + + # 统计信息 + self.sample_count = 0 + self.success_count = 0 + self.last_sample_results: List[Dict[str, Any]] = [] + + async def initialize(self): + """初始化采样器""" + try: + # 初始化LLM模型 + from src.config.config import model_config + task_config = getattr(model_config.model_task_config, "utils", None) + if task_config: + self.memory_builder_model = LLMRequest( + model_set=task_config, + request_type="memory.hippocampus_build" + ) + asyncio.create_task(self.start_background_sampling()) + logger.info("✅ 海马体采样器初始化成功") + else: + raise RuntimeError("未找到记忆构建模型配置") + + except Exception as e: + logger.error(f"❌ 海马体采样器初始化失败: {e}") + raise + + def generate_time_samples(self) -> List[datetime]: + """生成双峰分布的时间采样点""" + # 计算每个分布的样本数 + recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight)) + distant_samples = max(1, self.config.total_samples - recent_samples) + + # 生成两个正态分布的小时偏移 + recent_offsets = np.random.normal( + loc=self.config.recent_mean_hours, + scale=self.config.recent_std_hours, + size=recent_samples + ) + distant_offsets = np.random.normal( + loc=self.config.distant_mean_hours, + scale=self.config.distant_std_hours, + size=distant_samples + ) + + # 合并两个分布的偏移 + all_offsets = np.concatenate([recent_offsets, distant_offsets]) + + # 转换为时间戳(使用绝对值确保时间点在过去) + base_time = datetime.now() + timestamps = [ + base_time - timedelta(hours=abs(offset)) + for offset in all_offsets + ] + + # 按时间排序(从最早到最近) + return sorted(timestamps) + + async def collect_message_samples(self, target_timestamp: float) -> Optional[List[Dict[str, Any]]]: + """收集指定时间戳附近的消息样本""" + try: + # 随机时间窗口:5-30分钟 + time_window_seconds = random.randint(300, 1800) + + # 尝试3次获取消息 + for attempt in range(3): + timestamp_start = target_timestamp + timestamp_end = target_timestamp + time_window_seconds + + # 获取单条消息作为锚点 + anchor_messages = await get_raw_msg_by_timestamp( + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + limit=1, + limit_mode="earliest", + ) + + if not anchor_messages: + target_timestamp -= 120 # 向前调整2分钟 + continue + + anchor_message = anchor_messages[0] + chat_id = anchor_message.get("chat_id") + + if not chat_id: + continue + + # 获取同聊天的多条消息 + messages = await get_raw_msg_by_timestamp_with_chat( + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + limit=self.config.max_sample_length, + limit_mode="earliest", + chat_id=chat_id, + ) + + if messages and len(messages) >= 2: # 至少需要2条消息 + # 过滤掉已经记忆过的消息 + filtered_messages = [ + msg for msg in messages + if msg.get("memorized_times", 0) < 2 # 最多记忆2次 + ] + + if filtered_messages: + logger.debug(f"成功收集 {len(filtered_messages)} 条消息样本") + return filtered_messages + + target_timestamp -= 120 # 向前调整再试 + + logger.debug(f"时间戳 {target_timestamp} 附近未找到有效消息样本") + return None + + except Exception as e: + logger.error(f"收集消息样本失败: {e}") + return None + + async def build_memory_from_samples(self, messages: List[Dict[str, Any]], target_timestamp: float) -> Optional[str]: + """从消息样本构建记忆""" + if not messages or not self.memory_system or not self.memory_builder_model: + return None + + try: + # 构建可读消息文本 + readable_text = await build_readable_messages( + messages, + merge_messages=True, + timestamp_mode="normal_no_YMD", + replace_bot_name=False, + ) + + if not readable_text: + logger.warning("无法从消息样本生成可读文本") + return None + + # 添加当前日期信息 + current_date = f"当前日期: {datetime.now().isoformat()}" + input_text = f"{current_date}\n{readable_text}" + + logger.debug(f"开始构建记忆,文本长度: {len(input_text)}") + + # 构建上下文 + context = { + "user_id": "hippocampus_sampler", + "timestamp": time.time(), + "source": "hippocampus_sampling", + "message_count": len(messages), + "sample_mode": "bimodal_distribution", + "is_hippocampus_sample": True, # 标识为海马体样本 + "bypass_value_threshold": True, # 绕过价值阈值检查 + "hippocampus_sample_time": target_timestamp, # 记录样本时间 + } + + # 使用记忆系统构建记忆(绕过构建间隔检查) + memories = await self.memory_system.build_memory_from_conversation( + conversation_text=input_text, + context=context, + timestamp=time.time(), + bypass_interval=True # 海马体采样器绕过构建间隔限制 + ) + + if memories: + memory_count = len(memories) + self.success_count += 1 + + # 记录采样结果 + result = { + "timestamp": time.time(), + "memory_count": memory_count, + "message_count": len(messages), + "text_preview": readable_text[:100] + "..." if len(readable_text) > 100 else readable_text, + "memory_types": [m.memory_type.value for m in memories], + } + self.last_sample_results.append(result) + + # 限制结果历史长度 + if len(self.last_sample_results) > 10: + self.last_sample_results.pop(0) + + logger.info(f"✅ 海马体采样成功构建 {memory_count} 条记忆") + return f"构建{memory_count}条记忆" + else: + logger.debug("海马体采样未生成有效记忆") + return None + + except Exception as e: + logger.error(f"海马体采样构建记忆失败: {e}") + return None + + async def perform_sampling_cycle(self) -> Dict[str, Any]: + """执行一次完整的采样周期(优化版:批量融合构建)""" + if not self.should_sample(): + return {"status": "skipped", "reason": "interval_not_met"} + + start_time = time.time() + self.sample_count += 1 + + try: + # 生成时间采样点 + time_samples = self.generate_time_samples() + logger.debug(f"生成 {len(time_samples)} 个时间采样点") + + # 记录时间采样点(调试用) + readable_timestamps = [ + translate_timestamp_to_human_readable(int(ts.timestamp()), mode="normal") + for ts in time_samples[:5] # 只显示前5个 + ] + logger.debug(f"时间采样点示例: {readable_timestamps}") + + # 第一步:批量收集所有消息样本 + logger.debug("开始批量收集消息样本...") + collected_messages = await self._collect_all_message_samples(time_samples) + + if not collected_messages: + logger.info("未收集到有效消息样本,跳过本次采样") + self.last_sample_time = time.time() + return { + "status": "success", + "sample_count": self.sample_count, + "success_count": self.success_count, + "processed_samples": len(time_samples), + "successful_builds": 0, + "duration": time.time() - start_time, + "samples_generated": len(time_samples), + "message": "未收集到有效消息样本", + } + + logger.info(f"收集到 {len(collected_messages)} 组消息样本") + + # 第二步:融合和去重消息 + logger.debug("开始融合和去重消息...") + fused_messages = await self._fuse_and_deduplicate_messages(collected_messages) + + if not fused_messages: + logger.info("消息融合后为空,跳过记忆构建") + self.last_sample_time = time.time() + return { + "status": "success", + "sample_count": self.sample_count, + "success_count": self.success_count, + "processed_samples": len(time_samples), + "successful_builds": 0, + "duration": time.time() - start_time, + "samples_generated": len(time_samples), + "message": "消息融合后为空", + } + + logger.info(f"融合后得到 {len(fused_messages)} 组有效消息") + + # 第三步:一次性构建记忆 + logger.debug("开始批量构建记忆...") + build_result = await self._build_batch_memory(fused_messages, time_samples) + + # 更新最后采样时间 + self.last_sample_time = time.time() + + duration = time.time() - start_time + result = { + "status": "success", + "sample_count": self.sample_count, + "success_count": self.success_count, + "processed_samples": len(time_samples), + "successful_builds": build_result.get("memory_count", 0), + "duration": duration, + "samples_generated": len(time_samples), + "messages_collected": len(collected_messages), + "messages_fused": len(fused_messages), + "optimization_mode": "batch_fusion", + } + + logger.info( + f"✅ 海马体采样周期完成(批量融合模式) | " + f"采样点: {len(time_samples)} | " + f"收集消息: {len(collected_messages)} | " + f"融合消息: {len(fused_messages)} | " + f"构建记忆: {build_result.get('memory_count', 0)} | " + f"耗时: {duration:.2f}s" + ) + + return result + + except Exception as e: + logger.error(f"❌ 海马体采样周期失败: {e}") + return { + "status": "error", + "error": str(e), + "sample_count": self.sample_count, + "duration": time.time() - start_time, + } + + async def _collect_all_message_samples(self, time_samples: List[datetime]) -> List[List[Dict[str, Any]]]: + """批量收集所有时间点的消息样本""" + collected_messages = [] + max_concurrent = min(5, len(time_samples)) # 提高并发数到5 + + for i in range(0, len(time_samples), max_concurrent): + batch = time_samples[i:i + max_concurrent] + tasks = [] + + # 创建并发收集任务 + for timestamp in batch: + target_ts = timestamp.timestamp() + task = self.collect_message_samples(target_ts) + tasks.append(task) + + # 执行并发收集 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理收集结果 + for result in results: + if isinstance(result, list) and result: + collected_messages.append(result) + elif isinstance(result, Exception): + logger.debug(f"消息收集异常: {result}") + + # 批次间短暂延迟 + if i + max_concurrent < len(time_samples): + await asyncio.sleep(0.5) + + return collected_messages + + async def _fuse_and_deduplicate_messages(self, collected_messages: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]: + """融合和去重消息样本""" + if not collected_messages: + return [] + + try: + # 展平所有消息 + all_messages = [] + for message_group in collected_messages: + all_messages.extend(message_group) + + logger.debug(f"展开后总消息数: {len(all_messages)}") + + # 去重逻辑:基于消息内容和时间戳 + unique_messages = [] + seen_hashes = set() + + for message in all_messages: + # 创建消息哈希用于去重 + content = message.get("processed_plain_text", "") or message.get("display_message", "") + timestamp = message.get("time", 0) + chat_id = message.get("chat_id", "") + + # 简单哈希:内容前50字符 + 时间戳(精确到分钟) + 聊天ID + hash_key = f"{content[:50]}_{int(timestamp//60)}_{chat_id}" + + if hash_key not in seen_hashes and len(content.strip()) > 10: + seen_hashes.add(hash_key) + unique_messages.append(message) + + logger.debug(f"去重后消息数: {len(unique_messages)}") + + # 按时间排序 + unique_messages.sort(key=lambda x: x.get("time", 0)) + + # 按聊天ID分组重新组织 + chat_groups = {} + for message in unique_messages: + chat_id = message.get("chat_id", "unknown") + if chat_id not in chat_groups: + chat_groups[chat_id] = [] + chat_groups[chat_id].append(message) + + # 合并相邻时间范围内的消息 + fused_groups = [] + for chat_id, messages in chat_groups.items(): + fused_groups.extend(self._merge_adjacent_messages(messages)) + + logger.debug(f"融合后消息组数: {len(fused_groups)}") + return fused_groups + + except Exception as e: + logger.error(f"消息融合失败: {e}") + # 返回原始消息组作为备选 + return collected_messages[:5] # 限制返回数量 + + def _merge_adjacent_messages(self, messages: List[Dict[str, Any]], time_gap: int = 1800) -> List[List[Dict[str, Any]]]: + """合并时间间隔内的消息""" + if not messages: + return [] + + merged_groups = [] + current_group = [messages[0]] + + for i in range(1, len(messages)): + current_time = messages[i].get("time", 0) + prev_time = current_group[-1].get("time", 0) + + # 如果时间间隔小于阈值,合并到当前组 + if current_time - prev_time <= time_gap: + current_group.append(messages[i]) + else: + # 否则开始新组 + merged_groups.append(current_group) + current_group = [messages[i]] + + # 添加最后一组 + merged_groups.append(current_group) + + # 过滤掉只有一条消息的组(除非内容较长) + result_groups = [] + for group in merged_groups: + if len(group) > 1 or any(len(msg.get("processed_plain_text", "")) > 100 for msg in group): + result_groups.append(group) + + return result_groups + + async def _build_batch_memory(self, fused_messages: List[List[Dict[str, Any]]], time_samples: List[datetime]) -> Dict[str, Any]: + """批量构建记忆""" + if not fused_messages: + return {"memory_count": 0, "memories": []} + + try: + total_memories = [] + total_memory_count = 0 + + # 构建融合后的文本 + batch_input_text = await self._build_fused_conversation_text(fused_messages) + + if not batch_input_text: + logger.warning("无法构建融合文本,尝试单独构建") + # 备选方案:分别构建 + return await self._fallback_individual_build(fused_messages) + + # 创建批量上下文 + batch_context = { + "user_id": "hippocampus_batch_sampler", + "timestamp": time.time(), + "source": "hippocampus_batch_sampling", + "message_groups_count": len(fused_messages), + "total_messages": sum(len(group) for group in fused_messages), + "sample_count": len(time_samples), + "is_hippocampus_sample": True, + "bypass_value_threshold": True, + "optimization_mode": "batch_fusion", + } + + logger.debug(f"批量构建记忆,文本长度: {len(batch_input_text)}") + + # 一次性构建记忆 + memories = await self.memory_system.build_memory_from_conversation( + conversation_text=batch_input_text, + context=batch_context, + timestamp=time.time(), + bypass_interval=True + ) + + if memories: + memory_count = len(memories) + self.success_count += 1 + total_memory_count += memory_count + total_memories.extend(memories) + + logger.info(f"✅ 批量海马体采样成功构建 {memory_count} 条记忆") + else: + logger.debug("批量海马体采样未生成有效记忆") + + # 记录采样结果 + result = { + "timestamp": time.time(), + "memory_count": total_memory_count, + "message_groups_count": len(fused_messages), + "total_messages": sum(len(group) for group in fused_messages), + "text_preview": batch_input_text[:200] + "..." if len(batch_input_text) > 200 else batch_input_text, + "memory_types": [m.memory_type.value for m in total_memories], + } + + self.last_sample_results.append(result) + + # 限制结果历史长度 + if len(self.last_sample_results) > 10: + self.last_sample_results.pop(0) + + return { + "memory_count": total_memory_count, + "memories": total_memories, + "result": result + } + + except Exception as e: + logger.error(f"批量构建记忆失败: {e}") + return {"memory_count": 0, "error": str(e)} + + async def _build_fused_conversation_text(self, fused_messages: List[List[Dict[str, Any]]]) -> str: + """构建融合后的对话文本""" + try: + # 添加批次标识 + current_date = f"海马体批量采样 - {datetime.now().isoformat()}\n" + conversation_parts = [current_date] + + for group_idx, message_group in enumerate(fused_messages): + if not message_group: + continue + + # 为每个消息组添加分隔符 + group_header = f"\n=== 对话片段 {group_idx + 1} ===" + conversation_parts.append(group_header) + + # 构建可读消息 + group_text = await build_readable_messages( + message_group, + merge_messages=True, + timestamp_mode="normal_no_YMD", + replace_bot_name=False, + ) + + if group_text and len(group_text.strip()) > 10: + conversation_parts.append(group_text.strip()) + + return "\n".join(conversation_parts) + + except Exception as e: + logger.error(f"构建融合文本失败: {e}") + return "" + + async def _fallback_individual_build(self, fused_messages: List[List[Dict[str, Any]]]) -> Dict[str, Any]: + """备选方案:单独构建每个消息组""" + total_memories = [] + total_count = 0 + + for group in fused_messages[:5]: # 限制最多5组 + try: + memories = await self.build_memory_from_samples(group, time.time()) + if memories: + total_memories.extend(memories) + total_count += len(memories) + except Exception as e: + logger.debug(f"单独构建失败: {e}") + + return { + "memory_count": total_count, + "memories": total_memories, + "fallback_mode": True + } + + async def process_sample_timestamp(self, target_timestamp: float) -> Optional[str]: + """处理单个时间戳采样(保留作为备选方法)""" + try: + # 收集消息样本 + messages = await self.collect_message_samples(target_timestamp) + if not messages: + return None + + # 构建记忆 + result = await self.build_memory_from_samples(messages, target_timestamp) + return result + + except Exception as e: + logger.debug(f"处理时间戳采样失败 {target_timestamp}: {e}") + return None + + def should_sample(self) -> bool: + """检查是否应该进行采样""" + current_time = time.time() + + # 检查时间间隔 + if current_time - self.last_sample_time < self.config.sample_interval: + return False + + # 检查是否已初始化 + if not self.memory_builder_model: + logger.warning("海马体采样器未初始化") + return False + + return True + + async def start_background_sampling(self): + """启动后台采样""" + if self.is_running: + logger.warning("海马体后台采样已在运行") + return + + self.is_running = True + logger.info("🚀 启动海马体后台采样任务") + + try: + while self.is_running: + try: + # 执行采样周期 + result = await self.perform_sampling_cycle() + + # 如果是跳过状态,短暂睡眠 + if result.get("status") == "skipped": + await asyncio.sleep(60) # 1分钟后重试 + else: + # 正常等待下一个采样间隔 + await asyncio.sleep(self.config.sample_interval) + + except Exception as e: + logger.error(f"海马体后台采样异常: {e}") + await asyncio.sleep(300) # 异常时等待5分钟 + + except asyncio.CancelledError: + logger.info("海马体后台采样任务被取消") + finally: + self.is_running = False + + def stop_background_sampling(self): + """停止后台采样""" + self.is_running = False + logger.info("🛑 停止海马体后台采样任务") + + def get_sampling_stats(self) -> Dict[str, Any]: + """获取采样统计信息""" + success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0 + + # 计算最近的平均数据 + recent_avg_messages = 0 + recent_avg_memory_count = 0 + if self.last_sample_results: + recent_results = self.last_sample_results[-5:] # 最近5次 + recent_avg_messages = sum(r.get("total_messages", 0) for r in recent_results) / len(recent_results) + recent_avg_memory_count = sum(r.get("memory_count", 0) for r in recent_results) / len(recent_results) + + return { + "is_running": self.is_running, + "sample_count": self.sample_count, + "success_count": self.success_count, + "success_rate": f"{success_rate:.1f}%", + "last_sample_time": self.last_sample_time, + "optimization_mode": "batch_fusion", # 显示优化模式 + "performance_metrics": { + "avg_messages_per_sample": f"{recent_avg_messages:.1f}", + "avg_memories_per_sample": f"{recent_avg_memory_count:.1f}", + "fusion_efficiency": f"{(recent_avg_messages/max(recent_avg_memory_count, 1)):.1f}x" if recent_avg_messages > 0 else "N/A" + }, + "config": { + "sample_interval": self.config.sample_interval, + "total_samples": self.config.total_samples, + "recent_weight": f"{self.config.recent_weight:.1%}", + "distant_weight": f"{self.config.distant_weight:.1%}", + "max_concurrent": 5, # 批量模式并发数 + "fusion_time_gap": "30分钟", # 消息融合时间间隔 + }, + "recent_results": self.last_sample_results[-5:], # 最近5次结果 + } + + +# 全局海马体采样器实例 +_hippocampus_sampler: Optional[HippocampusSampler] = None + + +def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler: + """获取全局海马体采样器实例""" + global _hippocampus_sampler + if _hippocampus_sampler is None: + _hippocampus_sampler = HippocampusSampler(memory_system) + return _hippocampus_sampler + + +async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampler: + """初始化全局海马体采样器""" + sampler = get_hippocampus_sampler(memory_system) + await sampler.initialize() + return sampler \ No newline at end of file diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index 5236da62a..fc802c5d2 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -19,6 +19,12 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio from src.chat.memory_system.memory_chunk import MemoryChunk from src.chat.memory_system.memory_fusion import MemoryFusionEngine from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner +# 简化的记忆采样模式枚举 +class MemorySamplingMode(Enum): + """记忆采样模式""" + HIPPOCAMPUS = "hippocampus" # 海马体模式:定时任务采样 + IMMEDIATE = "immediate" # 即时模式:回复后立即采样 + ALL = "all" # 所有模式:同时使用海马体和即时采样 from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest @@ -148,6 +154,9 @@ class MemorySystem: # 记忆指纹缓存,用于快速检测重复记忆 self._memory_fingerprints: dict[str, str] = {} + # 海马体采样器 + self.hippocampus_sampler = None + logger.info("MemorySystem 初始化开始") async def initialize(self): @@ -249,6 +258,16 @@ class MemorySystem: self.query_planner = MemoryQueryPlanner(planner_model, default_limit=self.config.final_recall_limit) + # 初始化海马体采样器 + if global_config.memory.enable_hippocampus_sampling: + try: + from .hippocampus_sampler import initialize_hippocampus_sampler + self.hippocampus_sampler = await initialize_hippocampus_sampler(self) + logger.info("✅ 海马体采样器初始化成功") + except Exception as e: + logger.warning(f"海马体采样器初始化失败: {e}") + self.hippocampus_sampler = None + # 统一存储已经自动加载数据,无需额外加载 logger.info("✅ 简化版记忆系统初始化完成") @@ -283,14 +302,14 @@ class MemorySystem: try: # 使用统一存储检索相似记忆 + filters = {"user_id": user_id} if user_id else None search_results = await self.unified_storage.search_similar_memories( - query_text=query_text, limit=limit, scope_id=user_id + query_text=query_text, limit=limit, filters=filters ) # 转换为记忆对象 memories = [] - for memory_id, similarity_score in search_results: - memory = self.unified_storage.get_memory_by_id(memory_id) + for memory, similarity_score in search_results: if memory: memory.update_access() # 更新访问信息 memories.append(memory) @@ -302,7 +321,7 @@ class MemorySystem: return [] async def build_memory_from_conversation( - self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None + self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None, bypass_interval: bool = False ) -> list[MemoryChunk]: """从对话中构建记忆 @@ -310,6 +329,7 @@ class MemorySystem: conversation_text: 对话文本 context: 上下文信息 timestamp: 时间戳,默认为当前时间 + bypass_interval: 是否绕过构建间隔检查(海马体采样器专用) Returns: 构建的记忆块列表 @@ -328,7 +348,8 @@ class MemorySystem: min_interval = max(0.0, getattr(self.config, "min_build_interval_seconds", 0.0)) current_time = time.time() - if build_scope_key and min_interval > 0: + # 构建间隔检查(海马体采样器可以绕过) + if build_scope_key and min_interval > 0 and not bypass_interval: last_time = self._last_memory_build_times.get(build_scope_key) if last_time and (current_time - last_time) < min_interval: remaining = min_interval - (current_time - last_time) @@ -340,18 +361,35 @@ class MemorySystem: build_marker_time = current_time self._last_memory_build_times[build_scope_key] = current_time + elif bypass_interval: + # 海马体采样模式:不更新构建时间记录,避免影响即时模式 + logger.debug("海马体采样模式:绕过构建间隔检查") conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context) logger.debug("开始构建记忆,文本长度: %d", len(conversation_text)) - # 1. 信息价值评估 - value_score = await self._assess_information_value(conversation_text, normalized_context) + # 1. 信息价值评估(海马体采样器可以绕过) + if not bypass_interval and not context.get("bypass_value_threshold", False): + value_score = await self._assess_information_value(conversation_text, normalized_context) - if value_score < self.config.memory_value_threshold: - logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建") - self.status = original_status - return [] + if value_score < self.config.memory_value_threshold: + logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建") + self.status = original_status + return [] + else: + # 海马体采样器:使用默认价值分数或简单评估 + value_score = 0.6 # 默认中等价值 + if context.get("is_hippocampus_sample", False): + # 对海马体样本进行简单价值评估 + if len(conversation_text) > 100: # 长文本可能有更多信息 + value_score = 0.7 + elif len(conversation_text) > 50: + value_score = 0.6 + else: + value_score = 0.5 + + logger.debug(f"海马体采样模式:使用价值评分 {value_score:.2f}") # 2. 构建记忆块(所有记忆统一使用 global 作用域,实现完全共享) memory_chunks = await self.memory_builder.build_memories( @@ -469,7 +507,7 @@ class MemorySystem: continue search_tasks.append( self.unified_storage.search_similar_memories( - query_text=display_text, limit=8, scope_id=GLOBAL_MEMORY_SCOPE + query_text=display_text, limit=8, filters={"user_id": GLOBAL_MEMORY_SCOPE} ) ) @@ -512,12 +550,70 @@ class MemorySystem: return existing_candidates async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]: - """对外暴露的对话记忆处理接口,仅依赖上下文信息""" + """对外暴露的对话记忆处理接口,支持海马体、即时、所有三种采样模式""" start_time = time.time() try: context = dict(context or {}) + # 获取配置的采样模式 + sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate') + current_mode = MemorySamplingMode(sampling_mode) + + logger.debug(f"使用记忆采样模式: {current_mode.value}") + + # 根据采样模式处理记忆 + if current_mode == MemorySamplingMode.HIPPOCAMPUS: + # 海马体模式:仅后台定时采样,不立即处理 + return { + "success": True, + "created_memories": [], + "memory_count": 0, + "processing_time": time.time() - start_time, + "status": self.status.value, + "processing_mode": "hippocampus", + "message": "海马体模式:记忆将由后台定时任务采样处理", + } + + elif current_mode == MemorySamplingMode.IMMEDIATE: + # 即时模式:立即处理记忆构建 + return await self._process_immediate_memory(context, start_time) + + elif current_mode == MemorySamplingMode.ALL: + # 所有模式:同时进行即时处理和海马体采样 + immediate_result = await self._process_immediate_memory(context, start_time) + + # 海马体采样器会在后台继续处理,这里只是记录 + if self.hippocampus_sampler: + immediate_result["processing_mode"] = "all_modes" + immediate_result["hippocampus_status"] = "background_sampling_enabled" + immediate_result["message"] = "所有模式:即时处理已完成,海马体采样将在后台继续" + else: + immediate_result["processing_mode"] = "immediate_fallback" + immediate_result["hippocampus_status"] = "not_available" + immediate_result["message"] = "海马体采样器不可用,回退到即时模式" + + return immediate_result + + else: + # 默认回退到即时模式 + logger.warning(f"未知的采样模式 {sampling_mode},回退到即时模式") + return await self._process_immediate_memory(context, start_time) + + except Exception as e: + processing_time = time.time() - start_time + logger.error(f"对话记忆处理失败: {e}", exc_info=True) + return { + "success": False, + "error": str(e), + "processing_time": processing_time, + "status": self.status.value, + "processing_mode": "error", + } + + async def _process_immediate_memory(self, context: dict[str, Any], start_time: float) -> dict[str, Any]: + """即时记忆处理的辅助方法""" + try: conversation_candidate = ( context.get("conversation_text") or context.get("message_content") @@ -537,6 +633,23 @@ class MemorySystem: normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp) normalized_context.setdefault("conversation_text", conversation_text) + # 检查信息价值阈值 + value_score = await self._assess_information_value(conversation_text, normalized_context) + threshold = getattr(global_config.memory, 'precision_memory_reply_threshold', 0.5) + + if value_score < threshold: + logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建") + return { + "success": True, + "created_memories": [], + "memory_count": 0, + "processing_time": time.time() - start_time, + "status": self.status.value, + "processing_mode": "immediate", + "skip_reason": f"value_score_{value_score:.2f}_below_threshold_{threshold}", + "value_score": value_score, + } + memories = await self.build_memory_from_conversation( conversation_text=conversation_text, context=normalized_context, timestamp=timestamp ) @@ -550,12 +663,20 @@ class MemorySystem: "memory_count": memory_count, "processing_time": processing_time, "status": self.status.value, + "processing_mode": "immediate", + "value_score": value_score, } except Exception as e: processing_time = time.time() - start_time - logger.error(f"对话记忆处理失败: {e}", exc_info=True) - return {"success": False, "error": str(e), "processing_time": processing_time, "status": self.status.value} + logger.error(f"即时记忆处理失败: {e}", exc_info=True) + return { + "success": False, + "error": str(e), + "processing_time": processing_time, + "status": self.status.value, + "processing_mode": "immediate_error", + } async def retrieve_relevant_memories( self, @@ -1372,11 +1493,53 @@ class MemorySystem: except Exception as e: logger.error(f"❌ 记忆系统维护失败: {e}", exc_info=True) + def start_hippocampus_sampling(self): + """启动海马体采样""" + if self.hippocampus_sampler: + asyncio.create_task(self.hippocampus_sampler.start_background_sampling()) + logger.info("🚀 海马体后台采样已启动") + else: + logger.warning("海马体采样器未初始化,无法启动采样") + + def stop_hippocampus_sampling(self): + """停止海马体采样""" + if self.hippocampus_sampler: + self.hippocampus_sampler.stop_background_sampling() + logger.info("🛑 海马体后台采样已停止") + + def get_system_stats(self) -> dict[str, Any]: + """获取系统统计信息""" + base_stats = { + "status": self.status.value, + "total_memories": self.total_memories, + "last_build_time": self.last_build_time, + "last_retrieval_time": self.last_retrieval_time, + "config": asdict(self.config), + } + + # 添加海马体采样器统计 + if self.hippocampus_sampler: + base_stats["hippocampus_sampler"] = self.hippocampus_sampler.get_sampling_stats() + + # 添加存储统计 + if self.unified_storage: + try: + storage_stats = self.unified_storage.get_storage_stats() + base_stats["storage_stats"] = storage_stats + except Exception as e: + logger.debug(f"获取存储统计失败: {e}") + + return base_stats + async def shutdown(self): """关闭系统(简化版)""" try: logger.info("正在关闭简化记忆系统...") + # 停止海马体采样 + if self.hippocampus_sampler: + self.hippocampus_sampler.stop_background_sampling() + # 保存统一存储数据 if self.unified_storage: await self.unified_storage.cleanup() @@ -1456,4 +1619,10 @@ async def initialize_memory_system(llm_model: LLMRequest | None = None): if memory_system is None: memory_system = MemorySystem(llm_model=llm_model) await memory_system.initialize() + + # 根据配置启动海马体采样 + sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate') + if sampling_mode in ['hippocampus', 'all']: + memory_system.start_hippocampus_sampling() + return memory_system diff --git a/src/config/official_configs.py b/src/config/official_configs.py index ecdb5d5b5..07fa87091 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -337,6 +337,41 @@ class MemoryConfig(ValidatedConfigBase): # 休眠机制 dormant_threshold_days: int = Field(default=90, description="休眠状态判定天数") + # === 混合记忆系统配置 === + # 采样模式配置 + memory_sampling_mode: Literal["adaptive", "hippocampus", "precision"] = Field( + default="adaptive", description="记忆采样模式:adaptive(自适应),hippocampus(海马体双峰采样),precision(精准记忆)" + ) + + # 海马体双峰采样配置 + enable_hippocampus_sampling: bool = Field(default=True, description="启用海马体双峰采样策略") + hippocampus_sample_interval: int = Field(default=1800, description="海马体采样间隔(秒,默认30分钟)") + hippocampus_sample_size: int = Field(default=30, description="海马体每次采样的消息数量") + hippocampus_batch_size: int = Field(default=5, description="海马体每批处理的记忆数量") + + # 双峰分布配置 [近期均值, 近期标准差, 近期权重, 远期均值, 远期标准差, 远期权重] + hippocampus_distribution_config: list[float] = Field( + default=[12.0, 8.0, 0.7, 48.0, 24.0, 0.3], + description="海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]" + ) + + # 自适应采样配置 + adaptive_sampling_enabled: bool = Field(default=True, description="启用自适应采样策略") + adaptive_sampling_threshold: float = Field(default=0.8, description="自适应采样负载阈值(0-1)") + adaptive_sampling_check_interval: int = Field(default=300, description="自适应采样检查间隔(秒)") + adaptive_sampling_max_concurrent_builds: int = Field(default=3, description="自适应采样最大并发记忆构建数") + + # 精准记忆配置(现有系统的增强版本) + precision_memory_reply_threshold: float = Field( + default=0.6, description="精准记忆回复触发阈值(对话价值评分超过此值时触发记忆构建)" + ) + precision_memory_max_builds_per_hour: int = Field(default=10, description="精准记忆每小时最大构建数量") + + # 混合系统优化配置 + memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡") + memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流") + memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列") + class MoodConfig(ValidatedConfigBase): """情绪配置类""" diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index e0097e1ad..41a95b6e7 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.1.5" +version = "7.1.6" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -208,6 +208,19 @@ max_context_emojis = 30 # 每次随机传递给LLM的表情包详细描述的最 enable_memory = true # 是否启用记忆系统 memory_build_interval = 600 # 记忆构建间隔(秒)。间隔越低,学习越频繁,但可能产生更多冗余信息 +# === 记忆采样系统配置 === +memory_sampling_mode = "immediate" # 记忆采样模式:hippocampus(海马体定时采样),immediate(即时采样),all(所有模式) + +# 海马体双峰采样配置 +enable_hippocampus_sampling = true # 启用海马体双峰采样策略 +hippocampus_sample_interval = 1800 # 海马体采样间隔(秒,默认30分钟) +hippocampus_sample_size = 30 # 海马体采样样本数量 +hippocampus_batch_size = 10 # 海马体批量处理大小 +hippocampus_distribution_config = [12.0, 8.0, 0.7, 48.0, 24.0, 0.3] # 海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重] + +# 即时采样配置 +precision_memory_reply_threshold = 0.5 # 精准记忆回复阈值(0-1),高于此值的对话将立即构建记忆 + min_memory_length = 10 # 最小记忆长度 max_memory_length = 500 # 最大记忆长度 memory_value_threshold = 0.5 # 记忆价值阈值,低于该值的记忆会被丢弃 From a8a42694f552c3042743318eea394ce3e2671400 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Fri, 3 Oct 2025 22:11:49 +0800 Subject: [PATCH 3/5] =?UTF-8?q?refactor(proactive=5Fthinker):=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=86=B3=E7=AD=96=E6=8F=90=E7=A4=BA=E8=AF=8D=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=97=A5=E7=A8=8B=E4=B8=8E=E5=8E=86=E5=8F=B2?= =?UTF-8?q?=E8=AE=B0=E5=BD=95=E4=B8=8A=E4=B8=8B=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为了让主动思考的决策更加贴近当前情境,对决策提示词(Prompt)进行了重构。 - **增强情境感知**:在提示词中增加了日程安排、最近聊天摘要和近期动作历史,为 AI 提供更全面的决策依据。 - **优化结构**:将所有上下文信息整合到“情境分析”部分,使提示词结构更清晰,便于模型理解。 - 修复了获取最近消息时参数传递的错误。 --- .../proactive_thinker/proactive_thinker_executor.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py index 96377c800..8af6bbc35 100644 --- a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py +++ b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py @@ -140,7 +140,7 @@ class ProactiveThinkerExecutor: else "今天没有日程安排。" ) - recent_messages = await message_api.get_recent_messages(stream_id, limit=10) + recent_messages = await message_api.get_recent_messages(stream.stream_id) recent_chat_history = ( await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无" ) @@ -386,10 +386,18 @@ class ProactiveThinkerExecutor: # 决策上下文 - **决策理由**: {reason} -- **你和Ta的关系**: + +# 情境分析 +1. **你的日程**: +{context["schedule_context"]} +2. **你和Ta的关系**: - 简短印象: {relationship["short_impression"]} - 详细印象: {relationship["impression"]} - 好感度: {relationship["attitude"]}/100 +3. **最近的聊天摘要**: +{context["recent_chat_history"]} +4. **你最近的相关动作**: +{context["action_history_context"]} # 对话指引 - 你的目标是“破冰”,让对话自然地开始。 From df5a4c717b386c89447d670fe5e8bef8c202a4f3 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Fri, 3 Oct 2025 22:14:58 +0800 Subject: [PATCH 4/5] =?UTF-8?q?refactor(proactive=5Fthinker):=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=86=B3=E7=AD=96=E6=8F=90=E7=A4=BA=E8=AF=8D=EF=BC=8C?= =?UTF-8?q?=E9=81=BF=E5=85=8D=E5=9C=A8=E4=BB=85=E6=9C=89=E8=87=AA=E8=BA=AB?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E6=97=B6=E5=88=B7=E5=B1=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为了防止在没有其他人回复的情况下出现自言自语或刷屏的现象,为主动思考模块的决策提示词增加了一条新规则。 该规则指示模型在判断是否主动发言时,如果上下文中仅存在自身发送的消息,则倾向于不回复,以提升交互的自然性和用户体验。 --- .../built_in/proactive_thinker/proactive_thinker_executor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py index 78d2d6462..e651ecfd5 100644 --- a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py +++ b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py @@ -292,7 +292,8 @@ class ProactiveThinkerExecutor: - **简单问候作为备选**: 如果上下文中没有合适的话题,可以生成一个简单、真诚的日常问候(例如“在忙吗?”,“下午好呀~”)。 - **避免抽象**: 避免创造过于复杂、抽象或需要对方思考很久才能明白的话题。目标是轻松、自然地开启对话。 - **避免过于频繁**: 如果你最近(尤其是在最近的几次决策中)已经主动发起过对话,请倾向于选择“不回复”,除非有非常重要和紧急的事情。 -- **避免打扰**: 如果你最近(尤其是在最近的几次决策中)已经主动发起过对话,请倾向于选择“不回复”,除非有非常重要和紧急的事情。 +- **如果上下文中只有你的消息而没有别人的消息**:选择不回复,以防刷屏或者打扰到别人 + --- 示例1 (基于上下文): From 04e7776a4518944020a50e2900058b70463829a8 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 4 Oct 2025 01:38:41 +0800 Subject: [PATCH 5/5] =?UTF-8?q?refactor(memory):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E5=BA=9F=E5=BC=83=E7=9A=84=E8=AE=B0=E5=BF=86=E7=B3=BB=E7=BB=9F?= =?UTF-8?q?=E5=A4=87=E4=BB=BD=E6=96=87=E4=BB=B6=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E7=AE=A1=E7=90=86=E5=99=A8=E6=9E=B6=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 移除了deprecated_backup目录下的所有废弃记忆系统文件,包括增强记忆适配器、钩子、集成层、重排序器、元数据索引、多阶段检索和向量存储等模块。同时优化了消息管理器,集成了批量数据库写入器、流缓存管理器和自适应流管理器,提升了系统性能和可维护性。 --- .../enhanced_memory_adapter.py | 363 ----- .../enhanced_memory_hooks.py | 194 --- .../enhanced_memory_integration.py | 177 -- .../deprecated_backup/enhanced_reranker.py | 356 ---- .../deprecated_backup/integration_layer.py | 245 --- .../memory_integration_hooks.py | 526 ------ .../deprecated_backup/metadata_index.py | 1027 ------------ .../multi_stage_retrieval.py | 1432 ----------------- .../deprecated_backup/vector_storage.py | 875 ---------- .../adaptive_stream_manager.py | 489 ++++++ .../message_manager/batch_database_writer.py | 348 ++++ .../message_manager/distribution_manager.py | 127 +- src/chat/message_manager/message_manager.py | 48 + .../message_manager/stream_cache_manager.py | 381 +++++ src/chat/message_receive/chat_stream.py | 96 +- .../message_receive/optimized_chat_stream.py | 494 ++++++ 16 files changed, 1975 insertions(+), 5203 deletions(-) delete mode 100644 src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py delete mode 100644 src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py delete mode 100644 src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py delete mode 100644 src/chat/memory_system/deprecated_backup/enhanced_reranker.py delete mode 100644 src/chat/memory_system/deprecated_backup/integration_layer.py delete mode 100644 src/chat/memory_system/deprecated_backup/memory_integration_hooks.py delete mode 100644 src/chat/memory_system/deprecated_backup/metadata_index.py delete mode 100644 src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py delete mode 100644 src/chat/memory_system/deprecated_backup/vector_storage.py create mode 100644 src/chat/message_manager/adaptive_stream_manager.py create mode 100644 src/chat/message_manager/batch_database_writer.py create mode 100644 src/chat/message_manager/stream_cache_manager.py create mode 100644 src/chat/message_receive/optimized_chat_stream.py diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py deleted file mode 100644 index 9f35d2d82..000000000 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_adapter.py +++ /dev/null @@ -1,363 +0,0 @@ -""" -增强记忆系统适配器 -将增强记忆系统集成到现有MoFox Bot架构中 -""" - -import time -from dataclasses import dataclass -from typing import Any - -from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType -from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest - -logger = get_logger(__name__) - - -MEMORY_TYPE_LABELS = { - MemoryType.PERSONAL_FACT: "个人事实", - MemoryType.EVENT: "事件", - MemoryType.PREFERENCE: "偏好", - MemoryType.OPINION: "观点", - MemoryType.RELATIONSHIP: "关系", - MemoryType.EMOTION: "情感", - MemoryType.KNOWLEDGE: "知识", - MemoryType.SKILL: "技能", - MemoryType.GOAL: "目标", - MemoryType.EXPERIENCE: "经验", - MemoryType.CONTEXTUAL: "上下文", -} - - -@dataclass -class AdapterConfig: - """适配器配置""" - - enable_enhanced_memory: bool = True - integration_mode: str = "enhanced_only" # replace, enhanced_only - auto_migration: bool = True - memory_value_threshold: float = 0.6 - fusion_threshold: float = 0.85 - max_retrieval_results: int = 10 - - -class EnhancedMemoryAdapter: - """增强记忆系统适配器""" - - def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None): - self.llm_model = llm_model - self.config = config or AdapterConfig() - self.integration_layer: MemoryIntegrationLayer | None = None - self._initialized = False - - # 统计信息 - self.adapter_stats = { - "total_processed": 0, - "enhanced_used": 0, - "legacy_used": 0, - "hybrid_used": 0, - "memories_created": 0, - "memories_retrieved": 0, - "average_processing_time": 0.0, - } - - async def initialize(self): - """初始化适配器""" - if self._initialized: - return - - try: - logger.info("🚀 初始化增强记忆系统适配器...") - - # 转换配置格式 - integration_config = IntegrationConfig( - mode=IntegrationMode(self.config.integration_mode), - enable_enhanced_memory=self.config.enable_enhanced_memory, - memory_value_threshold=self.config.memory_value_threshold, - fusion_threshold=self.config.fusion_threshold, - max_retrieval_results=self.config.max_retrieval_results, - enable_learning=True, # 启用学习功能 - ) - - # 创建集成层 - self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config) - - # 初始化集成层 - await self.integration_layer.initialize() - - self._initialized = True - logger.info("✅ 增强记忆系统适配器初始化完成") - - except Exception as e: - logger.error(f"❌ 增强记忆系统适配器初始化失败: {e}", exc_info=True) - # 如果初始化失败,禁用增强记忆功能 - self.config.enable_enhanced_memory = False - - async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]: - """处理对话记忆,以上下文为唯一输入""" - if not self._initialized or not self.config.enable_enhanced_memory: - return {"success": False, "error": "Enhanced memory not available"} - - start_time = time.time() - self.adapter_stats["total_processed"] += 1 - - try: - payload_context: dict[str, Any] = dict(context or {}) - - conversation_text = payload_context.get("conversation_text") - if not conversation_text: - conversation_candidate = ( - payload_context.get("message_content") - or payload_context.get("latest_message") - or payload_context.get("raw_text") - ) - if conversation_candidate is not None: - conversation_text = str(conversation_candidate) - payload_context["conversation_text"] = conversation_text - else: - conversation_text = "" - else: - conversation_text = str(conversation_text) - - if "timestamp" not in payload_context: - payload_context["timestamp"] = time.time() - - logger.debug("适配器收到记忆构建请求,文本长度=%d", len(conversation_text)) - - # 使用集成层处理对话 - result = await self.integration_layer.process_conversation(payload_context) - - # 更新统计 - processing_time = time.time() - start_time - self._update_processing_stats(processing_time) - - if result["success"]: - created_count = len(result.get("created_memories", [])) - self.adapter_stats["memories_created"] += created_count - logger.debug(f"对话记忆处理完成,创建 {created_count} 条记忆") - - return result - - except Exception as e: - logger.error(f"处理对话记忆失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - async def retrieve_relevant_memories( - self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None - ) -> list[MemoryChunk]: - """检索相关记忆""" - if not self._initialized or not self.config.enable_enhanced_memory: - return [] - - try: - limit = limit or self.config.max_retrieval_results - memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit) - - self.adapter_stats["memories_retrieved"] += len(memories) - logger.debug(f"检索到 {len(memories)} 条相关记忆") - - return memories - - except Exception as e: - logger.error(f"检索相关记忆失败: {e}", exc_info=True) - return [] - - async def get_memory_context_for_prompt( - self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5 - ) -> str: - """获取用于提示词的记忆上下文""" - memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories) - - if not memories: - return "" - - # 使用新的记忆格式化器 - formatter_config = FormatterConfig( - include_timestamps=True, - include_memory_types=True, - include_confidence=False, - use_emoji_icons=True, - group_by_type=False, - max_display_length=150, - ) - - return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config) - - async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]: - """获取增强记忆系统摘要""" - if not self._initialized or not self.config.enable_enhanced_memory: - return {"available": False, "reason": "Not initialized or disabled"} - - try: - # 获取系统状态 - status = await self.integration_layer.get_system_status() - - # 获取适配器统计 - adapter_stats = self.adapter_stats.copy() - - # 获取集成统计 - integration_stats = self.integration_layer.get_integration_stats() - - return { - "available": True, - "system_status": status, - "adapter_stats": adapter_stats, - "integration_stats": integration_stats, - "total_memories_created": adapter_stats["memories_created"], - "total_memories_retrieved": adapter_stats["memories_retrieved"], - } - - except Exception as e: - logger.error(f"获取增强记忆摘要失败: {e}", exc_info=True) - return {"available": False, "error": str(e)} - - def _update_processing_stats(self, processing_time: float): - """更新处理统计""" - total_processed = self.adapter_stats["total_processed"] - if total_processed > 0: - current_avg = self.adapter_stats["average_processing_time"] - new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed - self.adapter_stats["average_processing_time"] = new_avg - - def get_adapter_stats(self) -> dict[str, Any]: - """获取适配器统计信息""" - return self.adapter_stats.copy() - - async def maintenance(self): - """维护操作""" - if not self._initialized: - return - - try: - logger.info("🔧 增强记忆系统适配器维护...") - await self.integration_layer.maintenance() - logger.info("✅ 增强记忆系统适配器维护完成") - except Exception as e: - logger.error(f"❌ 增强记忆系统适配器维护失败: {e}", exc_info=True) - - async def shutdown(self): - """关闭适配器""" - if not self._initialized: - return - - try: - logger.info("🔄 关闭增强记忆系统适配器...") - await self.integration_layer.shutdown() - self._initialized = False - logger.info("✅ 增强记忆系统适配器已关闭") - except Exception as e: - logger.error(f"❌ 关闭增强记忆系统适配器失败: {e}", exc_info=True) - - -# 全局适配器实例 -_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None - - -async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter: - """获取全局增强记忆适配器实例""" - global _enhanced_memory_adapter - - if _enhanced_memory_adapter is None: - # 从配置中获取适配器配置 - from src.config.config import global_config - - adapter_config = AdapterConfig( - enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True), - integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"), - auto_migration=getattr(global_config.memory, "enable_memory_migration", True), - memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6), - fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85), - max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10), - ) - - _enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config) - await _enhanced_memory_adapter.initialize() - - return _enhanced_memory_adapter - - -async def initialize_enhanced_memory_system(llm_model: LLMRequest): - """初始化增强记忆系统""" - try: - logger.info("🚀 初始化增强记忆系统...") - adapter = await get_enhanced_memory_adapter(llm_model) - logger.info("✅ 增强记忆系统初始化完成") - return adapter - except Exception as e: - logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True) - return None - - -async def process_conversation_with_enhanced_memory( - context: dict[str, Any], llm_model: LLMRequest | None = None -) -> dict[str, Any]: - """使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息""" - if not llm_model: - # 获取默认的LLM模型 - from src.llm_models.utils_model import get_global_llm_model - - llm_model = get_global_llm_model() - - try: - adapter = await get_enhanced_memory_adapter(llm_model) - payload_context = dict(context or {}) - - if "conversation_text" not in payload_context: - conversation_candidate = ( - payload_context.get("message_content") - or payload_context.get("latest_message") - or payload_context.get("raw_text") - ) - if conversation_candidate is not None: - payload_context["conversation_text"] = str(conversation_candidate) - - return await adapter.process_conversation_memory(payload_context) - except Exception as e: - logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - -async def retrieve_memories_with_enhanced_system( - query: str, - user_id: str, - context: dict[str, Any] | None = None, - limit: int = 10, - llm_model: LLMRequest | None = None, -) -> list[MemoryChunk]: - """使用增强记忆系统检索记忆""" - if not llm_model: - # 获取默认的LLM模型 - from src.llm_models.utils_model import get_global_llm_model - - llm_model = get_global_llm_model() - - try: - adapter = await get_enhanced_memory_adapter(llm_model) - return await adapter.retrieve_relevant_memories(query, user_id, context, limit) - except Exception as e: - logger.error(f"使用增强记忆系统检索记忆失败: {e}", exc_info=True) - return [] - - -async def get_memory_context_for_prompt( - query: str, - user_id: str, - context: dict[str, Any] | None = None, - max_memories: int = 5, - llm_model: LLMRequest | None = None, -) -> str: - """获取用于提示词的记忆上下文""" - if not llm_model: - # 获取默认的LLM模型 - from src.llm_models.utils_model import get_global_llm_model - - llm_model = get_global_llm_model() - - try: - adapter = await get_enhanced_memory_adapter(llm_model) - return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories) - except Exception as e: - logger.error(f"获取记忆上下文失败: {e}", exc_info=True) - return "" diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py deleted file mode 100644 index 1d6e65396..000000000 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_hooks.py +++ /dev/null @@ -1,194 +0,0 @@ -""" -增强记忆系统钩子 -用于在消息处理过程中自动构建和检索记忆 -""" - -from datetime import datetime -from typing import Any - -from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager -from src.common.logger import get_logger -from src.config.config import global_config - -logger = get_logger(__name__) - - -class EnhancedMemoryHooks: - """增强记忆系统钩子 - 自动处理消息的记忆构建和检索""" - - def __init__(self): - self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory - self.processed_messages = set() # 避免重复处理 - - async def process_message_for_memory( - self, - message_content: str, - user_id: str, - chat_id: str, - message_id: str, - context: dict[str, Any] | None = None, - ) -> bool: - """ - 处理消息并构建记忆 - - Args: - message_content: 消息内容 - user_id: 用户ID - chat_id: 聊天ID - message_id: 消息ID - context: 上下文信息 - - Returns: - bool: 是否成功处理 - """ - if not self.enabled: - return False - - if message_id in self.processed_messages: - return False - - try: - # 确保增强记忆管理器已初始化 - if not enhanced_memory_manager.is_initialized: - await enhanced_memory_manager.initialize() - - # 注入机器人基础人设,帮助记忆构建时避免记录自身信息 - bot_config = getattr(global_config, "bot", None) - personality_config = getattr(global_config, "personality", None) - bot_context = {} - if bot_config is not None: - bot_context["bot_name"] = getattr(bot_config, "nickname", None) - bot_context["bot_aliases"] = list(getattr(bot_config, "alias_names", []) or []) - bot_context["bot_account"] = getattr(bot_config, "qq_account", None) - - if personality_config is not None: - bot_context["bot_identity"] = getattr(personality_config, "identity", None) - bot_context["bot_personality"] = getattr(personality_config, "personality_core", None) - bot_context["bot_personality_side"] = getattr(personality_config, "personality_side", None) - - # 构建上下文 - memory_context = { - "chat_id": chat_id, - "message_id": message_id, - "timestamp": datetime.now().timestamp(), - "message_type": "user_message", - **bot_context, - **(context or {}), - } - - # 处理对话并构建记忆 - memory_chunks = await enhanced_memory_manager.process_conversation( - conversation_text=message_content, - context=memory_context, - user_id=user_id, - timestamp=memory_context["timestamp"], - ) - - # 标记消息已处理 - self.processed_messages.add(message_id) - - # 限制处理历史大小 - if len(self.processed_messages) > 1000: - # 移除最旧的500个记录 - self.processed_messages = set(list(self.processed_messages)[-500:]) - - logger.debug(f"为消息 {message_id} 构建了 {len(memory_chunks)} 条记忆") - return len(memory_chunks) > 0 - - except Exception as e: - logger.error(f"处理消息记忆失败: {e}") - return False - - async def get_memory_for_response( - self, - query_text: str, - user_id: str, - chat_id: str, - limit: int = 5, - extra_context: dict[str, Any] | None = None, - ) -> list[dict[str, Any]]: - """ - 为回复获取相关记忆 - - Args: - query_text: 查询文本 - user_id: 用户ID - chat_id: 聊天ID - limit: 返回记忆数量限制 - - Returns: - List[Dict]: 相关记忆列表 - """ - if not self.enabled: - return [] - - try: - # 确保增强记忆管理器已初始化 - if not enhanced_memory_manager.is_initialized: - await enhanced_memory_manager.initialize() - - # 构建查询上下文 - context = { - "chat_id": chat_id, - "query_intent": "response_generation", - "expected_memory_types": ["personal_fact", "event", "preference", "opinion"], - } - - if extra_context: - context.update(extra_context) - - # 获取相关记忆 - enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context( - query_text=query_text, user_id=user_id, context=context, limit=limit - ) - - # 转换为字典格式 - results = [] - for result in enhanced_results: - memory_dict = { - "content": result.content, - "type": result.memory_type, - "confidence": result.confidence, - "importance": result.importance, - "timestamp": result.timestamp, - "source": result.source, - "relevance": result.relevance_score, - "structure": result.structure, - } - results.append(memory_dict) - - logger.debug(f"为回复查询到 {len(results)} 条相关记忆") - return results - - except Exception as e: - logger.error(f"获取回复记忆失败: {e}") - return [] - - async def cleanup_old_memories(self): - """清理旧记忆""" - try: - if enhanced_memory_manager.is_initialized: - # 调用增强记忆系统的维护功能 - await enhanced_memory_manager.enhanced_system.maintenance() - logger.debug("增强记忆系统维护完成") - except Exception as e: - logger.error(f"清理旧记忆失败: {e}") - - def clear_processed_cache(self): - """清除已处理消息的缓存""" - self.processed_messages.clear() - logger.debug("已清除消息处理缓存") - - def enable(self): - """启用记忆钩子""" - self.enabled = True - logger.info("增强记忆钩子已启用") - - def disable(self): - """禁用记忆钩子""" - self.enabled = False - logger.info("增强记忆钩子已禁用") - - -# 创建全局实例 -enhanced_memory_hooks = EnhancedMemoryHooks() diff --git a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py b/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py deleted file mode 100644 index 068326113..000000000 --- a/src/chat/memory_system/deprecated_backup/enhanced_memory_integration.py +++ /dev/null @@ -1,177 +0,0 @@ -""" -增强记忆系统集成脚本 -用于在现有系统中无缝集成增强记忆功能 -""" - -from typing import Any - -from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks -from src.common.logger import get_logger - -logger = get_logger(__name__) - - -async def process_user_message_memory( - message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None -) -> bool: - """ - 处理用户消息并构建记忆 - - Args: - message_content: 消息内容 - user_id: 用户ID - chat_id: 聊天ID - message_id: 消息ID - context: 额外的上下文信息 - - Returns: - bool: 是否成功构建记忆 - """ - try: - success = await enhanced_memory_hooks.process_message_for_memory( - message_content=message_content, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context - ) - - if success: - logger.debug(f"成功为消息 {message_id} 构建记忆") - - return success - - except Exception as e: - logger.error(f"处理用户消息记忆失败: {e}") - return False - - -async def get_relevant_memories_for_response( - query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None -) -> dict[str, Any]: - """ - 为回复获取相关记忆 - - Args: - query_text: 查询文本(通常是用户的当前消息) - user_id: 用户ID - chat_id: 聊天ID - limit: 返回记忆数量限制 - extra_context: 额外上下文信息 - - Returns: - Dict: 包含记忆信息的字典 - """ - try: - memories = await enhanced_memory_hooks.get_memory_for_response( - query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context - ) - - result = {"has_memories": len(memories) > 0, "memories": memories, "memory_count": len(memories)} - - logger.debug(f"为回复获取到 {len(memories)} 条相关记忆") - return result - - except Exception as e: - logger.error(f"获取回复记忆失败: {e}") - return {"has_memories": False, "memories": [], "memory_count": 0} - - -def format_memories_for_prompt(memories: dict[str, Any]) -> str: - """ - 格式化记忆信息用于Prompt - - Args: - memories: 记忆信息字典 - - Returns: - str: 格式化后的记忆文本 - """ - if not memories["has_memories"]: - return "" - - memory_lines = ["以下是相关的记忆信息:"] - - for memory in memories["memories"]: - content = memory["content"] - memory_type = memory["type"] - confidence = memory["confidence"] - importance = memory["importance"] - - # 根据重要性添加不同的标记 - importance_marker = "🔥" if importance >= 3 else "⭐" if importance >= 2 else "📝" - confidence_marker = "✅" if confidence >= 3 else "⚠️" if confidence >= 2 else "💭" - - memory_line = f"{importance_marker} {content} ({memory_type}, {confidence_marker}置信度)" - memory_lines.append(memory_line) - - return "\n".join(memory_lines) - - -async def cleanup_memory_system(): - """清理记忆系统""" - try: - await enhanced_memory_hooks.cleanup_old_memories() - logger.info("记忆系统清理完成") - except Exception as e: - logger.error(f"记忆系统清理失败: {e}") - - -def get_memory_system_status() -> dict[str, Any]: - """ - 获取记忆系统状态 - - Returns: - Dict: 系统状态信息 - """ - from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager - - return { - "enabled": enhanced_memory_hooks.enabled, - "enhanced_system_initialized": enhanced_memory_manager.is_initialized, - "processed_messages_count": len(enhanced_memory_hooks.processed_messages), - "system_type": "enhanced_memory_system", - } - - -# 便捷函数 -async def remember_message( - message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None -) -> bool: - """ - 便捷的记忆构建函数 - - Args: - message: 要记住的消息 - user_id: 用户ID - chat_id: 聊天ID - - Returns: - bool: 是否成功 - """ - import uuid - - message_id = str(uuid.uuid4()) - return await process_user_message_memory( - message_content=message, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context - ) - - -async def recall_memories( - query: str, - user_id: str = "default_user", - chat_id: str = "default_chat", - limit: int = 5, - context: dict[str, Any] | None = None, -) -> dict[str, Any]: - """ - 便捷的记忆检索函数 - - Args: - query: 查询文本 - user_id: 用户ID - chat_id: 聊天ID - limit: 返回数量限制 - - Returns: - Dict: 记忆信息 - """ - return await get_relevant_memories_for_response( - query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context - ) diff --git a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py b/src/chat/memory_system/deprecated_backup/enhanced_reranker.py deleted file mode 100644 index c35b9de53..000000000 --- a/src/chat/memory_system/deprecated_backup/enhanced_reranker.py +++ /dev/null @@ -1,356 +0,0 @@ -""" -增强重排序器 -实现文档设计的多维度评分模型 -""" - -import math -import time -from dataclasses import dataclass -from enum import Enum -from typing import Any - -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType -from src.common.logger import get_logger - -logger = get_logger(__name__) - - -class IntentType(Enum): - """对话意图类型""" - - FACT_QUERY = "fact_query" # 事实查询 - EVENT_RECALL = "event_recall" # 事件回忆 - PREFERENCE_CHECK = "preference_check" # 偏好检查 - GENERAL_CHAT = "general_chat" # 一般对话 - UNKNOWN = "unknown" # 未知意图 - - -@dataclass -class ReRankingConfig: - """重排序配置""" - - # 权重配置 (w1 + w2 + w3 + w4 = 1.0) - semantic_weight: float = 0.5 # 语义相似度权重 - recency_weight: float = 0.2 # 时效性权重 - usage_freq_weight: float = 0.2 # 使用频率权重 - type_match_weight: float = 0.1 # 类型匹配权重 - - # 时效性衰减参数 - recency_decay_rate: float = 0.1 # 时效性衰减率 (天) - - # 使用频率计算参数 - freq_log_base: float = 2.0 # 对数底数 - freq_max_score: float = 5.0 # 最大频率得分 - - # 类型匹配权重映射 - type_match_weights: dict[str, dict[str, float]] = None - - def __post_init__(self): - """初始化类型匹配权重""" - if self.type_match_weights is None: - self.type_match_weights = { - IntentType.FACT_QUERY.value: { - MemoryType.PERSONAL_FACT.value: 1.0, - MemoryType.KNOWLEDGE.value: 0.8, - MemoryType.PREFERENCE.value: 0.5, - MemoryType.EVENT.value: 0.3, - "default": 0.3, - }, - IntentType.EVENT_RECALL.value: { - MemoryType.EVENT.value: 1.0, - MemoryType.EXPERIENCE.value: 0.8, - MemoryType.EMOTION.value: 0.6, - MemoryType.PERSONAL_FACT.value: 0.5, - "default": 0.5, - }, - IntentType.PREFERENCE_CHECK.value: { - MemoryType.PREFERENCE.value: 1.0, - MemoryType.OPINION.value: 0.8, - MemoryType.GOAL.value: 0.6, - MemoryType.PERSONAL_FACT.value: 0.4, - "default": 0.4, - }, - IntentType.GENERAL_CHAT.value: {"default": 0.8}, - IntentType.UNKNOWN.value: {"default": 0.8}, - } - - -class IntentClassifier: - """轻量级意图识别器""" - - def __init__(self): - # 关键词模式匹配规则 - self.patterns = { - IntentType.FACT_QUERY: [ - # 中文模式 - "我是", - "我的", - "我叫", - "我在", - "我住在", - "我的职业", - "我的工作", - "什么时候", - "在哪里", - "是什么", - "多少", - "几岁", - "年龄", - # 英文模式 - "what is", - "where is", - "when is", - "how old", - "my name", - "i am", - "i live", - ], - IntentType.EVENT_RECALL: [ - # 中文模式 - "记得", - "想起", - "还记得", - "那次", - "上次", - "之前", - "以前", - "曾经", - "发生过", - "经历", - "做过", - "去过", - "见过", - # 英文模式 - "remember", - "recall", - "last time", - "before", - "previously", - "happened", - "experience", - ], - IntentType.PREFERENCE_CHECK: [ - # 中文模式 - "喜欢", - "不喜欢", - "偏好", - "爱好", - "兴趣", - "讨厌", - "最爱", - "最喜欢", - "习惯", - "通常", - "一般", - "倾向于", - "更喜欢", - # 英文模式 - "like", - "love", - "hate", - "prefer", - "favorite", - "usually", - "tend to", - "interest", - ], - } - - def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType: - """识别对话意图""" - if not query: - return IntentType.UNKNOWN - - query_lower = query.lower() - - # 统计各意图的匹配分数 - intent_scores = dict.fromkeys(IntentType, 0) - - for intent, patterns in self.patterns.items(): - for pattern in patterns: - if pattern in query_lower: - intent_scores[intent] += 1 - - # 返回得分最高的意图 - max_score = max(intent_scores.values()) - if max_score == 0: - return IntentType.GENERAL_CHAT - - for intent, score in intent_scores.items(): - if score == max_score: - return intent - - return IntentType.GENERAL_CHAT - - -class EnhancedReRanker: - """增强重排序器 - 实现文档设计的多维度评分模型""" - - def __init__(self, config: ReRankingConfig | None = None): - self.config = config or ReRankingConfig() - self.intent_classifier = IntentClassifier() - - # 验证权重和为1.0 - total_weight = ( - self.config.semantic_weight - + self.config.recency_weight - + self.config.usage_freq_weight - + self.config.type_match_weight - ) - - if abs(total_weight - 1.0) > 0.01: - logger.warning(f"重排序权重和不为1.0: {total_weight}, 将进行归一化") - # 归一化权重 - self.config.semantic_weight /= total_weight - self.config.recency_weight /= total_weight - self.config.usage_freq_weight /= total_weight - self.config.type_match_weight /= total_weight - - def rerank_memories( - self, - query: str, - candidate_memories: list[tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity) - context: dict[str, Any], - limit: int = 10, - ) -> list[tuple[str, MemoryChunk, float]]: - """ - 对候选记忆进行重排序 - - Args: - query: 查询文本 - candidate_memories: 候选记忆列表 [(memory_id, memory, vector_similarity)] - context: 上下文信息 - limit: 返回数量限制 - - Returns: - 重排序后的记忆列表 [(memory_id, memory, final_score)] - """ - if not candidate_memories: - return [] - - # 识别查询意图 - intent = self.intent_classifier.classify_intent(query, context) - logger.debug(f"识别到查询意图: {intent.value}") - - # 计算每个候选记忆的最终得分 - scored_memories = [] - current_time = time.time() - - for memory_id, memory, vector_sim in candidate_memories: - try: - # 1. 语义相似度得分 (已归一化到[0,1]) - semantic_score = self._normalize_similarity(vector_sim) - - # 2. 时效性得分 - recency_score = self._calculate_recency_score(memory, current_time) - - # 3. 使用频率得分 - usage_freq_score = self._calculate_usage_frequency_score(memory) - - # 4. 类型匹配得分 - type_match_score = self._calculate_type_match_score(memory, intent) - - # 计算最终得分 - final_score = ( - self.config.semantic_weight * semantic_score - + self.config.recency_weight * recency_score - + self.config.usage_freq_weight * usage_freq_score - + self.config.type_match_weight * type_match_score - ) - - scored_memories.append((memory_id, memory, final_score)) - - # 记录调试信息 - logger.debug( - f"记忆评分 {memory_id[:8]}: semantic={semantic_score:.3f}, " - f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, " - f"type={type_match_score:.3f}, final={final_score:.3f}" - ) - - except Exception as e: - logger.error(f"计算记忆 {memory_id} 得分时出错: {e}") - # 使用向量相似度作为后备得分 - scored_memories.append((memory_id, memory, vector_sim)) - - # 按最终得分降序排序 - scored_memories.sort(key=lambda x: x[2], reverse=True) - - # 返回前N个结果 - result = scored_memories[:limit] - - highest_score = result[0][2] if result else 0.0 - logger.info( - f"重排序完成: 候选={len(candidate_memories)}, 返回={len(result)}, " - f"意图={intent.value}, 最高分={highest_score:.3f}" - ) - - return result - - def _normalize_similarity(self, raw_similarity: float) -> float: - """归一化相似度到[0,1]区间""" - # 假设原始相似度已经在[-1,1]或[0,1]区间 - if raw_similarity < 0: - return (raw_similarity + 1) / 2 # 从[-1,1]映射到[0,1] - return min(1.0, max(0.0, raw_similarity)) # 确保在[0,1]区间 - - def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float: - """ - 计算时效性得分 - 公式: Recency = 1 / (1 + decay_rate * days_old) - """ - last_accessed = memory.metadata.last_accessed or memory.metadata.created_at - days_old = (current_time - last_accessed) / (24 * 3600) # 转换为天数 - - if days_old < 0: - days_old = 0 # 处理时间异常 - - score = 1 / (1 + self.config.recency_decay_rate * days_old) - return min(1.0, max(0.0, score)) - - def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float: - """ - 计算使用频率得分 - 公式: Usage_Freq = min(1.0, log2(access_count + 1) / max_score) - """ - access_count = memory.metadata.access_count - if access_count <= 0: - return 0.0 - - log_count = math.log2(access_count + 1) - score = log_count / self.config.freq_max_score - return min(1.0, max(0.0, score)) - - def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float: - """计算类型匹配得分""" - memory_type = memory.memory_type.value - intent_value = intent.value - - # 获取对应意图的类型权重映射 - type_weights = self.config.type_match_weights.get(intent_value, {}) - - # 查找具体类型的权重,如果没有则使用默认权重 - score = type_weights.get(memory_type, type_weights.get("default", 0.8)) - - return min(1.0, max(0.0, score)) - - -# 创建默认的重排序器实例 -default_reranker = EnhancedReRanker() - - -def rerank_candidate_memories( - query: str, - candidate_memories: list[tuple[str, MemoryChunk, float]], - context: dict[str, Any], - limit: int = 10, - config: ReRankingConfig | None = None, -) -> list[tuple[str, MemoryChunk, float]]: - """ - 便捷函数:对候选记忆进行重排序 - """ - if config: - reranker = EnhancedReRanker(config) - else: - reranker = default_reranker - - return reranker.rerank_memories(query, candidate_memories, context, limit) diff --git a/src/chat/memory_system/deprecated_backup/integration_layer.py b/src/chat/memory_system/deprecated_backup/integration_layer.py deleted file mode 100644 index 220ec00f5..000000000 --- a/src/chat/memory_system/deprecated_backup/integration_layer.py +++ /dev/null @@ -1,245 +0,0 @@ -""" -增强记忆系统集成层 -现在只管理新的增强记忆系统,旧系统已被完全移除 -""" - -import asyncio -import time -from dataclasses import dataclass -from enum import Enum -from typing import Any - -from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem -from src.chat.memory_system.memory_chunk import MemoryChunk -from src.common.logger import get_logger -from src.llm_models.utils_model import LLMRequest - -logger = get_logger(__name__) - - -class IntegrationMode(Enum): - """集成模式""" - - REPLACE = "replace" # 完全替换现有记忆系统 - ENHANCED_ONLY = "enhanced_only" # 仅使用增强记忆系统 - - -@dataclass -class IntegrationConfig: - """集成配置""" - - mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY - enable_enhanced_memory: bool = True - memory_value_threshold: float = 0.6 - fusion_threshold: float = 0.85 - max_retrieval_results: int = 10 - enable_learning: bool = True - - -class MemoryIntegrationLayer: - """记忆系统集成层 - 现在只管理增强记忆系统""" - - def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None): - self.llm_model = llm_model - self.config = config or IntegrationConfig() - - # 只初始化增强记忆系统 - self.enhanced_memory: EnhancedMemorySystem | None = None - - # 集成统计 - self.integration_stats = { - "total_queries": 0, - "enhanced_queries": 0, - "memory_creations": 0, - "average_response_time": 0.0, - "success_rate": 0.0, - } - - # 初始化锁 - self._initialization_lock = asyncio.Lock() - self._initialized = False - - async def initialize(self): - """初始化集成层""" - if self._initialized: - return - - async with self._initialization_lock: - if self._initialized: - return - - logger.info("🚀 开始初始化增强记忆系统集成层...") - - try: - # 初始化增强记忆系统 - if self.config.enable_enhanced_memory: - await self._initialize_enhanced_memory() - - self._initialized = True - logger.info("✅ 增强记忆系统集成层初始化完成") - - except Exception as e: - logger.error(f"❌ 集成层初始化失败: {e}", exc_info=True) - raise - - async def _initialize_enhanced_memory(self): - """初始化增强记忆系统""" - try: - logger.debug("初始化增强记忆系统...") - - # 创建增强记忆系统配置 - from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig - - memory_config = MemorySystemConfig.from_global_config() - - # 使用集成配置覆盖部分值 - memory_config.memory_value_threshold = self.config.memory_value_threshold - memory_config.fusion_similarity_threshold = self.config.fusion_threshold - memory_config.final_recall_limit = self.config.max_retrieval_results - - # 创建增强记忆系统 - self.enhanced_memory = EnhancedMemorySystem(config=memory_config) - - # 如果外部提供了LLM模型,注入到系统中 - if self.llm_model is not None: - self.enhanced_memory.llm_model = self.llm_model - - # 初始化系统 - await self.enhanced_memory.initialize() - logger.info("✅ 增强记忆系统初始化完成") - - except Exception as e: - logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True) - raise - - async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]: - """处理对话记忆,仅使用上下文信息""" - if not self._initialized or not self.enhanced_memory: - return {"success": False, "error": "Memory system not available"} - - start_time = time.time() - self.integration_stats["total_queries"] += 1 - self.integration_stats["enhanced_queries"] += 1 - - try: - payload_context = dict(context or {}) - conversation_text = payload_context.get("conversation_text") or payload_context.get("message_content") or "" - logger.debug("集成层收到记忆构建请求,文本长度=%d", len(conversation_text)) - - # 直接使用增强记忆系统处理 - result = await self.enhanced_memory.process_conversation_memory(payload_context) - - # 更新统计 - processing_time = time.time() - start_time - self._update_response_stats(processing_time, result.get("success", False)) - - if result.get("success"): - created_count = len(result.get("created_memories", [])) - self.integration_stats["memory_creations"] += created_count - logger.debug(f"对话处理完成,创建 {created_count} 条记忆") - - return result - - except Exception as e: - processing_time = time.time() - start_time - self._update_response_stats(processing_time, False) - logger.error(f"处理对话记忆失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} - - async def retrieve_relevant_memories( - self, - query: str, - user_id: str | None = None, - context: dict[str, Any] | None = None, - limit: int | None = None, - ) -> list[MemoryChunk]: - """检索相关记忆""" - if not self._initialized or not self.enhanced_memory: - return [] - - try: - limit = limit or self.config.max_retrieval_results - memories = await self.enhanced_memory.retrieve_relevant_memories( - query=query, user_id=None, context=context or {}, limit=limit - ) - - memory_count = len(memories) - logger.debug(f"检索到 {memory_count} 条相关记忆") - return memories - - except Exception as e: - logger.error(f"检索相关记忆失败: {e}", exc_info=True) - return [] - - async def get_system_status(self) -> dict[str, Any]: - """获取系统状态""" - if not self._initialized: - return {"status": "not_initialized"} - - try: - enhanced_status = {} - if self.enhanced_memory: - enhanced_status = await self.enhanced_memory.get_system_status() - - return { - "status": "initialized", - "mode": self.config.mode.value, - "enhanced_memory": enhanced_status, - "integration_stats": self.integration_stats.copy(), - } - - except Exception as e: - logger.error(f"获取系统状态失败: {e}", exc_info=True) - return {"status": "error", "error": str(e)} - - def get_integration_stats(self) -> dict[str, Any]: - """获取集成统计信息""" - return self.integration_stats.copy() - - def _update_response_stats(self, processing_time: float, success: bool): - """更新响应统计""" - total_queries = self.integration_stats["total_queries"] - if total_queries > 0: - # 更新平均响应时间 - current_avg = self.integration_stats["average_response_time"] - new_avg = (current_avg * (total_queries - 1) + processing_time) / total_queries - self.integration_stats["average_response_time"] = new_avg - - # 更新成功率 - if success: - current_success_rate = self.integration_stats["success_rate"] - new_success_rate = (current_success_rate * (total_queries - 1) + 1) / total_queries - self.integration_stats["success_rate"] = new_success_rate - - async def maintenance(self): - """执行维护操作""" - if not self._initialized: - return - - try: - logger.info("🔧 执行记忆系统集成层维护...") - - if self.enhanced_memory: - await self.enhanced_memory.maintenance() - - logger.info("✅ 记忆系统集成层维护完成") - - except Exception as e: - logger.error(f"❌ 集成层维护失败: {e}", exc_info=True) - - async def shutdown(self): - """关闭集成层""" - if not self._initialized: - return - - try: - logger.info("🔄 关闭记忆系统集成层...") - - if self.enhanced_memory: - await self.enhanced_memory.shutdown() - - self._initialized = False - logger.info("✅ 记忆系统集成层已关闭") - - except Exception as e: - logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True) diff --git a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py b/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py deleted file mode 100644 index 5dfd52c38..000000000 --- a/src/chat/memory_system/deprecated_backup/memory_integration_hooks.py +++ /dev/null @@ -1,526 +0,0 @@ -""" -记忆系统集成钩子 -提供与现有MoFox Bot系统的无缝集成点 -""" - -import time -from dataclasses import dataclass -from typing import Any - -from src.chat.memory_system.enhanced_memory_adapter import ( - get_memory_context_for_prompt, - process_conversation_with_enhanced_memory, - retrieve_memories_with_enhanced_system, -) -from src.common.logger import get_logger - -logger = get_logger(__name__) - - -@dataclass -class HookResult: - """钩子执行结果""" - - success: bool - data: Any = None - error: str | None = None - processing_time: float = 0.0 - - -class MemoryIntegrationHooks: - """记忆系统集成钩子""" - - def __init__(self): - self.hooks_registered = False - self.hook_stats = { - "message_processing_hooks": 0, - "memory_retrieval_hooks": 0, - "prompt_enhancement_hooks": 0, - "total_hook_executions": 0, - "average_hook_time": 0.0, - } - - async def register_hooks(self): - """注册所有集成钩子""" - if self.hooks_registered: - return - - try: - logger.info("🔗 注册记忆系统集成钩子...") - - # 注册消息处理钩子 - await self._register_message_processing_hooks() - - # 注册记忆检索钩子 - await self._register_memory_retrieval_hooks() - - # 注册提示词增强钩子 - await self._register_prompt_enhancement_hooks() - - # 注册系统维护钩子 - await self._register_maintenance_hooks() - - self.hooks_registered = True - logger.info("✅ 记忆系统集成钩子注册完成") - - except Exception as e: - logger.error(f"❌ 注册记忆系统集成钩子失败: {e}", exc_info=True) - - async def _register_message_processing_hooks(self): - """注册消息处理钩子""" - try: - # 钩子1: 在消息处理后创建记忆 - await self._register_post_message_hook() - - # 钩子2: 在聊天流保存时处理记忆 - await self._register_chat_stream_hook() - - logger.debug("消息处理钩子注册完成") - - except Exception as e: - logger.error(f"注册消息处理钩子失败: {e}") - - async def _register_memory_retrieval_hooks(self): - """注册记忆检索钩子""" - try: - # 钩子1: 在生成回复前检索相关记忆 - await self._register_pre_response_hook() - - # 钩子2: 在知识库查询前增强上下文 - await self._register_knowledge_query_hook() - - logger.debug("记忆检索钩子注册完成") - - except Exception as e: - logger.error(f"注册记忆检索钩子失败: {e}") - - async def _register_prompt_enhancement_hooks(self): - """注册提示词增强钩子""" - try: - # 钩子1: 增强提示词构建 - await self._register_prompt_building_hook() - - logger.debug("提示词增强钩子注册完成") - - except Exception as e: - logger.error(f"注册提示词增强钩子失败: {e}") - - async def _register_maintenance_hooks(self): - """注册系统维护钩子""" - try: - # 钩子1: 系统维护时的记忆系统维护 - await self._register_system_maintenance_hook() - - logger.debug("系统维护钩子注册完成") - - except Exception as e: - logger.error(f"注册系统维护钩子失败: {e}") - - async def _register_post_message_hook(self): - """注册消息后处理钩子""" - try: - # 这里需要根据实际的系统架构来注册钩子 - # 以下是一个示例实现,需要根据实际的插件系统或事件系统来调整 - - # 尝试注册到事件系统 - try: - from src.plugin_system.base.component_types import EventType - from src.plugin_system.core.event_manager import event_manager - - # 注册消息后处理事件 - event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler) - logger.debug("已注册到事件系统的消息处理钩子") - - except ImportError: - logger.debug("事件系统不可用,跳过事件钩子注册") - - # 尝试注册到消息管理器 - try: - from src.chat.message_manager import message_manager - - # 如果消息管理器支持钩子注册 - if hasattr(message_manager, "register_post_process_hook"): - message_manager.register_post_process_hook(self._on_message_processed_hook) - logger.debug("已注册到消息管理器的处理钩子") - - except ImportError: - logger.debug("消息管理器不可用,跳过消息管理器钩子注册") - - except Exception as e: - logger.error(f"注册消息后处理钩子失败: {e}") - - async def _register_chat_stream_hook(self): - """注册聊天流钩子""" - try: - # 尝试注册到聊天流管理器 - try: - from src.chat.message_receive.chat_stream import get_chat_manager - - chat_manager = get_chat_manager() - if hasattr(chat_manager, "register_save_hook"): - chat_manager.register_save_hook(self._on_chat_stream_save_hook) - logger.debug("已注册到聊天流管理器的保存钩子") - - except ImportError: - logger.debug("聊天流管理器不可用,跳过聊天流钩子注册") - - except Exception as e: - logger.error(f"注册聊天流钩子失败: {e}") - - async def _register_pre_response_hook(self): - """注册回复前钩子""" - try: - # 尝试注册到回复生成器 - try: - from src.chat.replyer.default_generator import default_generator - - if hasattr(default_generator, "register_pre_generation_hook"): - default_generator.register_pre_generation_hook(self._on_pre_response_hook) - logger.debug("已注册到回复生成器的前置钩子") - - except ImportError: - logger.debug("回复生成器不可用,跳过回复前钩子注册") - - except Exception as e: - logger.error(f"注册回复前钩子失败: {e}") - - async def _register_knowledge_query_hook(self): - """注册知识库查询钩子""" - try: - # 尝试注册到知识库系统 - try: - from src.chat.knowledge.knowledge_lib import knowledge_manager - - if hasattr(knowledge_manager, "register_query_enhancer"): - knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook) - logger.debug("已注册到知识库的查询增强钩子") - - except ImportError: - logger.debug("知识库系统不可用,跳过知识库钩子注册") - - except Exception as e: - logger.error(f"注册知识库查询钩子失败: {e}") - - async def _register_prompt_building_hook(self): - """注册提示词构建钩子""" - try: - # 尝试注册到提示词系统 - try: - from src.chat.utils.prompt import prompt_manager - - if hasattr(prompt_manager, "register_enhancer"): - prompt_manager.register_enhancer(self._on_prompt_building_hook) - logger.debug("已注册到提示词管理器的增强钩子") - - except ImportError: - logger.debug("提示词系统不可用,跳过提示词钩子注册") - - except Exception as e: - logger.error(f"注册提示词构建钩子失败: {e}") - - async def _register_system_maintenance_hook(self): - """注册系统维护钩子""" - try: - # 尝试注册到系统维护器 - try: - from src.manager.async_task_manager import async_task_manager - - # 注册定期维护任务 - async_task_manager.add_task(MemoryMaintenanceTask()) - logger.debug("已注册到系统维护器的定期任务") - - except ImportError: - logger.debug("异步任务管理器不可用,跳过系统维护钩子注册") - - except Exception as e: - logger.error(f"注册系统维护钩子失败: {e}") - - # 钩子处理器方法 - - async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult: - """事件系统的消息处理处理器""" - return await self._on_message_processed_hook(event_data) - - async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult: - """消息后处理钩子""" - start_time = time.time() - - try: - self.hook_stats["message_processing_hooks"] += 1 - - # 提取必要的信息 - message_info = message_data.get("message_info", {}) - user_info = message_info.get("user_info", {}) - conversation_text = message_data.get("processed_plain_text", "") - - if not conversation_text: - return HookResult(success=True, data="No conversation text") - - user_id = str(user_info.get("user_id", "unknown")) - context = { - "chat_id": message_data.get("chat_id"), - "message_type": message_data.get("message_type", "normal"), - "platform": message_info.get("platform", "unknown"), - "interest_value": message_data.get("interest_value", 0.0), - "keywords": message_data.get("key_words", []), - "timestamp": message_data.get("time", time.time()), - } - - # 使用增强记忆系统处理对话 - memory_context = dict(context) - memory_context["conversation_text"] = conversation_text - memory_context["user_id"] = user_id - - result = await process_conversation_with_enhanced_memory(memory_context) - - processing_time = time.time() - start_time - self._update_hook_stats(processing_time) - - if result["success"]: - logger.debug(f"消息处理钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆") - return HookResult(success=True, data=result, processing_time=processing_time) - else: - logger.warning(f"消息处理钩子执行失败: {result.get('error')}") - return HookResult(success=False, error=result.get("error"), processing_time=processing_time) - - except Exception as e: - processing_time = time.time() - start_time - logger.error(f"消息处理钩子执行异常: {e}", exc_info=True) - return HookResult(success=False, error=str(e), processing_time=processing_time) - - async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult: - """聊天流保存钩子""" - start_time = time.time() - - try: - self.hook_stats["message_processing_hooks"] += 1 - - # 从聊天流数据中提取对话信息 - stream_context = chat_stream_data.get("stream_context", {}) - user_id = stream_context.get("user_id", "unknown") - messages = stream_context.get("messages", []) - - if not messages: - return HookResult(success=True, data="No messages to process") - - # 构建对话文本 - conversation_parts = [] - for msg in messages[-10:]: # 只处理最近10条消息 - text = msg.get("processed_plain_text", "") - if text: - conversation_parts.append(f"{msg.get('user_nickname', 'User')}: {text}") - - conversation_text = "\n".join(conversation_parts) - if not conversation_text: - return HookResult(success=True, data="No conversation text") - - context = { - "chat_id": chat_stream_data.get("chat_id"), - "stream_id": chat_stream_data.get("stream_id"), - "platform": chat_stream_data.get("platform", "unknown"), - "message_count": len(messages), - "timestamp": time.time(), - } - - # 使用增强记忆系统处理对话 - memory_context = dict(context) - memory_context["conversation_text"] = conversation_text - memory_context["user_id"] = user_id - - result = await process_conversation_with_enhanced_memory(memory_context) - - processing_time = time.time() - start_time - self._update_hook_stats(processing_time) - - if result["success"]: - logger.debug(f"聊天流保存钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆") - return HookResult(success=True, data=result, processing_time=processing_time) - else: - logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}") - return HookResult(success=False, error=result.get("error"), processing_time=processing_time) - - except Exception as e: - processing_time = time.time() - start_time - logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True) - return HookResult(success=False, error=str(e), processing_time=processing_time) - - async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult: - """回复前钩子""" - start_time = time.time() - - try: - self.hook_stats["memory_retrieval_hooks"] += 1 - - # 提取查询信息 - query = response_data.get("query", "") - user_id = response_data.get("user_id", "unknown") - context = response_data.get("context", {}) - - if not query: - return HookResult(success=True, data="No query provided") - - # 检索相关记忆 - memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5) - - processing_time = time.time() - start_time - self._update_hook_stats(processing_time) - - # 将记忆添加到响应数据中 - response_data["enhanced_memories"] = memories - response_data["enhanced_memory_context"] = await get_memory_context_for_prompt( - query, user_id, context, max_memories=5 - ) - - logger.debug(f"回复前钩子执行成功,检索到 {len(memories)} 条记忆") - return HookResult(success=True, data=memories, processing_time=processing_time) - - except Exception as e: - processing_time = time.time() - start_time - logger.error(f"回复前钩子执行异常: {e}", exc_info=True) - return HookResult(success=False, error=str(e), processing_time=processing_time) - - async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult: - """知识库查询钩子""" - start_time = time.time() - - try: - self.hook_stats["memory_retrieval_hooks"] += 1 - - query = query_data.get("query", "") - user_id = query_data.get("user_id", "unknown") - context = query_data.get("context", {}) - - if not query: - return HookResult(success=True, data="No query provided") - - # 获取记忆上下文并增强查询 - memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3) - - processing_time = time.time() - start_time - self._update_hook_stats(processing_time) - - # 将记忆上下文添加到查询数据中 - query_data["enhanced_memory_context"] = memory_context - - logger.debug("知识库查询钩子执行成功") - return HookResult(success=True, data=memory_context, processing_time=processing_time) - - except Exception as e: - processing_time = time.time() - start_time - logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True) - return HookResult(success=False, error=str(e), processing_time=processing_time) - - async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult: - """提示词构建钩子""" - start_time = time.time() - - try: - self.hook_stats["prompt_enhancement_hooks"] += 1 - - query = prompt_data.get("query", "") - user_id = prompt_data.get("user_id", "unknown") - context = prompt_data.get("context", {}) - base_prompt = prompt_data.get("base_prompt", "") - - if not query: - return HookResult(success=True, data="No query provided") - - # 获取记忆上下文 - memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5) - - processing_time = time.time() - start_time - self._update_hook_stats(processing_time) - - # 构建增强的提示词 - enhanced_prompt = base_prompt - if memory_context: - enhanced_prompt += f"\n\n### 相关记忆上下文 ###\n{memory_context}\n" - - # 将增强的提示词添加到数据中 - prompt_data["enhanced_prompt"] = enhanced_prompt - prompt_data["memory_context"] = memory_context - - logger.debug("提示词构建钩子执行成功") - return HookResult(success=True, data=enhanced_prompt, processing_time=processing_time) - - except Exception as e: - processing_time = time.time() - start_time - logger.error(f"提示词构建钩子执行异常: {e}", exc_info=True) - return HookResult(success=False, error=str(e), processing_time=processing_time) - - def _update_hook_stats(self, processing_time: float): - """更新钩子统计""" - self.hook_stats["total_hook_executions"] += 1 - - total_executions = self.hook_stats["total_hook_executions"] - if total_executions > 0: - current_avg = self.hook_stats["average_hook_time"] - new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions - self.hook_stats["average_hook_time"] = new_avg - - def get_hook_stats(self) -> dict[str, Any]: - """获取钩子统计信息""" - return self.hook_stats.copy() - - -class MemoryMaintenanceTask: - """记忆系统维护任务""" - - def __init__(self): - self.task_name = "enhanced_memory_maintenance" - self.interval = 3600 # 1小时执行一次 - - async def execute(self): - """执行维护任务""" - try: - logger.info("🔧 执行增强记忆系统维护任务...") - - # 获取适配器实例 - try: - from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter - - if _enhanced_memory_adapter: - await _enhanced_memory_adapter.maintenance() - logger.info("✅ 增强记忆系统维护任务完成") - else: - logger.debug("增强记忆适配器未初始化,跳过维护") - except Exception as e: - logger.error(f"增强记忆系统维护失败: {e}") - - except Exception as e: - logger.error(f"执行维护任务时发生异常: {e}", exc_info=True) - - def get_interval(self) -> int: - """获取执行间隔""" - return self.interval - - def get_task_name(self) -> str: - """获取任务名称""" - return self.task_name - - -# 全局钩子实例 -_memory_hooks: MemoryIntegrationHooks | None = None - - -async def get_memory_integration_hooks() -> MemoryIntegrationHooks: - """获取全局记忆集成钩子实例""" - global _memory_hooks - - if _memory_hooks is None: - _memory_hooks = MemoryIntegrationHooks() - await _memory_hooks.register_hooks() - - return _memory_hooks - - -async def initialize_memory_integration_hooks(): - """初始化记忆集成钩子""" - try: - logger.info("🚀 初始化记忆集成钩子...") - hooks = await get_memory_integration_hooks() - logger.info("✅ 记忆集成钩子初始化完成") - return hooks - except Exception as e: - logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True) - return None diff --git a/src/chat/memory_system/deprecated_backup/metadata_index.py b/src/chat/memory_system/deprecated_backup/metadata_index.py deleted file mode 100644 index 8c89e5c34..000000000 --- a/src/chat/memory_system/deprecated_backup/metadata_index.py +++ /dev/null @@ -1,1027 +0,0 @@ -""" -元数据索引系统 -为记忆系统提供多维度的精准过滤和查询能力 -""" - -import threading -import time -from collections import defaultdict -from dataclasses import dataclass -from enum import Enum -from pathlib import Path -from typing import Any - -import orjson - -from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk, MemoryType -from src.common.logger import get_logger - -logger = get_logger(__name__) - - -class IndexType(Enum): - """索引类型""" - - MEMORY_TYPE = "memory_type" # 记忆类型索引 - USER_ID = "user_id" # 用户ID索引 - SUBJECT = "subject" # 主体索引 - KEYWORD = "keyword" # 关键词索引 - TAG = "tag" # 标签索引 - CATEGORY = "category" # 分类索引 - TIMESTAMP = "timestamp" # 时间索引 - CONFIDENCE = "confidence" # 置信度索引 - IMPORTANCE = "importance" # 重要性索引 - RELATIONSHIP_SCORE = "relationship_score" # 关系分索引 - ACCESS_FREQUENCY = "access_frequency" # 访问频率索引 - SEMANTIC_HASH = "semantic_hash" # 语义哈希索引 - - -@dataclass -class IndexQuery: - """索引查询条件""" - - user_ids: list[str] | None = None - memory_types: list[MemoryType] | None = None - subjects: list[str] | None = None - keywords: list[str] | None = None - tags: list[str] | None = None - categories: list[str] | None = None - time_range: tuple[float, float] | None = None - confidence_levels: list[ConfidenceLevel] | None = None - importance_levels: list[ImportanceLevel] | None = None - min_relationship_score: float | None = None - max_relationship_score: float | None = None - min_access_count: int | None = None - semantic_hashes: list[str] | None = None - limit: int | None = None - sort_by: str | None = None # "created_at", "access_count", "relevance_score" - sort_order: str = "desc" # "asc", "desc" - - -@dataclass -class IndexResult: - """索引结果""" - - memory_ids: list[str] - total_count: int - query_time: float - filtered_by: list[str] - - -class MetadataIndexManager: - """元数据索引管理器""" - - def __init__(self, index_path: str = "data/memory_metadata"): - self.index_path = Path(index_path) - self.index_path.mkdir(parents=True, exist_ok=True) - - # 各类索引 - self.indices = { - IndexType.MEMORY_TYPE: defaultdict(set), - IndexType.USER_ID: defaultdict(set), - IndexType.SUBJECT: defaultdict(set), - IndexType.KEYWORD: defaultdict(set), - IndexType.TAG: defaultdict(set), - IndexType.CATEGORY: defaultdict(set), - IndexType.CONFIDENCE: defaultdict(set), - IndexType.IMPORTANCE: defaultdict(set), - IndexType.SEMANTIC_HASH: defaultdict(set), - } - - # 时间索引(特殊处理) - self.time_index = [] # [(timestamp, memory_id), ...] - self.relationship_index = [] # [(relationship_score, memory_id), ...] - self.access_frequency_index = [] # [(access_count, memory_id), ...] - - # 内存缓存 - self.memory_metadata_cache: dict[str, dict[str, Any]] = {} - - # 统计信息 - self.index_stats = { - "total_memories": 0, - "index_build_time": 0.0, - "average_query_time": 0.0, - "total_queries": 0, - "cache_hit_rate": 0.0, - "cache_hits": 0, - } - - # 线程锁 - self._lock = threading.RLock() - self._dirty = False # 标记索引是否有未保存的更改 - - # 自动保存配置 - self.auto_save_interval = 500 # 每500次操作自动保存 - self._operation_count = 0 - - @staticmethod - def _serialize_index_key(index_type: IndexType, key: Any) -> str: - """将索引键序列化为字符串以便存储""" - if isinstance(key, Enum): - value = key.value - else: - value = key - return str(value) - - @staticmethod - def _deserialize_index_key(index_type: IndexType, key: str) -> Any: - """根据索引类型反序列化索引键""" - try: - if index_type == IndexType.MEMORY_TYPE: - return MemoryType(key) - if index_type == IndexType.CONFIDENCE: - return ConfidenceLevel(int(key)) - if index_type == IndexType.IMPORTANCE: - return ImportanceLevel(int(key)) - # 其他索引键默认使用原始字符串(可能已经是lower后的字符串) - return key - except Exception: - logger.warning("无法反序列化索引键 %s 在索引 %s 中,使用原始字符串", key, index_type.value) - return key - - @staticmethod - def _serialize_metadata_entry(metadata: dict[str, Any]) -> dict[str, Any]: - serialized = {} - for field_name, value in metadata.items(): - if isinstance(value, Enum): - serialized[field_name] = value.value - else: - serialized[field_name] = value - return serialized - - async def index_memories(self, memories: list[MemoryChunk]): - """为记忆建立索引""" - if not memories: - return - - start_time = time.time() - - try: - with self._lock: - for memory in memories: - self._index_single_memory(memory) - - # 标记为需要保存 - self._dirty = True - self._operation_count += len(memories) - - # 自动保存检查 - if self._operation_count >= self.auto_save_interval: - await self.save_index() - self._operation_count = 0 - - index_time = time.time() - start_time - self.index_stats["index_build_time"] = ( - self.index_stats["index_build_time"] * (len(memories) - 1) + index_time - ) / len(memories) - - logger.debug(f"元数据索引完成,{len(memories)} 条记忆,耗时 {index_time:.3f}秒") - - except Exception as e: - logger.error(f"❌ 元数据索引失败: {e}", exc_info=True) - - async def update_memory_entry(self, memory: MemoryChunk): - """更新已存在记忆的索引信息""" - if not memory: - return - - with self._lock: - entry = self.memory_metadata_cache.get(memory.memory_id) - if entry is None: - # 若不存在则作为新记忆索引 - self._index_single_memory(memory) - return - - old_confidence = entry.get("confidence") - old_importance = entry.get("importance") - old_semantic_hash = entry.get("semantic_hash") - - entry.update( - { - "user_id": memory.user_id, - "memory_type": memory.memory_type, - "created_at": memory.metadata.created_at, - "last_accessed": memory.metadata.last_accessed, - "access_count": memory.metadata.access_count, - "confidence": memory.metadata.confidence, - "importance": memory.metadata.importance, - "relationship_score": memory.metadata.relationship_score, - "relevance_score": memory.metadata.relevance_score, - "semantic_hash": memory.semantic_hash, - "subjects": memory.subjects, - } - ) - - # 更新置信度/重要性索引 - if isinstance(old_confidence, ConfidenceLevel): - self.indices[IndexType.CONFIDENCE][old_confidence].discard(memory.memory_id) - if isinstance(old_importance, ImportanceLevel): - self.indices[IndexType.IMPORTANCE][old_importance].discard(memory.memory_id) - if isinstance(old_semantic_hash, str): - self.indices[IndexType.SEMANTIC_HASH][old_semantic_hash].discard(memory.memory_id) - - self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory.memory_id) - self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory.memory_id) - if memory.semantic_hash: - self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory.memory_id) - - # 同步关键词/标签/分类索引 - for keyword in memory.keywords: - if keyword: - self.indices[IndexType.KEYWORD][keyword.lower()].add(memory.memory_id) - - for tag in memory.tags: - if tag: - self.indices[IndexType.TAG][tag.lower()].add(memory.memory_id) - - for category in memory.categories: - if category: - self.indices[IndexType.CATEGORY][category.lower()].add(memory.memory_id) - - for subject in memory.subjects: - if subject: - self.indices[IndexType.SUBJECT][subject.strip().lower()].add(memory.memory_id) - - def _index_single_memory(self, memory: MemoryChunk): - """为单个记忆建立索引""" - memory_id = memory.memory_id - - # 更新内存缓存 - self.memory_metadata_cache[memory_id] = { - "user_id": memory.user_id, - "memory_type": memory.memory_type, - "created_at": memory.metadata.created_at, - "last_accessed": memory.metadata.last_accessed, - "access_count": memory.metadata.access_count, - "confidence": memory.metadata.confidence, - "importance": memory.metadata.importance, - "relationship_score": memory.metadata.relationship_score, - "relevance_score": memory.metadata.relevance_score, - "semantic_hash": memory.semantic_hash, - "subjects": memory.subjects, - } - - # 记忆类型索引 - self.indices[IndexType.MEMORY_TYPE][memory.memory_type].add(memory_id) - - # 用户ID索引 - self.indices[IndexType.USER_ID][memory.user_id].add(memory_id) - - # 主体索引 - for subject in memory.subjects: - normalized = subject.strip().lower() - if normalized: - self.indices[IndexType.SUBJECT][normalized].add(memory_id) - - # 关键词索引 - for keyword in memory.keywords: - self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id) - - # 标签索引 - for tag in memory.tags: - self.indices[IndexType.TAG][tag.lower()].add(memory_id) - - # 分类索引 - for category in memory.categories: - self.indices[IndexType.CATEGORY][category.lower()].add(memory_id) - - # 置信度索引 - self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory_id) - - # 重要性索引 - self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory_id) - - # 语义哈希索引 - if memory.semantic_hash: - self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory_id) - - # 时间索引(插入排序保持有序) - self._insert_into_time_index(memory.metadata.created_at, memory_id) - - # 关系分索引(插入排序保持有序) - self._insert_into_relationship_index(memory.metadata.relationship_score, memory_id) - - # 访问频率索引(插入排序保持有序) - self._insert_into_access_frequency_index(memory.metadata.access_count, memory_id) - - # 更新统计 - self.index_stats["total_memories"] += 1 - - def _insert_into_time_index(self, timestamp: float, memory_id: str): - """插入时间索引(保持降序)""" - insert_pos = len(self.time_index) - for i, (ts, _) in enumerate(self.time_index): - if timestamp >= ts: - insert_pos = i - break - - self.time_index.insert(insert_pos, (timestamp, memory_id)) - - def _insert_into_relationship_index(self, relationship_score: float, memory_id: str): - """插入关系分索引(保持降序)""" - insert_pos = len(self.relationship_index) - for i, (score, _) in enumerate(self.relationship_index): - if relationship_score >= score: - insert_pos = i - break - - self.relationship_index.insert(insert_pos, (relationship_score, memory_id)) - - def _insert_into_access_frequency_index(self, access_count: int, memory_id: str): - """插入访问频率索引(保持降序)""" - insert_pos = len(self.access_frequency_index) - for i, (count, _) in enumerate(self.access_frequency_index): - if access_count >= count: - insert_pos = i - break - - self.access_frequency_index.insert(insert_pos, (access_count, memory_id)) - - async def query_memories(self, query: IndexQuery) -> IndexResult: - """查询记忆""" - start_time = time.time() - - try: - with self._lock: - # 获取候选记忆ID集合 - candidate_ids = self._get_candidate_memories(query) - - # 应用过滤条件 - filtered_ids = self._apply_filters(candidate_ids, query) - - # 排序 - if query.sort_by: - filtered_ids = self._sort_memories(filtered_ids, query.sort_by, query.sort_order) - - # 限制数量 - if query.limit and len(filtered_ids) > query.limit: - filtered_ids = filtered_ids[: query.limit] - - # 记录查询统计 - query_time = time.time() - start_time - self.index_stats["total_queries"] += 1 - self.index_stats["average_query_time"] = ( - self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time - ) / self.index_stats["total_queries"] - - return IndexResult( - memory_ids=filtered_ids, - total_count=len(filtered_ids), - query_time=query_time, - filtered_by=self._get_applied_filters(query), - ) - - except Exception as e: - logger.error(f"❌ 元数据查询失败: {e}", exc_info=True) - return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[]) - - def _get_candidate_memories(self, query: IndexQuery) -> set[str]: - """获取候选记忆ID集合""" - candidate_ids = set() - - # 获取所有记忆ID作为起点 - all_memory_ids = set(self.memory_metadata_cache.keys()) - - if not all_memory_ids: - return candidate_ids - - # 应用最严格的过滤条件 - applied_filters = [] - - if query.memory_types: - memory_types_set = set() - for memory_type in query.memory_types: - memory_types_set.update(self.indices[IndexType.MEMORY_TYPE].get(memory_type, set())) - if applied_filters: - candidate_ids &= memory_types_set - else: - candidate_ids.update(memory_types_set) - applied_filters.append("memory_types") - - if query.keywords: - keywords_set = set() - for keyword in query.keywords: - keywords_set.update(self._collect_index_matches(IndexType.KEYWORD, keyword)) - if applied_filters: - candidate_ids &= keywords_set - else: - candidate_ids.update(keywords_set) - applied_filters.append("keywords") - - if query.tags: - tags_set = set() - for tag in query.tags: - tags_set.update(self.indices[IndexType.TAG].get(tag.lower(), set())) - if applied_filters: - candidate_ids &= tags_set - else: - candidate_ids.update(tags_set) - applied_filters.append("tags") - - if query.categories: - categories_set = set() - for category in query.categories: - categories_set.update(self.indices[IndexType.CATEGORY].get(category.lower(), set())) - if applied_filters: - candidate_ids &= categories_set - else: - candidate_ids.update(categories_set) - applied_filters.append("categories") - - if query.subjects: - subjects_set = set() - for subject in query.subjects: - subjects_set.update(self._collect_index_matches(IndexType.SUBJECT, subject)) - if applied_filters: - candidate_ids &= subjects_set - else: - candidate_ids.update(subjects_set) - applied_filters.append("subjects") - - # 如果没有应用任何过滤条件,返回所有记忆 - if not applied_filters: - return all_memory_ids - - return candidate_ids - - def _collect_index_matches(self, index_type: IndexType, token: str | Enum | None) -> set[str]: - """根据给定token收集索引匹配,支持部分匹配""" - mapping = self.indices.get(index_type) - if mapping is None: - return set() - - key = "" - if isinstance(token, Enum): - key = str(token.value).strip().lower() - elif isinstance(token, str): - key = token.strip().lower() - elif token is not None: - key = str(token).strip().lower() - - if not key: - return set() - - matches: set[str] = set(mapping.get(key, set())) - - if matches: - return set(matches) - - for existing_key, ids in mapping.items(): - if not existing_key or not isinstance(existing_key, str): - continue - normalized = existing_key.strip().lower() - if not normalized: - continue - if key in normalized or normalized in key: - matches.update(ids) - - return matches - - def _apply_filters(self, candidate_ids: set[str], query: IndexQuery) -> list[str]: - """应用过滤条件""" - filtered_ids = list(candidate_ids) - - # 时间范围过滤 - if query.time_range: - start_time, end_time = query.time_range - filtered_ids = [ - memory_id for memory_id in filtered_ids if self._is_in_time_range(memory_id, start_time, end_time) - ] - - # 置信度过滤 - if query.confidence_levels: - confidence_set = set(query.confidence_levels) - filtered_ids = [ - memory_id - for memory_id in filtered_ids - if self.memory_metadata_cache[memory_id]["confidence"] in confidence_set - ] - - # 重要性过滤 - if query.importance_levels: - importance_set = set(query.importance_levels) - filtered_ids = [ - memory_id - for memory_id in filtered_ids - if self.memory_metadata_cache[memory_id]["importance"] in importance_set - ] - - # 关系分范围过滤 - if query.min_relationship_score is not None: - filtered_ids = [ - memory_id - for memory_id in filtered_ids - if self.memory_metadata_cache[memory_id]["relationship_score"] >= query.min_relationship_score - ] - - if query.max_relationship_score is not None: - filtered_ids = [ - memory_id - for memory_id in filtered_ids - if self.memory_metadata_cache[memory_id]["relationship_score"] <= query.max_relationship_score - ] - - # 最小访问次数过滤 - if query.min_access_count is not None: - filtered_ids = [ - memory_id - for memory_id in filtered_ids - if self.memory_metadata_cache[memory_id]["access_count"] >= query.min_access_count - ] - - # 语义哈希过滤 - if query.semantic_hashes: - hash_set = set(query.semantic_hashes) - filtered_ids = [ - memory_id - for memory_id in filtered_ids - if self.memory_metadata_cache[memory_id]["semantic_hash"] in hash_set - ] - - return filtered_ids - - def _is_in_time_range(self, memory_id: str, start_time: float, end_time: float) -> bool: - """检查记忆是否在时间范围内""" - created_at = self.memory_metadata_cache[memory_id]["created_at"] - return start_time <= created_at <= end_time - - def _sort_memories(self, memory_ids: list[str], sort_by: str, sort_order: str) -> list[str]: - """对记忆进行排序""" - if sort_by == "created_at": - # 使用时间索引(已经有序) - if sort_order == "desc": - return memory_ids # 时间索引已经是降序 - else: - return memory_ids[::-1] # 反转为升序 - - elif sort_by == "access_count": - # 使用访问频率索引(已经有序) - if sort_order == "desc": - return memory_ids # 访问频率索引已经是降序 - else: - return memory_ids[::-1] # 反转为升序 - - elif sort_by == "relevance_score": - # 按相关度排序 - memory_ids.sort( - key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"], reverse=(sort_order == "desc") - ) - - elif sort_by == "relationship_score": - # 使用关系分索引(已经有序) - if sort_order == "desc": - return memory_ids # 关系分索引已经是降序 - else: - return memory_ids[::-1] # 反转为升序 - - elif sort_by == "last_accessed": - # 按最后访问时间排序 - memory_ids.sort( - key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"], reverse=(sort_order == "desc") - ) - - return memory_ids - - def _get_applied_filters(self, query: IndexQuery) -> list[str]: - """获取应用的过滤器列表""" - filters = [] - if query.memory_types: - filters.append("memory_types") - if query.subjects: - filters.append("subjects") - if query.keywords: - filters.append("keywords") - if query.tags: - filters.append("tags") - if query.categories: - filters.append("categories") - if query.time_range: - filters.append("time_range") - if query.confidence_levels: - filters.append("confidence_levels") - if query.importance_levels: - filters.append("importance_levels") - if query.min_relationship_score is not None or query.max_relationship_score is not None: - filters.append("relationship_score_range") - if query.min_access_count is not None: - filters.append("min_access_count") - if query.semantic_hashes: - filters.append("semantic_hashes") - return filters - - async def update_memory_index(self, memory: MemoryChunk): - """更新记忆索引""" - with self._lock: - try: - memory_id = memory.memory_id - - # 如果记忆已存在,先删除旧索引 - if memory_id in self.memory_metadata_cache: - await self.remove_memory_index(memory_id) - - # 重新建立索引 - self._index_single_memory(memory) - self._dirty = True - self._operation_count += 1 - - # 自动保存检查 - if self._operation_count >= self.auto_save_interval: - await self.save_index() - self._operation_count = 0 - - logger.debug(f"更新记忆索引完成: {memory_id}") - - except Exception as e: - logger.error(f"❌ 更新记忆索引失败: {e}") - - async def remove_memory_index(self, memory_id: str): - """移除记忆索引""" - with self._lock: - try: - if memory_id not in self.memory_metadata_cache: - return - - # 获取记忆元数据 - metadata = self.memory_metadata_cache[memory_id] - - # 从各类索引中移除 - self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id) - self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id) - subjects = metadata.get("subjects") or [] - for subject in subjects: - if not isinstance(subject, str): - continue - normalized = subject.strip().lower() - if not normalized: - continue - subject_bucket = self.indices[IndexType.SUBJECT].get(normalized) - if subject_bucket is not None: - subject_bucket.discard(memory_id) - if not subject_bucket: - self.indices[IndexType.SUBJECT].pop(normalized, None) - - # 从时间索引中移除 - self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id] - - # 从关系分索引中移除 - self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid != memory_id] - - # 从访问频率索引中移除 - self.access_frequency_index = [ - (count, mid) for count, mid in self.access_frequency_index if mid != memory_id - ] - - # 注意:关键词、标签、分类索引需要从原始记忆中获取,这里简化处理 - # 实际实现中可能需要重新加载记忆或维护反向索引 - - # 从缓存中移除 - del self.memory_metadata_cache[memory_id] - - # 更新统计 - self.index_stats["total_memories"] = max(0, self.index_stats["total_memories"] - 1) - self._dirty = True - - logger.debug(f"移除记忆索引完成: {memory_id}") - - except Exception as e: - logger.error(f"❌ 移除记忆索引失败: {e}") - - async def get_memory_metadata(self, memory_id: str) -> dict[str, Any] | None: - """获取记忆元数据""" - return self.memory_metadata_cache.get(memory_id) - - async def get_user_memory_ids(self, user_id: str, limit: int | None = None) -> list[str]: - """获取用户的所有记忆ID""" - user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set())) - - if limit and len(user_memory_ids) > limit: - user_memory_ids = user_memory_ids[:limit] - - return user_memory_ids - - async def get_memory_statistics(self, user_id: str | None = None) -> dict[str, Any]: - """获取记忆统计信息""" - stats = { - "total_memories": self.index_stats["total_memories"], - "memory_types": {}, - "average_confidence": 0.0, - "average_importance": 0.0, - "average_relationship_score": 0.0, - "top_keywords": [], - "top_tags": [], - } - - if user_id: - # 限定用户统计 - user_memory_ids = self.indices[IndexType.USER_ID].get(user_id, set()) - stats["user_total_memories"] = len(user_memory_ids) - - if not user_memory_ids: - return stats - - # 用户记忆类型分布 - user_types = {} - for memory_type, memory_ids in self.indices[IndexType.MEMORY_TYPE].items(): - user_count = len(user_memory_ids & memory_ids) - if user_count > 0: - user_types[memory_type.value] = user_count - stats["memory_types"] = user_types - - # 计算用户平均值 - user_confidences = [] - user_importances = [] - user_relationship_scores = [] - - for memory_id in user_memory_ids: - metadata = self.memory_metadata_cache.get(memory_id, {}) - if metadata: - user_confidences.append(metadata["confidence"].value) - user_importances.append(metadata["importance"].value) - user_relationship_scores.append(metadata["relationship_score"]) - - if user_confidences: - stats["average_confidence"] = sum(user_confidences) / len(user_confidences) - if user_importances: - stats["average_importance"] = sum(user_importances) / len(user_importances) - if user_relationship_scores: - stats["average_relationship_score"] = sum(user_relationship_scores) / len(user_relationship_scores) - - else: - # 全局统计 - for memory_type, memory_ids in self.indices[IndexType.MEMORY_TYPE].items(): - stats["memory_types"][memory_type.value] = len(memory_ids) - - # 计算全局平均值 - if self.memory_metadata_cache: - all_confidences = [m["confidence"].value for m in self.memory_metadata_cache.values()] - all_importances = [m["importance"].value for m in self.memory_metadata_cache.values()] - all_relationship_scores = [m["relationship_score"] for m in self.memory_metadata_cache.values()] - - if all_confidences: - stats["average_confidence"] = sum(all_confidences) / len(all_confidences) - if all_importances: - stats["average_importance"] = sum(all_importances) / len(all_importances) - if all_relationship_scores: - stats["average_relationship_score"] = sum(all_relationship_scores) / len(all_relationship_scores) - - # 统计热门关键词和标签 - keyword_counts = [(keyword, len(memory_ids)) for keyword, memory_ids in self.indices[IndexType.KEYWORD].items()] - keyword_counts.sort(key=lambda x: x[1], reverse=True) - stats["top_keywords"] = keyword_counts[:10] - - tag_counts = [(tag, len(memory_ids)) for tag, memory_ids in self.indices[IndexType.TAG].items()] - tag_counts.sort(key=lambda x: x[1], reverse=True) - stats["top_tags"] = tag_counts[:10] - - return stats - - async def save_index(self): - """保存索引到文件""" - if not self._dirty: - return - - try: - logger.info("正在保存元数据索引...") - - # 保存各类索引 - indices_data: dict[str, dict[str, list[str]]] = {} - for index_type, index_data in self.indices.items(): - serialized_index = {} - for key, values in index_data.items(): - serialized_key = self._serialize_index_key(index_type, key) - serialized_index[serialized_key] = list(values) - indices_data[index_type.value] = serialized_index - - indices_file = self.index_path / "indices.json" - with open(indices_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode("utf-8")) - - # 保存时间索引 - time_index_file = self.index_path / "time_index.json" - with open(time_index_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode("utf-8")) - - # 保存关系分索引 - relationship_index_file = self.index_path / "relationship_index.json" - with open(relationship_index_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode("utf-8")) - - # 保存访问频率索引 - access_frequency_index_file = self.index_path / "access_frequency_index.json" - with open(access_frequency_index_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode("utf-8")) - - # 保存元数据缓存 - metadata_cache_file = self.index_path / "metadata_cache.json" - metadata_serialized = { - memory_id: self._serialize_metadata_entry(metadata) - for memory_id, metadata in self.memory_metadata_cache.items() - } - with open(metadata_cache_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode("utf-8")) - - # 保存统计信息 - stats_file = self.index_path / "index_stats.json" - with open(stats_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode("utf-8")) - - self._dirty = False - logger.info("✅ 元数据索引保存完成") - - except Exception as e: - logger.error(f"❌ 保存元数据索引失败: {e}") - - async def load_index(self): - """从文件加载索引""" - try: - logger.info("正在加载元数据索引...") - - # 加载各类索引 - indices_file = self.index_path / "indices.json" - if indices_file.exists(): - with open(indices_file, encoding="utf-8") as f: - indices_data = orjson.loads(f.read()) - - for index_type_value, index_data in indices_data.items(): - index_type = IndexType(index_type_value) - restored_index = defaultdict(set) - for key_str, values in index_data.items(): - restored_key = self._deserialize_index_key(index_type, key_str) - restored_index[restored_key] = set(values) - self.indices[index_type] = restored_index - - # 加载时间索引 - time_index_file = self.index_path / "time_index.json" - if time_index_file.exists(): - with open(time_index_file, encoding="utf-8") as f: - self.time_index = orjson.loads(f.read()) - - # 加载关系分索引 - relationship_index_file = self.index_path / "relationship_index.json" - if relationship_index_file.exists(): - with open(relationship_index_file, encoding="utf-8") as f: - self.relationship_index = orjson.loads(f.read()) - - # 加载访问频率索引 - access_frequency_index_file = self.index_path / "access_frequency_index.json" - if access_frequency_index_file.exists(): - with open(access_frequency_index_file, encoding="utf-8") as f: - self.access_frequency_index = orjson.loads(f.read()) - - # 加载元数据缓存 - metadata_cache_file = self.index_path / "metadata_cache.json" - if metadata_cache_file.exists(): - with open(metadata_cache_file, encoding="utf-8") as f: - cache_data = orjson.loads(f.read()) - - # 转换置信度和重要性为枚举类型 - for memory_id, metadata in cache_data.items(): - memory_type_value = metadata.get("memory_type") - if isinstance(memory_type_value, str): - try: - metadata["memory_type"] = MemoryType(memory_type_value) - except ValueError: - logger.warning("无法解析memory_type %s", memory_type_value) - - confidence_value = metadata.get("confidence") - if isinstance(confidence_value, (str, int)): - try: - metadata["confidence"] = ConfidenceLevel(int(confidence_value)) - except ValueError: - logger.warning("无法解析confidence %s", confidence_value) - - importance_value = metadata.get("importance") - if isinstance(importance_value, (str, int)): - try: - metadata["importance"] = ImportanceLevel(int(importance_value)) - except ValueError: - logger.warning("无法解析importance %s", importance_value) - - subjects_value = metadata.get("subjects") - if isinstance(subjects_value, str): - metadata["subjects"] = [subjects_value] - elif isinstance(subjects_value, list): - cleaned_subjects = [] - for item in subjects_value: - if isinstance(item, str) and item.strip(): - cleaned_subjects.append(item.strip()) - metadata["subjects"] = cleaned_subjects - else: - metadata["subjects"] = [] - - self.memory_metadata_cache = cache_data - - # 加载统计信息 - stats_file = self.index_path / "index_stats.json" - if stats_file.exists(): - with open(stats_file, encoding="utf-8") as f: - self.index_stats = orjson.loads(f.read()) - - # 更新记忆计数 - self.index_stats["total_memories"] = len(self.memory_metadata_cache) - - logger.info(f"✅ 元数据索引加载完成,{self.index_stats['total_memories']} 个记忆") - - except Exception as e: - logger.error(f"❌ 加载元数据索引失败: {e}") - - async def optimize_index(self): - """优化索引""" - try: - logger.info("开始元数据索引优化...") - - # 清理无效引用 - self._cleanup_invalid_references() - - # 重建有序索引 - self._rebuild_ordered_indices() - - # 清理低频关键词和标签 - self._cleanup_low_frequency_terms() - - # 更新统计信息 - if self.index_stats["total_queries"] > 0: - self.index_stats["cache_hit_rate"] = self.index_stats["cache_hits"] / self.index_stats["total_queries"] - - logger.info("✅ 元数据索引优化完成") - - except Exception as e: - logger.error(f"❌ 元数据索引优化失败: {e}") - - def _cleanup_invalid_references(self): - """清理无效引用""" - valid_memory_ids = set(self.memory_metadata_cache.keys()) - - # 清理各类索引中的无效引用 - for index_type in self.indices: - for key in list(self.indices[index_type].keys()): - valid_ids = self.indices[index_type][key] & valid_memory_ids - self.indices[index_type][key] = valid_ids - - # 如果某类别下没有记忆了,删除该类别 - if not valid_ids: - del self.indices[index_type][key] - - # 清理时间索引中的无效引用 - self.time_index = [(ts, mid) for ts, mid in self.time_index if mid in valid_memory_ids] - - # 清理关系分索引中的无效引用 - self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid in valid_memory_ids] - - # 清理访问频率索引中的无效引用 - self.access_frequency_index = [ - (count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids - ] - - # 更新总记忆数 - self.index_stats["total_memories"] = len(valid_memory_ids) - - def _rebuild_ordered_indices(self): - """重建有序索引""" - # 重建时间索引 - self.time_index.sort(key=lambda x: x[0], reverse=True) - - # 重建关系分索引 - self.relationship_index.sort(key=lambda x: x[0], reverse=True) - - # 重建访问频率索引 - self.access_frequency_index.sort(key=lambda x: x[0], reverse=True) - - def _cleanup_low_frequency_terms(self, min_frequency: int = 2): - """清理低频术语""" - # 清理低频关键词 - for keyword in list(self.indices[IndexType.KEYWORD].keys()): - if len(self.indices[IndexType.KEYWORD][keyword]) < min_frequency: - del self.indices[IndexType.KEYWORD][keyword] - - # 清理低频标签 - for tag in list(self.indices[IndexType.TAG].keys()): - if len(self.indices[IndexType.TAG][tag]) < min_frequency: - del self.indices[IndexType.TAG][tag] - - # 清理低频分类 - for category in list(self.indices[IndexType.CATEGORY].keys()): - if len(self.indices[IndexType.CATEGORY][category]) < min_frequency: - del self.indices[IndexType.CATEGORY][category] - - def get_index_stats(self) -> dict[str, Any]: - """获取索引统计信息""" - stats = self.index_stats.copy() - if stats["total_queries"] > 0: - stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_queries"] - else: - stats["cache_hit_rate"] = 0.0 - - # 添加索引详细信息 - stats["index_details"] = { - "memory_types": len(self.indices[IndexType.MEMORY_TYPE]), - "user_ids": len(self.indices[IndexType.USER_ID]), - "keywords": len(self.indices[IndexType.KEYWORD]), - "tags": len(self.indices[IndexType.TAG]), - "categories": len(self.indices[IndexType.CATEGORY]), - "confidence_levels": len(self.indices[IndexType.CONFIDENCE]), - "importance_levels": len(self.indices[IndexType.IMPORTANCE]), - "semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH]), - } - - return stats diff --git a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py b/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py deleted file mode 100644 index 529d9db99..000000000 --- a/src/chat/memory_system/deprecated_backup/multi_stage_retrieval.py +++ /dev/null @@ -1,1432 +0,0 @@ -""" -多阶段召回机制 -实现粗粒度到细粒度的记忆检索优化 -""" - -import time -from dataclasses import dataclass, field -from enum import Enum -from typing import Any - -import orjson - -from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig -from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType -from src.common.logger import get_logger - -logger = get_logger(__name__) - - -class RetrievalStage(Enum): - """检索阶段""" - - METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段 - VECTOR_SEARCH = "vector_search" # 向量搜索阶段 - SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段 - CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段 - - -@dataclass -class RetrievalConfig: - """检索配置""" - - # 各阶段配置 - 优化召回率 - metadata_filter_limit: int = 150 # 元数据过滤阶段返回数量(增加) - vector_search_limit: int = 80 # 向量搜索阶段返回数量(增加) - semantic_rerank_limit: int = 30 # 语义重排序阶段返回数量(增加) - final_result_limit: int = 10 # 最终结果数量 - - # 相似度阈值 - 优化召回率 - vector_similarity_threshold: float = 0.5 # 向量相似度阈值(降低以提升召回率) - semantic_similarity_threshold: float = 0.05 # 语义相似度阈值(保持较低以获得更多相关记忆) - - # 权重配置 - vector_weight: float = 0.4 # 向量相似度权重 - semantic_weight: float = 0.3 # 语义相似度权重 - context_weight: float = 0.2 # 上下文权重 - recency_weight: float = 0.1 # 时效性权重 - - @classmethod - def from_global_config(cls): - """从全局配置创建配置实例""" - from src.config.config import global_config - - return cls( - # 各阶段配置 - 优化召回率 - metadata_filter_limit=max(150, global_config.memory.metadata_filter_limit), # 增加候选池 - vector_search_limit=max(80, global_config.memory.vector_search_limit), # 增加向量搜索结果 - semantic_rerank_limit=max(30, global_config.memory.semantic_rerank_limit), # 增加重排序候选 - final_result_limit=global_config.memory.final_result_limit, - # 相似度阈值 - 优化召回率 - vector_similarity_threshold=max(0.5, global_config.memory.vector_similarity_threshold), # 确保不低于0.5 - semantic_similarity_threshold=0.05, # 进一步降低以提升召回率 - # 权重配置 - vector_weight=global_config.memory.vector_weight, - semantic_weight=global_config.memory.semantic_weight, - context_weight=global_config.memory.context_weight, - recency_weight=global_config.memory.recency_weight, - ) - - -@dataclass -class StageResult: - """阶段结果""" - - stage: RetrievalStage - memory_ids: list[str] - processing_time: float - filtered_count: int - score_threshold: float - details: list[dict[str, Any]] = field(default_factory=list) - - -@dataclass -class RetrievalResult: - """检索结果""" - - query: str - user_id: str - final_memories: list[MemoryChunk] - stage_results: list[StageResult] - total_processing_time: float - total_filtered: int - retrieval_stats: dict[str, Any] - - -class MultiStageRetrieval: - """多阶段召回系统""" - - def __init__(self, config: RetrievalConfig | None = None): - self.config = config or RetrievalConfig.from_global_config() - - # 初始化增强重排序器 - reranker_config = ReRankingConfig( - semantic_weight=self.config.vector_weight, - recency_weight=self.config.recency_weight, - usage_freq_weight=0.2, # 新增的使用频率权重 - type_match_weight=0.1, # 新增的类型匹配权重 - ) - self.reranker = EnhancedReRanker(reranker_config) - - self.retrieval_stats = { - "total_queries": 0, - "average_retrieval_time": 0.0, - "stage_stats": { - "metadata_filtering": {"calls": 0, "avg_time": 0.0}, - "vector_search": {"calls": 0, "avg_time": 0.0}, - "semantic_reranking": {"calls": 0, "avg_time": 0.0}, - "contextual_filtering": {"calls": 0, "avg_time": 0.0}, - "enhanced_reranking": {"calls": 0, "avg_time": 0.0}, # 新增统计 - }, - } - - async def retrieve_memories( - self, - query: str, - user_id: str, - context: dict[str, Any], - metadata_index, - vector_storage, - all_memories_cache: dict[str, MemoryChunk], - limit: int | None = None, - ) -> RetrievalResult: - """多阶段记忆检索""" - start_time = time.time() - limit = limit or self.config.final_result_limit - - stage_results = [] - current_memory_ids = set() - memory_debug_info: dict[str, dict[str, Any]] = {} - - try: - logger.debug(f"开始多阶段检索:query='{query}', user_id='{user_id}'") - - # 阶段1:元数据过滤 - stage1_result = await self._metadata_filtering_stage( - query, user_id, context, metadata_index, all_memories_cache, debug_log=memory_debug_info - ) - stage_results.append(stage1_result) - current_memory_ids.update(stage1_result.memory_ids) - - # 阶段2:向量搜索 - stage2_result = await self._vector_search_stage( - query, - user_id, - context, - vector_storage, - current_memory_ids, - all_memories_cache, - debug_log=memory_debug_info, - ) - stage_results.append(stage2_result) - current_memory_ids.update(stage2_result.memory_ids) - - # 阶段3:语义重排序 - stage3_result = await self._semantic_reranking_stage( - query, user_id, context, current_memory_ids, all_memories_cache, debug_log=memory_debug_info - ) - stage_results.append(stage3_result) - - # 阶段4:上下文过滤 - stage4_result = await self._contextual_filtering_stage( - query, - user_id, - context, - stage3_result.memory_ids, - all_memories_cache, - limit, - debug_log=memory_debug_info, - ) - stage_results.append(stage4_result) - - # 检查是否需要回退机制 - if len(stage4_result.memory_ids) < min(3, limit): - logger.debug(f"上下文过滤结果过少({len(stage4_result.memory_ids)}),启用回退机制") - # 回退到更宽松的检索策略 - fallback_result = await self._fallback_retrieval_stage( - query, - user_id, - context, - all_memories_cache, - limit, - excluded_ids=set(stage4_result.memory_ids), - debug_log=memory_debug_info, - ) - if fallback_result.memory_ids: - stage4_result.memory_ids.extend(fallback_result.memory_ids[: limit - len(stage4_result.memory_ids)]) - logger.debug(f"回退机制补充了 {len(fallback_result.memory_ids)} 条记忆") - - # 阶段5:增强重排序 (新增) - stage5_result = await self._enhanced_reranking_stage( - query, - user_id, - context, - stage4_result.memory_ids, - all_memories_cache, - limit, - debug_log=memory_debug_info, - ) - stage_results.append(stage5_result) - - # 获取最终记忆对象 - final_memories = [] - for memory_id in stage5_result.memory_ids: # 使用重排序后的结果 - if memory_id in all_memories_cache: - memory = all_memories_cache[memory_id] - memory.update_access() # 更新访问统计 - final_memories.append(memory) - - # 更新统计 - total_time = time.time() - start_time - self._update_retrieval_stats(total_time, stage_results) - - total_filtered = sum(result.filtered_count for result in stage_results) - - logger.debug(f"多阶段检索完成:返回 {len(final_memories)} 条记忆,耗时 {total_time:.3f}s") - - if memory_debug_info: - final_ids_set = set(stage5_result.memory_ids) # 使用重排序后的结果 - debug_entries = [] - for memory_id, trace in memory_debug_info.items(): - memory_obj = all_memories_cache.get(memory_id) - display_text = "" - if memory_obj: - display_text = (memory_obj.display or memory_obj.text_content or "").strip() - if len(display_text) > 80: - display_text = display_text[:77] + "..." - - entry = { - "memory_id": memory_id, - "display": display_text, - "memory_type": memory_obj.memory_type.value if memory_obj else None, - "vector_similarity": trace.get("vector_stage", {}).get("similarity"), - "semantic_score": trace.get("semantic_stage", {}).get("score"), - "context_score": trace.get("context_stage", {}).get("context_score"), - "final_score": trace.get("context_stage", {}).get("final_score"), - "status": trace.get("context_stage", {}).get("status") - or trace.get("vector_stage", {}).get("status") - or trace.get("semantic_stage", {}).get("status"), - "is_final": memory_id in final_ids_set, - } - debug_entries.append(entry) - - # 限制日志输出数量 - debug_entries.sort( - key=lambda item: ( - item.get("is_final", False), - item.get("final_score") or item.get("vector_similarity") or 0.0, - ), - reverse=True, - ) - debug_payload = { - "query": query, - "semantic_query": context.get("resolved_query_text", query), - "user_id": user_id, - "stage_summaries": [ - { - "stage": result.stage.value, - "returned": len(result.memory_ids), - "filtered": result.filtered_count, - "duration": round(result.processing_time, 4), - "details": result.details, - } - for result in stage_results - ], - "candidates": debug_entries[:20], - } - try: - logger.info( - f"🧭 记忆检索调试 | query='{query}' | final={len(stage5_result.memory_ids)}", - extra={"memory_debug": debug_payload}, - ) - except Exception: - logger.info( - f"🧭 记忆检索调试详情: {orjson.dumps(debug_payload, ensure_ascii=False).decode('utf-8')}", - ) - - return RetrievalResult( - query=query, - user_id=user_id, - final_memories=final_memories, - stage_results=stage_results, - total_processing_time=total_time, - total_filtered=total_filtered, - retrieval_stats=self.retrieval_stats.copy(), - ) - - except Exception as e: - logger.error(f"多阶段检索失败: {e}", exc_info=True) - # 返回空结果 - return RetrievalResult( - query=query, - user_id=user_id, - final_memories=[], - stage_results=stage_results, - total_processing_time=time.time() - start_time, - total_filtered=0, - retrieval_stats=self.retrieval_stats.copy(), - ) - - async def _metadata_filtering_stage( - self, - query: str, - user_id: str, - context: dict[str, Any], - metadata_index, - all_memories_cache: dict[str, MemoryChunk], - *, - debug_log: dict[str, dict[str, Any]] | None = None, - ) -> StageResult: - """阶段1:元数据过滤""" - start_time = time.time() - - try: - from .metadata_index import IndexQuery - - query_plan = context.get("query_plan") - - memory_types = self._extract_memory_types_from_context(context) - keywords = self._extract_keywords_from_query(query, query_plan) - subjects = ( - query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None - ) - - index_query = IndexQuery( - user_ids=None, - memory_types=memory_types, - subjects=subjects, - keywords=keywords, - limit=self.config.metadata_filter_limit, - sort_by="last_accessed", - sort_order="desc", - ) - - # 执行查询 - result = await metadata_index.query_memories(index_query) - result_ids = list(result.memory_ids) - filtered_count = max(0, len(all_memories_cache) - len(result_ids)) - details: list[dict[str, Any]] = [] - - # 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆 - if not result_ids: - sorted_ids = sorted( - (memory.memory_id for memory in all_memories_cache.values()), - key=lambda mid: all_memories_cache[mid].metadata.last_accessed if mid in all_memories_cache else 0, - reverse=True, - ) - if memory_types: - type_filtered = [mid for mid in sorted_ids if all_memories_cache[mid].memory_type in memory_types] - sorted_ids = type_filtered or sorted_ids - if subjects: - subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()] - if subject_candidates: - subject_filtered = [ - mid - for mid in sorted_ids - if any( - subj.strip().lower() in subject_candidates for subj in all_memories_cache[mid].subjects - ) - ] - sorted_ids = subject_filtered or sorted_ids - - if keywords: - keyword_pool = {kw.lower() for kw in keywords if isinstance(kw, str) and kw.strip()} - if keyword_pool: - keyword_filtered = [] - for mid in sorted_ids: - memory_text = ( - (all_memories_cache[mid].display or "") - + "\n" - + (all_memories_cache[mid].text_content or "") - ).lower() - if any(kw in memory_text for kw in keyword_pool): - keyword_filtered.append(mid) - sorted_ids = keyword_filtered or sorted_ids - - result_ids = sorted_ids[: self.config.metadata_filter_limit] - filtered_count = max(0, len(all_memories_cache) - len(result_ids)) - logger.debug( - "元数据过滤未命中索引,使用近似回退: types=%s, subjects=%s, keywords=%s", - bool(memory_types), - bool(subjects), - bool(keywords), - ) - details.append( - { - "note": "fallback_recent", - "requested_types": [mt.value for mt in memory_types] if memory_types else [], - "subjects": subjects or [], - "keywords": keywords or [], - } - ) - - logger.debug( - "元数据过滤:候选=%d, 返回=%d", - len(all_memories_cache), - len(result_ids), - ) - - for memory_id in result_ids[:20]: - detail_entry = { - "memory_id": memory_id, - "status": "candidate", - } - details.append(detail_entry) - if debug_log is not None: - stage_entry = debug_log.setdefault(memory_id, {}).setdefault("metadata_stage", {}) - stage_entry["status"] = "candidate" - - return StageResult( - stage=RetrievalStage.METADATA_FILTERING, - memory_ids=result_ids, - processing_time=time.time() - start_time, - filtered_count=filtered_count, - score_threshold=0.0, - details=details, - ) - - except Exception as e: - logger.error(f"元数据过滤阶段失败: {e}") - return StageResult( - stage=RetrievalStage.METADATA_FILTERING, - memory_ids=[], - processing_time=time.time() - start_time, - filtered_count=0, - score_threshold=0.0, - details=[{"error": str(e)}], - ) - - async def _vector_search_stage( - self, - query: str, - user_id: str, - context: dict[str, Any], - vector_storage, - candidate_ids: set[str], - all_memories_cache: dict[str, MemoryChunk], - *, - debug_log: dict[str, dict[str, Any]] | None = None, - ) -> StageResult: - """阶段2:向量搜索""" - start_time = time.time() - - try: - # 生成查询向量 - query_embedding = await self._generate_query_embedding(query, context, vector_storage) - - if not query_embedding: - logger.warning("向量搜索阶段:查询向量生成失败") - return StageResult( - stage=RetrievalStage.VECTOR_SEARCH, - memory_ids=[], - processing_time=time.time() - start_time, - filtered_count=0, - score_threshold=self.config.vector_similarity_threshold, - details=[{"note": "query_embedding_unavailable"}], - ) - - # 执行向量搜索 - search_result = await vector_storage.search_similar_memories( - query_vector=query_embedding, limit=self.config.vector_search_limit - ) - - if not search_result: - logger.warning("向量搜索阶段:搜索返回空结果,尝试回退到文本匹配") - # 向量搜索失败时的回退策略 - return self._create_text_search_fallback(candidate_ids, all_memories_cache, query, start_time) - - candidate_pool = candidate_ids or set(all_memories_cache.keys()) - - # 过滤候选记忆 - filtered_memories = [] - details: list[dict[str, Any]] = [] - raw_details: list[dict[str, Any]] = [] - threshold = self.config.vector_similarity_threshold - - for memory_id, similarity in search_result: - in_metadata_candidates = memory_id in candidate_pool - above_threshold = similarity >= threshold - if in_metadata_candidates and above_threshold: - filtered_memories.append((memory_id, similarity)) - - raw_details.append( - { - "memory_id": memory_id, - "similarity": similarity, - "in_metadata": in_metadata_candidates, - "above_threshold": above_threshold, - } - ) - - # 按相似度排序 - filtered_memories.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in filtered_memories[: self.config.vector_search_limit]] - kept_ids = set(result_ids) - - for entry in raw_details: - memory_id = entry["memory_id"] - similarity = entry["similarity"] - in_metadata = entry["in_metadata"] - above_threshold = entry["above_threshold"] - - status = "kept" - reason = None - if not in_metadata: - status = "excluded" - reason = "not_in_metadata_candidates" - elif not above_threshold: - status = "excluded" - reason = "below_threshold" - elif memory_id not in kept_ids: - status = "excluded" - reason = "limit_pruned" - - detail_entry = { - "memory_id": memory_id, - "similarity": round(similarity, 4), - "status": status, - "reason": reason, - } - details.append(detail_entry) - - if debug_log is not None: - stage_entry = debug_log.setdefault(memory_id, {}).setdefault("vector_stage", {}) - stage_entry["similarity"] = round(similarity, 4) - stage_entry["status"] = status - if reason: - stage_entry["reason"] = reason - - filtered_count = max(0, len(candidate_pool) - len(result_ids)) - - logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆") - - return StageResult( - stage=RetrievalStage.VECTOR_SEARCH, - memory_ids=result_ids, - processing_time=time.time() - start_time, - filtered_count=filtered_count, - score_threshold=self.config.vector_similarity_threshold, - details=details, - ) - - except Exception as e: - logger.error(f"向量搜索阶段失败: {e}") - return StageResult( - stage=RetrievalStage.VECTOR_SEARCH, - memory_ids=[], - processing_time=time.time() - start_time, - filtered_count=0, - score_threshold=self.config.vector_similarity_threshold, - details=[{"error": str(e)}], - ) - - def _create_text_search_fallback( - self, candidate_ids: set[str], all_memories_cache: dict[str, MemoryChunk], query_text: str, start_time: float - ) -> StageResult: - """当向量搜索失败时,使用文本搜索作为回退策略""" - try: - query_lower = query_text.lower() - query_words = set(query_lower.split()) - - text_matches = [] - for memory_id in candidate_ids: - if memory_id not in all_memories_cache: - continue - - memory = all_memories_cache[memory_id] - memory_text = (memory.display or memory.text_content or "").lower() - - # 简单的文本匹配评分 - word_matches = sum(1 for word in query_words if word in memory_text) - if word_matches > 0: - score = word_matches / len(query_words) - text_matches.append((memory_id, score)) - - # 按匹配度排序 - text_matches.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in text_matches[: self.config.vector_search_limit]] - - details = [] - for memory_id, score in text_matches[: self.config.vector_search_limit]: - details.append( - {"memory_id": memory_id, "text_match_score": round(score, 4), "status": "text_match_fallback"} - ) - - logger.debug(f"向量搜索回退到文本匹配:找到 {len(result_ids)} 条匹配记忆") - - return StageResult( - stage=RetrievalStage.VECTOR_SEARCH, - memory_ids=result_ids, - processing_time=time.time() - start_time, - filtered_count=len(candidate_ids) - len(result_ids), - score_threshold=0.0, # 文本匹配无严格阈值 - details=details, - ) - - except Exception as e: - logger.error(f"文本搜索回退失败: {e}") - return StageResult( - stage=RetrievalStage.VECTOR_SEARCH, - memory_ids=list(candidate_ids)[: self.config.vector_search_limit], - processing_time=time.time() - start_time, - filtered_count=0, - score_threshold=0.0, - details=[{"error": str(e), "note": "text_fallback_failed"}], - ) - - async def _semantic_reranking_stage( - self, - query: str, - user_id: str, - context: dict[str, Any], - candidate_ids: set[str], - all_memories_cache: dict[str, MemoryChunk], - *, - debug_log: dict[str, dict[str, Any]] | None = None, - ) -> StageResult: - """阶段3:语义重排序""" - start_time = time.time() - - try: - reranked_memories = [] - details: list[dict[str, Any]] = [] - threshold = self.config.semantic_similarity_threshold - - for memory_id in candidate_ids: - if memory_id not in all_memories_cache: - continue - - memory = all_memories_cache[memory_id] - - # 计算综合语义相似度 - semantic_score = await self._calculate_semantic_similarity(query, memory, context) - - if semantic_score >= threshold: - reranked_memories.append((memory_id, semantic_score)) - - status = "kept" if semantic_score >= threshold else "excluded" - reason = None if status == "kept" else "below_threshold" - - detail_entry = { - "memory_id": memory_id, - "score": round(semantic_score, 4), - "status": status, - "reason": reason, - } - details.append(detail_entry) - - if debug_log is not None: - stage_entry = debug_log.setdefault(memory_id, {}).setdefault("semantic_stage", {}) - stage_entry["score"] = round(semantic_score, 4) - stage_entry["status"] = status - if reason: - stage_entry["reason"] = reason - - # 按语义相似度排序 - reranked_memories.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in reranked_memories[: self.config.semantic_rerank_limit]] - kept_ids = set(result_ids) - - filtered_count = len(candidate_ids) - len(result_ids) - - for detail in details: - if detail["status"] == "kept" and detail["memory_id"] not in kept_ids: - detail["status"] = "excluded" - detail["reason"] = "limit_pruned" - if debug_log is not None: - stage_entry = debug_log.setdefault(detail["memory_id"], {}).setdefault("semantic_stage", {}) - stage_entry["status"] = "excluded" - stage_entry["reason"] = "limit_pruned" - - logger.debug(f"语义重排序:{len(candidate_ids)} -> {len(result_ids)} 条记忆") - - return StageResult( - stage=RetrievalStage.SEMANTIC_RERANKING, - memory_ids=result_ids, - processing_time=time.time() - start_time, - filtered_count=filtered_count, - score_threshold=self.config.semantic_similarity_threshold, - details=details, - ) - - except Exception as e: - logger.error(f"语义重排序阶段失败: {e}") - return StageResult( - stage=RetrievalStage.SEMANTIC_RERANKING, - memory_ids=list(candidate_ids), # 失败时返回原候选集 - processing_time=time.time() - start_time, - filtered_count=0, - score_threshold=self.config.semantic_similarity_threshold, - details=[{"error": str(e)}], - ) - - async def _contextual_filtering_stage( - self, - query: str, - user_id: str, - context: dict[str, Any], - candidate_ids: list[str], - all_memories_cache: dict[str, MemoryChunk], - limit: int, - *, - debug_log: dict[str, dict[str, Any]] | None = None, - ) -> StageResult: - """阶段4:上下文过滤""" - start_time = time.time() - - try: - final_memories = [] - details: list[dict[str, Any]] = [] - - for memory_id in candidate_ids: - if memory_id not in all_memories_cache: - continue - - memory = all_memories_cache[memory_id] - - # 计算上下文相关度评分 - context_score = await self._calculate_context_relevance(query, memory, context) - - # 结合多因子评分 - final_score = await self._calculate_final_score(query, memory, context, context_score) - - final_memories.append((memory_id, final_score)) - - detail_entry = { - "memory_id": memory_id, - "context_score": round(context_score, 4), - "final_score": round(final_score, 4), - "status": "candidate", - } - details.append(detail_entry) - - if debug_log is not None: - stage_entry = debug_log.setdefault(memory_id, {}).setdefault("context_stage", {}) - stage_entry["context_score"] = round(context_score, 4) - stage_entry["final_score"] = round(final_score, 4) - - # 按最终评分排序 - final_memories.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in final_memories[:limit]] - kept_ids = set(result_ids) - - for detail in details: - memory_id = detail["memory_id"] - if memory_id in kept_ids: - detail["status"] = "final" - if debug_log is not None: - stage_entry = debug_log.setdefault(memory_id, {}).setdefault("context_stage", {}) - stage_entry["status"] = "final" - else: - detail["status"] = "excluded" - detail["reason"] = "ranked_out" - if debug_log is not None: - stage_entry = debug_log.setdefault(memory_id, {}).setdefault("context_stage", {}) - stage_entry["status"] = "excluded" - stage_entry["reason"] = "ranked_out" - - filtered_count = len(candidate_ids) - len(result_ids) - - logger.debug(f"上下文过滤:{len(candidate_ids)} -> {len(result_ids)} 条记忆") - - return StageResult( - stage=RetrievalStage.CONTEXTUAL_FILTERING, - memory_ids=result_ids, - processing_time=time.time() - start_time, - filtered_count=filtered_count, - score_threshold=0.0, # 动态阈值 - details=details, - ) - - except Exception as e: - logger.error(f"上下文过滤阶段失败: {e}") - return StageResult( - stage=RetrievalStage.CONTEXTUAL_FILTERING, - memory_ids=candidate_ids[:limit], # 失败时返回前limit个 - processing_time=time.time() - start_time, - filtered_count=0, - score_threshold=0.0, - details=[{"error": str(e)}], - ) - - async def _fallback_retrieval_stage( - self, - query: str, - user_id: str, - context: dict[str, Any], - all_memories_cache: dict[str, MemoryChunk], - limit: int, - *, - excluded_ids: set[str] | None = None, - debug_log: dict[str, dict[str, Any]] | None = None, - ) -> StageResult: - """回退检索阶段 - 当主检索失败时使用更宽松的策略""" - start_time = time.time() - - try: - excluded_ids = excluded_ids or set() - fallback_candidates = [] - - # 策略1:基于关键词的简单匹配 - query_lower = query.lower() - query_words = set(query_lower.split()) - - for memory_id, memory in all_memories_cache.items(): - if memory_id in excluded_ids: - continue - - memory_text = (memory.display or memory.text_content or "").lower() - - # 简单的关键词匹配 - word_matches = sum(1 for word in query_words if word in memory_text) - if word_matches > 0: - score = word_matches / len(query_words) - fallback_candidates.append((memory_id, score)) - - # 策略2:如果没有关键词匹配,使用时序最近的原则 - if not fallback_candidates: - logger.debug("关键词匹配无结果,使用时序最近策略") - recent_memories = sorted( - [ - (mid, mem.metadata.last_accessed or mem.metadata.created_at) - for mid, mem in all_memories_cache.items() - if mid not in excluded_ids - ], - key=lambda x: x[1], - reverse=True, - ) - fallback_candidates = [(mid, 0.5) for mid, _ in recent_memories[: limit * 2]] - - # 按分数排序 - fallback_candidates.sort(key=lambda x: x[1], reverse=True) - result_ids = [memory_id for memory_id, _ in fallback_candidates[:limit]] - - # 记录调试信息 - details = [] - for memory_id, score in fallback_candidates[:limit]: - detail_entry = { - "memory_id": memory_id, - "fallback_score": round(score, 4), - "status": "fallback_candidate", - } - details.append(detail_entry) - - if debug_log is not None: - stage_entry = debug_log.setdefault(memory_id, {}).setdefault("fallback_stage", {}) - stage_entry["score"] = round(score, 4) - stage_entry["status"] = "fallback_candidate" - - filtered_count = len(all_memories_cache) - len(result_ids) - - logger.debug(f"回退检索完成:返回 {len(result_ids)} 条记忆") - - return StageResult( - stage=RetrievalStage.CONTEXTUAL_FILTERING, # 复用现有枚举 - memory_ids=result_ids, - processing_time=time.time() - start_time, - filtered_count=filtered_count, - score_threshold=0.0, # 回退机制无阈值 - details=details, - ) - - except Exception as e: - logger.error(f"回退检索阶段失败: {e}") - return StageResult( - stage=RetrievalStage.CONTEXTUAL_FILTERING, - memory_ids=[], - processing_time=time.time() - start_time, - filtered_count=0, - score_threshold=0.0, - details=[{"error": str(e)}], - ) - - async def _generate_query_embedding( - self, query: str, context: dict[str, Any], vector_storage - ) -> list[float] | None: - """生成查询向量""" - try: - query_plan = context.get("query_plan") - query_text = query - if query_plan and getattr(query_plan, "semantic_query", None): - query_text = query_plan.semantic_query - - if not query_text: - logger.debug("查询文本为空,无法生成查询向量") - return None - - if not hasattr(vector_storage, "generate_query_embedding"): - logger.warning("向量存储对象缺少 generate_query_embedding 方法") - return None - - logger.debug(f"正在生成查询向量,文本: '{query_text[:100]}'") - embedding = await vector_storage.generate_query_embedding(query_text) - - if embedding is None: - logger.warning("向量存储返回空的查询向量") - return None - - if len(embedding) == 0: - logger.warning("向量存储返回空列表作为查询向量") - return None - - logger.debug(f"查询向量生成成功,维度: {len(embedding)}") - return embedding - - except Exception as e: - logger.error(f"生成查询向量时发生异常: {e}", exc_info=True) - return None - - async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float: - """计算语义相似度 - 简化优化版本,提升召回率""" - try: - query_plan = context.get("query_plan") - query_text = query - if query_plan and getattr(query_plan, "semantic_query", None): - query_text = query_plan.semantic_query - - # 预处理:清理和标准化文本 - memory_text = (memory.display or memory.text_content or "").strip() - query_text = query_text.strip() - - if not query_text or not memory_text: - return 0.0 - - # 创建小写版本用于匹配 - query_lower = query_text.lower() - memory_lower = memory_text.lower() - - # 核心匹配策略1:精确子串匹配(最重要) - exact_score = 0.0 - if query_text in memory_text: - exact_score = 1.0 - elif query_lower in memory_lower: - exact_score = 0.9 - elif any(word in memory_lower for word in query_lower.split() if len(word) > 1): - exact_score = 0.4 - - # 核心匹配策略2:词汇匹配 - word_score = 0.0 - try: - import re - - import jieba - - # 分词处理 - query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text) - memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text) - - # 清理和标准化 - query_words = [w.strip().lower() for w in query_words if w.strip() and len(w.strip()) > 1] - memory_words = [w.strip().lower() for w in memory_words if w.strip() and len(w.strip()) > 1] - - if query_words and memory_words: - query_set = set(query_words) - memory_set = set(memory_words) - - # 精确匹配 - exact_matches = query_set & memory_set - exact_ratio = len(exact_matches) / len(query_set) if query_set else 0 - - # 部分匹配(包含关系) - partial_matches = 0 - for q_word in query_set: - if any(q_word in m_word or m_word in q_word for m_word in memory_set if len(q_word) >= 2): - partial_matches += 1 - - partial_ratio = partial_matches / len(query_set) if query_set else 0 - word_score = exact_ratio * 0.8 + partial_ratio * 0.3 - - except ImportError: - # 如果jieba不可用,使用简单分词 - import re - - query_words = re.findall(r"[\w\u4e00-\u9fa5]+", query_lower) - memory_words = re.findall(r"[\w\u4e00-\u9fa5]+", memory_lower) - - if query_words and memory_words: - query_set = set(w for w in query_words if len(w) > 1) - memory_set = set(w for w in memory_words if len(w) > 1) - - if query_set: - intersection = query_set & memory_set - word_score = len(intersection) / len(query_set) - - # 核心匹配策略3:语义概念匹配 - concept_score = 0.0 - concept_groups = { - "饮食": ["吃", "饭", "菜", "餐", "饿", "饱", "食", "dinner", "eat", "food", "meal"], - "天气": ["天气", "阳光", "雨", "晴", "阴", "温度", "weather", "sunny", "rain"], - "编程": ["编程", "代码", "程序", "开发", "语言", "programming", "code", "develop", "python"], - "时间": ["今天", "昨天", "明天", "现在", "时间", "today", "yesterday", "tomorrow", "time"], - "情感": ["好", "坏", "开心", "难过", "有趣", "good", "bad", "happy", "sad", "fun"], - } - - query_concepts = { - concept - for concept, keywords in concept_groups.items() - if any(keyword in query_lower for keyword in keywords) - } - memory_concepts = { - concept - for concept, keywords in concept_groups.items() - if any(keyword in memory_lower for keyword in keywords) - } - - if query_concepts and memory_concepts: - concept_overlap = query_concepts & memory_concepts - concept_score = len(concept_overlap) / len(query_concepts) * 0.5 - - # 核心匹配策略4:查询计划增强 - plan_bonus = 0.0 - if query_plan: - # 主体匹配 - if hasattr(query_plan, "subjects") and query_plan.subjects: - for subject in query_plan.subjects: - if subject.lower() in memory_lower: - plan_bonus += 0.15 - - # 对象匹配 - if hasattr(query_plan, "objects") and query_plan.objects: - for obj in query_plan.objects: - if obj.lower() in memory_lower: - plan_bonus += 0.1 - - # 记忆类型匹配 - if hasattr(query_plan, "memory_types") and query_plan.memory_types: - if memory.memory_type in query_plan.memory_types: - plan_bonus += 0.1 - - # 综合评分计算 - 简化权重分配 - if exact_score >= 0.9: - # 精确匹配为主 - final_score = exact_score * 0.6 + word_score * 0.2 + concept_score + plan_bonus - else: - # 综合评分 - final_score = exact_score * 0.3 + word_score * 0.3 + concept_score + plan_bonus - - # 基础分数保障:避免过低分数 - if final_score > 0: - if exact_score > 0 or word_score > 0.1: - final_score = max(final_score, 0.1) # 有实际匹配的最小分数 - else: - final_score = max(final_score, 0.05) # 仅概念匹配的最小分数 - - # 确保分数在合理范围 - final_score = min(1.0, max(0.0, final_score)) - - return final_score - - except Exception as e: - logger.warning(f"计算语义相似度失败: {e}") - return 0.0 - - async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float: - """计算上下文相关度""" - try: - score = 0.0 - - query_plan = context.get("query_plan") - - # 检查记忆类型是否匹配上下文 - if context.get("expected_memory_types"): - if memory.memory_type in context["expected_memory_types"]: - score += 0.3 - elif query_plan and getattr(query_plan, "memory_types", None): - if memory.memory_type in query_plan.memory_types: - score += 0.3 - - # 检查关键词匹配 - if context.get("keywords"): - memory_keywords = set(memory.keywords) - context_keywords = set(context["keywords"]) - overlap = memory_keywords & context_keywords - if overlap: - score += len(overlap) / max(len(context_keywords), 1) * 0.4 - - if query_plan: - # 主体匹配 - subject_score = self._calculate_subject_overlap(memory, getattr(query_plan, "subject_includes", [])) - score += subject_score * 0.3 - - # 对象/描述匹配 - object_keywords = getattr(query_plan, "object_includes", []) or [] - if object_keywords: - display_text = (memory.display or memory.text_content or "").lower() - hits = sum( - 1 - for kw in object_keywords - if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text - ) - if hits: - score += min(0.3, hits * 0.1) - - optional_keywords = getattr(query_plan, "optional_keywords", []) or [] - if optional_keywords: - display_text = (memory.display or memory.text_content or "").lower() - hits = sum( - 1 - for kw in optional_keywords - if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text - ) - if hits: - score += min(0.2, hits * 0.05) - - # 时间偏好 - recency_pref = getattr(query_plan, "recency_preference", "") - if recency_pref: - memory_age = time.time() - memory.metadata.created_at - if recency_pref == "recent" and memory_age < 7 * 24 * 3600: - score += 0.2 - elif recency_pref == "historical" and memory_age > 30 * 24 * 3600: - score += 0.1 - - # 检查时效性 - if context.get("recent_only", False): - memory_age = time.time() - memory.metadata.created_at - if memory_age < 7 * 24 * 3600: # 7天内 - score += 0.3 - - return min(score, 1.0) - - except Exception as e: - logger.warning(f"计算上下文相关度失败: {e}") - return 0.0 - - async def _calculate_final_score( - self, query: str, memory: MemoryChunk, context: dict[str, Any], context_score: float - ) -> float: - """计算最终评分""" - try: - query_plan = context.get("query_plan") - - # 语义相似度 - semantic_score = await self._calculate_semantic_similarity(query, memory, context) - - # 向量相似度(如果有) - vector_score = 0.0 - if memory.embedding: - # 这里应该有向量相似度计算,简化处理 - vector_score = 0.5 - - # 时效性评分 - recency_score = self._calculate_recency_score(memory.metadata.created_at) - if query_plan: - recency_pref = getattr(query_plan, "recency_preference", "") - if recency_pref == "recent": - recency_score = max(recency_score, 0.8) - elif recency_pref == "historical": - recency_score = min(recency_score, 0.5) - - # 权重组合 - vector_weight = self.config.vector_weight - semantic_weight = self.config.semantic_weight - context_weight = self.config.context_weight - recency_weight = self.config.recency_weight - - if query_plan and getattr(query_plan, "emphasis", None) == "precision": - semantic_weight += 0.05 - elif query_plan and getattr(query_plan, "emphasis", None) == "recall": - context_weight += 0.05 - - final_score = ( - semantic_score * semantic_weight - + vector_score * vector_weight - + context_score * context_weight - + recency_score * recency_weight - ) - - # 加入记忆重要性权重 - importance_weight = memory.metadata.importance.value / 4.0 # 标准化到0-1 - final_score = final_score * (0.7 + importance_weight * 0.3) # 重要性影响30% - - return final_score - - except Exception as e: - logger.warning(f"计算最终评分失败: {e}") - return 0.0 - - def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: list[str] | None) -> float: - if not required_subjects: - return 0.0 - - memory_subjects = {subject.lower() for subject in memory.subjects if isinstance(subject, str)} - if not memory_subjects: - return 0.0 - - hit = 0 - total = 0 - for subject in required_subjects: - if not isinstance(subject, str): - continue - total += 1 - normalized = subject.strip().lower() - if not normalized: - continue - if any(normalized in mem_subject for mem_subject in memory_subjects): - hit += 1 - - if total == 0: - return 0.0 - - return hit / total - - def _calculate_recency_score(self, timestamp: float) -> float: - """计算时效性评分""" - try: - age = time.time() - timestamp - age_days = age / (24 * 3600) - - if age_days < 1: - return 1.0 - elif age_days < 7: - return 0.8 - elif age_days < 30: - return 0.6 - elif age_days < 90: - return 0.4 - else: - return 0.2 - - except Exception: - return 0.5 - - def _extract_memory_types_from_context(self, context: dict[str, Any]) -> list[MemoryType]: - """从上下文中提取记忆类型""" - try: - query_plan = context.get("query_plan") - if query_plan and getattr(query_plan, "memory_types", None): - return query_plan.memory_types - - if "expected_memory_types" in context: - return context["expected_memory_types"] - - # 根据上下文推断记忆类型 - if "message_type" in context: - message_type = context["message_type"] - if message_type in ["personal_info", "fact"]: - return [MemoryType.PERSONAL_FACT] - elif message_type in ["event", "activity"]: - return [MemoryType.EVENT] - elif message_type in ["preference", "like"]: - return [MemoryType.PREFERENCE] - elif message_type in ["opinion", "view"]: - return [MemoryType.OPINION] - - return [] - - except Exception: - return [] - - def _extract_keywords_from_query(self, query: str, query_plan: Any | None = None) -> list[str]: - """从查询中提取关键词""" - try: - extracted: list[str] = [] - - if query_plan and getattr(query_plan, "required_keywords", None): - extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)]) - - # 简单的关键词提取 - words = query.lower().split() - # 过滤停用词 - stopwords = {"的", "是", "在", "有", "我", "你", "他", "她", "它", "这", "那", "了", "吗", "呢"} - extracted.extend(word for word in words if len(word) > 1 and word not in stopwords) - - # 去重并保留顺序 - seen = set() - deduplicated = [] - for word in extracted: - if word in seen or not word: - continue - seen.add(word) - deduplicated.append(word) - - return deduplicated[:10] - except Exception: - return [] - - def _update_retrieval_stats(self, total_time: float, stage_results: list[StageResult]): - """更新检索统计""" - self.retrieval_stats["total_queries"] += 1 - - # 更新平均检索时间 - current_avg = self.retrieval_stats["average_retrieval_time"] - total_queries = self.retrieval_stats["total_queries"] - new_avg = (current_avg * (total_queries - 1) + total_time) / total_queries - self.retrieval_stats["average_retrieval_time"] = new_avg - - # 更新各阶段统计 - for result in stage_results: - stage_name = result.stage.value - if stage_name in self.retrieval_stats["stage_stats"]: - stage_stat = self.retrieval_stats["stage_stats"][stage_name] - stage_stat["calls"] += 1 - - current_stage_avg = stage_stat["avg_time"] - new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat[ - "calls" - ] - stage_stat["avg_time"] = new_stage_avg - - def get_retrieval_stats(self) -> dict[str, Any]: - """获取检索统计信息""" - return self.retrieval_stats.copy() - - def reset_stats(self): - """重置统计信息""" - self.retrieval_stats = { - "total_queries": 0, - "average_retrieval_time": 0.0, - "stage_stats": { - "metadata_filtering": {"calls": 0, "avg_time": 0.0}, - "vector_search": {"calls": 0, "avg_time": 0.0}, - "semantic_reranking": {"calls": 0, "avg_time": 0.0}, - "contextual_filtering": {"calls": 0, "avg_time": 0.0}, - "enhanced_reranking": {"calls": 0, "avg_time": 0.0}, - }, - } - - async def _enhanced_reranking_stage( - self, - query: str, - user_id: str, - context: dict[str, Any], - candidate_ids: list[str], - all_memories_cache: dict[str, MemoryChunk], - limit: int, - *, - debug_log: dict[str, dict[str, Any]] | None = None, - ) -> StageResult: - """阶段5:增强重排序 - 使用多维度评分模型""" - start_time = time.time() - - try: - if not candidate_ids: - return StageResult( - stage=RetrievalStage.CONTEXTUAL_FILTERING, # 保持与原有枚举兼容 - memory_ids=[], - processing_time=time.time() - start_time, - filtered_count=0, - score_threshold=0.0, - details=[{"note": "no_candidates"}], - ) - - # 准备候选记忆数据 - candidate_memories = [] - for memory_id in candidate_ids: - memory = all_memories_cache.get(memory_id) - if memory: - # 使用原始向量相似度作为基础分数 - vector_similarity = 0.8 # 默认分数,实际应该从前面阶段传递 - candidate_memories.append((memory_id, memory, vector_similarity)) - - if not candidate_memories: - return StageResult( - stage=RetrievalStage.CONTEXTUAL_FILTERING, - memory_ids=[], - processing_time=time.time() - start_time, - filtered_count=len(candidate_ids), - score_threshold=0.0, - details=[{"note": "candidates_not_found_in_cache"}], - ) - - # 使用增强重排序器 - reranked_memories = self.reranker.rerank_memories( - query=query, candidate_memories=candidate_memories, context=context, limit=limit - ) - - # 提取重排序后的记忆ID - result_ids = [memory_id for memory_id, _, _ in reranked_memories] - - # 生成调试详情 - details = [] - for memory_id, memory, final_score in reranked_memories: - detail_entry = { - "memory_id": memory_id, - "final_score": round(final_score, 4), - "status": "reranked", - "memory_type": memory.memory_type.value, - "access_count": memory.metadata.access_count, - } - details.append(detail_entry) - - if debug_log is not None: - stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {}) - stage_entry["final_score"] = round(final_score, 4) - stage_entry["status"] = "reranked" - stage_entry["rank"] = len(details) - - # 记录被过滤的记忆 - kept_ids = set(result_ids) - for memory_id in candidate_ids: - if memory_id not in kept_ids: - detail_entry = {"memory_id": memory_id, "status": "filtered_out", "reason": "ranked_below_limit"} - details.append(detail_entry) - - if debug_log is not None: - stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {}) - stage_entry["status"] = "filtered_out" - stage_entry["reason"] = "ranked_below_limit" - - filtered_count = len(candidate_ids) - len(result_ids) - - logger.debug(f"增强重排序完成:候选={len(candidate_ids)}, 返回={len(result_ids)}, 过滤={filtered_count}") - - return StageResult( - stage=RetrievalStage.CONTEXTUAL_FILTERING, # 保持与原有枚举兼容 - memory_ids=result_ids, - processing_time=time.time() - start_time, - filtered_count=filtered_count, - score_threshold=0.0, # 动态阈值,由重排序器决定 - details=details, - ) - - except Exception as e: - logger.error(f"增强重排序阶段失败: {e}", exc_info=True) - return StageResult( - stage=RetrievalStage.CONTEXTUAL_FILTERING, - memory_ids=candidate_ids[:limit], # 失败时返回前limit个 - processing_time=time.time() - start_time, - filtered_count=0, - score_threshold=0.0, - details=[{"error": str(e)}], - ) diff --git a/src/chat/memory_system/deprecated_backup/vector_storage.py b/src/chat/memory_system/deprecated_backup/vector_storage.py deleted file mode 100644 index d5d974486..000000000 --- a/src/chat/memory_system/deprecated_backup/vector_storage.py +++ /dev/null @@ -1,875 +0,0 @@ -""" -向量数据库存储接口 -为记忆系统提供高效的向量存储和语义搜索能力 -""" - -import asyncio -import threading -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -import numpy as np -import orjson - -from src.chat.memory_system.memory_chunk import MemoryChunk -from src.common.config_helpers import resolve_embedding_dimension -from src.common.logger import get_logger -from src.config.config import model_config -from src.llm_models.utils_model import LLMRequest - -logger = get_logger(__name__) - -# 尝试导入FAISS,如果不可用则使用简单替代 -try: - import faiss - - FAISS_AVAILABLE = True -except ImportError: - FAISS_AVAILABLE = False - logger.warning("FAISS not available, using simple vector storage") - - -@dataclass -class VectorStorageConfig: - """向量存储配置""" - - dimension: int = 1024 - similarity_threshold: float = 0.8 - index_type: str = "flat" # flat, ivf, hnsw - max_index_size: int = 100000 - storage_path: str = "data/memory_vectors" - auto_save_interval: int = 10 # 每N次操作自动保存 - enable_compression: bool = True - - -class VectorStorageManager: - """向量存储管理器""" - - def __init__(self, config: VectorStorageConfig | None = None): - self.config = config or VectorStorageConfig() - - resolved_dimension = resolve_embedding_dimension(self.config.dimension) - if resolved_dimension and resolved_dimension != self.config.dimension: - logger.info( - "向量存储维度调整: 使用嵌入模型配置的维度 %d (原始配置: %d)", - resolved_dimension, - self.config.dimension, - ) - self.config.dimension = resolved_dimension - self.storage_path = Path(self.config.storage_path) - self.storage_path.mkdir(parents=True, exist_ok=True) - - # 向量索引 - self.vector_index = None - self.memory_id_to_index = {} # memory_id -> vector index - self.index_to_memory_id = {} # vector index -> memory_id - - # 内存缓存 - self.memory_cache: dict[str, MemoryChunk] = {} - self.vector_cache: dict[str, list[float]] = {} - - # 统计信息 - self.storage_stats = { - "total_vectors": 0, - "index_build_time": 0.0, - "average_search_time": 0.0, - "cache_hit_rate": 0.0, - "total_searches": 0, - "cache_hits": 0, - } - - # 线程锁 - self._lock = threading.RLock() - self._operation_count = 0 - - # 初始化索引 - self._initialize_index() - - # 嵌入模型 - self.embedding_model: LLMRequest = None - - def _initialize_index(self): - """初始化向量索引""" - try: - if FAISS_AVAILABLE: - if self.config.index_type == "flat": - self.vector_index = faiss.IndexFlatIP(self.config.dimension) - elif self.config.index_type == "ivf": - quantizer = faiss.IndexFlatIP(self.config.dimension) - nlist = min(100, max(1, self.config.max_index_size // 1000)) - self.vector_index = faiss.IndexIVFFlat(quantizer, self.config.dimension, nlist) - elif self.config.index_type == "hnsw": - self.vector_index = faiss.IndexHNSWFlat(self.config.dimension, 32) - self.vector_index.hnsw.efConstruction = 40 - else: - self.vector_index = faiss.IndexFlatIP(self.config.dimension) - else: - # 简单的向量存储实现 - self.vector_index = SimpleVectorIndex(self.config.dimension) - - logger.info(f"✅ 向量索引初始化完成,类型: {self.config.index_type}") - - except Exception as e: - logger.error(f"❌ 向量索引初始化失败: {e}") - # 回退到简单实现 - self.vector_index = SimpleVectorIndex(self.config.dimension) - - async def initialize_embedding_model(self): - """初始化嵌入模型""" - if self.embedding_model is None: - self.embedding_model = LLMRequest( - model_set=model_config.model_task_config.embedding, request_type="memory.embedding" - ) - logger.info("✅ 嵌入模型初始化完成") - - async def generate_query_embedding(self, query_text: str) -> list[float] | None: - """生成查询向量,用于记忆召回""" - if not query_text: - logger.warning("查询文本为空,无法生成向量") - return None - - try: - await self.initialize_embedding_model() - - logger.debug(f"开始生成查询向量,文本: '{query_text[:50]}{'...' if len(query_text) > 50 else ''}'") - - embedding, _ = await self.embedding_model.get_embedding(query_text) - if not embedding: - logger.warning("嵌入模型返回空向量") - return None - - logger.debug(f"生成的向量维度: {len(embedding)}, 期望维度: {self.config.dimension}") - - if len(embedding) != self.config.dimension: - logger.error("查询向量维度不匹配: 期望 %d, 实际 %d", self.config.dimension, len(embedding)) - return None - - normalized_vector = self._normalize_vector(embedding) - logger.debug(f"查询向量生成成功,向量范围: [{min(normalized_vector):.4f}, {max(normalized_vector):.4f}]") - return normalized_vector - - except Exception as exc: - logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True) - return None - - async def store_memories(self, memories: list[MemoryChunk]): - """存储记忆向量""" - if not memories: - return - - start_time = time.time() - - try: - # 确保嵌入模型已初始化 - await self.initialize_embedding_model() - - # 批量获取嵌入向量 - memory_texts = [] - - for memory in memories: - # 预先缓存记忆,确保后续流程可访问 - self.memory_cache[memory.memory_id] = memory - if memory.embedding is None: - # 如果没有嵌入向量,需要生成 - text = self._prepare_embedding_text(memory) - memory_texts.append((memory.memory_id, text)) - else: - # 已有嵌入向量,直接使用 - await self._add_single_memory(memory, memory.embedding) - - # 批量生成缺失的嵌入向量 - if memory_texts: - await self._batch_generate_and_store_embeddings(memory_texts) - - # 自动保存检查 - self._operation_count += len(memories) - if self._operation_count >= self.config.auto_save_interval: - await self.save_storage() - self._operation_count = 0 - - storage_time = time.time() - start_time - logger.debug(f"向量存储完成,{len(memories)} 条记忆,耗时 {storage_time:.3f}秒") - - except Exception as e: - logger.error(f"❌ 向量存储失败: {e}", exc_info=True) - - def _prepare_embedding_text(self, memory: MemoryChunk) -> str: - """准备用于嵌入的文本,仅使用自然语言展示内容""" - display_text = (memory.display or "").strip() - if display_text: - return display_text - - fallback_text = (memory.text_content or "").strip() - if fallback_text: - return fallback_text - - subjects = "、".join(s.strip() for s in memory.subjects if s and isinstance(s, str)) - predicate = (memory.content.predicate or "").strip() - - obj = memory.content.object - if isinstance(obj, dict): - object_parts = [] - for key, value in obj.items(): - if value is None: - continue - if isinstance(value, (list, tuple)): - preview = "、".join(str(item) for item in value[:3]) - object_parts.append(f"{key}:{preview}") - else: - object_parts.append(f"{key}:{value}") - object_text = ", ".join(object_parts) - else: - object_text = str(obj or "").strip() - - composite_parts = [part for part in [subjects, predicate, object_text] if part] - if composite_parts: - return " ".join(composite_parts) - - logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id) - return memory.memory_id - - async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]): - """批量生成和存储嵌入向量""" - if not memory_texts: - return - - try: - texts = [text for _, text in memory_texts] - memory_ids = [memory_id for memory_id, _ in memory_texts] - - # 批量生成嵌入向量 - embeddings = await self._batch_generate_embeddings(memory_ids, texts) - - # 存储向量和记忆 - for memory_id, embedding in embeddings.items(): - if embedding and len(embedding) == self.config.dimension: - memory = self.memory_cache.get(memory_id) - if memory: - await self._add_single_memory(memory, embedding) - - except Exception as e: - logger.error(f"❌ 批量生成嵌入向量失败: {e}") - - async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]: - """批量生成嵌入向量""" - if not texts: - return {} - - results: dict[str, list[float]] = {} - - try: - semaphore = asyncio.Semaphore(min(4, max(1, len(texts)))) - - async def generate_embedding(memory_id: str, text: str) -> None: - async with semaphore: - try: - embedding, _ = await self.embedding_model.get_embedding(text) - if embedding and len(embedding) == self.config.dimension: - results[memory_id] = embedding - else: - logger.warning( - "嵌入向量维度不匹配: 期望 %d, 实际 %d (memory_id=%s)。请检查模型嵌入配置 model_config.model_task_config.embedding.embedding_dimension 或 LPMM 任务定义。", - self.config.dimension, - len(embedding) if embedding else 0, - memory_id, - ) - results[memory_id] = [] - except Exception as exc: - logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc) - results[memory_id] = [] - - tasks = [ - asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False) - ] - await asyncio.gather(*tasks, return_exceptions=True) - - except Exception as e: - logger.error(f"❌ 批量生成嵌入向量失败: {e}") - for memory_id in memory_ids: - results.setdefault(memory_id, []) - - return results - - async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]): - """添加单个记忆到向量存储""" - with self._lock: - try: - # 规范化向量 - if embedding: - embedding = self._normalize_vector(embedding) - - # 添加到缓存 - self.memory_cache[memory.memory_id] = memory - self.vector_cache[memory.memory_id] = embedding - - # 更新记忆的嵌入向量 - memory.set_embedding(embedding) - - # 添加到向量索引 - if hasattr(self.vector_index, "add"): - # FAISS索引 - if isinstance(embedding, np.ndarray): - vector_array = embedding.reshape(1, -1).astype("float32") - else: - vector_array = np.array([embedding], dtype="float32") - - # 特殊处理IVF索引 - if self.config.index_type == "ivf" and self.vector_index.ntotal == 0: - # IVF索引需要先训练 - logger.debug("训练IVF索引...") - self.vector_index.train(vector_array) - - self.vector_index.add(vector_array) - index_id = self.vector_index.ntotal - 1 - - else: - # 简单索引 - index_id = self.vector_index.add_vector(embedding) - - # 更新映射关系 - self.memory_id_to_index[memory.memory_id] = index_id - self.index_to_memory_id[index_id] = memory.memory_id - - # 更新统计 - self.storage_stats["total_vectors"] += 1 - - except Exception as e: - logger.error(f"❌ 添加记忆到向量存储失败: {e}") - - def _normalize_vector(self, vector: list[float]) -> list[float]: - """L2归一化向量""" - if not vector: - return vector - - try: - vector_array = np.array(vector, dtype=np.float32) - norm = np.linalg.norm(vector_array) - if norm == 0: - return vector - - normalized = vector_array / norm - return normalized.tolist() - - except Exception as e: - logger.warning(f"向量归一化失败: {e}") - return vector - - async def search_similar_memories( - self, - query_vector: list[float] | None = None, - *, - query_text: str | None = None, - limit: int = 10, - scope_id: str | None = None, - ) -> list[tuple[str, float]]: - """搜索相似记忆""" - start_time = time.time() - - try: - logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}") - - if query_vector is None: - if not query_text: - logger.warning("查询向量和查询文本都为空") - return [] - - query_vector = await self.generate_query_embedding(query_text) - if not query_vector: - logger.warning("查询向量生成失败") - return [] - - scope_filter: str | None = None - if isinstance(scope_id, str): - normalized_scope = scope_id.strip().lower() - if normalized_scope and normalized_scope not in {"global", "global_memory"}: - scope_filter = scope_id - elif scope_id: - scope_filter = str(scope_id) - - # 规范化查询向量 - query_vector = self._normalize_vector(query_vector) - - logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}") - - # 检查向量索引状态 - if not self.vector_index: - logger.error("向量索引未初始化") - return [] - - total_vectors = 0 - if hasattr(self.vector_index, "ntotal"): - total_vectors = self.vector_index.ntotal - elif hasattr(self.vector_index, "vectors"): - total_vectors = len(self.vector_index.vectors) - - logger.debug(f"向量索引中实际向量数: {total_vectors}") - - if total_vectors == 0: - logger.warning("向量索引为空,无法执行搜索") - return [] - - # 执行向量搜索 - with self._lock: - if hasattr(self.vector_index, "search"): - # FAISS索引 - if isinstance(query_vector, np.ndarray): - query_array = query_vector.reshape(1, -1).astype("float32") - else: - query_array = np.array([query_vector], dtype="float32") - - if self.config.index_type == "ivf" and self.vector_index.ntotal > 0: - # 设置IVF搜索参数 - nprobe = min(self.vector_index.nlist, 10) - self.vector_index.nprobe = nprobe - logger.debug(f"IVF搜索参数: nprobe={nprobe}") - - search_limit = min(limit, total_vectors) - logger.debug(f"执行FAISS搜索,搜索限制: {search_limit}") - - distances, indices = self.vector_index.search(query_array, search_limit) - distances = distances.flatten().tolist() - indices = indices.flatten().tolist() - - logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引") - else: - # 简单索引 - logger.debug("使用简单向量索引执行搜索") - results = self.vector_index.search(query_vector, limit) - distances = [score for _, score in results] - indices = [idx for idx, _ in results] - logger.debug(f"简单索引搜索结果: {len(results)} 个结果") - - # 处理搜索结果 - results = [] - valid_results = 0 - invalid_indices = 0 - filtered_by_scope = 0 - - for distance, index in zip(distances, indices, strict=False): - if index == -1: # FAISS的无效索引标记 - invalid_indices += 1 - continue - - memory_id = self.index_to_memory_id.get(index) - if not memory_id: - logger.debug(f"索引 {index} 没有对应的记忆ID") - invalid_indices += 1 - continue - - if scope_filter: - memory = self.memory_cache.get(memory_id) - if memory and str(memory.user_id) != scope_filter: - filtered_by_scope += 1 - continue - - similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内 - results.append((memory_id, similarity)) - valid_results += 1 - - logger.debug( - f"搜索结果处理: 总距离={len(distances)}, 有效结果={valid_results}, " - f"无效索引={invalid_indices}, 作用域过滤={filtered_by_scope}" - ) - - # 更新统计 - search_time = time.time() - start_time - self.storage_stats["total_searches"] += 1 - self.storage_stats["average_search_time"] = ( - self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time - ) / self.storage_stats["total_searches"] - - final_results = results[:limit] - logger.info( - f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' " - f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果" - ) - - return final_results - - except Exception as e: - logger.error(f"❌ 向量搜索失败: {e}", exc_info=True) - return [] - - async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None: - """根据ID获取记忆""" - # 先检查缓存 - if memory_id in self.memory_cache: - self.storage_stats["cache_hits"] += 1 - return self.memory_cache[memory_id] - - self.storage_stats["total_searches"] += 1 - return None - - async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]): - """更新记忆的嵌入向量""" - with self._lock: - try: - if memory_id not in self.memory_id_to_index: - logger.warning(f"记忆 {memory_id} 不存在于向量索引中") - return - - # 获取旧索引 - old_index = self.memory_id_to_index[memory_id] - - # 删除旧向量(如果支持) - if hasattr(self.vector_index, "remove_ids"): - try: - self.vector_index.remove_ids(np.array([old_index])) - except: - logger.warning("无法删除旧向量,将直接添加新向量") - - # 规范化新向量 - new_embedding = self._normalize_vector(new_embedding) - - # 添加新向量 - if hasattr(self.vector_index, "add"): - if isinstance(new_embedding, np.ndarray): - vector_array = new_embedding.reshape(1, -1).astype("float32") - else: - vector_array = np.array([new_embedding], dtype="float32") - - self.vector_index.add(vector_array) - new_index = self.vector_index.ntotal - 1 - else: - new_index = self.vector_index.add_vector(new_embedding) - - # 更新映射关系 - self.memory_id_to_index[memory_id] = new_index - self.index_to_memory_id[new_index] = memory_id - - # 更新缓存 - self.vector_cache[memory_id] = new_embedding - - # 更新记忆对象 - memory = self.memory_cache.get(memory_id) - if memory: - memory.set_embedding(new_embedding) - - logger.debug(f"更新记忆 {memory_id} 的嵌入向量") - - except Exception as e: - logger.error(f"❌ 更新记忆嵌入向量失败: {e}") - - async def delete_memory(self, memory_id: str): - """删除记忆""" - with self._lock: - try: - if memory_id not in self.memory_id_to_index: - return - - # 获取索引 - index = self.memory_id_to_index[memory_id] - - # 从向量索引中删除(如果支持) - if hasattr(self.vector_index, "remove_ids"): - try: - self.vector_index.remove_ids(np.array([index])) - except: - logger.warning("无法从向量索引中删除,仅从缓存中移除") - - # 删除映射关系 - del self.memory_id_to_index[memory_id] - if index in self.index_to_memory_id: - del self.index_to_memory_id[index] - - # 从缓存中删除 - self.memory_cache.pop(memory_id, None) - self.vector_cache.pop(memory_id, None) - - # 更新统计 - self.storage_stats["total_vectors"] = max(0, self.storage_stats["total_vectors"] - 1) - - logger.debug(f"删除记忆 {memory_id}") - - except Exception as e: - logger.error(f"❌ 删除记忆失败: {e}") - - async def save_storage(self): - """保存向量存储到文件""" - try: - logger.info("正在保存向量存储...") - - # 保存记忆缓存 - cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()} - - cache_file = self.storage_path / "memory_cache.json" - with open(cache_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8")) - - # 保存向量缓存 - vector_cache_file = self.storage_path / "vector_cache.json" - with open(vector_cache_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8")) - - # 保存映射关系 - mapping_file = self.storage_path / "id_mapping.json" - mapping_data = { - "memory_id_to_index": { - str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items() - }, - "index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()}, - } - with open(mapping_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8")) - - # 保存FAISS索引(如果可用) - if FAISS_AVAILABLE and hasattr(self.vector_index, "save"): - index_file = self.storage_path / "vector_index.faiss" - faiss.write_index(self.vector_index, str(index_file)) - - # 保存统计信息 - stats_file = self.storage_path / "storage_stats.json" - with open(stats_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8")) - - logger.info("✅ 向量存储保存完成") - - except Exception as e: - logger.error(f"❌ 保存向量存储失败: {e}") - - async def load_storage(self): - """从文件加载向量存储""" - try: - logger.info("正在加载向量存储...") - - # 加载记忆缓存 - cache_file = self.storage_path / "memory_cache.json" - if cache_file.exists(): - with open(cache_file, encoding="utf-8") as f: - cache_data = orjson.loads(f.read()) - - self.memory_cache = { - memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items() - } - - # 加载向量缓存 - vector_cache_file = self.storage_path / "vector_cache.json" - if vector_cache_file.exists(): - with open(vector_cache_file, encoding="utf-8") as f: - self.vector_cache = orjson.loads(f.read()) - - # 加载映射关系 - mapping_file = self.storage_path / "id_mapping.json" - if mapping_file.exists(): - with open(mapping_file, encoding="utf-8") as f: - mapping_data = orjson.loads(f.read()) - raw_memory_to_index = mapping_data.get("memory_id_to_index", {}) - self.memory_id_to_index = { - str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items() - } - - raw_index_to_memory = mapping_data.get("index_to_memory_id", {}) - self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()} - - # 加载FAISS索引(如果可用) - index_loaded = False - if FAISS_AVAILABLE: - index_file = self.storage_path / "vector_index.faiss" - if index_file.exists(): - try: - loaded_index = faiss.read_index(str(index_file)) - # 如果索引类型匹配,则替换 - if type(loaded_index) == type(self.vector_index): - self.vector_index = loaded_index - index_loaded = True - logger.info("✅ FAISS索引文件加载完成") - else: - logger.warning("索引类型不匹配,重新构建索引") - except Exception as e: - logger.warning(f"加载FAISS索引失败: {e},重新构建") - else: - logger.info("FAISS索引文件不存在,将重新构建") - - # 如果索引没有成功加载且有向量数据,则重建索引 - if not index_loaded and self.vector_cache: - logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引") - await self._rebuild_index() - - # 加载统计信息 - stats_file = self.storage_path / "storage_stats.json" - if stats_file.exists(): - with open(stats_file, encoding="utf-8") as f: - self.storage_stats = orjson.loads(f.read()) - - # 更新向量计数 - self.storage_stats["total_vectors"] = len(self.memory_id_to_index) - - logger.info(f"✅ 向量存储加载完成,{self.storage_stats['total_vectors']} 个向量") - - except Exception as e: - logger.error(f"❌ 加载向量存储失败: {e}") - - async def _rebuild_index(self): - """重建向量索引""" - try: - logger.info(f"正在重建向量索引...向量数量: {len(self.vector_cache)}") - - # 重新初始化索引 - self._initialize_index() - - # 清空映射关系 - self.memory_id_to_index.clear() - self.index_to_memory_id.clear() - - if not self.vector_cache: - logger.warning("没有向量缓存数据,跳过重建") - return - - # 准备向量数据 - memory_ids = [] - vectors = [] - - for memory_id, embedding in self.vector_cache.items(): - if embedding and len(embedding) == self.config.dimension: - memory_ids.append(memory_id) - vectors.append(self._normalize_vector(embedding)) - else: - logger.debug(f"跳过无效向量: {memory_id}, 维度: {len(embedding) if embedding else 0}") - - if not vectors: - logger.warning("没有有效的向量数据") - return - - logger.info(f"准备重建 {len(vectors)} 个向量到索引") - - # 批量添加向量到FAISS索引 - if hasattr(self.vector_index, "add"): - # FAISS索引 - vector_array = np.array(vectors, dtype="float32") - - # 特殊处理IVF索引 - if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"): - logger.info("训练IVF索引...") - self.vector_index.train(vector_array) - - # 添加向量 - self.vector_index.add(vector_array) - - # 重建映射关系 - for i, memory_id in enumerate(memory_ids): - self.memory_id_to_index[memory_id] = i - self.index_to_memory_id[i] = memory_id - - else: - # 简单索引 - for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors, strict=False)): - index_id = self.vector_index.add_vector(vector) - self.memory_id_to_index[memory_id] = index_id - self.index_to_memory_id[index_id] = memory_id - - # 更新统计 - self.storage_stats["total_vectors"] = len(self.memory_id_to_index) - - final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index)) - logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}") - - except Exception as e: - logger.error(f"❌ 重建向量索引失败: {e}", exc_info=True) - - async def optimize_storage(self): - """优化存储""" - try: - logger.info("开始向量存储优化...") - - # 清理无效引用 - self._cleanup_invalid_references() - - # 重新构建索引(如果碎片化严重) - if self.storage_stats["total_vectors"] > 1000: - await self._rebuild_index() - - # 更新缓存命中率 - if self.storage_stats["total_searches"] > 0: - self.storage_stats["cache_hit_rate"] = ( - self.storage_stats["cache_hits"] / self.storage_stats["total_searches"] - ) - - logger.info("✅ 向量存储优化完成") - - except Exception as e: - logger.error(f"❌ 向量存储优化失败: {e}") - - def _cleanup_invalid_references(self): - """清理无效引用""" - with self._lock: - # 清理无效的memory_id到index的映射 - valid_memory_ids = set(self.memory_cache.keys()) - invalid_memory_ids = set(self.memory_id_to_index.keys()) - valid_memory_ids - - for memory_id in invalid_memory_ids: - index = self.memory_id_to_index[memory_id] - del self.memory_id_to_index[memory_id] - if index in self.index_to_memory_id: - del self.index_to_memory_id[index] - - if invalid_memory_ids: - logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用") - - def get_storage_stats(self) -> dict[str, Any]: - """获取存储统计信息""" - stats = self.storage_stats.copy() - if stats["total_searches"] > 0: - stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_searches"] - else: - stats["cache_hit_rate"] = 0.0 - return stats - - -class SimpleVectorIndex: - """简单的向量索引实现(当FAISS不可用时的替代方案)""" - - def __init__(self, dimension: int): - self.dimension = dimension - self.vectors: list[list[float]] = [] - self.vector_ids: list[int] = [] - self.next_id = 0 - - def add_vector(self, vector: list[float]) -> int: - """添加向量""" - if len(vector) != self.dimension: - raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}") - - vector_id = self.next_id - self.vectors.append(vector.copy()) - self.vector_ids.append(vector_id) - self.next_id += 1 - - return vector_id - - def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]: - """搜索相似向量""" - if len(query_vector) != self.dimension: - raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}") - - results = [] - - for i, vector in enumerate(self.vectors): - similarity = self._calculate_cosine_similarity(query_vector, vector) - results.append((self.vector_ids[i], similarity)) - - # 按相似度排序 - results.sort(key=lambda x: x[1], reverse=True) - - return results[:limit] - - def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float: - """计算余弦相似度""" - try: - dot_product = sum(x * y for x, y in zip(v1, v2, strict=False)) - norm1 = sum(x * x for x in v1) ** 0.5 - norm2 = sum(x * x for x in v2) ** 0.5 - - if norm1 == 0 or norm2 == 0: - return 0.0 - - return dot_product / (norm1 * norm2) - - except Exception: - return 0.0 - - @property - def ntotal(self) -> int: - """向量总数""" - return len(self.vectors) diff --git a/src/chat/message_manager/adaptive_stream_manager.py b/src/chat/message_manager/adaptive_stream_manager.py new file mode 100644 index 000000000..c48b2d300 --- /dev/null +++ b/src/chat/message_manager/adaptive_stream_manager.py @@ -0,0 +1,489 @@ +""" +自适应流管理器 - 动态并发限制和异步流池管理 +根据系统负载和流优先级动态调整并发限制 +""" + +import asyncio +import psutil +import time +from typing import Dict, List, Optional, Set, Tuple +from dataclasses import dataclass, field +from enum import Enum + +from src.common.logger import get_logger +from src.chat.message_receive.chat_stream import ChatStream + +logger = get_logger("adaptive_stream_manager") + + +class StreamPriority(Enum): + """流优先级""" + LOW = 1 + NORMAL = 2 + HIGH = 3 + CRITICAL = 4 + + +@dataclass +class SystemMetrics: + """系统指标""" + cpu_usage: float = 0.0 + memory_usage: float = 0.0 + active_coroutines: int = 0 + event_loop_lag: float = 0.0 + timestamp: float = field(default_factory=time.time) + + +@dataclass +class StreamMetrics: + """流指标""" + stream_id: str + priority: StreamPriority + message_rate: float = 0.0 # 消息速率(消息/分钟) + response_time: float = 0.0 # 平均响应时间 + last_activity: float = field(default_factory=time.time) + consecutive_failures: int = 0 + is_active: bool = True + + +class AdaptiveStreamManager: + """自适应流管理器""" + + def __init__( + self, + base_concurrent_limit: int = 50, + max_concurrent_limit: int = 200, + min_concurrent_limit: int = 10, + metrics_window: float = 60.0, # 指标窗口时间 + adjustment_interval: float = 30.0, # 调整间隔 + cpu_threshold_high: float = 0.8, # CPU高负载阈值 + cpu_threshold_low: float = 0.3, # CPU低负载阈值 + memory_threshold_high: float = 0.85, # 内存高负载阈值 + ): + self.base_concurrent_limit = base_concurrent_limit + self.max_concurrent_limit = max_concurrent_limit + self.min_concurrent_limit = min_concurrent_limit + self.metrics_window = metrics_window + self.adjustment_interval = adjustment_interval + self.cpu_threshold_high = cpu_threshold_high + self.cpu_threshold_low = cpu_threshold_low + self.memory_threshold_high = memory_threshold_high + + # 当前状态 + self.current_limit = base_concurrent_limit + self.active_streams: Set[str] = set() + self.pending_streams: Set[str] = set() + self.stream_metrics: Dict[str, StreamMetrics] = {} + + # 异步信号量 + self.semaphore = asyncio.Semaphore(base_concurrent_limit) + self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量 + + # 系统监控 + self.system_metrics: List[SystemMetrics] = [] + self.last_adjustment_time = 0.0 + + # 统计信息 + self.stats = { + "total_requests": 0, + "accepted_requests": 0, + "rejected_requests": 0, + "priority_accepts": 0, + "limit_adjustments": 0, + "avg_concurrent_streams": 0, + "peak_concurrent_streams": 0, + } + + # 监控任务 + self.monitor_task: Optional[asyncio.Task] = None + self.adjustment_task: Optional[asyncio.Task] = None + self.is_running = False + + logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})") + + async def start(self): + """启动自适应管理器""" + if self.is_running: + logger.warning("自适应流管理器已经在运行") + return + + self.is_running = True + self.monitor_task = asyncio.create_task(self._system_monitor_loop(), name="system_monitor") + self.adjustment_task = asyncio.create_task(self._adjustment_loop(), name="limit_adjustment") + logger.info("自适应流管理器已启动") + + async def stop(self): + """停止自适应管理器""" + if not self.is_running: + return + + self.is_running = False + + # 停止监控任务 + if self.monitor_task and not self.monitor_task.done(): + self.monitor_task.cancel() + try: + await asyncio.wait_for(self.monitor_task, timeout=10.0) + except asyncio.TimeoutError: + logger.warning("系统监控任务停止超时") + except Exception as e: + logger.error(f"停止系统监控任务时出错: {e}") + + if self.adjustment_task and not self.adjustment_task.done(): + self.adjustment_task.cancel() + try: + await asyncio.wait_for(self.adjustment_task, timeout=10.0) + except asyncio.TimeoutError: + logger.warning("限制调整任务停止超时") + except Exception as e: + logger.error(f"停止限制调整任务时出错: {e}") + + logger.info("自适应流管理器已停止") + + async def acquire_stream_slot( + self, + stream_id: str, + priority: StreamPriority = StreamPriority.NORMAL, + force: bool = False + ) -> bool: + """ + 获取流处理槽位 + + Args: + stream_id: 流ID + priority: 优先级 + force: 是否强制获取(突破限制) + + Returns: + bool: 是否成功获取槽位 + """ + # 检查管理器是否已启动 + if not self.is_running: + logger.warning(f"自适应流管理器未运行,直接允许流 {stream_id}") + return True + + self.stats["total_requests"] += 1 + current_time = time.time() + + # 更新流指标 + if stream_id not in self.stream_metrics: + self.stream_metrics[stream_id] = StreamMetrics( + stream_id=stream_id, + priority=priority + ) + self.stream_metrics[stream_id].last_activity = current_time + + # 检查是否已经活跃 + if stream_id in self.active_streams: + logger.debug(f"流 {stream_id} 已经在活跃列表中") + return True + + # 优先级处理 + if priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]: + return await self._acquire_priority_slot(stream_id, priority, force) + + # 检查是否需要强制分发(消息积压) + if not force and self._should_force_dispatch(stream_id): + force = True + logger.info(f"流 {stream_id} 消息积压严重,强制分发") + + # 尝试获取常规信号量 + try: + # 使用wait_for实现非阻塞获取 + acquired = await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001) + if acquired: + self.active_streams.add(stream_id) + self.stats["accepted_requests"] += 1 + logger.debug(f"流 {stream_id} 获取常规槽位成功 (当前活跃: {len(self.active_streams)})") + return True + except asyncio.TimeoutError: + logger.debug(f"常规信号量已满: {stream_id}") + except Exception as e: + logger.warning(f"获取常规槽位时出错: {e}") + + # 如果强制分发,尝试突破限制 + if force: + return await self._force_acquire_slot(stream_id) + + # 无法获取槽位 + self.stats["rejected_requests"] += 1 + logger.debug(f"流 {stream_id} 获取槽位失败,当前限制: {self.current_limit}, 活跃流: {len(self.active_streams)}") + return False + + async def _acquire_priority_slot(self, stream_id: str, priority: StreamPriority, force: bool) -> bool: + """获取优先级槽位""" + try: + # 优先级信号量有少量槽位 + acquired = await asyncio.wait_for(self.priority_semaphore.acquire(), timeout=0.001) + if acquired: + self.active_streams.add(stream_id) + self.stats["priority_accepts"] += 1 + self.stats["accepted_requests"] += 1 + logger.debug(f"流 {stream_id} 获取优先级槽位成功 (优先级: {priority.name})") + return True + except asyncio.TimeoutError: + logger.debug(f"优先级信号量已满: {stream_id}") + except Exception as e: + logger.warning(f"获取优先级槽位时出错: {e}") + + # 如果优先级槽位也满了,检查是否强制 + if force or priority == StreamPriority.CRITICAL: + return await self._force_acquire_slot(stream_id) + + return False + + async def _force_acquire_slot(self, stream_id: str) -> bool: + """强制获取槽位(突破限制)""" + # 检查是否超过最大限制 + if len(self.active_streams) >= self.max_concurrent_limit: + logger.warning(f"达到最大并发限制 {self.max_concurrent_limit},无法为流 {stream_id} 强制分发") + return False + + # 强制添加到活跃列表 + self.active_streams.add(stream_id) + self.stats["accepted_requests"] += 1 + logger.warning(f"流 {stream_id} 突破并发限制强制分发 (当前活跃: {len(self.active_streams)})") + return True + + def release_stream_slot(self, stream_id: str): + """释放流处理槽位""" + if stream_id in self.active_streams: + self.active_streams.remove(stream_id) + + # 释放相应的信号量 + metrics = self.stream_metrics.get(stream_id) + if metrics and metrics.priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]: + self.priority_semaphore.release() + else: + self.semaphore.release() + + logger.debug(f"流 {stream_id} 释放槽位 (当前活跃: {len(self.active_streams)})") + + def _should_force_dispatch(self, stream_id: str) -> bool: + """判断是否应该强制分发""" + # 这里可以实现基于消息积压的判断逻辑 + # 简化版本:基于流的历史活跃度和优先级 + metrics = self.stream_metrics.get(stream_id) + if not metrics: + return False + + # 如果是高优先级流,更容易强制分发 + if metrics.priority == StreamPriority.HIGH: + return True + + # 如果最近有活跃且响应时间较长,可能需要强制分发 + current_time = time.time() + if (current_time - metrics.last_activity < 300 and # 5分钟内有活动 + metrics.response_time > 5.0): # 响应时间超过5秒 + return True + + return False + + async def _system_monitor_loop(self): + """系统监控循环""" + logger.info("系统监控循环启动") + + while self.is_running: + try: + await asyncio.sleep(5.0) # 每5秒监控一次 + await self._collect_system_metrics() + except asyncio.CancelledError: + logger.info("系统监控循环被取消") + break + except Exception as e: + logger.error(f"系统监控出错: {e}") + + logger.info("系统监控循环结束") + + async def _collect_system_metrics(self): + """收集系统指标""" + try: + # CPU使用率 + cpu_usage = psutil.cpu_percent(interval=None) / 100.0 + + # 内存使用率 + memory = psutil.virtual_memory() + memory_usage = memory.percent / 100.0 + + # 活跃协程数量 + try: + active_coroutines = len(asyncio.all_tasks()) + except: + active_coroutines = 0 + + # 事件循环延迟 + event_loop_lag = 0.0 + try: + loop = asyncio.get_running_loop() + start_time = time.time() + await asyncio.sleep(0) + event_loop_lag = time.time() - start_time + except: + pass + + metrics = SystemMetrics( + cpu_usage=cpu_usage, + memory_usage=memory_usage, + active_coroutines=active_coroutines, + event_loop_lag=event_loop_lag, + timestamp=time.time() + ) + + self.system_metrics.append(metrics) + + # 保持指标窗口大小 + cutoff_time = time.time() - self.metrics_window + self.system_metrics = [ + m for m in self.system_metrics + if m.timestamp > cutoff_time + ] + + # 更新统计信息 + self.stats["avg_concurrent_streams"] = ( + self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1 + ) + self.stats["peak_concurrent_streams"] = max( + self.stats["peak_concurrent_streams"], + len(self.active_streams) + ) + + except Exception as e: + logger.error(f"收集系统指标失败: {e}") + + async def _adjustment_loop(self): + """限制调整循环""" + logger.info("限制调整循环启动") + + while self.is_running: + try: + await asyncio.sleep(self.adjustment_interval) + await self._adjust_concurrent_limit() + except asyncio.CancelledError: + logger.info("限制调整循环被取消") + break + except Exception as e: + logger.error(f"限制调整出错: {e}") + + logger.info("限制调整循环结束") + + async def _adjust_concurrent_limit(self): + """调整并发限制""" + if not self.system_metrics: + return + + current_time = time.time() + if current_time - self.last_adjustment_time < self.adjustment_interval: + return + + # 计算平均系统指标 + recent_metrics = self.system_metrics[-10:] if len(self.system_metrics) >= 10 else self.system_metrics + if not recent_metrics: + return + + avg_cpu = sum(m.cpu_usage for m in recent_metrics) / len(recent_metrics) + avg_memory = sum(m.memory_usage for m in recent_metrics) / len(recent_metrics) + avg_coroutines = sum(m.active_coroutines for m in recent_metrics) / len(recent_metrics) + + # 调整策略 + old_limit = self.current_limit + adjustment_factor = 1.0 + + # CPU负载调整 + if avg_cpu > self.cpu_threshold_high: + adjustment_factor *= 0.8 # 减少20% + elif avg_cpu < self.cpu_threshold_low: + adjustment_factor *= 1.2 # 增加20% + + # 内存负载调整 + if avg_memory > self.memory_threshold_high: + adjustment_factor *= 0.7 # 减少30% + + # 协程数量调整 + if avg_coroutines > 1000: + adjustment_factor *= 0.9 # 减少10% + + # 应用调整 + new_limit = int(self.current_limit * adjustment_factor) + new_limit = max(self.min_concurrent_limit, min(self.max_concurrent_limit, new_limit)) + + # 检查是否需要调整信号量 + if new_limit != self.current_limit: + await self._adjust_semaphore(self.current_limit, new_limit) + self.current_limit = new_limit + self.stats["limit_adjustments"] += 1 + self.last_adjustment_time = current_time + + logger.info( + f"并发限制调整: {old_limit} -> {new_limit} " + f"(CPU: {avg_cpu:.2f}, 内存: {avg_memory:.2f}, 协程: {avg_coroutines:.0f})" + ) + + async def _adjust_semaphore(self, old_limit: int, new_limit: int): + """调整信号量大小""" + if new_limit > old_limit: + # 增加信号量槽位 + for _ in range(new_limit - old_limit): + self.semaphore.release() + elif new_limit < old_limit: + # 减少信号量槽位(通过等待槽位被释放) + reduction = old_limit - new_limit + for _ in range(reduction): + try: + await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001) + except: + # 如果无法立即获取,说明当前使用量接近限制 + break + + def update_stream_metrics(self, stream_id: str, **kwargs): + """更新流指标""" + if stream_id not in self.stream_metrics: + return + + metrics = self.stream_metrics[stream_id] + for key, value in kwargs.items(): + if hasattr(metrics, key): + setattr(metrics, key, value) + + def get_stats(self) -> Dict: + """获取统计信息""" + stats = self.stats.copy() + stats.update({ + "current_limit": self.current_limit, + "active_streams": len(self.active_streams), + "pending_streams": len(self.pending_streams), + "is_running": self.is_running, + "system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0, + "system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0, + }) + + # 计算接受率 + if stats["total_requests"] > 0: + stats["acceptance_rate"] = stats["accepted_requests"] / stats["total_requests"] + else: + stats["acceptance_rate"] = 0 + + return stats + + +# 全局自适应管理器实例 +_adaptive_manager: Optional[AdaptiveStreamManager] = None + + +def get_adaptive_stream_manager() -> AdaptiveStreamManager: + """获取自适应流管理器实例""" + global _adaptive_manager + if _adaptive_manager is None: + _adaptive_manager = AdaptiveStreamManager() + return _adaptive_manager + + +async def init_adaptive_stream_manager(): + """初始化自适应流管理器""" + manager = get_adaptive_stream_manager() + await manager.start() + + +async def shutdown_adaptive_stream_manager(): + """关闭自适应流管理器""" + manager = get_adaptive_stream_manager() + await manager.stop() \ No newline at end of file diff --git a/src/chat/message_manager/batch_database_writer.py b/src/chat/message_manager/batch_database_writer.py new file mode 100644 index 000000000..4b35b01c1 --- /dev/null +++ b/src/chat/message_manager/batch_database_writer.py @@ -0,0 +1,348 @@ +""" +异步批量数据库写入器 +优化频繁的数据库写入操作,减少I/O阻塞 +""" + +import asyncio +import time +from typing import Any, Dict, List, Optional +from dataclasses import dataclass, field +from collections import defaultdict + +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import ChatStreams +from src.common.logger import get_logger +from src.config.config import global_config + +logger = get_logger("batch_database_writer") + + +@dataclass +class StreamUpdatePayload: + """流更新数据结构""" + stream_id: str + update_data: Dict[str, Any] + priority: int = 0 # 优先级,数字越大优先级越高 + timestamp: float = field(default_factory=time.time) + + +class BatchDatabaseWriter: + """异步批量数据库写入器""" + + def __init__(self, batch_size: int = 50, flush_interval: float = 5.0, max_queue_size: int = 1000): + """ + 初始化批量写入器 + + Args: + batch_size: 批量写入的大小 + flush_interval: 刷新间隔(秒) + max_queue_size: 最大队列大小 + """ + self.batch_size = batch_size + self.flush_interval = flush_interval + self.max_queue_size = max_queue_size + + # 异步队列 + self.write_queue: asyncio.Queue[StreamUpdatePayload] = asyncio.Queue(maxsize=max_queue_size) + + # 运行状态 + self.is_running = False + self.writer_task: Optional[asyncio.Task] = None + + # 统计信息 + self.stats = { + "total_writes": 0, + "batch_writes": 0, + "failed_writes": 0, + "queue_size": 0, + "avg_batch_size": 0, + "last_flush_time": 0, + } + + # 按优先级分类的批次 + self.priority_batches: Dict[int, List[StreamUpdatePayload]] = defaultdict(list) + + logger.info(f"批量数据库写入器初始化完成 (batch_size={batch_size}, interval={flush_interval}s)") + + async def start(self): + """启动批量写入器""" + if self.is_running: + logger.warning("批量写入器已经在运行") + return + + self.is_running = True + self.writer_task = asyncio.create_task(self._batch_writer_loop(), name="batch_database_writer") + logger.info("批量数据库写入器已启动") + + async def stop(self): + """停止批量写入器""" + if not self.is_running: + return + + self.is_running = False + + # 等待当前批次写入完成 + if self.writer_task and not self.writer_task.done(): + try: + # 先处理剩余的数据 + await self._flush_all_batches() + # 取消任务 + self.writer_task.cancel() + await asyncio.wait_for(self.writer_task, timeout=10.0) + except asyncio.TimeoutError: + logger.warning("批量写入器停止超时") + except Exception as e: + logger.error(f"停止批量写入器时出错: {e}") + + logger.info("批量数据库写入器已停止") + + async def schedule_stream_update( + self, + stream_id: str, + update_data: Dict[str, Any], + priority: int = 0 + ) -> bool: + """ + 调度流更新 + + Args: + stream_id: 流ID + update_data: 更新数据 + priority: 优先级 + + Returns: + bool: 是否成功加入队列 + """ + try: + if not self.is_running: + logger.warning("批量写入器未运行,直接写入数据库") + await self._direct_write(stream_id, update_data) + return True + + # 创建更新载荷 + payload = StreamUpdatePayload( + stream_id=stream_id, + update_data=update_data, + priority=priority + ) + + # 非阻塞方式加入队列 + try: + self.write_queue.put_nowait(payload) + self.stats["total_writes"] += 1 + self.stats["queue_size"] = self.write_queue.qsize() + return True + except asyncio.QueueFull: + logger.warning(f"写入队列已满,丢弃低优先级更新: stream_id={stream_id}") + return False + + except Exception as e: + logger.error(f"调度流更新失败: {e}") + return False + + async def _batch_writer_loop(self): + """批量写入主循环""" + logger.info("批量写入循环启动") + + while self.is_running: + try: + # 等待批次填满或超时 + batch = await self._collect_batch() + + if batch: + await self._write_batch(batch) + + # 更新统计信息 + self.stats["queue_size"] = self.write_queue.qsize() + + except asyncio.CancelledError: + logger.info("批量写入循环被取消") + break + except Exception as e: + logger.error(f"批量写入循环出错: {e}") + # 短暂等待后继续 + await asyncio.sleep(1.0) + + # 循环结束前处理剩余数据 + await self._flush_all_batches() + logger.info("批量写入循环结束") + + async def _collect_batch(self) -> List[StreamUpdatePayload]: + """收集一个批次的数据""" + batch = [] + deadline = time.time() + self.flush_interval + + while len(batch) < self.batch_size and time.time() < deadline: + try: + # 计算剩余等待时间 + remaining_time = max(0, deadline - time.time()) + if remaining_time == 0: + break + + payload = await asyncio.wait_for( + self.write_queue.get(), + timeout=remaining_time + ) + batch.append(payload) + + except asyncio.TimeoutError: + break + + return batch + + async def _write_batch(self, batch: List[StreamUpdatePayload]): + """批量写入数据库""" + if not batch: + return + + start_time = time.time() + + try: + # 按优先级排序 + batch.sort(key=lambda x: (-x.priority, x.timestamp)) + + # 合并同一流ID的更新(保留最新的) + merged_updates = {} + for payload in batch: + if payload.stream_id not in merged_updates or payload.timestamp > merged_updates[payload.stream_id].timestamp: + merged_updates[payload.stream_id] = payload + + # 批量写入 + await self._batch_write_to_database(list(merged_updates.values())) + + # 更新统计 + self.stats["batch_writes"] += 1 + self.stats["avg_batch_size"] = ( + self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1 + ) # 滑动平均 + self.stats["last_flush_time"] = start_time + + logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s") + + except Exception as e: + self.stats["failed_writes"] += 1 + logger.error(f"批量写入失败: {e}") + # 降级到单个写入 + for payload in batch: + try: + await self._direct_write(payload.stream_id, payload.update_data) + except Exception as single_e: + logger.error(f"单个写入也失败: {single_e}") + + async def _batch_write_to_database(self, payloads: List[StreamUpdatePayload]): + """批量写入数据库""" + async with get_db_session() as session: + for payload in payloads: + stream_id = payload.stream_id + update_data = payload.update_data + + # 根据数据库类型选择不同的插入/更新策略 + if global_config.database.database_type == "sqlite": + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + stmt = sqlite_insert(ChatStreams).values( + stream_id=stream_id, **update_data + ) + stmt = stmt.on_conflict_do_update( + index_elements=["stream_id"], + set_=update_data + ) + elif global_config.database.database_type == "mysql": + from sqlalchemy.dialects.mysql import insert as mysql_insert + stmt = mysql_insert(ChatStreams).values( + stream_id=stream_id, **update_data + ) + stmt = stmt.on_duplicate_key_update( + **{key: value for key, value in update_data.items() if key != "stream_id"} + ) + else: + # 默认使用SQLite语法 + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + stmt = sqlite_insert(ChatStreams).values( + stream_id=stream_id, **update_data + ) + stmt = stmt.on_conflict_do_update( + index_elements=["stream_id"], + set_=update_data + ) + + await session.execute(stmt) + + await session.commit() + + async def _direct_write(self, stream_id: str, update_data: Dict[str, Any]): + """直接写入数据库(降级方案)""" + async with get_db_session() as session: + if global_config.database.database_type == "sqlite": + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + stmt = sqlite_insert(ChatStreams).values( + stream_id=stream_id, **update_data + ) + stmt = stmt.on_conflict_do_update( + index_elements=["stream_id"], + set_=update_data + ) + elif global_config.database.database_type == "mysql": + from sqlalchemy.dialects.mysql import insert as mysql_insert + stmt = mysql_insert(ChatStreams).values( + stream_id=stream_id, **update_data + ) + stmt = stmt.on_duplicate_key_update( + **{key: value for key, value in update_data.items() if key != "stream_id"} + ) + else: + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + stmt = sqlite_insert(ChatStreams).values( + stream_id=stream_id, **update_data + ) + stmt = stmt.on_conflict_do_update( + index_elements=["stream_id"], + set_=update_data + ) + + await session.execute(stmt) + await session.commit() + + async def _flush_all_batches(self): + """刷新所有剩余批次""" + # 收集所有剩余数据 + remaining_batch = [] + while not self.write_queue.empty(): + try: + payload = self.write_queue.get_nowait() + remaining_batch.append(payload) + except asyncio.QueueEmpty: + break + + if remaining_batch: + await self._write_batch(remaining_batch) + + def get_stats(self) -> Dict[str, Any]: + """获取统计信息""" + stats = self.stats.copy() + stats["is_running"] = self.is_running + stats["current_queue_size"] = self.write_queue.qsize() if self.is_running else 0 + return stats + + +# 全局批量写入器实例 +_batch_writer: Optional[BatchDatabaseWriter] = None + + +def get_batch_writer() -> BatchDatabaseWriter: + """获取批量写入器实例""" + global _batch_writer + if _batch_writer is None: + _batch_writer = BatchDatabaseWriter() + return _batch_writer + + +async def init_batch_writer(): + """初始化批量写入器""" + writer = get_batch_writer() + await writer.start() + + +async def shutdown_batch_writer(): + """关闭批量写入器""" + writer = get_batch_writer() + await writer.stop() \ No newline at end of file diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index 75a0033e0..25b4dd6b1 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -23,6 +23,8 @@ class StreamLoopManager: def __init__(self, max_concurrent_streams: int | None = None): # 流循环任务管理 self.stream_loops: dict[str, asyncio.Task] = {} + # 跟踪流使用的管理器类型 + self.stream_management_type: dict[str, str] = {} # stream_id -> "adaptive" or "fallback" # 统计信息 self.stats: dict[str, Any] = { @@ -99,7 +101,7 @@ class StreamLoopManager: logger.info("流循环管理器已停止") async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool: - """启动指定流的循环任务 + """启动指定流的循环任务 - 优化版本使用自适应管理器 Args: stream_id: 流ID @@ -113,6 +115,71 @@ class StreamLoopManager: logger.debug(f"流 {stream_id} 循环已在运行") return True + # 使用自适应流管理器获取槽位 + use_adaptive = False + try: + from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager, StreamPriority + adaptive_manager = get_adaptive_stream_manager() + + if adaptive_manager.is_running: + # 确定流优先级 + priority = self._determine_stream_priority(stream_id) + + # 获取处理槽位 + slot_acquired = await adaptive_manager.acquire_stream_slot( + stream_id=stream_id, + priority=priority, + force=force + ) + + if slot_acquired: + use_adaptive = True + logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})") + else: + logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案") + else: + logger.debug(f"自适应管理器未运行,使用原始方法") + + except Exception as e: + logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}") + + # 如果自适应管理器失败或未运行,使用回退方案 + if not use_adaptive: + if not await self._fallback_acquire_slot(stream_id, force): + logger.debug(f"回退方案也失败: {stream_id}") + return False + + # 创建流循环任务 + try: + loop_task = asyncio.create_task( + self._stream_loop_worker(stream_id), + name=f"stream_loop_{stream_id}" + ) + self.stream_loops[stream_id] = loop_task + # 记录管理器类型 + self.stream_management_type[stream_id] = "adaptive" if use_adaptive else "fallback" + + # 更新统计信息 + self.stats["active_streams"] += 1 + self.stats["total_loops"] += 1 + + logger.info(f"启动流循环任务: {stream_id} (管理器: {'adaptive' if use_adaptive else 'fallback'})") + return True + + except Exception as e: + logger.error(f"启动流循环任务失败 {stream_id}: {e}") + # 释放槽位 + if use_adaptive: + try: + from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager + adaptive_manager = get_adaptive_stream_manager() + adaptive_manager.release_stream_slot(stream_id) + except: + pass + return False + + async def _fallback_acquire_slot(self, stream_id: str, force: bool) -> bool: + """回退方案:获取槽位(原始方法)""" # 判断是否需要强制分发 should_force = force or self._should_force_dispatch_for_stream(stream_id) @@ -149,6 +216,28 @@ class StreamLoopManager: del self.stream_loops[stream_id] current_streams -= 1 # 更新当前流数量 + return True + + def _determine_stream_priority(self, stream_id: str) -> "StreamPriority": + """确定流优先级""" + try: + from src.chat.message_manager.adaptive_stream_manager import StreamPriority + + # 这里可以基于流的历史数据、用户身份等确定优先级 + # 简化版本:基于流ID的哈希值分配优先级 + hash_value = hash(stream_id) % 10 + + if hash_value >= 8: # 20% 高优先级 + return StreamPriority.HIGH + elif hash_value >= 5: # 30% 中等优先级 + return StreamPriority.NORMAL + else: # 50% 低优先级 + return StreamPriority.LOW + + except Exception: + from src.chat.message_manager.adaptive_stream_manager import StreamPriority + return StreamPriority.NORMAL + # 创建流循环任务 try: task = asyncio.create_task( @@ -201,13 +290,13 @@ class StreamLoopManager: logger.info(f"停止流循环: {stream_id} (剩余: {len(self.stream_loops)})") return True - async def _stream_loop(self, stream_id: str) -> None: - """单个流的无限循环 + async def _stream_loop_worker(self, stream_id: str) -> None: + """单个流的工作循环 - 优化版本 Args: stream_id: 流ID """ - logger.info(f"流循环开始: {stream_id}") + logger.info(f"流循环工作器启动: {stream_id}") try: while self.is_running: @@ -223,6 +312,18 @@ class StreamLoopManager: unread_count = self._get_unread_count(context) force_dispatch = self._needs_force_dispatch_for_context(context, unread_count) + # 3. 更新自适应管理器指标 + try: + from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager + adaptive_manager = get_adaptive_stream_manager() + adaptive_manager.update_stream_metrics( + stream_id, + message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算 + last_activity=time.time() + ) + except Exception as e: + logger.debug(f"更新流指标失败: {e}") + has_messages = force_dispatch or await self._has_messages_to_process(context) if has_messages: @@ -278,6 +379,24 @@ class StreamLoopManager: del self.stream_loops[stream_id] logger.debug(f"清理流循环标记: {stream_id}") + # 根据管理器类型释放相应的槽位 + management_type = self.stream_management_type.get(stream_id, "fallback") + if management_type == "adaptive": + # 释放自适应管理器的槽位 + try: + from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager + adaptive_manager = get_adaptive_stream_manager() + adaptive_manager.release_stream_slot(stream_id) + logger.debug(f"释放自适应流处理槽位: {stream_id}") + except Exception as e: + logger.debug(f"释放自适应流处理槽位失败: {e}") + else: + logger.debug(f"流 {stream_id} 使用回退方案,无需释放自适应槽位") + + # 清理管理器类型记录 + if stream_id in self.stream_management_type: + del self.stream_management_type[stream_id] + logger.info(f"流循环结束: {stream_id}") async def _get_stream_context(self, stream_id: str) -> Any | None: diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index a714fd957..d6d56294b 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -56,6 +56,30 @@ class MessageManager: self.is_running = True + # 启动批量数据库写入器 + try: + from src.chat.message_manager.batch_database_writer import init_batch_writer + await init_batch_writer() + logger.info("📦 批量数据库写入器已启动") + except Exception as e: + logger.error(f"启动批量数据库写入器失败: {e}") + + # 启动流缓存管理器 + try: + from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager + await init_stream_cache_manager() + logger.info("🗄️ 流缓存管理器已启动") + except Exception as e: + logger.error(f"启动流缓存管理器失败: {e}") + + # 启动自适应流管理器 + try: + from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager + await init_adaptive_stream_manager() + logger.info("🎯 自适应流管理器已启动") + except Exception as e: + logger.error(f"启动自适应流管理器失败: {e}") + # 启动睡眠和唤醒管理器 await self.wakeup_manager.start() @@ -72,6 +96,30 @@ class MessageManager: self.is_running = False + # 停止批量数据库写入器 + try: + from src.chat.message_manager.batch_database_writer import shutdown_batch_writer + await shutdown_batch_writer() + logger.info("📦 批量数据库写入器已停止") + except Exception as e: + logger.error(f"停止批量数据库写入器失败: {e}") + + # 停止流缓存管理器 + try: + from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager + await shutdown_stream_cache_manager() + logger.info("🗄️ 流缓存管理器已停止") + except Exception as e: + logger.error(f"停止流缓存管理器失败: {e}") + + # 停止自适应流管理器 + try: + from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager + await shutdown_adaptive_stream_manager() + logger.info("🎯 自适应流管理器已停止") + except Exception as e: + logger.error(f"停止自适应流管理器失败: {e}") + # 停止睡眠和唤醒管理器 await self.wakeup_manager.stop() diff --git a/src/chat/message_manager/stream_cache_manager.py b/src/chat/message_manager/stream_cache_manager.py new file mode 100644 index 000000000..fcbfd7bd4 --- /dev/null +++ b/src/chat/message_manager/stream_cache_manager.py @@ -0,0 +1,381 @@ +""" +流缓存管理器 - 使用优化版聊天流和智能缓存策略 +提供分层缓存和自动清理功能 +""" + +import asyncio +import time +from typing import Dict, List, Optional, Set +from dataclasses import dataclass +from collections import OrderedDict + +from maim_message import GroupInfo, UserInfo +from src.common.logger import get_logger +from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream + +logger = get_logger("stream_cache_manager") + + +@dataclass +class StreamCacheStats: + """缓存统计信息""" + hot_cache_size: int = 0 + warm_storage_size: int = 0 + cold_storage_size: int = 0 + total_memory_usage: int = 0 # 估算的内存使用(字节) + cache_hits: int = 0 + cache_misses: int = 0 + evictions: int = 0 + last_cleanup_time: float = 0 + + +class TieredStreamCache: + """分层流缓存管理器""" + + def __init__( + self, + max_hot_size: int = 100, + max_warm_size: int = 500, + max_cold_size: int = 2000, + cleanup_interval: float = 300.0, # 5分钟清理一次 + hot_timeout: float = 1800.0, # 30分钟未访问降级到warm + warm_timeout: float = 7200.0, # 2小时未访问降级到cold + cold_timeout: float = 86400.0, # 24小时未访问删除 + ): + self.max_hot_size = max_hot_size + self.max_warm_size = max_warm_size + self.max_cold_size = max_cold_size + self.cleanup_interval = cleanup_interval + self.hot_timeout = hot_timeout + self.warm_timeout = warm_timeout + self.cold_timeout = cold_timeout + + # 三层缓存存储 + self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU) + self.warm_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间) + self.cold_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间) + + # 统计信息 + self.stats = StreamCacheStats() + + # 清理任务 + self.cleanup_task: Optional[asyncio.Task] = None + self.is_running = False + + logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})") + + async def start(self): + """启动缓存管理器""" + if self.is_running: + logger.warning("缓存管理器已经在运行") + return + + self.is_running = True + self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup") + logger.info("分层流缓存管理器已启动") + + async def stop(self): + """停止缓存管理器""" + if not self.is_running: + return + + self.is_running = False + + if self.cleanup_task and not self.cleanup_task.done(): + self.cleanup_task.cancel() + try: + await asyncio.wait_for(self.cleanup_task, timeout=10.0) + except asyncio.TimeoutError: + logger.warning("缓存清理任务停止超时") + except Exception as e: + logger.error(f"停止缓存清理任务时出错: {e}") + + logger.info("分层流缓存管理器已停止") + + async def get_or_create_stream( + self, + stream_id: str, + platform: str, + user_info: UserInfo, + group_info: Optional[GroupInfo] = None, + data: Optional[Dict] = None, + ) -> OptimizedChatStream: + """获取或创建流 - 优化版本""" + current_time = time.time() + + # 1. 检查热缓存 + if stream_id in self.hot_cache: + stream = self.hot_cache[stream_id] + # 移动到末尾(LRU更新) + self.hot_cache.move_to_end(stream_id) + self.stats.cache_hits += 1 + logger.debug(f"热缓存命中: {stream_id}") + return stream.create_snapshot() + + # 2. 检查温存储 + if stream_id in self.warm_storage: + stream, last_access = self.warm_storage[stream_id] + self.warm_storage[stream_id] = (stream, current_time) + self.stats.cache_hits += 1 + logger.debug(f"温缓存命中: {stream_id}") + # 提升到热缓存 + await self._promote_to_hot(stream_id, stream) + return stream.create_snapshot() + + # 3. 检查冷存储 + if stream_id in self.cold_storage: + stream, last_access = self.cold_storage[stream_id] + self.cold_storage[stream_id] = (stream, current_time) + self.stats.cache_hits += 1 + logger.debug(f"冷缓存命中: {stream_id}") + # 提升到温缓存 + await self._promote_to_warm(stream_id, stream) + return stream.create_snapshot() + + # 4. 缓存未命中,创建新流 + self.stats.cache_misses += 1 + stream = create_optimized_chat_stream( + stream_id=stream_id, + platform=platform, + user_info=user_info, + group_info=group_info, + data=data + ) + logger.debug(f"缓存未命中,创建新流: {stream_id}") + + # 添加到热缓存 + await self._add_to_hot(stream_id, stream) + + return stream + + async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream): + """添加到热缓存""" + # 检查是否需要驱逐 + if len(self.hot_cache) >= self.max_hot_size: + await self._evict_from_hot() + + self.hot_cache[stream_id] = stream + self.stats.hot_cache_size = len(self.hot_cache) + + async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream): + """提升到热缓存""" + # 从温存储中移除 + if stream_id in self.warm_storage: + del self.warm_storage[stream_id] + self.stats.warm_storage_size = len(self.warm_storage) + + # 添加到热缓存 + await self._add_to_hot(stream_id, stream) + logger.debug(f"流 {stream_id} 提升到热缓存") + + async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream): + """提升到温缓存""" + # 从冷存储中移除 + if stream_id in self.cold_storage: + del self.cold_storage[stream_id] + self.stats.cold_storage_size = len(self.cold_storage) + + # 添加到温存储 + if len(self.warm_storage) >= self.max_warm_size: + await self._evict_from_warm() + + current_time = time.time() + self.warm_storage[stream_id] = (stream, current_time) + self.stats.warm_storage_size = len(self.warm_storage) + logger.debug(f"流 {stream_id} 提升到温缓存") + + async def _evict_from_hot(self): + """从热缓存驱逐最久未使用的流""" + if not self.hot_cache: + return + + # LRU驱逐 + stream_id, stream = self.hot_cache.popitem(last=False) + self.stats.evictions += 1 + logger.debug(f"从热缓存驱逐: {stream_id}") + + # 移动到温存储 + if len(self.warm_storage) < self.max_warm_size: + current_time = time.time() + self.warm_storage[stream_id] = (stream, current_time) + self.stats.warm_storage_size = len(self.warm_storage) + else: + # 温存储也满了,直接删除 + logger.debug(f"温存储已满,删除流: {stream_id}") + + self.stats.hot_cache_size = len(self.hot_cache) + + async def _evict_from_warm(self): + """从温存储驱逐最久未使用的流""" + if not self.warm_storage: + return + + # 找到最久未访问的流 + oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1]) + stream, last_access = self.warm_storage.pop(oldest_stream_id) + self.stats.evictions += 1 + logger.debug(f"从温存储驱逐: {oldest_stream_id}") + + # 移动到冷存储 + if len(self.cold_storage) < self.max_cold_size: + current_time = time.time() + self.cold_storage[oldest_stream_id] = (stream, current_time) + self.stats.cold_storage_size = len(self.cold_storage) + else: + # 冷存储也满了,直接删除 + logger.debug(f"冷存储已满,删除流: {oldest_stream_id}") + + self.stats.warm_storage_size = len(self.warm_storage) + + async def _cleanup_loop(self): + """清理循环""" + logger.info("流缓存清理循环启动") + + while self.is_running: + try: + await asyncio.sleep(self.cleanup_interval) + await self._perform_cleanup() + except asyncio.CancelledError: + logger.info("流缓存清理循环被取消") + break + except Exception as e: + logger.error(f"流缓存清理出错: {e}") + + logger.info("流缓存清理循环结束") + + async def _perform_cleanup(self): + """执行清理操作""" + current_time = time.time() + cleanup_stats = { + "hot_to_warm": 0, + "warm_to_cold": 0, + "cold_removed": 0, + } + + # 1. 检查热缓存超时 + hot_to_demote = [] + for stream_id, stream in self.hot_cache.items(): + # 获取最后访问时间(简化:使用创建时间作为近似) + last_access = getattr(stream, 'last_active_time', stream.create_time) + if current_time - last_access > self.hot_timeout: + hot_to_demote.append(stream_id) + + for stream_id in hot_to_demote: + stream = self.hot_cache.pop(stream_id) + current_time_local = time.time() + self.warm_storage[stream_id] = (stream, current_time_local) + cleanup_stats["hot_to_warm"] += 1 + + # 2. 检查温存储超时 + warm_to_demote = [] + for stream_id, (stream, last_access) in self.warm_storage.items(): + if current_time - last_access > self.warm_timeout: + warm_to_demote.append(stream_id) + + for stream_id in warm_to_demote: + stream, last_access = self.warm_storage.pop(stream_id) + self.cold_storage[stream_id] = (stream, last_access) + cleanup_stats["warm_to_cold"] += 1 + + # 3. 检查冷存储超时 + cold_to_remove = [] + for stream_id, (stream, last_access) in self.cold_storage.items(): + if current_time - last_access > self.cold_timeout: + cold_to_remove.append(stream_id) + + for stream_id in cold_to_remove: + self.cold_storage.pop(stream_id) + cleanup_stats["cold_removed"] += 1 + + # 更新统计信息 + self.stats.hot_cache_size = len(self.hot_cache) + self.stats.warm_storage_size = len(self.warm_storage) + self.stats.cold_storage_size = len(self.cold_storage) + self.stats.last_cleanup_time = current_time + + # 估算内存使用(粗略估计) + self.stats.total_memory_usage = ( + len(self.hot_cache) * 1024 + # 每个热流约1KB + len(self.warm_storage) * 512 + # 每个温流约512B + len(self.cold_storage) * 256 # 每个冷流约256B + ) + + if sum(cleanup_stats.values()) > 0: + logger.info( + f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, " + f"{cleanup_stats['warm_to_cold']}温→冷, " + f"{cleanup_stats['cold_removed']}冷删除" + ) + + def get_stats(self) -> StreamCacheStats: + """获取缓存统计信息""" + # 计算命中率 + total_requests = self.stats.cache_hits + self.stats.cache_misses + hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0 + + stats_copy = StreamCacheStats( + hot_cache_size=self.stats.hot_cache_size, + warm_storage_size=self.stats.warm_storage_size, + cold_storage_size=self.stats.cold_storage_size, + total_memory_usage=self.stats.total_memory_usage, + cache_hits=self.stats.cache_hits, + cache_misses=self.stats.cache_misses, + evictions=self.stats.evictions, + last_cleanup_time=self.stats.last_cleanup_time, + ) + + # 添加命中率信息 + stats_copy.hit_rate = hit_rate + + return stats_copy + + def clear_cache(self): + """清空所有缓存""" + self.hot_cache.clear() + self.warm_storage.clear() + self.cold_storage.clear() + + self.stats.hot_cache_size = 0 + self.stats.warm_storage_size = 0 + self.stats.cold_storage_size = 0 + self.stats.total_memory_usage = 0 + + logger.info("所有缓存已清空") + + async def get_stream_snapshot(self, stream_id: str) -> Optional[OptimizedChatStream]: + """获取流的快照(不修改缓存状态)""" + if stream_id in self.hot_cache: + return self.hot_cache[stream_id].create_snapshot() + elif stream_id in self.warm_storage: + return self.warm_storage[stream_id][0].create_snapshot() + elif stream_id in self.cold_storage: + return self.cold_storage[stream_id][0].create_snapshot() + return None + + def get_cached_stream_ids(self) -> Set[str]: + """获取所有缓存的流ID""" + return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys()) + + +# 全局缓存管理器实例 +_cache_manager: Optional[TieredStreamCache] = None + + +def get_stream_cache_manager() -> TieredStreamCache: + """获取流缓存管理器实例""" + global _cache_manager + if _cache_manager is None: + _cache_manager = TieredStreamCache() + return _cache_manager + + +async def init_stream_cache_manager(): + """初始化流缓存管理器""" + manager = get_stream_cache_manager() + await manager.start() + + +async def shutdown_stream_cache_manager(): + """关闭流缓存管理器""" + manager = get_stream_cache_manager() + await manager.stop() \ No newline at end of file diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 326620f75..33ac25604 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -464,7 +464,7 @@ class ChatManager: async def get_or_create_stream( self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None ) -> ChatStream: - """获取或创建聊天流 + """获取或创建聊天流 - 优化版本使用缓存管理器 Args: platform: 平台标识 @@ -478,6 +478,31 @@ class ChatManager: try: stream_id = self._generate_stream_id(platform, user_info, group_info) + # 优先使用缓存管理器(优化版本) + try: + from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager + cache_manager = get_stream_cache_manager() + + if cache_manager.is_running: + optimized_stream = await cache_manager.get_or_create_stream( + stream_id=stream_id, + platform=platform, + user_info=user_info, + group_info=group_info + ) + + # 设置消息上下文 + from .message import MessageRecv + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): + optimized_stream.set_context(self.last_messages[stream_id]) + + # 转换为原始ChatStream以保持兼容性 + return self._convert_to_original_stream(optimized_stream) + + except Exception as e: + logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}") + + # 回退到原始方法 # 检查内存中是否存在 if stream_id in self.streams: stream = self.streams[stream_id] @@ -634,12 +659,35 @@ class ChatManager: @staticmethod async def _save_stream(stream: ChatStream): - """保存聊天流到数据库""" + """保存聊天流到数据库 - 优化版本使用异步批量写入""" if stream.saved: return stream_data_dict = stream.to_dict() - # 尝试使用数据库批量调度器 + # 优先使用新的批量写入器 + try: + from src.chat.message_manager.batch_database_writer import get_batch_writer + + batch_writer = get_batch_writer() + if batch_writer.is_running: + success = await batch_writer.schedule_stream_update( + stream_id=stream_data_dict["stream_id"], + update_data=ChatManager._prepare_stream_data(stream_data_dict), + priority=1 # 流更新的优先级 + ) + if success: + stream.saved = True + logger.debug(f"聊天流 {stream.stream_id} 通过批量写入器调度成功") + return + else: + logger.warning(f"批量写入器队列已满,使用原始方法: {stream.stream_id}") + else: + logger.debug(f"批量写入器未运行,使用原始方法: {stream.stream_id}") + + except Exception as e: + logger.debug(f"批量写入器保存聊天流失败,使用原始方法: {e}") + + # 尝试使用数据库批量调度器(回退方案1) try: from src.common.database.db_batch_scheduler import batch_update, get_batch_session @@ -657,7 +705,7 @@ class ChatManager: except (ImportError, Exception) as e: logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}") - # 回退到原始方法 + # 回退到原始方法(最终方案) async def _db_save_stream_async(s_data_dict: dict): async with get_db_session() as session: user_info_d = s_data_dict.get("user_info") @@ -782,6 +830,46 @@ class ChatManager: chat_manager = None +def _convert_to_original_stream(self, optimized_stream) -> "ChatStream": + """将OptimizedChatStream转换为原始ChatStream以保持兼容性""" + try: + # 创建原始ChatStream实例 + original_stream = ChatStream( + stream_id=optimized_stream.stream_id, + platform=optimized_stream.platform, + user_info=optimized_stream._get_effective_user_info(), + group_info=optimized_stream._get_effective_group_info() + ) + + # 复制状态 + original_stream.create_time = optimized_stream.create_time + original_stream.last_active_time = optimized_stream.last_active_time + original_stream.sleep_pressure = optimized_stream.sleep_pressure + original_stream.base_interest_energy = optimized_stream.base_interest_energy + original_stream._focus_energy = optimized_stream._focus_energy + original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive + original_stream.saved = optimized_stream.saved + + # 复制上下文信息(如果存在) + if hasattr(optimized_stream, '_stream_context') and optimized_stream._stream_context: + original_stream.stream_context = optimized_stream._stream_context + + if hasattr(optimized_stream, '_context_manager') and optimized_stream._context_manager: + original_stream.context_manager = optimized_stream._context_manager + + return original_stream + + except Exception as e: + logger.error(f"转换OptimizedChatStream失败: {e}") + # 如果转换失败,创建一个新的原始流 + return ChatStream( + stream_id=optimized_stream.stream_id, + platform=optimized_stream.platform, + user_info=optimized_stream._get_effective_user_info(), + group_info=optimized_stream._get_effective_group_info() + ) + + def get_chat_manager(): global chat_manager if chat_manager is None: diff --git a/src/chat/message_receive/optimized_chat_stream.py b/src/chat/message_receive/optimized_chat_stream.py new file mode 100644 index 000000000..438f4e65c --- /dev/null +++ b/src/chat/message_receive/optimized_chat_stream.py @@ -0,0 +1,494 @@ +""" +优化版聊天流 - 实现写时复制机制 +避免不必要的深拷贝开销,提升多流并发性能 +""" + +import asyncio +import copy +import hashlib +import time +from typing import TYPE_CHECKING, Any, Dict, Optional + +from maim_message import GroupInfo, UserInfo +from rich.traceback import install + +from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.sqlalchemy_models import ChatStreams +from src.common.logger import get_logger +from src.config.config import global_config + +if TYPE_CHECKING: + from .message import MessageRecv + +install(extra_lines=3) + +logger = get_logger("optimized_chat_stream") + + +class SharedContext: + """共享上下文数据 - 只读数据结构""" + + def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None): + self.stream_id = stream_id + self.platform = platform + self.user_info = user_info + self.group_info = group_info + self.create_time = time.time() + self._frozen = True + + def __setattr__(self, name, value): + if hasattr(self, '_frozen') and self._frozen and name not in ['_frozen']: + raise AttributeError(f"SharedContext is frozen, cannot modify {name}") + super().__setattr__(name, value) + + +class LocalChanges: + """本地修改跟踪器""" + + def __init__(self): + self._changes: Dict[str, Any] = {} + self._dirty = False + + def set_change(self, key: str, value: Any): + """设置修改项""" + self._changes[key] = value + self._dirty = True + + def get_change(self, key: str, default: Any = None) -> Any: + """获取修改项""" + return self._changes.get(key, default) + + def has_changes(self) -> bool: + """是否有修改""" + return self._dirty + + def get_changes(self) -> Dict[str, Any]: + """获取所有修改""" + return self._changes.copy() + + def clear_changes(self): + """清除修改记录""" + self._changes.clear() + self._dirty = False + + +class OptimizedChatStream: + """优化版聊天流 - 使用写时复制机制""" + + def __init__( + self, + stream_id: str, + platform: str, + user_info: UserInfo, + group_info: Optional[GroupInfo] = None, + data: Optional[Dict] = None, + ): + # 共享的只读数据 + self._shared_context = SharedContext( + stream_id=stream_id, + platform=platform, + user_info=user_info, + group_info=group_info + ) + + # 本地修改数据 + self._local_changes = LocalChanges() + + # 写时复制标志 + self._copy_on_write = False + + # 基础参数 + self.base_interest_energy = data.get("base_interest_energy", 0.5) if data else 0.5 + self._focus_energy = data.get("focus_energy", 0.5) if data else 0.5 + self.no_reply_consecutive = 0 + + # 创建StreamContext(延迟创建) + self._stream_context = None + self._context_manager = None + + # 更新活跃时间 + self.update_active_time() + + # 保存标志 + self.saved = False + + @property + def stream_id(self) -> str: + return self._shared_context.stream_id + + @property + def platform(self) -> str: + return self._shared_context.platform + + @property + def user_info(self) -> UserInfo: + return self._shared_context.user_info + + @user_info.setter + def user_info(self, value: UserInfo): + """修改用户信息时触发写时复制""" + self._ensure_copy_on_write() + # 由于SharedContext是frozen的,我们需要在本地修改中记录 + self._local_changes.set_change('user_info', value) + + @property + def group_info(self) -> Optional[GroupInfo]: + if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes: + return self._local_changes.get_change('group_info') + return self._shared_context.group_info + + @group_info.setter + def group_info(self, value: Optional[GroupInfo]): + """修改群组信息时触发写时复制""" + self._ensure_copy_on_write() + self._local_changes.set_change('group_info', value) + + @property + def create_time(self) -> float: + if self._local_changes.has_changes() and 'create_time' in self._local_changes._changes: + return self._local_changes.get_change('create_time') + return self._shared_context.create_time + + @property + def last_active_time(self) -> float: + return self._local_changes.get_change('last_active_time', self.create_time) + + @last_active_time.setter + def last_active_time(self, value: float): + self._local_changes.set_change('last_active_time', value) + self.saved = False + + @property + def sleep_pressure(self) -> float: + return self._local_changes.get_change('sleep_pressure', 0.0) + + @sleep_pressure.setter + def sleep_pressure(self, value: float): + self._local_changes.set_change('sleep_pressure', value) + self.saved = False + + def _ensure_copy_on_write(self): + """确保写时复制机制生效""" + if not self._copy_on_write: + self._copy_on_write = True + # 深拷贝共享上下文到本地 + logger.debug(f"触发写时复制: {self.stream_id}") + + def _get_effective_user_info(self) -> UserInfo: + """获取有效的用户信息""" + if self._local_changes.has_changes() and 'user_info' in self._local_changes._changes: + return self._local_changes.get_change('user_info') + return self._shared_context.user_info + + def _get_effective_group_info(self) -> Optional[GroupInfo]: + """获取有效的群组信息""" + if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes: + return self._local_changes.get_change('group_info') + return self._shared_context.group_info + + def update_active_time(self): + """更新最后活跃时间""" + self.last_active_time = time.time() + + def set_context(self, message: "MessageRecv"): + """设置聊天消息上下文""" + # 确保stream_context存在 + if self._stream_context is None: + self._ensure_copy_on_write() + self._create_stream_context() + + # 将MessageRecv转换为DatabaseMessages并设置到stream_context + import json + from src.common.data_models.database_data_model import DatabaseMessages + + message_info = getattr(message, "message_info", {}) + user_info = getattr(message_info, "user_info", {}) + group_info = getattr(message_info, "group_info", {}) + + reply_to = None + if hasattr(message, "message_segment") and message.message_segment: + reply_to = self._extract_reply_from_segment(message.message_segment) + + db_message = DatabaseMessages( + message_id=getattr(message, "message_id", ""), + time=getattr(message, "time", time.time()), + chat_id=self._generate_chat_id(message_info), + reply_to=reply_to, + interest_value=getattr(message, "interest_value", 0.0), + key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False) + if getattr(message, "key_words", None) + else None, + key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False) + if getattr(message, "key_words_lite", None) + else None, + is_mentioned=getattr(message, "is_mentioned", None), + is_at=getattr(message, "is_at", False), + is_emoji=getattr(message, "is_emoji", False), + is_picid=getattr(message, "is_picid", False), + is_voice=getattr(message, "is_voice", False), + is_video=getattr(message, "is_video", False), + is_command=getattr(message, "is_command", False), + is_notify=getattr(message, "is_notify", False), + processed_plain_text=getattr(message, "processed_plain_text", ""), + display_message=getattr(message, "processed_plain_text", ""), + priority_mode=getattr(message, "priority_mode", None), + priority_info=json.dumps(getattr(message, "priority_info", None)) + if getattr(message, "priority_info", None) + else None, + additional_config=getattr(message_info, "additional_config", None), + user_id=str(getattr(user_info, "user_id", "")), + user_nickname=getattr(user_info, "user_nickname", ""), + user_cardname=getattr(user_info, "user_cardname", None), + user_platform=getattr(user_info, "platform", ""), + chat_info_group_id=getattr(group_info, "group_id", None), + chat_info_group_name=getattr(group_info, "group_name", None), + chat_info_group_platform=getattr(group_info, "platform", None), + chat_info_user_id=str(getattr(user_info, "user_id", "")), + chat_info_user_nickname=getattr(user_info, "user_nickname", ""), + chat_info_user_cardname=getattr(user_info, "user_cardname", None), + chat_info_user_platform=getattr(user_info, "platform", ""), + chat_info_stream_id=self.stream_id, + chat_info_platform=self.platform, + chat_info_create_time=self.create_time, + chat_info_last_active_time=self.last_active_time, + actions=self._safe_get_actions(message), + should_reply=getattr(message, "should_reply", False), + ) + + self._stream_context.set_current_message(db_message) + self._stream_context.priority_mode = getattr(message, "priority_mode", None) + self._stream_context.priority_info = getattr(message, "priority_info", None) + + logger.debug( + f"消息数据转移完成 - message_id: {db_message.message_id}, " + f"chat_id: {db_message.chat_id}, " + f"interest_value: {db_message.interest_value}" + ) + + def _create_stream_context(self): + """创建StreamContext""" + from src.common.data_models.message_manager_data_model import StreamContext + from src.plugin_system.base.component_types import ChatMode, ChatType + + self._stream_context = StreamContext( + stream_id=self.stream_id, + chat_type=ChatType.GROUP if self.group_info else ChatType.PRIVATE, + chat_mode=ChatMode.NORMAL + ) + + # 创建单流上下文管理器 + from src.chat.message_manager.context_manager import SingleStreamContextManager + self._context_manager = SingleStreamContextManager( + stream_id=self.stream_id, context=self._stream_context + ) + + @property + def stream_context(self): + """获取StreamContext""" + if self._stream_context is None: + self._ensure_copy_on_write() + self._create_stream_context() + return self._stream_context + + @property + def context_manager(self): + """获取ContextManager""" + if self._context_manager is None: + self._ensure_copy_on_write() + self._create_stream_context() + return self._context_manager + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式 - 考虑本地修改""" + user_info = self._get_effective_user_info() + group_info = self._get_effective_group_info() + + return { + "stream_id": self.stream_id, + "platform": self.platform, + "user_info": user_info.to_dict() if user_info else None, + "group_info": group_info.to_dict() if group_info else None, + "create_time": self.create_time, + "last_active_time": self.last_active_time, + "sleep_pressure": self.sleep_pressure, + "focus_energy": self.focus_energy, + "base_interest_energy": self.base_interest_energy, + "stream_context_chat_type": self.stream_context.chat_type.value, + "stream_context_chat_mode": self.stream_context.chat_mode.value, + "interruption_count": self.stream_context.interruption_count, + } + + @classmethod + def from_dict(cls, data: Dict) -> "OptimizedChatStream": + """从字典创建实例""" + user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None + group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None + + instance = cls( + stream_id=data["stream_id"], + platform=data["platform"], + user_info=user_info, # type: ignore + group_info=group_info, + data=data, + ) + + # 恢复stream_context信息 + if "stream_context_chat_type" in data: + from src.plugin_system.base.component_types import ChatMode, ChatType + instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"]) + if "stream_context_chat_mode" in data: + from src.plugin_system.base.component_types import ChatMode, ChatType + instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"]) + + # 恢复interruption_count信息 + if "interruption_count" in data: + instance.stream_context.interruption_count = data["interruption_count"] + + return instance + + def _safe_get_actions(self, message: "MessageRecv") -> list | None: + """安全获取消息的actions字段""" + try: + actions = getattr(message, "actions", None) + if actions is None: + return None + + if isinstance(actions, str): + try: + import json + actions = json.loads(actions) + except json.JSONDecodeError: + logger.warning(f"无法解析actions JSON字符串: {actions}") + return None + + if isinstance(actions, list): + filtered_actions = [action for action in actions if action is not None and isinstance(action, str)] + return filtered_actions if filtered_actions else None + else: + logger.warning(f"actions字段类型不支持: {type(actions)}") + return None + + except Exception as e: + logger.warning(f"获取actions字段失败: {e}") + return None + + def _extract_reply_from_segment(self, segment) -> str | None: + """从消息段中提取reply_to信息""" + try: + if hasattr(segment, "type") and segment.type == "seglist": + if hasattr(segment, "data") and segment.data: + for seg in segment.data: + reply_id = self._extract_reply_from_segment(seg) + if reply_id: + return reply_id + elif hasattr(segment, "type") and segment.type == "reply": + return str(segment.data) if segment.data else None + except Exception as e: + logger.warning(f"提取reply_to信息失败: {e}") + return None + + def _generate_chat_id(self, message_info) -> str: + """生成chat_id,基于群组或用户信息""" + try: + group_info = getattr(message_info, "group_info", None) + user_info = getattr(message_info, "user_info", None) + + if group_info and hasattr(group_info, "group_id") and group_info.group_id: + return f"{self.platform}_{group_info.group_id}" + elif user_info and hasattr(user_info, "user_id") and user_info.user_id: + return f"{self.platform}_{user_info.user_id}_private" + else: + return self.stream_id + except Exception as e: + logger.warning(f"生成chat_id失败: {e}") + return self.stream_id + + @property + def focus_energy(self) -> float: + """获取缓存的focus_energy值""" + return self._focus_energy + + async def calculate_focus_energy(self) -> float: + """异步计算focus_energy""" + try: + all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size) + + user_id = None + effective_user_info = self._get_effective_user_info() + if effective_user_info and hasattr(effective_user_info, "user_id"): + user_id = str(effective_user_info.user_id) + + from src.chat.energy_system import energy_manager + + energy = await energy_manager.calculate_focus_energy( + stream_id=self.stream_id, messages=all_messages, user_id=user_id + ) + + self._focus_energy = energy + + logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}") + return energy + + except Exception as e: + logger.error(f"获取focus_energy失败: {e}", exc_info=True) + return self._focus_energy + + @focus_energy.setter + def focus_energy(self, value: float): + """设置focus_energy值""" + self._focus_energy = max(0.0, min(1.0, value)) + + async def _get_user_relationship_score(self) -> float: + """获取用户关系分""" + try: + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system + + effective_user_info = self._get_effective_user_info() + if effective_user_info and hasattr(effective_user_info, "user_id"): + user_id = str(effective_user_info.user_id) + relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id) + logger.debug(f"OptimizedChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}") + return max(0.0, min(1.0, relationship_score)) + + except Exception as e: + logger.warning(f"OptimizedChatStream {self.stream_id}: 插件内部关系分计算失败: {e}") + + return 0.3 + + def create_snapshot(self) -> "OptimizedChatStream": + """创建当前状态的快照(用于缓存)""" + # 创建一个新的实例,共享相同的上下文 + snapshot = OptimizedChatStream( + stream_id=self.stream_id, + platform=self.platform, + user_info=self._get_effective_user_info(), + group_info=self._get_effective_group_info() + ) + + # 复制本地修改(但不触发写时复制) + snapshot._local_changes._changes = self._local_changes.get_changes() + snapshot._local_changes._dirty = self._local_changes._dirty + snapshot._focus_energy = self._focus_energy + snapshot.base_interest_energy = self.base_interest_energy + snapshot.no_reply_consecutive = self.no_reply_consecutive + snapshot.saved = self.saved + + return snapshot + + +# 为了向后兼容,创建一个工厂函数 +def create_optimized_chat_stream( + stream_id: str, + platform: str, + user_info: UserInfo, + group_info: Optional[GroupInfo] = None, + data: Optional[Dict] = None, +) -> OptimizedChatStream: + """创建优化版聊天流实例""" + return OptimizedChatStream( + stream_id=stream_id, + platform=platform, + user_info=user_info, + group_info=group_info, + data=data + ) \ No newline at end of file