diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py deleted file mode 100644 index 4141b44e0..000000000 --- a/src/chat/memory_system/Hippocampus.py +++ /dev/null @@ -1,1765 +0,0 @@ -# -*- coding: utf-8 -*- -import datetime -import math -import random -import time -import asyncio -import re -import orjson -import jieba -import networkx as nx -import numpy as np - -from itertools import combinations -from typing import List, Tuple, Coroutine, Any, Set -from collections import Counter -from rich.traceback import install - -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from sqlalchemy import select, insert, update, delete -from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges # SQLAlchemy Models导入 -from src.common.logger import get_logger -from src.common.database.sqlalchemy_database_api import get_db_session -from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 -from src.chat.utils.chat_message_builder import ( - get_raw_msg_by_timestamp, - build_readable_messages, - get_raw_msg_by_timestamp_with_chat, -) # 导入 build_readable_messages -from src.chat.utils.utils import translate_timestamp_to_human_readable - - -install(extra_lines=3) - - -def calculate_information_content(text): - """计算文本的信息量(熵)""" - char_count = Counter(text) - total_chars = len(text) - if total_chars == 0: - return 0 - entropy = 0 - for count in char_count.values(): - probability = count / total_chars - entropy -= probability * math.log2(probability) - - return entropy - - -def cosine_similarity(v1, v2): # sourcery skip: assign-if-exp, reintroduce-else - """计算余弦相似度""" - dot_product = np.dot(v1, v2) - norm1 = np.linalg.norm(v1) - norm2 = np.linalg.norm(v2) - if norm1 == 0 or norm2 == 0: - return 0 - return dot_product / (norm1 * norm2) - - -logger = get_logger("memory") - - -class MemoryGraph: - def __init__(self): - self.G = nx.Graph() # 使用 networkx 的图结构 - - def connect_dot(self, concept1, concept2): - # 避免自连接 - if concept1 == concept2: - return - - current_time = datetime.datetime.now().timestamp() - - # 如果边已存在,增加 strength - if self.G.has_edge(concept1, concept2): - self.G[concept1][concept2]["strength"] = self.G[concept1][concept2].get("strength", 1) + 1 - # 更新最后修改时间 - self.G[concept1][concept2]["last_modified"] = current_time - else: - # 如果是新边,初始化 strength 为 1 - self.G.add_edge( - concept1, - concept2, - strength=1, - created_time=current_time, # 添加创建时间 - last_modified=current_time, - ) # 添加最后修改时间 - - def add_dot(self, concept, memory): - current_time = datetime.datetime.now().timestamp() - - if concept in self.G: - if "memory_items" in self.G.nodes[concept]: - if not isinstance(self.G.nodes[concept]["memory_items"], list): - self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]] - self.G.nodes[concept]["memory_items"].append(memory) - else: - self.G.nodes[concept]["memory_items"] = [memory] - # 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time - if "created_time" not in self.G.nodes[concept]: - self.G.nodes[concept]["created_time"] = current_time - # 更新最后修改时间 - self.G.nodes[concept]["last_modified"] = current_time - else: - # 如果是新节点,创建新的记忆列表 - self.G.add_node( - concept, - memory_items=[memory], - created_time=current_time, # 添加创建时间 - last_modified=current_time, - ) # 添加最后修改时间 - - def get_dot(self, concept): - # 检查节点是否存在于图中 - return (concept, self.G.nodes[concept]) if concept in self.G else None - - def get_related_item(self, topic, 
depth=1): - if topic not in self.G: - return [], [] - - first_layer_items = [] - second_layer_items = [] - - # 获取相邻节点 - neighbors = list(self.G.neighbors(topic)) - - # 获取当前节点的记忆项 - node_data = self.get_dot(topic) - if node_data: - concept, data = node_data - if "memory_items" in data: - memory_items = data["memory_items"] - if isinstance(memory_items, list): - first_layer_items.extend(memory_items) - else: - first_layer_items.append(memory_items) - - # 只在depth=2时获取第二层记忆 - if depth >= 2: - # 获取相邻节点的记忆项 - for neighbor in neighbors: - if node_data := self.get_dot(neighbor): - concept, data = node_data - if "memory_items" in data: - memory_items = data["memory_items"] - if isinstance(memory_items, list): - second_layer_items.extend(memory_items) - else: - second_layer_items.append(memory_items) - - return first_layer_items, second_layer_items - - @property - def dots(self): - # 返回所有节点对应的 Memory_dot 对象 - return [self.get_dot(node) for node in self.G.nodes()] - - def forget_topic(self, topic): - """随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点""" - if topic not in self.G: - return None - - # 获取话题节点数据 - node_data = self.G.nodes[topic] - - # 如果节点存在memory_items - if "memory_items" in node_data: - memory_items = node_data["memory_items"] - - # 确保memory_items是列表 - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - # 如果有记忆项可以删除 - if memory_items: - # 随机选择一个记忆项删除 - removed_item = random.choice(memory_items) - memory_items.remove(removed_item) - - # 更新节点的记忆项 - if memory_items: - self.G.nodes[topic]["memory_items"] = memory_items - else: - # 如果没有记忆项了,删除整个节点 - self.G.remove_node(topic) - - return removed_item - - return None - - -# 海马体 -class Hippocampus: - def __init__(self): - self.memory_graph = MemoryGraph() - self.model_small: LLMRequest = None # type: ignore - self.entorhinal_cortex: EntorhinalCortex = None # type: ignore - self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore - - def initialize(self): - # 初始化子组件 - self.entorhinal_cortex = EntorhinalCortex(self) - self.parahippocampal_gyrus = ParahippocampalGyrus(self) - # 从数据库加载记忆图 - # self.entorhinal_cortex.sync_memory_from_db() # 改为异步启动 - self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.small") - - def get_all_node_names(self) -> list: - """获取记忆图中所有节点的名字列表""" - return list(self.memory_graph.G.nodes()) - - @staticmethod - def calculate_node_hash(concept, memory_items) -> int: - """计算节点的特征值""" - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - # 使用集合来去重,避免排序 - unique_items = {str(item) for item in memory_items} - # 使用frozenset来保证顺序一致性 - content = f"{concept}:{frozenset(unique_items)}" - return hash(content) - - @staticmethod - def calculate_edge_hash(source, target) -> int: - """计算边的特征值""" - # 直接使用元组,保证顺序一致性 - return hash((source, target)) - - @staticmethod - def find_topic_llm(text: str, topic_num: int | list[int]): - # sourcery skip: inline-immediately-returned-variable - topic_num_str = "" - if isinstance(topic_num, list): - topic_num_str = f"{topic_num[0]}-{topic_num[1]}" - else: - topic_num_str = topic_num - - prompt = ( - f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num_str}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来," - f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。" - f"如果确定找不出主题或者没有明显主题,返回。" - ) - return prompt - - @staticmethod - def topic_what(text, topic): - # sourcery skip: inline-immediately-returned-variable - # 不再需要 time_info 参数 - prompt = ( - 
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,' - f"要求包含对这个概念的定义,内容,知识,但是这些信息必须来自这段文字,不能添加信息。\n,请包含时间和人物。只输出这句话就好" - ) - return prompt - - @staticmethod - def calculate_topic_num(text, compress_rate): - """计算文本的话题数量""" - information_content = calculate_information_content(text) - topic_by_length = text.count("\n") * compress_rate - topic_by_information_content = max(1, min(5, int((information_content - 3) * 2))) - topic_num = int((topic_by_length + topic_by_information_content) / 2) - logger.debug( - f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, " - f"topic_num: {topic_num}" - ) - return topic_num - - def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: - """从关键词获取相关记忆。 - - Args: - keyword (str): 关键词 - max_depth (int, optional): 记忆检索深度,默认为2。1表示只获取直接相关的记忆,2表示获取间接相关的记忆。 - - Returns: - list: 记忆列表,每个元素是一个元组 (topic, memory_items, similarity) - - topic: str, 记忆主题 - - memory_items: list, 该主题下的记忆项列表 - - similarity: float, 与关键词的相似度 - """ - if not keyword: - return [] - - # 获取所有节点 - all_nodes = list(self.memory_graph.G.nodes()) - memories = [] - - # 计算关键词的词集合 - keyword_words = set(jieba.cut(keyword)) - - # 遍历所有节点,计算相似度 - for node in all_nodes: - node_words = set(jieba.cut(node)) - all_words = keyword_words | node_words - v1 = [1 if word in keyword_words else 0 for word in all_words] - v2 = [1 if word in node_words else 0 for word in all_words] - similarity = cosine_similarity(v1, v2) - - # 如果相似度超过阈值,获取该节点的记忆 - if similarity >= 0.3: # 可以调整这个阈值 - node_data = self.memory_graph.G.nodes[node] - memory_items = node_data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - memories.append((node, memory_items, similarity)) - - # 按相似度降序排序 - memories.sort(key=lambda x: x[2], reverse=True) - return memories - - async def get_keywords_from_text(self, text: str) -> list: - """从文本中提取关键词。 - - Args: - text (str): 输入文本 - fast_retrieval (bool, optional): 是否使用快速检索。默认为False。 - 如果为True,使用jieba分词提取关键词,速度更快但可能不够准确。 - 如果为False,使用LLM提取关键词,速度较慢但更准确。 - """ - if not text: - return [] - - # 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算 - text_length = len(text) - topic_num: int | list[int] = 0 - if text_length <= 6: - words = jieba.cut(text) - keywords = [word for word in words if len(word) > 1] - keywords = list(set(keywords))[:3] # 限制最多3个关键词 - if keywords: - logger.debug(f"提取关键词: {keywords}") - return keywords - elif text_length <= 12: - topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本) - elif text_length <= 20: - topic_num = [2, 4] # 11-20字符: 2个关键词 (22.76%的文本) - elif text_length <= 30: - topic_num = [3, 5] # 21-30字符: 3个关键词 (10.33%的文本) - elif text_length <= 50: - topic_num = [4, 5] # 31-50字符: 4个关键词 (9.79%的文本) - else: - topic_num = 5 # 51+字符: 5个关键词 (其余长文本) - - topics_response, _ = await self.model_small.generate_response_async(self.find_topic_llm(text, topic_num)) - - # 提取关键词 - keywords = re.findall(r"<([^>]+)>", topics_response) - if not keywords: - keywords = [] - else: - keywords = [ - keyword.strip() - for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - if keyword.strip() - ] - - if keywords: - logger.debug(f"提取关键词: {keywords}") - - return keywords - - async def get_memory_from_text( - self, - text: str, - max_memory_num: int = 3, - max_memory_length: int = 2, - max_depth: int = 3, - fast_retrieval: bool = False, - ) -> list: - """从文本中提取关键词并获取相关记忆。 - - Args: - text (str): 输入文本 - max_memory_num (int, optional): 
返回的记忆条目数量上限。默认为3,表示最多返回3条与输入文本相关度最高的记忆。 - max_memory_length (int, optional): 每个主题最多返回的记忆条目数量。默认为2,表示每个主题最多返回2条相似度最高的记忆。 - max_depth (int, optional): 记忆检索深度。默认为3。值越大,检索范围越广,可以获取更多间接相关的记忆,但速度会变慢。 - fast_retrieval (bool, optional): 是否使用快速检索。默认为False。 - 如果为True,使用jieba分词和TF-IDF提取关键词,速度更快但可能不够准确。 - 如果为False,使用LLM提取关键词,速度较慢但更准确。 - - Returns: - list: 记忆列表,每个元素是一个元组 (topic, memory_items, similarity) - - topic: str, 记忆主题 - - memory_items: list, 该主题下的记忆项列表 - - similarity: float, 与文本的相似度 - """ - keywords = await self.get_keywords_from_text(text) - - # 过滤掉不存在于记忆图中的关键词 - valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] - if not valid_keywords: - logger.debug("没有找到有效的关键词节点") - return [] - - logger.info(f"有效的关键词: {', '.join(valid_keywords)}") - - # 从每个关键词获取记忆 - activate_map = {} # 存储每个词的累计激活值 - - # 对每个关键词进行扩散式检索 - for keyword in valid_keywords: - logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):") - # 初始化激活值 - activation_values = {keyword: 1.0} - # 记录已访问的节点 - visited_nodes = {keyword} - # 待处理的节点队列,每个元素是(节点, 激活值, 当前深度) - nodes_to_process = [(keyword, 1.0, 0)] - - while nodes_to_process: - current_node, current_activation, current_depth = nodes_to_process.pop(0) - - # 如果激活值小于0或超过最大深度,停止扩散 - if current_activation <= 0 or current_depth >= max_depth: - continue - - # 获取当前节点的所有邻居 - neighbors = list(self.memory_graph.G.neighbors(current_node)) - - for neighbor in neighbors: - if neighbor in visited_nodes: - continue - - # 获取连接强度 - edge_data = self.memory_graph.G[current_node][neighbor] - strength = edge_data.get("strength", 1) - - # 计算新的激活值 - new_activation = current_activation - (1 / strength) - - if new_activation > 0: - activation_values[neighbor] = new_activation - visited_nodes.add(neighbor) - nodes_to_process.append((neighbor, new_activation, current_depth + 1)) - # logger.debug( - # f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})" - # ) # noqa: E501 - - # 更新激活映射 - for node, activation_value in activation_values.items(): - if activation_value > 0: - if node in activate_map: - activate_map[node] += activation_value - else: - activate_map[node] = activation_value - - # 输出激活映射 - # logger.info("激活映射统计:") - # for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True): - # logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}") - - # 基于激活值平方的独立概率选择 - remember_map = {} - # logger.info("基于激活值平方的归一化选择:") - - # 计算所有激活值的平方和 - total_squared_activation = sum(activation**2 for activation in activate_map.values()) - if total_squared_activation > 0: - # 计算归一化的激活值 - normalized_activations = { - node: (activation**2) / total_squared_activation for node, activation in activate_map.items() - } - - # 按归一化激活值排序并选择前max_memory_num个 - sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num] - - # 将选中的节点添加到remember_map - for node, normalized_activation in sorted_nodes: - remember_map[node] = activate_map[node] # 使用原始激活值 - logger.debug( - f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})" - ) - else: - logger.info("没有有效的激活值") - - # 从选中的节点中提取记忆 - all_memories = [] - # logger.info("开始从选中的节点中提取记忆:") - for node, activation in remember_map.items(): - logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):") - node_data = self.memory_graph.G.nodes[node] - memory_items = node_data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - if 
memory_items: - logger.debug(f"节点包含 {len(memory_items)} 条记忆") - # 计算每条记忆与输入文本的相似度 - memory_similarities = [] - for memory in memory_items: - # 计算与输入文本的相似度 - memory_words = set(jieba.cut(memory)) - text_words = set(jieba.cut(text)) - all_words = memory_words | text_words - v1 = [1 if word in memory_words else 0 for word in all_words] - v2 = [1 if word in text_words else 0 for word in all_words] - similarity = cosine_similarity(v1, v2) - memory_similarities.append((memory, similarity)) - - # 按相似度排序 - memory_similarities.sort(key=lambda x: x[1], reverse=True) - # 获取最匹配的记忆 - top_memories = memory_similarities[:max_memory_length] - - # 添加到结果中 - all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories) - else: - logger.info("节点没有记忆") - - # 去重(基于记忆内容) - logger.debug("开始记忆去重:") - seen_memories = set() - unique_memories = [] - for topic, memory_items, activation_value in all_memories: - memory = memory_items[0] # 因为每个topic只有一条记忆 - if memory not in seen_memories: - seen_memories.add(memory) - unique_memories.append((topic, memory_items, activation_value)) - logger.debug(f"保留记忆: {memory} (来自节点: {topic}, 激活值: {activation_value:.2f})") - else: - logger.debug(f"跳过重复记忆: {memory} (来自节点: {topic})") - - # 转换为(关键词, 记忆)格式 - result = [] - for topic, memory_items, _ in unique_memories: - memory = memory_items[0] # 因为每个topic只有一条记忆 - result.append((topic, memory)) - logger.debug(f"选中记忆: {memory} (来自节点: {topic})") - - return result - - async def get_memory_from_topic( - self, - keywords: list[str], - max_memory_num: int = 3, - max_memory_length: int = 2, - max_depth: int = 3, - ) -> list: - """从关键词列表中获取相关记忆。 - - Args: - keywords (list): 输入关键词列表 - max_memory_num (int, optional): 返回的记忆条目数量上限。默认为3,表示最多返回3条与输入关键词相关度最高的记忆。 - max_memory_length (int, optional): 每个主题最多返回的记忆条目数量。默认为2,表示每个主题最多返回2条相似度最高的记忆。 - max_depth (int, optional): 记忆检索深度。默认为3。值越大,检索范围越广,可以获取更多间接相关的记忆,但速度会变慢。 - - Returns: - list: 记忆列表,每个元素是一个元组 (topic, memory_items, similarity) - - topic: str, 记忆主题 - - memory_items: list, 该主题下的记忆项列表 - - similarity: float, 与关键词的相似度 - """ - if not keywords: - return [] - - logger.info(f"提取的关键词: {', '.join(keywords)}") - - # 过滤掉不存在于记忆图中的关键词 - valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] - if not valid_keywords: - logger.debug("没有找到有效的关键词节点") - return [] - - logger.debug(f"有效的关键词: {', '.join(valid_keywords)}") - - # 从每个关键词获取记忆 - activate_map = {} # 存储每个词的累计激活值 - - # 对每个关键词进行扩散式检索 - for keyword in valid_keywords: - logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):") - # 初始化激活值 - activation_values = {keyword: 1.0} - # 记录已访问的节点 - visited_nodes = {keyword} - # 待处理的节点队列,每个元素是(节点, 激活值, 当前深度) - nodes_to_process = [(keyword, 1.0, 0)] - - while nodes_to_process: - current_node, current_activation, current_depth = nodes_to_process.pop(0) - - # 如果激活值小于0或超过最大深度,停止扩散 - if current_activation <= 0 or current_depth >= max_depth: - continue - - # 获取当前节点的所有邻居 - neighbors = list(self.memory_graph.G.neighbors(current_node)) - - for neighbor in neighbors: - if neighbor in visited_nodes: - continue - - # 获取连接强度 - edge_data = self.memory_graph.G[current_node][neighbor] - strength = edge_data.get("strength", 1) - - # 计算新的激活值 - new_activation = current_activation - (1 / strength) - - if new_activation > 0: - activation_values[neighbor] = new_activation - visited_nodes.add(neighbor) - nodes_to_process.append((neighbor, new_activation, current_depth + 1)) - # logger.debug( - # f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: 
{current_depth + 1})" - # ) # noqa: E501 - - # 更新激活映射 - for node, activation_value in activation_values.items(): - if activation_value > 0: - if node in activate_map: - activate_map[node] += activation_value - else: - activate_map[node] = activation_value - - # 基于激活值平方的独立概率选择 - remember_map = {} - # logger.info("基于激活值平方的归一化选择:") - - # 计算所有激活值的平方和 - total_squared_activation = sum(activation**2 for activation in activate_map.values()) - if total_squared_activation > 0: - # 计算归一化的激活值 - normalized_activations = { - node: (activation**2) / total_squared_activation for node, activation in activate_map.items() - } - - # 按归一化激活值排序并选择前max_memory_num个 - sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num] - - # 将选中的节点添加到remember_map - for node, normalized_activation in sorted_nodes: - remember_map[node] = activate_map[node] # 使用原始激活值 - logger.debug( - f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})" - ) - else: - logger.info("没有有效的激活值") - - # 从选中的节点中提取记忆 - all_memories = [] - # logger.info("开始从选中的节点中提取记忆:") - for node, activation in remember_map.items(): - logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):") - node_data = self.memory_graph.G.nodes[node] - memory_items = node_data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - if memory_items: - logger.debug(f"节点包含 {len(memory_items)} 条记忆") - # 计算每条记忆与输入关键词的相似度 - memory_similarities = [] - for memory in memory_items: - # 计算与输入关键词的相似度 - memory_words = set(jieba.cut(memory)) - # 将所有关键词合并成一个字符串来计算相似度 - keywords_text = " ".join(valid_keywords) - keywords_words = set(jieba.cut(keywords_text)) - all_words = memory_words | keywords_words - v1 = [1 if word in memory_words else 0 for word in all_words] - v2 = [1 if word in keywords_words else 0 for word in all_words] - similarity = cosine_similarity(v1, v2) - memory_similarities.append((memory, similarity)) - - # 按相似度排序 - memory_similarities.sort(key=lambda x: x[1], reverse=True) - # 获取最匹配的记忆 - top_memories = memory_similarities[:max_memory_length] - - # 添加到结果中 - all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories) - else: - logger.info("节点没有记忆") - - # 去重(基于记忆内容) - logger.debug("开始记忆去重:") - seen_memories = set() - unique_memories = [] - for topic, memory_items, activation_value in all_memories: - memory = memory_items[0] # 因为每个topic只有一条记忆 - if memory not in seen_memories: - seen_memories.add(memory) - unique_memories.append((topic, memory_items, activation_value)) - logger.debug(f"保留记忆: {memory} (来自节点: {topic}, 激活值: {activation_value:.2f})") - else: - logger.debug(f"跳过重复记忆: {memory} (来自节点: {topic})") - - # 转换为(关键词, 记忆)格式 - result = [] - for topic, memory_items, _ in unique_memories: - memory = memory_items[0] # 因为每个topic只有一条记忆 - result.append((topic, memory)) - logger.debug(f"选中记忆: {memory} (来自节点: {topic})") - - return result - - async def get_activate_from_text( - self, text: str, max_depth: int = 3, fast_retrieval: bool = False - ) -> tuple[float, list[str]]: - """从文本中提取关键词并获取相关记忆。 - - Args: - text (str): 输入文本 - max_depth (int, optional): 记忆检索深度。默认为2。 - fast_retrieval (bool, optional): 是否使用快速检索。默认为False。 - 如果为True,使用jieba分词和TF-IDF提取关键词,速度更快但可能不够准确。 - 如果为False,使用LLM提取关键词,速度较慢但更准确。 - - Returns: - float: 激活节点数与总节点数的比值 - list[str]: 有效的关键词 - """ - keywords = await self.get_keywords_from_text(text) - - # 过滤掉不存在于记忆图中的关键词 - valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] - if 
not valid_keywords: - # logger.info("没有找到有效的关键词节点") - return 0, [] - - logger.debug(f"有效的关键词: {', '.join(valid_keywords)}") - - # 从每个关键词获取记忆 - activate_map = {} # 存储每个词的累计激活值 - - # 对每个关键词进行扩散式检索 - for keyword in valid_keywords: - logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):") - # 初始化激活值 - activation_values = {keyword: 1.5} - # 记录已访问的节点 - visited_nodes = {keyword} - # 待处理的节点队列,每个元素是(节点, 激活值, 当前深度) - nodes_to_process = [(keyword, 1.0, 0)] - - while nodes_to_process: - current_node, current_activation, current_depth = nodes_to_process.pop(0) - - # 如果激活值小于0或超过最大深度,停止扩散 - if current_activation <= 0 or current_depth >= max_depth: - continue - - # 获取当前节点的所有邻居 - neighbors = list(self.memory_graph.G.neighbors(current_node)) - - for neighbor in neighbors: - if neighbor in visited_nodes: - continue - - # 获取连接强度 - edge_data = self.memory_graph.G[current_node][neighbor] - strength = edge_data.get("strength", 1) - - # 计算新的激活值 - new_activation = current_activation - (1 / strength) - - if new_activation > 0: - activation_values[neighbor] = new_activation - visited_nodes.add(neighbor) - nodes_to_process.append((neighbor, new_activation, current_depth + 1)) - # logger.debug( - # f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})") # noqa: E501 - - # 更新激活映射 - for node, activation_value in activation_values.items(): - if activation_value > 0: - if node in activate_map: - activate_map[node] += activation_value - else: - activate_map[node] = activation_value - - # 计算激活节点数与总节点数的比值 - total_activation = sum(activate_map.values()) - # logger.debug(f"总激活值: {total_activation:.2f}") - total_nodes = len(self.memory_graph.G.nodes()) - # activated_nodes = len(activate_map) - activation_ratio = total_activation / total_nodes if total_nodes > 0 else 0 - activation_ratio = activation_ratio * 50 - logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}") - - return activation_ratio, keywords - - -# 负责海马体与其他部分的交互 -class EntorhinalCortex: - def __init__(self, hippocampus: Hippocampus): - self.hippocampus = hippocampus - self.memory_graph = hippocampus.memory_graph - - async def get_memory_sample(self) -> tuple[list, list[str]]: - """从数据库获取记忆样本""" - # 硬编码:每条消息最大记忆次数 - max_memorized_time_per_msg = 2 - - # 创建双峰分布的记忆调度器 - sample_scheduler = MemoryBuildScheduler( - n_hours1=global_config.memory.memory_build_distribution[0], - std_hours1=global_config.memory.memory_build_distribution[1], - weight1=global_config.memory.memory_build_distribution[2], - n_hours2=global_config.memory.memory_build_distribution[3], - std_hours2=global_config.memory.memory_build_distribution[4], - weight2=global_config.memory.memory_build_distribution[5], - total_samples=global_config.memory.memory_build_sample_num, - ) - - timestamps = sample_scheduler.get_timestamp_array() - # 使用 translate_timestamp_to_human_readable 并指定 mode="normal" - readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps] - for _, readable_timestamp in zip(timestamps, readable_timestamps, strict=False): - logger.debug(f"回忆往事: {readable_timestamp}") - chat_samples = [] - all_message_ids_to_update = [] - for timestamp in timestamps: - if result := await self.random_get_msg_snippet( - timestamp, - global_config.memory.memory_build_sample_length, - max_memorized_time_per_msg, - ): - messages, message_ids_to_update = result - time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 - logger.info(f"成功抽取 {time_diff:.1f} 
小时前的消息样本,共{len(messages)}条") - chat_samples.append(messages) - all_message_ids_to_update.extend(message_ids_to_update) - else: - logger.debug(f"时间戳 {timestamp} 的消息无需记忆") - - return chat_samples, all_message_ids_to_update - - @staticmethod - async def random_get_msg_snippet( - target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int - ) -> tuple[list, list[str]] | None: - # sourcery skip: invert-any-all, use-any, use-named-expression, use-next - """从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)""" - time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟 - - for _ in range(3): - # 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds - timestamp_start = target_timestamp - timestamp_end = target_timestamp + time_window_seconds - - if chosen_message := await get_raw_msg_by_timestamp( - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - limit=1, - limit_mode="earliest", - ): - chat_id: str = chosen_message[0].get("chat_id") # type: ignore - - if messages := await get_raw_msg_by_timestamp_with_chat( - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - limit=chat_size, - limit_mode="earliest", - chat_id=chat_id, - ): - # 检查获取到的所有消息是否都未达到最大记忆次数 - all_valid = True - for message in messages: - if message.get("memorized_times", 0) >= max_memorized_time_per_msg: - all_valid = False - break - - # 如果所有消息都有效 - if all_valid: - # 返回消息和需要更新的message_id - message_ids_to_update = [msg["message_id"] for msg in messages] - return messages, message_ids_to_update - - target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试 - - # 三次尝试都失败,返回 None - return None - - async def sync_memory_to_db(self): - """将记忆图同步到数据库""" - start_time = time.time() - current_time = datetime.datetime.now().timestamp() - - # 获取数据库中所有节点和内存中所有节点 - async with get_db_session() as session: - result = await session.execute(select(GraphNodes)) - db_nodes = {node.concept: node for node in result.scalars()} - memory_nodes = list(self.memory_graph.G.nodes(data=True)) - - # 批量准备节点数据 - nodes_to_create = [] - nodes_to_update = [] - nodes_to_delete = set() - - # 处理节点 - for concept, data in memory_nodes: - if not concept or not isinstance(concept, str): - self.memory_graph.G.remove_node(concept) - continue - - memory_items = data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - if not memory_items: - self.memory_graph.G.remove_node(concept) - continue - - # 计算内存中节点的特征值 - memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items) - created_time = data.get("created_time", current_time) - last_modified = data.get("last_modified", current_time) - - # 将memory_items转换为JSON字符串 - try: - memory_items = [str(item) for item in memory_items] - memory_items_json = orjson.dumps(memory_items).decode("utf-8") - if not memory_items_json: - continue - except Exception: - self.memory_graph.G.remove_node(concept) - continue - - if concept not in db_nodes: - nodes_to_create.append( - { - "concept": concept, - "memory_items": memory_items_json, - "hash": memory_hash, - "weight": 1.0, # 默认权重为1.0 - "created_time": created_time, - "last_modified": last_modified, - } - ) - else: - db_node = db_nodes[concept] - if db_node.hash != memory_hash: - nodes_to_update.append( - { - "concept": concept, - "memory_items": memory_items_json, - "hash": memory_hash, - "last_modified": last_modified, - } - ) - - # 计算需要删除的节点 - memory_concepts = {concept for concept, _ in memory_nodes} - nodes_to_delete = set(db_nodes.keys()) - memory_concepts - - # 批量处理节点 - if 
nodes_to_create: - # 在插入前进行去重检查 - unique_nodes_to_create = [] - seen_concepts = set(db_nodes.keys()) - for node_data in nodes_to_create: - concept = node_data["concept"] - if concept not in seen_concepts: - unique_nodes_to_create.append(node_data) - seen_concepts.add(concept) - - if unique_nodes_to_create: - batch_size = 100 - for i in range(0, len(unique_nodes_to_create), batch_size): - batch = unique_nodes_to_create[i : i + batch_size] - await session.execute(insert(GraphNodes), batch) - - if nodes_to_update: - batch_size = 100 - for i in range(0, len(nodes_to_update), batch_size): - batch = nodes_to_update[i : i + batch_size] - for node_data in batch: - await session.execute( - update(GraphNodes) - .where(GraphNodes.concept == node_data["concept"]) - .values(**{k: v for k, v in node_data.items() if k != "concept"}) - ) - - if nodes_to_delete: - await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete))) - - # 处理边的信息 - result = await session.execute(select(GraphEdges)) - db_edges = list(result.scalars()) - memory_edges = list(self.memory_graph.G.edges(data=True)) - - # 创建边的哈希值字典 - db_edge_dict = {} - for edge in db_edges: - edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target) - db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength} - - # 批量准备边数据 - edges_to_create = [] - edges_to_update = [] - - # 处理边 - for source, target, data in memory_edges: - edge_hash = self.hippocampus.calculate_edge_hash(source, target) - edge_key = (source, target) - strength = data.get("strength", 1) - created_time = data.get("created_time", current_time) - last_modified = data.get("last_modified", current_time) - - if edge_key not in db_edge_dict: - edges_to_create.append( - { - "source": source, - "target": target, - "strength": strength, - "hash": edge_hash, - "created_time": created_time, - "last_modified": last_modified, - } - ) - elif db_edge_dict[edge_key]["hash"] != edge_hash: - edges_to_update.append( - { - "source": source, - "target": target, - "strength": strength, - "hash": edge_hash, - "last_modified": last_modified, - } - ) - - # 计算需要删除的边 - memory_edge_keys = {(source, target) for source, target, _ in memory_edges} - edges_to_delete = set(db_edge_dict.keys()) - memory_edge_keys - - # 批量处理边 - if edges_to_create: - batch_size = 100 - for i in range(0, len(edges_to_create), batch_size): - batch = edges_to_create[i : i + batch_size] - await session.execute(insert(GraphEdges), batch) - - if edges_to_update: - batch_size = 100 - for i in range(0, len(edges_to_update), batch_size): - batch = edges_to_update[i : i + batch_size] - for edge_data in batch: - await session.execute( - update(GraphEdges) - .where( - (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"]) - ) - .values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}) - ) - - if edges_to_delete: - for source, target in edges_to_delete: - await session.execute( - delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target)) - ) - - # 提交事务 - await session.commit() - - end_time = time.time() - logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}秒") - logger.info(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边") - - async def resync_memory_to_db(self): - """清空数据库并重新同步所有记忆数据""" - start_time = time.time() - logger.info("[数据库] 开始重新同步所有记忆数据...") - - # 清空数据库 - async with get_db_session() as session: - clear_start = time.time() - await session.execute(delete(GraphNodes)) - await 
session.execute(delete(GraphEdges)) - - clear_end = time.time() - logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒") - - # 获取所有节点和边 - memory_nodes = list(self.memory_graph.G.nodes(data=True)) - memory_edges = list(self.memory_graph.G.edges(data=True)) - current_time = datetime.datetime.now().timestamp() - - # 批量准备节点数据 - nodes_data = [] - for concept, data in memory_nodes: - memory_items = data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - try: - memory_items = [str(item) for item in memory_items] - if memory_items_json := orjson.dumps(memory_items).decode("utf-8"): - nodes_data.append( - { - "concept": concept, - "memory_items": memory_items_json, - "hash": self.hippocampus.calculate_node_hash(concept, memory_items), - "weight": 1.0, # 默认权重为1.0 - "created_time": data.get("created_time", current_time), - "last_modified": data.get("last_modified", current_time), - } - ) - - except Exception as e: - logger.error(f"准备节点 {concept} 数据时发生错误: {e}") - continue - - # 批量准备边数据 - edges_data = [] - for source, target, data in memory_edges: - try: - edges_data.append( - { - "source": source, - "target": target, - "strength": data.get("strength", 1), - "hash": self.hippocampus.calculate_edge_hash(source, target), - "created_time": data.get("created_time", current_time), - "last_modified": data.get("last_modified", current_time), - } - ) - except Exception as e: - logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}") - continue - - # 批量写入节点 - node_start = time.time() - if nodes_data: - batch_size = 500 # 增加批量大小 - for i in range(0, len(nodes_data), batch_size): - batch = nodes_data[i : i + batch_size] - await session.execute(insert(GraphNodes), batch) - - node_end = time.time() - logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒") - - # 批量写入边 - edge_start = time.time() - if edges_data: - batch_size = 500 # 增加批量大小 - for i in range(0, len(edges_data), batch_size): - batch = edges_data[i : i + batch_size] - await session.execute(insert(GraphEdges), batch) - await session.commit() - - edge_end = time.time() - logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒") - - end_time = time.time() - logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒") - logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边") - - async def sync_memory_from_db(self): - """从数据库同步数据到内存中的图结构""" - current_time = datetime.datetime.now().timestamp() - need_update = False - - # 清空当前图 - self.memory_graph.G.clear() - - # 从数据库加载所有节点 - async with get_db_session() as session: - result = await session.execute(select(GraphNodes)) - nodes = list(result.scalars()) - for node in nodes: - concept = node.concept - try: - memory_items = orjson.loads(node.memory_items) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - # 检查时间字段是否存在 - if not node.created_time or not node.last_modified: - need_update = True - # 更新数据库中的节点 - update_data = {} - if not node.created_time: - update_data["created_time"] = current_time - if not node.last_modified: - update_data["last_modified"] = current_time - - await session.execute( - update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data) - ) - - # 获取时间信息(如果不存在则使用当前时间) - created_time = node.created_time or current_time - last_modified = node.last_modified or current_time - - # 添加节点到图中 - self.memory_graph.G.add_node( - concept, memory_items=memory_items, created_time=created_time, 
last_modified=last_modified - ) - except Exception as e: - logger.error(f"加载节点 {concept} 时发生错误: {e}") - continue - - # 从数据库加载所有边 - result = await session.execute(select(GraphEdges)) - edges = list(result.scalars()) - for edge in edges: - source = edge.source - target = edge.target - strength = edge.strength - - # 检查时间字段是否存在 - if not edge.created_time or not edge.last_modified: - need_update = True - # 更新数据库中的边 - update_data = {} - if not edge.created_time: - update_data["created_time"] = current_time - if not edge.last_modified: - update_data["last_modified"] = current_time - - await session.execute( - update(GraphEdges) - .where((GraphEdges.source == source) & (GraphEdges.target == target)) - .values(**update_data) - ) - - # 获取时间信息(如果不存在则使用当前时间) - created_time = edge.created_time or current_time - last_modified = edge.last_modified or current_time - - # 只有当源节点和目标节点都存在时才添加边 - if source in self.memory_graph.G and target in self.memory_graph.G: - self.memory_graph.G.add_edge( - source, target, strength=strength, created_time=created_time, last_modified=last_modified - ) - await session.commit() - - if need_update: - logger.info("[数据库] 已为缺失的时间字段进行补充") - - -# 负责整合,遗忘,合并记忆 -class ParahippocampalGyrus: - def __init__(self, hippocampus: Hippocampus): - self.hippocampus = hippocampus - self.memory_graph = hippocampus.memory_graph - - self.memory_modify_model = LLMRequest( - model_set=model_config.model_task_config.utils, request_type="memory.modify" - ) - - async def memory_compress(self, messages: list, compress_rate=0.1): - """压缩和总结消息内容,生成记忆主题和摘要。 - - Args: - messages (list): 消息列表,每个消息是一个字典,包含数据库消息结构。 - compress_rate (float, optional): 压缩率,用于控制生成的主题数量。默认为0.1。 - - Returns: - tuple: (compressed_memory, similar_topics_dict) - - compressed_memory: set, 压缩后的记忆集合,每个元素是一个元组 (topic, summary) - - similar_topics_dict: dict, 相似主题字典 - - Process: - 1. 使用 build_readable_messages 生成包含时间、人物信息的格式化文本。 - 2. 使用LLM提取关键主题。 - 3. 过滤掉包含禁用关键词的主题。 - 4. 为每个主题生成摘要。 - 5. 查找与现有记忆中的相似主题。 - """ - if not messages: - return set(), {} - - # 1. 使用 build_readable_messages 生成格式化文本 - # build_readable_messages 只返回一个字符串,不需要解包 - input_text = await build_readable_messages( - messages, - merge_messages=True, # 合并连续消息 - timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式 - replace_bot_name=False, # 保留原始用户名 - ) - - # 如果生成的可读文本为空(例如所有消息都无效),则直接返回 - if not input_text: - logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。") - return set(), {} - - current_date = f"当前日期: {datetime.datetime.now().isoformat()}" - input_text = f"{current_date}\n{input_text}" - - logger.debug(f"记忆来源:\n{input_text}") - - # 2. 使用LLM提取关键主题 - topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate) - topics_response, _ = await self.memory_modify_model.generate_response_async( - self.hippocampus.find_topic_llm(input_text, topic_num) - ) - - # 提取<>中的内容 - topics = re.findall(r"<([^>]+)>", topics_response) - - if not topics: - topics = ["none"] - else: - topics = [ - topic.strip() - for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - if topic.strip() - ] - - # 3. 过滤掉包含禁用关键词的topic - filtered_topics = [ - topic for topic in topics if all(keyword not in topic for keyword in global_config.memory.memory_ban_words) - ] - - logger.debug(f"过滤后话题: {filtered_topics}") - - # 4. 
创建所有话题的摘要生成任务 - tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List | None]]]]] = [] - for topic in filtered_topics: - # 调用修改后的 topic_what,不再需要 time_info - topic_what_prompt = self.hippocampus.topic_what(input_text, topic) - try: - task = self.memory_modify_model.generate_response_async(topic_what_prompt) - tasks.append((topic.strip(), task)) - except Exception as e: - logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}") - continue - - # 等待所有任务完成 - compressed_memory: Set[Tuple[str, str]] = set() - similar_topics_dict = {} - - for topic, task in tasks: - response = await task - if response: - compressed_memory.add((topic, response[0])) - - existing_topics = list(self.memory_graph.G.nodes()) - similar_topics = [] - - for existing_topic in existing_topics: - topic_words = set(jieba.cut(topic)) - existing_words = set(jieba.cut(existing_topic)) - - all_words = topic_words | existing_words - v1 = [1 if word in topic_words else 0 for word in all_words] - v2 = [1 if word in existing_words else 0 for word in all_words] - similarity = cosine_similarity(v1, v2) - - if similarity >= 0.7: - similar_topics.append((existing_topic, similarity)) - - similar_topics.sort(key=lambda x: x[1], reverse=True) - similar_topics = similar_topics[:3] - similar_topics_dict[topic] = similar_topics - - return compressed_memory, similar_topics_dict - - async def operation_build_memory(self): - # sourcery skip: merge-list-appends-into-extend - logger.info("------------------------------------开始构建记忆--------------------------------------") - start_time = time.time() - memory_samples, all_message_ids_to_update = await self.hippocampus.entorhinal_cortex.get_memory_sample() - all_added_nodes = [] - all_connected_nodes = [] - all_added_edges = [] - for i, messages in enumerate(memory_samples, 1): - all_topics = [] - compress_rate = global_config.memory.memory_compress_rate - try: - compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) - except Exception as e: - logger.error(f"压缩记忆时发生错误: {e}") - continue - for topic, memory in compressed_memory: - logger.info(f"取得记忆: {topic} - {memory}") - for topic, similar_topics in similar_topics_dict.items(): - logger.debug(f"相似话题: {topic} - {similar_topics}") - - current_time = datetime.datetime.now().timestamp() - logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}") - all_added_nodes.extend(topic for topic, _ in compressed_memory) - - for topic, memory in compressed_memory: - self.memory_graph.add_dot(topic, memory) - all_topics.append(topic) - - if topic in similar_topics_dict: - similar_topics = similar_topics_dict[topic] - for similar_topic, similarity in similar_topics: - if topic != similar_topic: - strength = int(similarity * 10) - - logger.debug(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})") - all_added_edges.append(f"{topic}-{similar_topic}") - - all_connected_nodes.append(topic) - all_connected_nodes.append(similar_topic) - - self.memory_graph.G.add_edge( - topic, - similar_topic, - strength=strength, - created_time=current_time, - last_modified=current_time, - ) - - for topic1, topic2 in combinations(all_topics, 2): - logger.debug(f"连接同批次节点: {topic1} 和 {topic2}") - all_added_edges.append(f"{topic1}-{topic2}") - self.memory_graph.connect_dot(topic1, topic2) - - progress = (i / len(memory_samples)) * 100 - bar_length = 30 - filled_length = int(bar_length * i // len(memory_samples)) - bar = "█" * filled_length + "-" * (bar_length - filled_length) - logger.debug(f"进度: [{bar}] {progress:.1f}% 
({i}/{len(memory_samples)})") - - if all_added_nodes: - logger.info(f"更新记忆: {', '.join(all_added_nodes)}") - if all_added_edges: - logger.debug(f"强化连接: {', '.join(all_added_edges)}") - if all_connected_nodes: - logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}") - - # 先同步记忆图 - await self.hippocampus.entorhinal_cortex.sync_memory_to_db() - - # 最后批量更新消息的记忆次数 - if all_message_ids_to_update: - async with get_db_session() as session: - # 使用 in_ 操作符进行批量更新 - await session.execute( - update(Messages) - .where(Messages.message_id.in_(all_message_ids_to_update)) - .values(memorized_times=Messages.memorized_times + 1) - ) - await session.commit() - logger.info(f"批量更新了 {len(all_message_ids_to_update)} 条消息的记忆次数") - - end_time = time.time() - logger.info(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------") - - async def operation_forget_topic(self, percentage=0.005): - start_time = time.time() - logger.info("[遗忘] 开始检查数据库...") - - # 验证百分比参数 - if not 0 <= percentage <= 1: - logger.warning(f"[遗忘] 无效的遗忘百分比: {percentage}, 使用默认值 0.005") - percentage = 0.005 - - all_nodes = list(self.memory_graph.G.nodes()) - all_edges = list(self.memory_graph.G.edges()) - - if not all_nodes and not all_edges: - logger.info("[遗忘] 记忆图为空,无需进行遗忘操作") - return - - # 确保至少检查1个节点和边,且不超过总数 - check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage))) - check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage))) - - # 只有在有足够的节点和边时进行采样 - if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count: - try: - nodes_to_check = random.sample(all_nodes, check_nodes_count) - edges_to_check = random.sample(all_edges, check_edges_count) - except ValueError as e: - logger.error(f"[遗忘] 采样错误: {str(e)}") - return - else: - logger.info("[遗忘] 没有足够的节点或边进行遗忘操作") - return - - # 使用列表存储变化信息 - edge_changes = { - "weakened": [], # 存储减弱的边 - "removed": [], # 存储移除的边 - } - node_changes = { - "reduced": [], # 存储减少记忆的节点 - "removed": [], # 存储移除的节点 - } - - current_time = datetime.datetime.now().timestamp() - - logger.info("[遗忘] 开始检查连接...") - edge_check_start = time.time() - for source, target in edges_to_check: - edge_data = self.memory_graph.G[source][target] - last_modified = edge_data.get("last_modified", current_time) - - if current_time - last_modified > 3600 * global_config.memory.memory_forget_time: - current_strength = edge_data.get("strength", 1) - new_strength = current_strength - 1 - - if new_strength <= 0: - self.memory_graph.G.remove_edge(source, target) - edge_changes["removed"].append(f"{source} -> {target}") - else: - edge_data["strength"] = new_strength - edge_data["last_modified"] = current_time - edge_changes["weakened"].append(f"{source}-{target} (强度: {current_strength} -> {new_strength})") - edge_check_end = time.time() - logger.info(f"[遗忘] 连接检查耗时: {edge_check_end - edge_check_start:.2f}秒") - - logger.info("[遗忘] 开始检查节点...") - node_check_start = time.time() - - # 初始化整合相关变量 - merged_count = 0 - nodes_modified = set() - - for node in nodes_to_check: - # 检查节点是否存在,以防在迭代中被移除(例如边移除导致) - if node not in self.memory_graph.G: - continue - - node_data = self.memory_graph.G.nodes[node] - - # 首先获取记忆项 - memory_items = node_data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - # 新增:检查节点是否为空 - if not memory_items: - try: - self.memory_graph.G.remove_node(node) - node_changes["removed"].append(f"{node}(空节点)") # 标记为空节点移除 - logger.debug(f"[遗忘] 移除了空的节点: {node}") - except nx.NetworkXError as e: - 
logger.warning(f"[遗忘] 移除空节点 {node} 时发生错误(可能已被移除): {e}") - continue # 处理下一个节点 - - # 检查节点的最后修改时间,如果太旧则尝试遗忘 - last_modified = node_data.get("last_modified", current_time) - if current_time - last_modified > 3600 * global_config.memory.memory_forget_time: - # 随机遗忘一条记忆 - if len(memory_items) > 1: - removed_item = self.memory_graph.forget_topic(node) - if removed_item: - node_changes["reduced"].append(f"{node} (移除: {removed_item[:50]}...)") - elif len(memory_items) == 1: - # 如果只有一条记忆,检查是否应该完全移除节点 - try: - self.memory_graph.G.remove_node(node) - node_changes["removed"].append(f"{node} (最后记忆)") - except nx.NetworkXError as e: - logger.warning(f"[遗忘] 移除节点 {node} 时发生错误: {e}") - - # 检查节点内是否有相似的记忆项需要整合 - if len(memory_items) > 1: - items_to_remove = [] - - for i in range(len(memory_items)): - for j in range(i + 1, len(memory_items)): - similarity = self._calculate_item_similarity(memory_items[i], memory_items[j]) - if similarity > 0.8: # 相似度阈值 - # 合并相似记忆项 - longer_item = ( - memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j] - ) - shorter_item = ( - memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i] - ) - - # 保留更长的记忆项,标记短的用于删除 - if shorter_item not in items_to_remove: - items_to_remove.append(shorter_item) - merged_count += 1 - logger.debug( - f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}..." - ) - - # 移除被合并的记忆项 - if items_to_remove: - for item in items_to_remove: - if item in memory_items: - memory_items.remove(item) - nodes_modified.add(node) - # 更新节点的记忆项 - self.memory_graph.G.nodes[node]["memory_items"] = memory_items - self.memory_graph.G.nodes[node]["last_modified"] = current_time - - node_check_end = time.time() - logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}秒") - - # 输出变化统计 - if edge_changes["weakened"]: - logger.info(f"[遗忘] 减弱了 {len(edge_changes['weakened'])} 个连接") - if edge_changes["removed"]: - logger.info(f"[遗忘] 移除了 {len(edge_changes['removed'])} 个连接") - if node_changes["reduced"]: - logger.info(f"[遗忘] 减少了 {len(node_changes['reduced'])} 个节点的记忆") - if node_changes["removed"]: - logger.info(f"[遗忘] 移除了 {len(node_changes['removed'])} 个节点") - - # 检查是否有变化需要同步到数据库 - has_changes = ( - edge_changes["weakened"] - or edge_changes["removed"] - or node_changes["reduced"] - or node_changes["removed"] - or merged_count > 0 - ) - - if has_changes: - logger.info("[遗忘] 开始将变更同步到数据库...") - sync_start = time.time() - await self.hippocampus.entorhinal_cortex.sync_memory_to_db() - sync_end = time.time() - logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}秒") - - if merged_count > 0: - logger.info(f"[整合] 共合并了 {merged_count} 对相似记忆项,分布在 {len(nodes_modified)} 个节点中。") - sync_start = time.time() - logger.info("[整合] 开始将变更同步到数据库...") - # 使用 resync 更安全地处理删除和添加 - await self.hippocampus.entorhinal_cortex.resync_memory_to_db() - sync_end = time.time() - logger.info(f"[整合] 数据库同步耗时: {sync_end - sync_start:.2f}秒") - else: - logger.info("[整合] 本次检查未发现需要合并的记忆项。") - - end_time = time.time() - logger.info(f"[整合] 整合检查完成,总耗时: {end_time - start_time:.2f}秒") - - @staticmethod - def _calculate_item_similarity(item1: str, item2: str) -> float: - """计算两条记忆项文本的余弦相似度""" - words1 = set(jieba.cut(item1)) - words2 = set(jieba.cut(item2)) - all_words = words1 | words2 - if not all_words: - return 0.0 - v1 = [1 if word in words1 else 0 for word in all_words] - v2 = [1 if word in words2 else 0 for word in all_words] - return cosine_similarity(v1, v2) - - -class HippocampusManager: - def __init__(self): - self._hippocampus: 
Hippocampus = None # type: ignore - self._initialized = False - self._db_lock = asyncio.Lock() - - def initialize(self): - """初始化海马体实例""" - if self._initialized: - return self._hippocampus - - self._hippocampus = Hippocampus() - # self._hippocampus.initialize() # 改为异步启动 - self._initialized = True - - # 输出记忆图统计信息 - memory_graph = self._hippocampus.memory_graph.G - node_count = len(memory_graph.nodes()) - edge_count = len(memory_graph.edges()) - - logger.info(f""" - -------------------------------- - 记忆系统参数配置: - 构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate} - 记忆构建分布: {global_config.memory.memory_build_distribution} - 遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后 - 记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count} - --------------------------------""") # noqa: E501 - - return self._hippocampus - - async def initialize_async(self): - """异步初始化海马体实例""" - if not self._initialized: - self.initialize() # 先进行同步部分的初始化 - self._hippocampus.initialize() - await self._hippocampus.entorhinal_cortex.sync_memory_from_db() - - def get_hippocampus(self): - if not self._initialized: - raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - return self._hippocampus - - async def build_memory(self): - """构建记忆的公共接口""" - if not self._initialized: - raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - return await self._hippocampus.parahippocampal_gyrus.operation_build_memory() - - async def forget_memory(self, percentage: float = 0.005): - """遗忘记忆的公共接口""" - if not self._initialized: - raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - async with self._db_lock: - return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage) - - async def consolidate_memory(self): - """整合记忆的公共接口""" - if not self._initialized: - raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - # 使用 operation_build_memory 方法来整合记忆 - async with self._db_lock: - return await self._hippocampus.parahippocampal_gyrus.operation_build_memory() - - async def get_memory_from_text( - self, - text: str, - max_memory_num: int = 3, - max_memory_length: int = 2, - max_depth: int = 3, - fast_retrieval: bool = False, - ) -> list: - """从文本中获取相关记忆的公共接口""" - if not self._initialized: - raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - try: - response = await self._hippocampus.get_memory_from_text( - text, max_memory_num, max_memory_length, max_depth, fast_retrieval - ) - except Exception as e: - logger.error(f"文本激活记忆失败: {e}") - response = [] - return response - - async def get_memory_from_topic( - self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 - ) -> list: - """从文本中获取相关记忆的公共接口""" - if not self._initialized: - raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - try: - response = await self._hippocampus.get_memory_from_topic( - valid_keywords, max_memory_num, max_memory_length, max_depth - ) - except Exception as e: - logger.error(f"文本激活记忆失败: {e}") - response = [] - return response - - async def get_activate_from_text( - self, text: str, max_depth: int = 3, fast_retrieval: bool = False - ) -> tuple[float, list[str]]: - """从文本中获取激活值的公共接口""" - if not self._initialized: - raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 
方法") - try: - response, keywords = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval) - except Exception as e: - logger.error(f"文本产生激活值失败: {e}") - response = 0.0 - keywords = [] # 初始化 keywords 为空列表 - return response, keywords - - def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: - """从关键词获取相关记忆的公共接口""" - if not self._initialized: - raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - return self._hippocampus.get_memory_from_keyword(keyword, max_depth) - - def get_all_node_names(self) -> list: - """获取所有节点名称的公共接口""" - if not self._initialized: - raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - return self._hippocampus.get_all_node_names() - - -# 创建全局实例 -hippocampus_manager = HippocampusManager() diff --git a/src/chat/memory_system/async_instant_memory_wrapper.py b/src/chat/memory_system/async_instant_memory_wrapper.py deleted file mode 100644 index 9b387c535..000000000 --- a/src/chat/memory_system/async_instant_memory_wrapper.py +++ /dev/null @@ -1,248 +0,0 @@ -# -*- coding: utf-8 -*- -""" -异步瞬时记忆包装器 -提供对现有瞬时记忆系统的异步包装,支持超时控制和回退机制 -""" - -import asyncio -import time -from typing import Optional, Dict, Any -from src.common.logger import get_logger -from src.config.config import global_config - -logger = get_logger("async_instant_memory_wrapper") - - -class AsyncInstantMemoryWrapper: - """异步瞬时记忆包装器""" - - def __init__(self, chat_id: str): - self.chat_id = chat_id - self.llm_memory = None - self.vector_memory = None - self.cache: Dict[str, tuple[Any, float]] = {} # 缓存:(结果, 时间戳) - self.cache_ttl = 300 # 缓存5分钟 - self.default_timeout = 3.0 # 默认超时3秒 - - # 从配置中读取是否启用各种记忆系统 - self.llm_memory_enabled = global_config.memory.enable_llm_instant_memory - self.vector_memory_enabled = global_config.memory.enable_vector_instant_memory - - async def _ensure_llm_memory(self): - """确保LLM记忆系统已初始化""" - if self.llm_memory is None and self.llm_memory_enabled: - try: - from src.chat.memory_system.instant_memory import InstantMemory - - self.llm_memory = InstantMemory(self.chat_id) - logger.info(f"LLM瞬时记忆系统已初始化: {self.chat_id}") - except Exception as e: - logger.warning(f"LLM瞬时记忆系统初始化失败: {e}") - self.llm_memory_enabled = False # 初始化失败则禁用 - - async def _ensure_vector_memory(self): - """确保向量记忆系统已初始化""" - if self.vector_memory is None and self.vector_memory_enabled: - try: - from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2 - - self.vector_memory = VectorInstantMemoryV2(self.chat_id) - logger.info(f"向量瞬时记忆系统已初始化: {self.chat_id}") - except Exception as e: - logger.warning(f"向量瞬时记忆系统初始化失败: {e}") - self.vector_memory_enabled = False # 初始化失败则禁用 - - def _get_cache_key(self, operation: str, content: str) -> str: - """生成缓存键""" - return f"{operation}_{self.chat_id}_{hash(content)}" - - def _is_cache_valid(self, cache_key: str) -> bool: - """检查缓存是否有效""" - if cache_key not in self.cache: - return False - - _, timestamp = self.cache[cache_key] - return time.time() - timestamp < self.cache_ttl - - def _get_cached_result(self, cache_key: str) -> Optional[Any]: - """获取缓存结果""" - if self._is_cache_valid(cache_key): - result, _ = self.cache[cache_key] - return result - return None - - def _cache_result(self, cache_key: str, result: Any): - """缓存结果""" - self.cache[cache_key] = (result, time.time()) - - async def store_memory_async(self, content: str, timeout: Optional[float] = None) -> bool: - """异步存储记忆(带超时控制)""" - if timeout is None: - timeout = self.default_timeout - - success_count = 0 - - # 异步存储到LLM记忆系统 - await 
self._ensure_llm_memory() - if self.llm_memory: - try: - await asyncio.wait_for(self.llm_memory.create_and_store_memory(content), timeout=timeout) - success_count += 1 - logger.debug(f"LLM记忆存储成功: {content[:50]}...") - except asyncio.TimeoutError: - logger.warning(f"LLM记忆存储超时: {content[:50]}...") - except Exception as e: - logger.error(f"LLM记忆存储失败: {e}") - - # 异步存储到向量记忆系统 - await self._ensure_vector_memory() - if self.vector_memory: - try: - await asyncio.wait_for(self.vector_memory.store_message(content), timeout=timeout) - success_count += 1 - logger.debug(f"向量记忆存储成功: {content[:50]}...") - except asyncio.TimeoutError: - logger.warning(f"向量记忆存储超时: {content[:50]}...") - except Exception as e: - logger.error(f"向量记忆存储失败: {e}") - - return success_count > 0 - - async def retrieve_memory_async( - self, query: str, timeout: Optional[float] = None, use_cache: bool = True - ) -> Optional[Any]: - """异步检索记忆(带缓存和超时控制)""" - if timeout is None: - timeout = self.default_timeout - - # 检查缓存 - if use_cache: - cache_key = self._get_cache_key("retrieve", query) - cached_result = self._get_cached_result(cache_key) - if cached_result is not None: - logger.debug(f"记忆检索命中缓存: {query[:30]}...") - return cached_result - - # 尝试多种记忆系统 - results = [] - - # 从向量记忆系统检索(优先,速度快) - await self._ensure_vector_memory() - if self.vector_memory: - try: - vector_result = await asyncio.wait_for( - self.vector_memory.get_memory_for_context(query), - timeout=timeout * 0.6, # 给向量系统60%的时间 - ) - if vector_result: - results.append(vector_result) - logger.debug(f"向量记忆检索成功: {query[:30]}...") - except asyncio.TimeoutError: - logger.warning(f"向量记忆检索超时: {query[:30]}...") - except Exception as e: - logger.error(f"向量记忆检索失败: {e}") - - # 从LLM记忆系统检索(备用,更准确但较慢) - await self._ensure_llm_memory() - if self.llm_memory and len(results) == 0: # 只有向量检索失败时才使用LLM - try: - llm_result = await asyncio.wait_for( - self.llm_memory.get_memory(query), - timeout=timeout * 0.4, # 给LLM系统40%的时间 - ) - if llm_result: - results.extend(llm_result) - logger.debug(f"LLM记忆检索成功: {query[:30]}...") - except asyncio.TimeoutError: - logger.warning(f"LLM记忆检索超时: {query[:30]}...") - except Exception as e: - logger.error(f"LLM记忆检索失败: {e}") - - # 合并结果 - final_result = None - if results: - if len(results) == 1: - final_result = results[0] - else: - # 合并多个结果 - if isinstance(results[0], str): - final_result = "\n".join(str(r) for r in results) - elif isinstance(results[0], list): - final_result = [] - for r in results: - if isinstance(r, list): - final_result.extend(r) - else: - final_result.append(r) - else: - final_result = results[0] # 使用第一个结果 - - # 缓存结果 - if use_cache and final_result is not None: - cache_key = self._get_cache_key("retrieve", query) - self._cache_result(cache_key, final_result) - - return final_result - - async def get_memory_with_fallback(self, query: str, max_timeout: float = 2.0) -> str: - """获取记忆的回退方法,保证不会长时间阻塞""" - try: - # 首先尝试快速检索 - result = await self.retrieve_memory_async(query, timeout=max_timeout) - - if result: - if isinstance(result, list): - return "\n".join(str(item) for item in result) - return str(result) - - return "" - - except Exception as e: - logger.error(f"记忆检索完全失败: {e}") - return "" - - def store_memory_background(self, content: str): - """在后台存储记忆(发后即忘模式)""" - - async def background_store(): - try: - await self.store_memory_async(content, timeout=10.0) # 后台任务可以用更长超时 - except Exception as e: - logger.error(f"后台记忆存储失败: {e}") - - # 创建后台任务 - asyncio.create_task(background_store()) - - def get_status(self) -> Dict[str, Any]: - """获取包装器状态""" - 
return { - "chat_id": self.chat_id, - "llm_memory_available": self.llm_memory is not None, - "vector_memory_available": self.vector_memory is not None, - "cache_entries": len(self.cache), - "cache_ttl": self.cache_ttl, - "default_timeout": self.default_timeout, - } - - def clear_cache(self): - """清理缓存""" - self.cache.clear() - logger.info(f"记忆缓存已清理: {self.chat_id}") - - -# 缓存包装器实例,避免重复创建 -_wrapper_cache: Dict[str, AsyncInstantMemoryWrapper] = {} - - -def get_async_instant_memory(chat_id: str) -> AsyncInstantMemoryWrapper: - """获取异步瞬时记忆包装器实例""" - if chat_id not in _wrapper_cache: - _wrapper_cache[chat_id] = AsyncInstantMemoryWrapper(chat_id) - return _wrapper_cache[chat_id] - - -def clear_wrapper_cache(): - """清理包装器缓存""" - global _wrapper_cache - _wrapper_cache.clear() - logger.info("异步瞬时记忆包装器缓存已清理") diff --git a/src/chat/memory_system/async_memory_optimizer.py b/src/chat/memory_system/async_memory_optimizer.py deleted file mode 100644 index 1fcacb32d..000000000 --- a/src/chat/memory_system/async_memory_optimizer.py +++ /dev/null @@ -1,358 +0,0 @@ -# -*- coding: utf-8 -*- -""" -异步记忆系统优化器 -解决记忆系统阻塞主程序的问题,将同步操作改为异步非阻塞操作 -""" - -import asyncio -import time -from typing import Dict, List, Optional, Any, Callable -from dataclasses import dataclass -from concurrent.futures import ThreadPoolExecutor -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory - -logger = get_logger("async_memory_optimizer") - - -@dataclass -class MemoryTask: - """记忆任务数据结构""" - - task_id: str - task_type: str # "store", "retrieve", "build" - chat_id: str - content: str - priority: int = 1 # 1=低优先级, 2=中优先级, 3=高优先级 - callback: Optional[Callable] = None - created_at: float = None - - def __post_init__(self): - if self.created_at is None: - self.created_at = time.time() - - -class AsyncMemoryQueue: - """异步记忆任务队列管理器""" - - def __init__(self, max_workers: int = 3): - self.max_workers = max_workers - self.executor = ThreadPoolExecutor(max_workers=max_workers) - self.task_queue = asyncio.Queue() - self.running_tasks: Dict[str, asyncio.Task] = {} - self.completed_tasks: Dict[str, Any] = {} - self.failed_tasks: Dict[str, str] = {} - self.is_running = False - self.worker_tasks: List[asyncio.Task] = [] - - async def start(self): - """启动异步队列处理器""" - if self.is_running: - return - - self.is_running = True - # 启动多个工作协程 - for i in range(self.max_workers): - worker = asyncio.create_task(self._worker(f"worker-{i}")) - self.worker_tasks.append(worker) - - logger.info(f"异步记忆队列已启动,工作线程数: {self.max_workers}") - - async def stop(self): - """停止队列处理器""" - self.is_running = False - - # 等待所有工作任务完成 - for task in self.worker_tasks: - task.cancel() - - await asyncio.gather(*self.worker_tasks, return_exceptions=True) - self.executor.shutdown(wait=True) - logger.info("异步记忆队列已停止") - - async def _worker(self, worker_name: str): - """工作协程,处理队列中的任务""" - logger.info(f"记忆处理工作线程 {worker_name} 启动") - - while self.is_running: - try: - # 等待任务,超时1秒避免永久阻塞 - task = await asyncio.wait_for(self.task_queue.get(), timeout=1.0) - - # 执行任务 - await self._execute_task(task, worker_name) - - except asyncio.TimeoutError: - # 超时正常,继续下一次循环 - continue - except Exception as e: - logger.error(f"工作线程 {worker_name} 处理任务时出错: {e}") - - async def _execute_task(self, task: MemoryTask, worker_name: str): - """执行具体的记忆任务""" - try: - logger.debug(f"[{worker_name}] 开始处理任务: {task.task_type} - {task.task_id}") - start_time = time.time() - - # 根据任务类型执行不同的处理逻辑 - result = 
None - if task.task_type == "store": - result = await self._handle_store_task(task) - elif task.task_type == "retrieve": - result = await self._handle_retrieve_task(task) - elif task.task_type == "build": - result = await self._handle_build_task(task) - else: - raise ValueError(f"未知的任务类型: {task.task_type}") - - # 记录完成的任务 - self.completed_tasks[task.task_id] = result - execution_time = time.time() - start_time - - logger.debug(f"[{worker_name}] 任务完成: {task.task_id} (耗时: {execution_time:.2f}s)") - - # 执行回调函数 - if task.callback: - try: - if asyncio.iscoroutinefunction(task.callback): - await task.callback(result) - else: - task.callback(result) - except Exception as e: - logger.error(f"任务回调执行失败: {e}") - - except Exception as e: - error_msg = f"任务执行失败: {e}" - logger.error(f"[{worker_name}] {error_msg}") - self.failed_tasks[task.task_id] = error_msg - - # 执行错误回调 - if task.callback: - try: - if asyncio.iscoroutinefunction(task.callback): - await task.callback(None) - else: - task.callback(None) - except Exception: - pass - - @staticmethod - async def _handle_store_task(task: MemoryTask) -> Any: - """处理记忆存储任务""" - # 这里需要根据具体的记忆系统来实现 - # 为了避免循环导入,这里使用延迟导入 - try: - # 获取包装器实例 - memory_wrapper = get_async_instant_memory(task.chat_id) - - # 使用包装器中的llm_memory实例 - if memory_wrapper and memory_wrapper.llm_memory: - await memory_wrapper.llm_memory.create_and_store_memory(task.content) - return True - else: - logger.warning(f"无法获取记忆系统实例,存储任务失败: chat_id={task.chat_id}") - return False - except Exception as e: - logger.error(f"记忆存储失败: {e}") - return False - - @staticmethod - async def _handle_retrieve_task(task: MemoryTask) -> Any: - """处理记忆检索任务""" - try: - # 获取包装器实例 - memory_wrapper = get_async_instant_memory(task.chat_id) - - # 使用包装器中的llm_memory实例 - if memory_wrapper and memory_wrapper.llm_memory: - memories = await memory_wrapper.llm_memory.get_memory(task.content) - return memories or [] - else: - logger.warning(f"无法获取记忆系统实例,检索任务失败: chat_id={task.chat_id}") - return [] - except Exception as e: - logger.error(f"记忆检索失败: {e}") - return [] - - @staticmethod - async def _handle_build_task(task: MemoryTask) -> Any: - """处理记忆构建任务(海马体系统)""" - try: - # 延迟导入避免循环依赖 - if global_config.memory.enable_memory: - from src.chat.memory_system.Hippocampus import hippocampus_manager - - if hippocampus_manager._initialized: - # 确保海马体对象已正确初始化 - if not hippocampus_manager._hippocampus.parahippocampal_gyrus: - logger.warning("海马体对象未完全初始化,进行同步初始化") - hippocampus_manager._hippocampus.initialize() - - await hippocampus_manager.build_memory() - return True - return False - except Exception as e: - logger.error(f"记忆构建失败: {e}") - return False - - async def add_task(self, task: MemoryTask) -> str: - """添加任务到队列""" - await self.task_queue.put(task) - self.running_tasks[task.task_id] = task - logger.debug(f"任务已加入队列: {task.task_type} - {task.task_id}") - return task.task_id - - def get_task_result(self, task_id: str) -> Optional[Any]: - """获取任务结果(非阻塞)""" - return self.completed_tasks.get(task_id) - - def is_task_completed(self, task_id: str) -> bool: - """检查任务是否完成""" - return task_id in self.completed_tasks or task_id in self.failed_tasks - - def get_queue_status(self) -> Dict[str, Any]: - """获取队列状态""" - return { - "is_running": self.is_running, - "queue_size": self.task_queue.qsize(), - "running_tasks": len(self.running_tasks), - "completed_tasks": len(self.completed_tasks), - "failed_tasks": len(self.failed_tasks), - "worker_count": len(self.worker_tasks), - } - - -class NonBlockingMemoryManager: - """非阻塞记忆管理器""" - - def __init__(self): 
- self.queue = AsyncMemoryQueue(max_workers=3) - self.cache: Dict[str, Any] = {} - self.cache_ttl: Dict[str, float] = {} - self.cache_timeout = 300 # 缓存5分钟 - - async def initialize(self): - """初始化管理器""" - await self.queue.start() - logger.info("非阻塞记忆管理器已初始化") - - async def shutdown(self): - """关闭管理器""" - await self.queue.stop() - logger.info("非阻塞记忆管理器已关闭") - - async def store_memory_async(self, chat_id: str, content: str, callback: Optional[Callable] = None) -> str: - """异步存储记忆(非阻塞)""" - task = MemoryTask( - task_id=f"store_{chat_id}_{int(time.time() * 1000)}", - task_type="store", - chat_id=chat_id, - content=content, - priority=1, # 存储优先级较低 - callback=callback, - ) - - return await self.queue.add_task(task) - - async def retrieve_memory_async(self, chat_id: str, query: str, callback: Optional[Callable] = None) -> str: - """异步检索记忆(非阻塞)""" - # 先检查缓存 - cache_key = f"retrieve_{chat_id}_{hash(query)}" - if self._is_cache_valid(cache_key): - result = self.cache[cache_key] - if callback: - if asyncio.iscoroutinefunction(callback): - await callback(result) - else: - callback(result) - return "cache_hit" - - task = MemoryTask( - task_id=f"retrieve_{chat_id}_{int(time.time() * 1000)}", - task_type="retrieve", - chat_id=chat_id, - content=query, - priority=2, # 检索优先级中等 - callback=self._create_cache_callback(cache_key, callback), - ) - - return await self.queue.add_task(task) - - async def build_memory_async(self, callback: Optional[Callable] = None) -> str: - """异步构建记忆(非阻塞)""" - task = MemoryTask( - task_id=f"build_memory_{int(time.time() * 1000)}", - task_type="build", - chat_id="system", - content="", - priority=1, # 构建优先级较低,避免影响用户体验 - callback=callback, - ) - - return await self.queue.add_task(task) - - def _is_cache_valid(self, cache_key: str) -> bool: - """检查缓存是否有效""" - if cache_key not in self.cache: - return False - - return time.time() - self.cache_ttl.get(cache_key, 0) < self.cache_timeout - - def _create_cache_callback(self, cache_key: str, original_callback: Optional[Callable]): - """创建带缓存的回调函数""" - - async def cache_callback(result): - # 存储到缓存 - if result is not None: - self.cache[cache_key] = result - self.cache_ttl[cache_key] = time.time() - - # 执行原始回调 - if original_callback: - if asyncio.iscoroutinefunction(original_callback): - await original_callback(result) - else: - original_callback(result) - - return cache_callback - - def get_cached_memory(self, chat_id: str, query: str) -> Optional[Any]: - """获取缓存的记忆(同步,立即返回)""" - cache_key = f"retrieve_{chat_id}_{hash(query)}" - if self._is_cache_valid(cache_key): - return self.cache[cache_key] - return None - - def get_status(self) -> Dict[str, Any]: - """获取管理器状态""" - status = self.queue.get_queue_status() - status.update({"cache_entries": len(self.cache), "cache_timeout": self.cache_timeout}) - return status - - -# 全局实例 -async_memory_manager = NonBlockingMemoryManager() - - -# 便捷函数 -async def store_memory_nonblocking(chat_id: str, content: str) -> str: - """非阻塞存储记忆的便捷函数""" - return await async_memory_manager.store_memory_async(chat_id, content) - - -async def retrieve_memory_nonblocking(chat_id: str, query: str) -> Optional[Any]: - """非阻塞检索记忆的便捷函数,支持缓存""" - # 先尝试从缓存获取 - cached_result = async_memory_manager.get_cached_memory(chat_id, query) - if cached_result is not None: - return cached_result - - # 缓存未命中,启动异步检索 - await async_memory_manager.retrieve_memory_async(chat_id, query) - return None # 返回None表示需要异步获取 - - -async def build_memory_nonblocking() -> str: - """非阻塞构建记忆的便捷函数""" - return await async_memory_manager.build_memory_async() diff 
--git a/src/chat/memory_system/async_optimization_guide.md b/src/chat/memory_system/async_optimization_guide.md deleted file mode 100644 index 613dbe439..000000000 --- a/src/chat/memory_system/async_optimization_guide.md +++ /dev/null @@ -1,196 +0,0 @@ -# 记忆系统异步优化说明 - -## 🎯 优化目标 - -解决MaiBot-Plus记忆系统阻塞主程序的问题,将原本的线性同步调用改为异步非阻塞运行。 - -## ⚠️ 问题分析 - -### 原有问题 -1. **瞬时记忆阻塞**:每次用户发消息时,`await self.instant_memory.get_memory_for_context(target)` 会阻塞等待LLM响应 -2. **定时记忆构建阻塞**:每600秒执行的 `build_memory_task()` 会完全阻塞主程序数十秒 -3. **LLM调用链阻塞**:记忆存储和检索都需要调用LLM,延迟较高 - -### 卡顿表现 -- 用户发消息后,程序响应延迟明显增加 -- 定时记忆构建时,整个程序无响应 -- 高并发时,记忆系统成为性能瓶颈 - -## 🚀 优化方案 - -### 1. 异步记忆队列系统 (`async_memory_optimizer.py`) - -**核心思想**:将记忆操作放入异步队列,后台处理,不阻塞主程序。 - -**关键特性**: -- 任务队列管理:支持存储、检索、构建三种任务类型 -- 优先级调度:高优先级任务(用户查询)优先处理 -- 线程池执行:避免阻塞事件循环 -- 结果缓存:减少重复计算 -- 失败重试:提高系统可靠性 - -```python -# 使用示例 -from src.chat.memory_system.async_memory_optimizer import ( - store_memory_nonblocking, - retrieve_memory_nonblocking, - build_memory_nonblocking -) - -# 非阻塞存储记忆 -task_id = await store_memory_nonblocking(chat_id, content) - -# 非阻塞检索记忆(支持缓存) -memories = await retrieve_memory_nonblocking(chat_id, query) - -# 非阻塞构建记忆 -task_id = await build_memory_nonblocking() -``` - -### 2. 异步瞬时记忆包装器 (`async_instant_memory_wrapper.py`) - -**核心思想**:为现有瞬时记忆系统提供异步包装,支持超时控制和多层回退。 - -**关键特性**: -- 超时控制:防止长时间阻塞 -- 缓存机制:热点查询快速响应 -- 多系统融合:LLM记忆 + 向量记忆 -- 回退策略:保证系统稳定性 -- 后台存储:存储操作完全非阻塞 - -```python -# 使用示例 -from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory - -async_memory = get_async_instant_memory(chat_id) - -# 后台存储(发后即忘) -async_memory.store_memory_background(content) - -# 快速检索(带超时) -result = await async_memory.get_memory_with_fallback(query, max_timeout=2.0) -``` - -### 3. 主程序优化 - -**记忆构建任务异步化**: -- 原来:`await self.hippocampus_manager.build_memory()` 阻塞主程序 -- 现在:使用异步队列或线程池,后台执行 - -**消息处理优化**: -- 原来:同步等待记忆检索完成 -- 现在:最大2秒超时,保证用户体验 - -## 📊 性能提升预期 - -### 响应速度 -- **用户消息响应**:从原来的3-10秒减少到0.5-2秒 -- **记忆检索**:缓存命中时几乎即时响应 -- **记忆存储**:从同步阻塞改为后台处理 - -### 并发能力 -- **多用户同时使用**:不再因记忆系统相互阻塞 -- **高峰期稳定性**:记忆任务排队处理,不会崩溃 - -### 资源使用 -- **CPU使用**:异步处理,更好的CPU利用率 -- **内存优化**:缓存机制,减少重复计算 -- **网络延迟**:LLM调用并行化,减少等待时间 - -## 🔧 部署和配置 - -### 1. 自动部署 -新的异步系统已经集成到现有代码中,支持自动回退: - -```python -# 优先级回退机制 -1. 异步瞬时记忆包装器 (最优) -2. 异步记忆管理器 (次优) -3. 带超时的同步模式 (保底) -``` - -### 2. 配置参数 - -在 `config.toml` 中可以调整相关参数: - -```toml -[memory] -enable_memory = true -enable_instant_memory = true -memory_build_interval = 600 # 记忆构建间隔(秒) -``` - -### 3. 监控和调试 - -```python -# 查看异步队列状态 -from src.chat.memory_system.async_memory_optimizer import async_memory_manager -status = async_memory_manager.get_status() -print(status) - -# 查看包装器状态 -from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory -wrapper = get_async_instant_memory(chat_id) -status = wrapper.get_status() -print(status) -``` - -## 🧪 验证方法 - -### 1. 性能测试 -```bash -# 测试用户消息响应时间 -time curl -X POST "http://localhost:8080/api/message" -d '{"message": "你还记得我们昨天聊的内容吗?"}' - -# 观察内存构建时的程序响应 -# 构建期间发送消息,观察是否还有阻塞 -``` - -### 2. 并发测试 -```python -import asyncio -import time - -async def test_concurrent_messages(): - """测试并发消息处理""" - tasks = [] - for i in range(10): - task = asyncio.create_task(send_message(f"测试消息 {i}")) - tasks.append(task) - - start_time = time.time() - results = await asyncio.gather(*tasks) - end_time = time.time() - - print(f"10条并发消息处理完成,耗时: {end_time - start_time:.2f}秒") -``` - -### 3. 
日志监控 -关注以下日志输出: -- `"异步瞬时记忆:"` - 确认使用了异步系统 -- `"记忆构建任务已提交"` - 确认构建任务非阻塞 -- `"瞬时记忆检索超时"` - 监控超时情况 - -## 🔄 回退机制 - -系统设计了多层回退机制,确保即使新系统出现问题,也能维持基本功能: - -1. **异步包装器失败** → 使用异步队列管理器 -2. **异步队列失败** → 使用带超时的同步模式 -3. **超时保护** → 最长等待时间不超过2秒 -4. **完全失败** → 跳过记忆功能,保证基本对话 - -## 📝 注意事项 - -1. **首次启动**:异步系统需要初始化时间,可能前几次记忆调用延迟稍高 -2. **缓存预热**:系统运行一段时间后,缓存效果会显著提升响应速度 -3. **内存使用**:缓存会增加内存使用,但相对于性能提升是值得的 -4. **兼容性**:如果发现异步系统有问题,可以临时禁用相关导入,自动回退到原系统 - -## 🎉 预期效果 - -- ✅ **消息响应速度提升60%+** -- ✅ **记忆构建不再阻塞主程序** -- ✅ **支持更高的并发用户数** -- ✅ **系统整体稳定性提升** -- ✅ **保持原有记忆功能完整性** diff --git a/src/chat/memory_system/enhanced_memory_activator.py b/src/chat/memory_system/enhanced_memory_activator.py new file mode 100644 index 000000000..64bbdd64e --- /dev/null +++ b/src/chat/memory_system/enhanced_memory_activator.py @@ -0,0 +1,237 @@ +# -*- coding: utf-8 -*- +""" +增强记忆激活器 +替代原有的 MemoryActivator,使用增强记忆系统 +""" + +import difflib +import orjson +import time +from typing import List, Dict, Optional +from datetime import datetime + +from json_repair import repair_json +from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config, model_config +from src.common.logger import get_logger +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager, EnhancedMemoryResult + +logger = get_logger("enhanced_memory_activator") + + +def get_keywords_from_json(json_str) -> List: + """ + 从JSON字符串中提取关键词列表 + + Args: + json_str: JSON格式的字符串 + + Returns: + List[str]: 关键词列表 + """ + try: + # 使用repair_json修复JSON格式 + fixed_json = repair_json(json_str) + + # 如果repair_json返回的是字符串,需要解析为Python对象 + result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json + return result.get("keywords", []) + except Exception as e: + logger.error(f"解析关键词JSON失败: {e}") + return [] + + +def init_prompt(): + # --- Enhanced Memory Activator Prompt --- + enhanced_memory_activator_prompt = """ + 你是一个增强记忆分析器,你需要根据以下信息来进行记忆检索 + + 以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆检索的触发词 + + 聊天记录: + {obs_info_text} + + 用户想要回复的消息: + {target_message} + + 历史关键词(请避免重复提取这些关键词): + {cached_keywords} + + 请输出一个json格式,包含以下字段: + {{ + "keywords": ["关键词1", "关键词2", "关键词3",......] 
+ }} + + 不要输出其他多余内容,只输出json格式就好 + """ + + Prompt(enhanced_memory_activator_prompt, "enhanced_memory_activator_prompt") + + +class EnhancedMemoryActivator: + """增强记忆激活器 - 替代原有的 MemoryActivator""" + + def __init__(self): + self.key_words_model = LLMRequest( + model_set=model_config.model_task_config.utils_small, + request_type="enhanced_memory.activator", + ) + + self.running_memory = [] + self.cached_keywords = set() # 用于缓存历史关键词 + self.last_enhanced_query_time = 0 # 上次查询增强记忆的时间 + + async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]: + """ + 激活增强记忆 + """ + # 如果记忆系统被禁用,直接返回空列表 + if not global_config.memory.enable_memory: + return [] + + # 将缓存的关键词转换为字符串,用于prompt + cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词" + + prompt = await global_prompt_manager.format_prompt( + "enhanced_memory_activator_prompt", + obs_info_text=chat_history_prompt, + target_message=target_message, + cached_keywords=cached_keywords_str, + ) + + # 生成关键词 + response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async( + prompt, temperature=0.5 + ) + + keywords = list(get_keywords_from_json(response)) + + # 更新关键词缓存 + if keywords: + # 限制缓存大小,最多保留10个关键词 + if len(self.cached_keywords) > 10: + # 转换为列表,移除最早的关键词 + cached_list = list(self.cached_keywords) + self.cached_keywords = set(cached_list[-8:]) + + # 添加新的关键词到缓存 + self.cached_keywords.update(keywords) + + logger.debug(f"增强记忆关键词: {self.cached_keywords}") + + # 使用增强记忆系统获取相关记忆 + enhanced_results = await self._query_enhanced_memory(keywords, target_message) + + # 处理和增强记忆结果 + if enhanced_results: + for result in enhanced_results: + # 检查是否已存在相似内容的记忆 + exists = any( + m["content"] == result.content or + difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7 + for m in self.running_memory + ) + if not exists: + memory_entry = { + "topic": result.memory_type, + "content": result.content, + "timestamp": datetime.fromtimestamp(result.timestamp).isoformat(), + "duration": 1, + "confidence": result.confidence, + "importance": result.importance, + "source": result.source + } + self.running_memory.append(memory_entry) + logger.debug(f"添加新增强记忆: {result.memory_type} - {result.content}") + + # 激活时,所有已有记忆的duration+1,达到3则移除 + for m in self.running_memory[:]: + m["duration"] = m.get("duration", 1) + 1 + self.running_memory = [m for m in self.running_memory if m["duration"] < 3] + + # 限制同时加载的记忆条数,最多保留最后5条(增强记忆可以处理更多) + if len(self.running_memory) > 5: + self.running_memory = self.running_memory[-5:] + + return self.running_memory + + async def _query_enhanced_memory(self, keywords: List[str], query_text: str) -> List[EnhancedMemoryResult]: + """查询增强记忆系统""" + try: + # 确保增强记忆管理器已初始化 + if not enhanced_memory_manager.is_initialized: + await enhanced_memory_manager.initialize() + + # 构建查询上下文 + context = { + "keywords": keywords, + "query_intent": "conversation_response", + "expected_memory_types": [ + "personal_fact", "event", "preference", "opinion" + ] + } + + # 查询增强记忆 + enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context( + query_text=query_text, + user_id="default_user", # 可以根据实际用户ID调整 + context=context, + limit=5 + ) + + logger.debug(f"增强记忆查询返回 {len(enhanced_results)} 条结果") + return enhanced_results + + except Exception as e: + logger.error(f"查询增强记忆失败: {e}") + return [] + + async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]: + """ + 获取即时记忆 - 兼容原有接口 + """ + try: + # 使用增强记忆系统获取相关记忆 + if 
not enhanced_memory_manager.is_initialized: + await enhanced_memory_manager.initialize() + + context = { + "query_intent": "instant_response", + "chat_id": chat_id, + "expected_memory_types": ["preference", "opinion", "personal_fact"] + } + + enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context( + query_text=target_message, + user_id="default_user", + context=context, + limit=1 + ) + + if enhanced_results: + # 返回最相关的记忆内容 + return enhanced_results[0].content + + return None + + except Exception as e: + logger.error(f"获取即时记忆失败: {e}") + return None + + def clear_cache(self): + """清除缓存""" + self.cached_keywords.clear() + self.running_memory.clear() + logger.debug("增强记忆激活器缓存已清除") + + +# 创建全局实例 +enhanced_memory_activator = EnhancedMemoryActivator() + + +# 为了兼容性,保留原有名称 +MemoryActivator = EnhancedMemoryActivator + + +init_prompt() \ No newline at end of file diff --git a/src/chat/memory_system/enhanced_memory_adapter.py b/src/chat/memory_system/enhanced_memory_adapter.py new file mode 100644 index 000000000..0d73b11f5 --- /dev/null +++ b/src/chat/memory_system/enhanced_memory_adapter.py @@ -0,0 +1,332 @@ +# -*- coding: utf-8 -*- +""" +增强记忆系统适配器 +将增强记忆系统集成到现有MoFox Bot架构中 +""" + +import asyncio +import time +from typing import Dict, List, Optional, Any, Tuple +from dataclasses import dataclass + +from src.common.logger import get_logger +from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode +from src.chat.memory_system.memory_chunk import MemoryChunk +from src.llm_models.utils_model import LLMRequest + +logger = get_logger(__name__) + + +@dataclass +class AdapterConfig: + """适配器配置""" + enable_enhanced_memory: bool = True + integration_mode: str = "enhanced_only" # replace, enhanced_only + auto_migration: bool = True + memory_value_threshold: float = 0.6 + fusion_threshold: float = 0.85 + max_retrieval_results: int = 10 + + +class EnhancedMemoryAdapter: + """增强记忆系统适配器""" + + def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None): + self.llm_model = llm_model + self.config = config or AdapterConfig() + self.integration_layer: Optional[MemoryIntegrationLayer] = None + self._initialized = False + + # 统计信息 + self.adapter_stats = { + "total_processed": 0, + "enhanced_used": 0, + "legacy_used": 0, + "hybrid_used": 0, + "memories_created": 0, + "memories_retrieved": 0, + "average_processing_time": 0.0 + } + + async def initialize(self): + """初始化适配器""" + if self._initialized: + return + + try: + logger.info("🚀 初始化增强记忆系统适配器...") + + # 转换配置格式 + integration_config = IntegrationConfig( + mode=IntegrationMode(self.config.integration_mode), + enable_enhanced_memory=self.config.enable_enhanced_memory, + memory_value_threshold=self.config.memory_value_threshold, + fusion_threshold=self.config.fusion_threshold, + max_retrieval_results=self.config.max_retrieval_results, + enable_learning=True # 启用学习功能 + ) + + # 创建集成层 + self.integration_layer = MemoryIntegrationLayer( + llm_model=self.llm_model, + config=integration_config + ) + + # 初始化集成层 + await self.integration_layer.initialize() + + self._initialized = True + logger.info("✅ 增强记忆系统适配器初始化完成") + + except Exception as e: + logger.error(f"❌ 增强记忆系统适配器初始化失败: {e}", exc_info=True) + # 如果初始化失败,禁用增强记忆功能 + self.config.enable_enhanced_memory = False + + async def process_conversation_memory( + self, + conversation_text: str, + context: Dict[str, Any], + user_id: str, + timestamp: Optional[float] = None + ) -> Dict[str, Any]: + """处理对话记忆""" + if not self._initialized or 
not self.config.enable_enhanced_memory: + return {"success": False, "error": "Enhanced memory not available"} + + start_time = time.time() + self.adapter_stats["total_processed"] += 1 + + try: + # 使用集成层处理对话 + result = await self.integration_layer.process_conversation( + conversation_text, context, user_id, timestamp + ) + + # 更新统计 + processing_time = time.time() - start_time + self._update_processing_stats(processing_time) + + if result["success"]: + created_count = len(result.get("created_memories", [])) + self.adapter_stats["memories_created"] += created_count + logger.debug(f"对话记忆处理完成,创建 {created_count} 条记忆") + + return result + + except Exception as e: + logger.error(f"处理对话记忆失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def retrieve_relevant_memories( + self, + query: str, + user_id: str, + context: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None + ) -> List[MemoryChunk]: + """检索相关记忆""" + if not self._initialized or not self.config.enable_enhanced_memory: + return [] + + try: + limit = limit or self.config.max_retrieval_results + memories = await self.integration_layer.retrieve_relevant_memories( + query, user_id, context, limit + ) + + self.adapter_stats["memories_retrieved"] += len(memories) + logger.debug(f"检索到 {len(memories)} 条相关记忆") + + return memories + + except Exception as e: + logger.error(f"检索相关记忆失败: {e}", exc_info=True) + return [] + + async def get_memory_context_for_prompt( + self, + query: str, + user_id: str, + context: Optional[Dict[str, Any]] = None, + max_memories: int = 5 + ) -> str: + """获取用于提示词的记忆上下文""" + memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories) + + if not memories: + return "" + + # 格式化记忆为提示词友好的格式 + memory_context_parts = [] + for memory in memories: + memory_context_parts.append(f"- {memory.text_content}") + + return "\n".join(memory_context_parts) + + async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]: + """获取增强记忆系统摘要""" + if not self._initialized or not self.config.enable_enhanced_memory: + return {"available": False, "reason": "Not initialized or disabled"} + + try: + # 获取系统状态 + status = await self.integration_layer.get_system_status() + + # 获取适配器统计 + adapter_stats = self.adapter_stats.copy() + + # 获取集成统计 + integration_stats = self.integration_layer.get_integration_stats() + + return { + "available": True, + "system_status": status, + "adapter_stats": adapter_stats, + "integration_stats": integration_stats, + "total_memories_created": adapter_stats["memories_created"], + "total_memories_retrieved": adapter_stats["memories_retrieved"] + } + + except Exception as e: + logger.error(f"获取增强记忆摘要失败: {e}", exc_info=True) + return {"available": False, "error": str(e)} + + def _update_processing_stats(self, processing_time: float): + """更新处理统计""" + total_processed = self.adapter_stats["total_processed"] + if total_processed > 0: + current_avg = self.adapter_stats["average_processing_time"] + new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed + self.adapter_stats["average_processing_time"] = new_avg + + def get_adapter_stats(self) -> Dict[str, Any]: + """获取适配器统计信息""" + return self.adapter_stats.copy() + + async def maintenance(self): + """维护操作""" + if not self._initialized: + return + + try: + logger.info("🔧 增强记忆系统适配器维护...") + await self.integration_layer.maintenance() + logger.info("✅ 增强记忆系统适配器维护完成") + except Exception as e: + logger.error(f"❌ 增强记忆系统适配器维护失败: {e}", exc_info=True) + + async def shutdown(self): + 
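# (Editor's note) Sketch of the typical call path through this adapter, assuming the
# module-level helpers defined further down in this file and an already-constructed
# LLMRequest instance named llm_model; the user and chat identifiers are illustrative:
#
#     adapter = await get_enhanced_memory_adapter(llm_model)
#     await adapter.process_conversation_memory("我下周要去北京出差", {"chat_id": "c1"}, "user_1")
#     memory_block = await adapter.get_memory_context_for_prompt("出差", "user_1", max_memories=5)
#
# get_memory_context_for_prompt() joins each retrieved MemoryChunk.text_content as a
# "- ..." bullet line ready to splice into a prompt.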
"""关闭适配器""" + if not self._initialized: + return + + try: + logger.info("🔄 关闭增强记忆系统适配器...") + await self.integration_layer.shutdown() + self._initialized = False + logger.info("✅ 增强记忆系统适配器已关闭") + except Exception as e: + logger.error(f"❌ 关闭增强记忆系统适配器失败: {e}", exc_info=True) + + +# 全局适配器实例 +_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None + + +async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter: + """获取全局增强记忆适配器实例""" + global _enhanced_memory_adapter + + if _enhanced_memory_adapter is None: + # 从配置中获取适配器配置 + from src.config.config import global_config + + adapter_config = AdapterConfig( + enable_enhanced_memory=getattr(global_config.memory, 'enable_enhanced_memory', True), + integration_mode=getattr(global_config.memory, 'enhanced_memory_mode', 'enhanced_only'), + auto_migration=getattr(global_config.memory, 'enable_memory_migration', True), + memory_value_threshold=getattr(global_config.memory, 'memory_value_threshold', 0.6), + fusion_threshold=getattr(global_config.memory, 'fusion_threshold', 0.85), + max_retrieval_results=getattr(global_config.memory, 'max_retrieval_results', 10) + ) + + _enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config) + await _enhanced_memory_adapter.initialize() + + return _enhanced_memory_adapter + + +async def initialize_enhanced_memory_system(llm_model: LLMRequest): + """初始化增强记忆系统""" + try: + logger.info("🚀 初始化增强记忆系统...") + adapter = await get_enhanced_memory_adapter(llm_model) + logger.info("✅ 增强记忆系统初始化完成") + return adapter + except Exception as e: + logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True) + return None + + +async def process_conversation_with_enhanced_memory( + conversation_text: str, + context: Dict[str, Any], + user_id: str, + timestamp: Optional[float] = None, + llm_model: Optional[LLMRequest] = None +) -> Dict[str, Any]: + """使用增强记忆系统处理对话""" + if not llm_model: + # 获取默认的LLM模型 + from src.llm_models.utils_model import get_global_llm_model + llm_model = get_global_llm_model() + + try: + adapter = await get_enhanced_memory_adapter(llm_model) + return await adapter.process_conversation_memory(conversation_text, context, user_id, timestamp) + except Exception as e: + logger.error(f"使用增强记忆系统处理对话失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + +async def retrieve_memories_with_enhanced_system( + query: str, + user_id: str, + context: Optional[Dict[str, Any]] = None, + limit: int = 10, + llm_model: Optional[LLMRequest] = None +) -> List[MemoryChunk]: + """使用增强记忆系统检索记忆""" + if not llm_model: + # 获取默认的LLM模型 + from src.llm_models.utils_model import get_global_llm_model + llm_model = get_global_llm_model() + + try: + adapter = await get_enhanced_memory_adapter(llm_model) + return await adapter.retrieve_relevant_memories(query, user_id, context, limit) + except Exception as e: + logger.error(f"使用增强记忆系统检索记忆失败: {e}", exc_info=True) + return [] + + +async def get_memory_context_for_prompt( + query: str, + user_id: str, + context: Optional[Dict[str, Any]] = None, + max_memories: int = 5, + llm_model: Optional[LLMRequest] = None +) -> str: + """获取用于提示词的记忆上下文""" + if not llm_model: + # 获取默认的LLM模型 + from src.llm_models.utils_model import get_global_llm_model + llm_model = get_global_llm_model() + + try: + adapter = await get_enhanced_memory_adapter(llm_model) + return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories) + except Exception as e: + logger.error(f"获取记忆上下文失败: {e}", exc_info=True) + return "" \ No newline at end of file diff 
--git a/src/chat/memory_system/enhanced_memory_core.py b/src/chat/memory_system/enhanced_memory_core.py new file mode 100644 index 000000000..ee33ab783 --- /dev/null +++ b/src/chat/memory_system/enhanced_memory_core.py @@ -0,0 +1,753 @@ +# -*- coding: utf-8 -*- +""" +增强型精准记忆系统核心模块 +基于文档设计的高效记忆构建、存储与召回优化系统 +""" + +import asyncio +import time +import orjson +import re +from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING +from datetime import datetime, timedelta +from dataclasses import dataclass, asdict +from enum import Enum +import numpy as np + +from src.common.logger import get_logger +from src.llm_models.utils_model import LLMRequest +from src.config.config import model_config, global_config +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.chat.memory_system.memory_builder import MemoryBuilder +from src.chat.memory_system.memory_fusion import MemoryFusionEngine +from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig +from src.chat.memory_system.metadata_index import MetadataIndexManager +from src.chat.memory_system.multi_stage_retrieval import MultiStageRetrieval, RetrievalConfig + +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + +logger = get_logger(__name__) + + +class MemorySystemStatus(Enum): + """记忆系统状态""" + INITIALIZING = "initializing" + READY = "ready" + BUILDING = "building" + RETRIEVING = "retrieving" + ERROR = "error" + + +@dataclass +class MemorySystemConfig: + """记忆系统配置""" + # 记忆构建配置 + min_memory_length: int = 10 + max_memory_length: int = 500 + memory_value_threshold: float = 0.7 + + # 向量存储配置 + vector_dimension: int = 768 + similarity_threshold: float = 0.8 + + # 召回配置 + coarse_recall_limit: int = 50 + fine_recall_limit: int = 10 + final_recall_limit: int = 5 + + # 融合配置 + fusion_similarity_threshold: float = 0.85 + deduplication_window: timedelta = timedelta(hours=24) + + @classmethod + def from_global_config(cls): + """从全局配置创建配置实例""" + from src.config.config import global_config + + return cls( + # 记忆构建配置 + min_memory_length=global_config.memory.min_memory_length, + max_memory_length=global_config.memory.max_memory_length, + memory_value_threshold=global_config.memory.memory_value_threshold, + + # 向量存储配置 + vector_dimension=global_config.memory.vector_dimension, + similarity_threshold=global_config.memory.vector_similarity_threshold, + + # 召回配置 + coarse_recall_limit=global_config.memory.metadata_filter_limit, + fine_recall_limit=global_config.memory.final_result_limit, + final_recall_limit=global_config.memory.final_result_limit, + + # 融合配置 + fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold, + deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours) + ) + + +class EnhancedMemorySystem: + """增强型精准记忆系统核心类""" + + def __init__( + self, + llm_model: Optional[LLMRequest] = None, + config: Optional[MemorySystemConfig] = None + ): + self.config = config or MemorySystemConfig.from_global_config() + self.llm_model = llm_model + self.status = MemorySystemStatus.INITIALIZING + + # 核心组件 + self.memory_builder: MemoryBuilder = None + self.fusion_engine: MemoryFusionEngine = None + self.vector_storage: VectorStorageManager = None + self.metadata_index: MetadataIndexManager = None + self.retrieval_system: MultiStageRetrieval = None + + # LLM模型 + self.value_assessment_model: LLMRequest = None + self.memory_extraction_model: LLMRequest = None + + # 统计信息 + self.total_memories = 0 + self.last_build_time = None 
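# (Editor's note) A minimal bring-up sketch for this class, assuming the module-level
# initialize_enhanced_memory_system() helper defined at the end of this file; the
# conversation text, stream_id and user_id are illustrative only:
#
#     system = await initialize_enhanced_memory_system()
#     chunks = await system.build_memory_from_conversation(
#         "我叫小明,下周要去北京出差", {"stream_id": "demo_stream"}, user_id="u1")
#     memories = await system.retrieve_relevant_memories(query_text="出差", user_id="u1", limit=5)
#
# build_memory_from_conversation() first runs the LLM value assessment and only
# extracts and stores memory chunks when the score clears memory_value_threshold.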
+ self.last_retrieval_time = None + + logger.info("EnhancedMemorySystem 初始化开始") + + async def initialize(self): + """异步初始化记忆系统""" + try: + logger.info("正在初始化增强型记忆系统...") + + # 初始化LLM模型 + task_config = ( + self.llm_model.model_for_task + if self.llm_model is not None + else model_config.model_task_config.utils + ) + + self.value_assessment_model = LLMRequest( + model_set=task_config, + request_type="memory.value_assessment" + ) + + self.memory_extraction_model = LLMRequest( + model_set=task_config, + request_type="memory.extraction" + ) + + # 初始化核心组件 + self.memory_builder = MemoryBuilder(self.memory_extraction_model) + self.fusion_engine = MemoryFusionEngine(self.config.fusion_similarity_threshold) + # 创建向量存储配置 + vector_config = VectorStorageConfig( + dimension=self.config.vector_dimension, + similarity_threshold=self.config.similarity_threshold + ) + self.vector_storage = VectorStorageManager(vector_config) + self.metadata_index = MetadataIndexManager() + # 创建检索配置 + retrieval_config = RetrievalConfig( + metadata_filter_limit=self.config.coarse_recall_limit, + vector_search_limit=self.config.fine_recall_limit, + final_result_limit=self.config.final_recall_limit + ) + self.retrieval_system = MultiStageRetrieval(retrieval_config) + + # 加载持久化数据 + await self.vector_storage.load_storage() + await self.metadata_index.load_index() + + self.status = MemorySystemStatus.READY + logger.info("✅ 增强型记忆系统初始化完成") + + except Exception as e: + self.status = MemorySystemStatus.ERROR + logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True) + raise + + async def build_memory_from_conversation( + self, + conversation_text: str, + context: Dict[str, Any], + user_id: str, + timestamp: Optional[float] = None + ) -> List[MemoryChunk]: + """从对话中构建记忆 + + Args: + conversation_text: 对话文本 + context: 上下文信息(包括用户信息、群组信息等) + user_id: 用户ID + timestamp: 时间戳,默认为当前时间 + + Returns: + 构建的记忆块列表 + """ + if self.status != MemorySystemStatus.READY: + raise RuntimeError("记忆系统未就绪") + + self.status = MemorySystemStatus.BUILDING + start_time = time.time() + + try: + normalized_context = self._normalize_context(context, user_id, timestamp) + conversation_text = self._resolve_conversation_context(conversation_text, normalized_context) + + logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}") + + # 1. 信息价值评估 + value_score = await self._assess_information_value(conversation_text, normalized_context) + + if value_score < self.config.memory_value_threshold: + logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建") + self.status = MemorySystemStatus.READY + return [] + + # 2. 构建记忆块 + memory_chunks = await self.memory_builder.build_memories( + conversation_text, + normalized_context, + user_id, + timestamp or time.time() + ) + + if not memory_chunks: + logger.debug("未提取到有效记忆块") + self.status = MemorySystemStatus.READY + return [] + + # 3. 记忆融合与去重 + fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks) + + # 4. 存储记忆 + await self._store_memories(fused_chunks) + + # 5. 
更新统计 + self.total_memories += len(fused_chunks) + self.last_build_time = time.time() + + build_time = time.time() - start_time + logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}秒") + + self.status = MemorySystemStatus.READY + return fused_chunks + + except Exception as e: + self.status = MemorySystemStatus.ERROR + logger.error(f"❌ 记忆构建失败: {e}", exc_info=True) + raise + + async def process_conversation_memory( + self, + conversation_text: str, + context: Dict[str, Any], + user_id: str, + timestamp: Optional[float] = None + ) -> Dict[str, Any]: + """对外暴露的对话记忆处理接口,兼容旧调用方式""" + start_time = time.time() + + try: + normalized_context = self._normalize_context(context, user_id, timestamp) + + memories = await self.build_memory_from_conversation( + conversation_text=conversation_text, + context=normalized_context, + user_id=user_id, + timestamp=timestamp + ) + + processing_time = time.time() - start_time + memory_count = len(memories) + + return { + "success": True, + "created_memories": memories, + "memory_count": memory_count, + "processing_time": processing_time, + "status": self.status.value + } + + except Exception as e: + processing_time = time.time() - start_time + logger.error(f"对话记忆处理失败: {e}", exc_info=True) + return { + "success": False, + "error": str(e), + "processing_time": processing_time, + "status": self.status.value + } + + async def retrieve_relevant_memories( + self, + query_text: Optional[str] = None, + user_id: Optional[str] = None, + context: Optional[Dict[str, Any]] = None, + limit: int = 5, + **kwargs + ) -> List[MemoryChunk]: + """检索相关记忆,兼容 query/query_text 参数形式""" + if self.status != MemorySystemStatus.READY: + raise RuntimeError("记忆系统未就绪") + + query_text = query_text or kwargs.get("query") + if not query_text: + raise ValueError("query_text 或 query 参数不能为空") + + context = context or {} + user_id = user_id or kwargs.get("user_id") + + self.status = MemorySystemStatus.RETRIEVING + start_time = time.time() + + try: + normalized_context = self._normalize_context(context, user_id, None) + + candidate_memories = list(self.vector_storage.memory_cache.values()) + if user_id: + candidate_memories = [m for m in candidate_memories if m.user_id == user_id] + + if not candidate_memories: + self.status = MemorySystemStatus.READY + self.last_retrieval_time = time.time() + logger.debug(f"未找到用户 {user_id} 的候选记忆") + return [] + + scored_memories = [] + for memory in candidate_memories: + score = self._compute_memory_score(query_text, memory, normalized_context) + if score > 0: + scored_memories.append((memory, score)) + + if not scored_memories: + # 如果所有分数为0,返回最近的记忆作为降级策略 + candidate_memories.sort(key=lambda m: m.metadata.last_accessed, reverse=True) + scored_memories = [(memory, 0.0) for memory in candidate_memories[:limit]] + else: + scored_memories.sort(key=lambda item: item[1], reverse=True) + + top_memories = [memory for memory, _ in scored_memories[:limit]] + + # 更新访问信息和缓存 + for memory, score in scored_memories[:limit]: + memory.update_access() + memory.update_relevance(score) + + cache_entry = self.metadata_index.memory_metadata_cache.get(memory.memory_id) + if cache_entry is not None: + cache_entry["last_accessed"] = memory.metadata.last_accessed + cache_entry["access_count"] = memory.metadata.access_count + cache_entry["relevance_score"] = memory.metadata.relevance_score + + retrieval_time = time.time() - start_time + logger.info( + f"✅ 为用户 {user_id or 'unknown'} 检索到 {len(top_memories)} 条相关记忆,耗时 {retrieval_time:.3f}秒" + ) + + self.last_retrieval_time = 
time.time() + self.status = MemorySystemStatus.READY + + return top_memories + + except Exception as e: + self.status = MemorySystemStatus.ERROR + logger.error(f"❌ 记忆检索失败: {e}", exc_info=True) + raise + + @staticmethod + def _extract_json_payload(response: str) -> Optional[str]: + """从模型响应中提取JSON部分,兼容Markdown代码块等格式""" + if not response: + return None + + stripped = response.strip() + + # 优先处理Markdown代码块格式 ```json ... ``` + code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL) + if code_block_match: + candidate = code_block_match.group(1).strip() + if candidate: + return candidate + + # 回退到查找第一个 JSON 对象的大括号范围 + start = stripped.find("{") + end = stripped.rfind("}") + if start != -1 and end != -1 and end > start: + return stripped[start:end + 1].strip() + + return stripped if stripped.startswith("{") and stripped.endswith("}") else None + + def _normalize_context( + self, + raw_context: Optional[Dict[str, Any]], + user_id: Optional[str], + timestamp: Optional[float] + ) -> Dict[str, Any]: + """标准化上下文,确保必备字段存在且格式正确""" + context: Dict[str, Any] = {} + if raw_context: + try: + context = dict(raw_context) + except Exception: + context = dict(raw_context or {}) + + # 基础字段 + context["user_id"] = context.get("user_id") or user_id or "unknown" + context["timestamp"] = context.get("timestamp") or timestamp or time.time() + context["message_type"] = context.get("message_type") or "normal" + context["platform"] = context.get("platform") or context.get("source_platform") or "unknown" + + # 标准化关键词类型 + keywords = context.get("keywords") + if keywords is None: + context["keywords"] = [] + elif isinstance(keywords, tuple): + context["keywords"] = list(keywords) + elif not isinstance(keywords, list): + context["keywords"] = [str(keywords)] if keywords else [] + + # 统一 stream_id + stream_id = context.get("stream_id") or context.get("stram_id") + if not stream_id: + potential = context.get("chat_id") or context.get("session_id") + if isinstance(potential, str) and potential: + stream_id = potential + if stream_id: + context["stream_id"] = stream_id + + # chat_id 兜底 + context["chat_id"] = context.get("chat_id") or context.get("stream_id") or f"session_{context['user_id']}" + + # 历史窗口配置 + window_candidate = ( + context.get("history_limit") + or context.get("history_window") + or context.get("memory_history_limit") + ) + if window_candidate is not None: + try: + context["history_limit"] = int(window_candidate) + except (TypeError, ValueError): + context.pop("history_limit", None) + + return context + + def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str: + """使用 stream_id 历史消息充实对话文本,默认回退到传入文本""" + if not context: + return fallback_text + + stream_id = context.get("stream_id") or context.get("stram_id") + if not stream_id: + return fallback_text + + try: + from src.chat.message_receive.chat_stream import get_chat_manager + + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) + if not chat_stream or not hasattr(chat_stream, "context_manager"): + logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器") + return fallback_text + + history_limit = self._determine_history_limit(context) + messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True) + if not messages: + logger.debug(f"stream_id={stream_id} 未获取到历史消息") + return fallback_text + + transcript = self._format_history_messages(messages) + if not transcript: + return fallback_text + + cleaned_fallback = 
(fallback_text or "").strip() + if cleaned_fallback and cleaned_fallback not in transcript: + transcript = f"{transcript}\n[当前消息] {cleaned_fallback}" + + logger.debug( + "使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d", + stream_id, + len(messages), + history_limit, + ) + return transcript + + except Exception as exc: + logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True) + return fallback_text + + def _determine_history_limit(self, context: Dict[str, Any]) -> int: + """确定历史消息获取数量,限制在30-50之间""" + default_limit = 40 + candidate = ( + context.get("history_limit") + or context.get("history_window") + or context.get("memory_history_limit") + ) + + if isinstance(candidate, str): + try: + candidate = int(candidate) + except ValueError: + candidate = None + + if isinstance(candidate, int): + history_limit = max(30, min(50, candidate)) + else: + history_limit = default_limit + + return history_limit + + def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]: + """将历史消息格式化为可供LLM处理的多轮对话文本""" + if not messages: + return None + + lines: List[str] = [] + for msg in messages: + try: + content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None) + if not content: + continue + + content = re.sub(r"\s+", " ", str(content).strip()) + if not content: + continue + + speaker = None + if hasattr(msg, "user_info") and msg.user_info: + speaker = ( + getattr(msg.user_info, "user_nickname", None) + or getattr(msg.user_info, "user_cardname", None) + or getattr(msg.user_info, "user_id", None) + ) + speaker = speaker or getattr(msg, "user_nickname", None) or getattr(msg, "user_id", None) or "用户" + + timestamp_value = getattr(msg, "time", None) or 0.0 + try: + timestamp_dt = datetime.fromtimestamp(float(timestamp_value)) if timestamp_value else datetime.now() + except (TypeError, ValueError, OSError): + timestamp_dt = datetime.now() + + timestamp_str = timestamp_dt.strftime("%Y-%m-%d %H:%M:%S") + lines.append(f"[{timestamp_str}] {speaker}: {content}") + + except Exception as message_exc: + logger.debug(f"格式化历史消息失败: {message_exc}") + continue + + return "\n".join(lines) if lines else None + + async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float: + """评估信息价值 + + Args: + text: 文本内容 + context: 上下文信息 + + Returns: + 价值评分 (0.0-1.0) + """ + try: + # 构建评估提示 + prompt = f""" +请评估以下对话内容的信息价值,重点识别包含个人事实、事件、偏好、观点等重要信息的内容。 + +## 🎯 价值评估重点标准: + +### 高价值信息 (0.7-1.0分): +1. **个人事实** (personal_fact):包含姓名、年龄、职业、联系方式、住址、健康状况、家庭情况等个人信息 +2. **重要事件** (event):约会、会议、旅行、考试、面试、搬家等重要活动或经历 +3. **明确偏好** (preference):表达喜欢/不喜欢的食物、电影、音乐、品牌、生活习惯等偏好信息 +4. **观点态度** (opinion):对事物的评价、看法、建议、态度等主观观点 +5. **核心关系** (relationship):重要的朋友、家人、同事等人际关系信息 + +### 中等价值信息 (0.4-0.7分): +1. **情感表达**:当前情绪状态、心情变化 +2. **日常活动**:常规的工作、学习、生活安排 +3. **一般兴趣**:兴趣爱好、休闲活动 +4. **短期计划**:即将进行的安排和计划 + +### 低价值信息 (0.0-0.4分): +1. **寒暄问候**:简单的打招呼、礼貌用语 +2. **重复信息**:已经多次提到的相同内容 +3. **临时状态**:短暂的情绪波动、临时想法 +4. 
**无关内容**:与用户画像建立无关的信息 + +对话内容: +{text} + +上下文信息: +- 用户ID: {context.get('user_id', 'unknown')} +- 消息类型: {context.get('message_type', 'unknown')} +- 时间: {datetime.fromtimestamp(context.get('timestamp', time.time()))} + +## 📋 评估要求: + +### 积极识别原则: +- **宁可高估,不可低估** - 对于可能的个人信息给予较高评估 +- **重点关注** - 特别注意包含 personal_fact、event、preference、opinion 的内容 +- **细节丰富** - 具体的细节信息比笼统的描述更有价值 +- **建立画像** - 有助于建立完整用户画像的信息更有价值 + +### 评分指导: +- **0.9-1.0**:核心个人信息(姓名、联系方式、重要偏好) +- **0.7-0.8**:重要的个人事实、观点、事件经历 +- **0.5-0.6**:一般性偏好、日常活动、情感表达 +- **0.3-0.4**:简单的兴趣表达、临时状态 +- **0.0-0.2**:寒暄问候、重复内容、无关信息 + +请以JSON格式输出评估结果: +{{ + "value_score": 0.0到1.0之间的数值, + "reasoning": "评估理由,包含具体识别到的信息类型", + "key_factors": ["关键因素1", "关键因素2"], + "detected_types": ["personal_fact", "preference", "opinion", "event", "relationship", "emotion", "goal"] +}} +""" + + response, _ = await self.value_assessment_model.generate_response_async( + prompt, temperature=0.3 + ) + + # 解析响应 + try: + payload = self._extract_json_payload(response) + if not payload: + raise ValueError("未在响应中找到有效的JSON负载") + + result = orjson.loads(payload) + value_score = float(result.get("value_score", 0.0)) + reasoning = result.get("reasoning", "") + key_factors = result.get("key_factors", []) + + logger.info(f"信息价值评估: {value_score:.2f}, 理由: {reasoning}") + if key_factors: + logger.info(f"关键因素: {', '.join(key_factors)}") + + return max(0.0, min(1.0, value_score)) + + except (orjson.JSONDecodeError, ValueError) as e: + preview = response[:200].replace('\n', ' ') + logger.warning(f"解析价值评估响应失败: {e}, 响应片段: {preview}") + return 0.5 # 默认中等价值 + + except Exception as e: + logger.error(f"信息价值评估失败: {e}", exc_info=True) + return 0.5 # 默认中等价值 + + async def _store_memories(self, memory_chunks: List[MemoryChunk]): + """存储记忆块到各个存储系统""" + if not memory_chunks: + return + + # 并行存储到向量数据库和元数据索引 + storage_tasks = [] + + # 向量存储 + storage_tasks.append(self.vector_storage.store_memories(memory_chunks)) + + # 元数据索引 + storage_tasks.append(self.metadata_index.index_memories(memory_chunks)) + + # 等待所有存储任务完成 + await asyncio.gather(*storage_tasks, return_exceptions=True) + + logger.debug(f"成功存储 {len(memory_chunks)} 条记忆到各个存储系统") + + def get_system_stats(self) -> Dict[str, Any]: + """获取系统统计信息""" + return { + "status": self.status.value, + "total_memories": self.total_memories, + "last_build_time": self.last_build_time, + "last_retrieval_time": self.last_retrieval_time, + "config": asdict(self.config) + } + + def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: + """根据查询和上下文为记忆计算匹配分数""" + tokens_query = self._tokenize_text(query_text) + tokens_memory = self._tokenize_text(memory.text_content) + + if tokens_query and tokens_memory: + base_score = len(tokens_query & tokens_memory) / len(tokens_query | tokens_memory) + else: + base_score = 0.0 + + context_keywords = context.get("keywords") or [] + keyword_overlap = 0.0 + if context_keywords: + memory_keywords = set(k.lower() for k in memory.keywords) + keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(len(context_keywords), 1) + + importance_boost = (memory.metadata.importance.value - 1) / 3 * 0.1 + confidence_boost = (memory.metadata.confidence.value - 1) / 3 * 0.05 + + final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost + return max(0.0, min(1.0, final_score)) + + def _tokenize_text(self, text: str) -> Set[str]: + """简单分词,兼容中英文""" + if not text: + return set() + + tokens = re.findall(r"[\w\u4e00-\u9fa5]+", text.lower()) + 
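# (Editor's note) Worked example for the scoring in _compute_memory_score above,
# assuming query tokens {a, b, c, d} and memory tokens {c, d, e, f} (Jaccard 2/6 ≈ 0.33),
# one of two context keywords matched (0.5), importance level 3 and confidence level 2:
#
#     score ≈ 0.33 * 0.7          # token overlap term
#           + 0.50 * 0.15         # keyword overlap term
#           + (3 - 1) / 3 * 0.1   # importance boost
#           + (2 - 1) / 3 * 0.05  # confidence boost
#           ≈ 0.39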
return {token for token in tokens if len(token) > 1} + + async def maintenance(self): + """系统维护操作""" + try: + logger.info("开始记忆系统维护...") + + # 向量存储优化 + await self.vector_storage.optimize_storage() + + # 元数据索引优化 + await self.metadata_index.optimize_index() + + # 记忆融合引擎维护 + await self.fusion_engine.maintenance() + + logger.info("✅ 记忆系统维护完成") + + except Exception as e: + logger.error(f"❌ 记忆系统维护失败: {e}", exc_info=True) + + async def shutdown(self): + """关闭系统""" + try: + logger.info("正在关闭增强型记忆系统...") + + # 保存持久化数据 + await self.vector_storage.save_storage() + await self.metadata_index.save_index() + + logger.info("✅ 增强型记忆系统已关闭") + + except Exception as e: + logger.error(f"❌ 记忆系统关闭失败: {e}", exc_info=True) + + +# 全局记忆系统实例 +enhanced_memory_system: EnhancedMemorySystem = None + + +def get_enhanced_memory_system() -> EnhancedMemorySystem: + """获取全局记忆系统实例""" + global enhanced_memory_system + if enhanced_memory_system is None: + enhanced_memory_system = EnhancedMemorySystem() + return enhanced_memory_system + + +async def initialize_enhanced_memory_system(): + """初始化全局记忆系统""" + global enhanced_memory_system + if enhanced_memory_system is None: + enhanced_memory_system = EnhancedMemorySystem() + await enhanced_memory_system.initialize() + return enhanced_memory_system \ No newline at end of file diff --git a/src/chat/memory_system/enhanced_memory_hooks.py b/src/chat/memory_system/enhanced_memory_hooks.py new file mode 100644 index 000000000..d28154e91 --- /dev/null +++ b/src/chat/memory_system/enhanced_memory_hooks.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +""" +增强记忆系统钩子 +用于在消息处理过程中自动构建和检索记忆 +""" + +import asyncio +from typing import Dict, List, Any, Optional +from datetime import datetime + +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager + +logger = get_logger(__name__) + + +class EnhancedMemoryHooks: + """增强记忆系统钩子 - 自动处理消息的记忆构建和检索""" + + def __init__(self): + self.enabled = (global_config.memory.enable_memory and + global_config.memory.enable_enhanced_memory) + self.processed_messages = set() # 避免重复处理 + + async def process_message_for_memory( + self, + message_content: str, + user_id: str, + chat_id: str, + message_id: str, + context: Optional[Dict[str, Any]] = None + ) -> bool: + """ + 处理消息并构建记忆 + + Args: + message_content: 消息内容 + user_id: 用户ID + chat_id: 聊天ID + message_id: 消息ID + context: 上下文信息 + + Returns: + bool: 是否成功处理 + """ + if not self.enabled: + return False + + if message_id in self.processed_messages: + return False + + try: + # 确保增强记忆管理器已初始化 + if not enhanced_memory_manager.is_initialized: + await enhanced_memory_manager.initialize() + + # 构建上下文 + memory_context = { + "chat_id": chat_id, + "message_id": message_id, + "timestamp": datetime.now().timestamp(), + "message_type": "user_message", + **(context or {}) + } + + # 处理对话并构建记忆 + memory_chunks = await enhanced_memory_manager.process_conversation( + conversation_text=message_content, + context=memory_context, + user_id=user_id, + timestamp=memory_context["timestamp"] + ) + + # 标记消息已处理 + self.processed_messages.add(message_id) + + # 限制处理历史大小 + if len(self.processed_messages) > 1000: + # 移除最旧的500个记录 + self.processed_messages = set(list(self.processed_messages)[-500:]) + + logger.debug(f"为消息 {message_id} 构建了 {len(memory_chunks)} 条记忆") + return len(memory_chunks) > 0 + + except Exception as e: + logger.error(f"处理消息记忆失败: {e}") + return False + + async def get_memory_for_response( + self, + query_text: str, + user_id: 
str, + chat_id: str, + limit: int = 5 + ) -> List[Dict[str, Any]]: + """ + 为回复获取相关记忆 + + Args: + query_text: 查询文本 + user_id: 用户ID + chat_id: 聊天ID + limit: 返回记忆数量限制 + + Returns: + List[Dict]: 相关记忆列表 + """ + if not self.enabled: + return [] + + try: + # 确保增强记忆管理器已初始化 + if not enhanced_memory_manager.is_initialized: + await enhanced_memory_manager.initialize() + + # 构建查询上下文 + context = { + "chat_id": chat_id, + "query_intent": "response_generation", + "expected_memory_types": [ + "personal_fact", "event", "preference", "opinion" + ] + } + + # 获取相关记忆 + enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context( + query_text=query_text, + user_id=user_id, + context=context, + limit=limit + ) + + # 转换为字典格式 + results = [] + for result in enhanced_results: + memory_dict = { + "content": result.content, + "type": result.memory_type, + "confidence": result.confidence, + "importance": result.importance, + "timestamp": result.timestamp, + "source": result.source + } + results.append(memory_dict) + + logger.debug(f"为回复查询到 {len(results)} 条相关记忆") + return results + + except Exception as e: + logger.error(f"获取回复记忆失败: {e}") + return [] + + async def cleanup_old_memories(self): + """清理旧记忆""" + try: + if enhanced_memory_manager.is_initialized: + # 调用增强记忆系统的维护功能 + await enhanced_memory_manager.enhanced_system.maintenance() + logger.debug("增强记忆系统维护完成") + except Exception as e: + logger.error(f"清理旧记忆失败: {e}") + + def clear_processed_cache(self): + """清除已处理消息的缓存""" + self.processed_messages.clear() + logger.debug("已清除消息处理缓存") + + def enable(self): + """启用记忆钩子""" + self.enabled = True + logger.info("增强记忆钩子已启用") + + def disable(self): + """禁用记忆钩子""" + self.enabled = False + logger.info("增强记忆钩子已禁用") + + +# 创建全局实例 +enhanced_memory_hooks = EnhancedMemoryHooks() \ No newline at end of file diff --git a/src/chat/memory_system/enhanced_memory_integration.py b/src/chat/memory_system/enhanced_memory_integration.py new file mode 100644 index 000000000..fea62db6a --- /dev/null +++ b/src/chat/memory_system/enhanced_memory_integration.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- +""" +增强记忆系统集成脚本 +用于在现有系统中无缝集成增强记忆功能 +""" + +import asyncio +from typing import Dict, Any, Optional + +from src.common.logger import get_logger +from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks + +logger = get_logger(__name__) + + +async def process_user_message_memory( + message_content: str, + user_id: str, + chat_id: str, + message_id: str, + context: Optional[Dict[str, Any]] = None +) -> bool: + """ + 处理用户消息并构建记忆 + + Args: + message_content: 消息内容 + user_id: 用户ID + chat_id: 聊天ID + message_id: 消息ID + context: 额外的上下文信息 + + Returns: + bool: 是否成功构建记忆 + """ + try: + success = await enhanced_memory_hooks.process_message_for_memory( + message_content=message_content, + user_id=user_id, + chat_id=chat_id, + message_id=message_id, + context=context + ) + + if success: + logger.debug(f"成功为消息 {message_id} 构建记忆") + + return success + + except Exception as e: + logger.error(f"处理用户消息记忆失败: {e}") + return False + + +async def get_relevant_memories_for_response( + query_text: str, + user_id: str, + chat_id: str, + limit: int = 5 +) -> Dict[str, Any]: + """ + 为回复获取相关记忆 + + Args: + query_text: 查询文本(通常是用户的当前消息) + user_id: 用户ID + chat_id: 聊天ID + limit: 返回记忆数量限制 + + Returns: + Dict: 包含记忆信息的字典 + """ + try: + memories = await enhanced_memory_hooks.get_memory_for_response( + query_text=query_text, + user_id=user_id, + chat_id=chat_id, + limit=limit + ) + + result = { + "has_memories": len(memories) > 0, + "memories": 
memories, + "memory_count": len(memories) + } + + logger.debug(f"为回复获取到 {len(memories)} 条相关记忆") + return result + + except Exception as e: + logger.error(f"获取回复记忆失败: {e}") + return { + "has_memories": False, + "memories": [], + "memory_count": 0 + } + + +def format_memories_for_prompt(memories: Dict[str, Any]) -> str: + """ + 格式化记忆信息用于Prompt + + Args: + memories: 记忆信息字典 + + Returns: + str: 格式化后的记忆文本 + """ + if not memories["has_memories"]: + return "" + + memory_lines = ["以下是相关的记忆信息:"] + + for memory in memories["memories"]: + content = memory["content"] + memory_type = memory["type"] + confidence = memory["confidence"] + importance = memory["importance"] + + # 根据重要性添加不同的标记 + importance_marker = "🔥" if importance >= 3 else "⭐" if importance >= 2 else "📝" + confidence_marker = "✅" if confidence >= 3 else "⚠️" if confidence >= 2 else "💭" + + memory_line = f"{importance_marker} {content} ({memory_type}, {confidence_marker}置信度)" + memory_lines.append(memory_line) + + return "\n".join(memory_lines) + + +async def cleanup_memory_system(): + """清理记忆系统""" + try: + await enhanced_memory_hooks.cleanup_old_memories() + logger.info("记忆系统清理完成") + except Exception as e: + logger.error(f"记忆系统清理失败: {e}") + + +def get_memory_system_status() -> Dict[str, Any]: + """ + 获取记忆系统状态 + + Returns: + Dict: 系统状态信息 + """ + from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager + + return { + "enabled": enhanced_memory_hooks.enabled, + "enhanced_system_initialized": enhanced_memory_manager.is_initialized, + "processed_messages_count": len(enhanced_memory_hooks.processed_messages), + "system_type": "enhanced_memory_system" + } + + +# 便捷函数 +async def remember_message( + message: str, + user_id: str = "default_user", + chat_id: str = "default_chat" +) -> bool: + """ + 便捷的记忆构建函数 + + Args: + message: 要记住的消息 + user_id: 用户ID + chat_id: 聊天ID + + Returns: + bool: 是否成功 + """ + import uuid + message_id = str(uuid.uuid4()) + return await process_user_message_memory( + message_content=message, + user_id=user_id, + chat_id=chat_id, + message_id=message_id + ) + + +async def recall_memories( + query: str, + user_id: str = "default_user", + chat_id: str = "default_chat", + limit: int = 5 +) -> Dict[str, Any]: + """ + 便捷的记忆检索函数 + + Args: + query: 查询文本 + user_id: 用户ID + chat_id: 聊天ID + limit: 返回数量限制 + + Returns: + Dict: 记忆信息 + """ + return await get_relevant_memories_for_response( + query_text=query, + user_id=user_id, + chat_id=chat_id, + limit=limit + ) \ No newline at end of file diff --git a/src/chat/memory_system/enhanced_memory_manager.py b/src/chat/memory_system/enhanced_memory_manager.py new file mode 100644 index 000000000..ee267ceda --- /dev/null +++ b/src/chat/memory_system/enhanced_memory_manager.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- +""" +增强记忆系统管理器 +替代原有的 Hippocampus 和 instant_memory 系统 +""" + +import asyncio +import time +from typing import Dict, List, Optional, Any, Tuple +from datetime import datetime +from dataclasses import dataclass + +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType +from src.chat.memory_system.enhanced_memory_adapter import ( + initialize_enhanced_memory_system +) + +logger = get_logger(__name__) + + +@dataclass +class EnhancedMemoryResult: + """增强记忆查询结果""" + content: str + memory_type: str + confidence: float + importance: float + timestamp: float + source: str = 
"enhanced_memory" + relevance_score: float = 0.0 + + +class EnhancedMemoryManager: + """增强记忆系统管理器 - 替代原有的 HippocampusManager""" + + def __init__(self): + self.enhanced_system: Optional[EnhancedMemorySystem] = None + self.is_initialized = False + self.user_cache = {} # 用户记忆缓存 + + async def initialize(self): + """初始化增强记忆系统""" + if self.is_initialized: + return + + try: + from src.config.config import global_config + + # 检查是否启用增强记忆系统 + if not global_config.memory.enable_enhanced_memory: + logger.info("增强记忆系统已禁用,跳过初始化") + self.is_initialized = True + return + + logger.info("正在初始化增强记忆系统...") + + # 获取LLM模型 + from src.llm_models.utils_model import LLMRequest + from src.config.config import model_config + llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory") + + # 初始化增强记忆系统 + self.enhanced_system = await initialize_enhanced_memory_system(llm_model) + + # 设置全局实例 + global_enhanced_manager = self.enhanced_system + + self.is_initialized = True + logger.info("✅ 增强记忆系统初始化完成") + + except Exception as e: + logger.error(f"❌ 增强记忆系统初始化失败: {e}") + # 如果增强系统初始化失败,创建一个空的管理器避免系统崩溃 + self.enhanced_system = None + self.is_initialized = True # 标记为已初始化但系统不可用 + + def get_hippocampus(self): + """兼容原有接口 - 返回空""" + logger.debug("get_hippocampus 调用 - 增强记忆系统不使用此方法") + return {} + + async def build_memory(self): + """兼容原有接口 - 构建记忆""" + if not self.is_initialized or not self.enhanced_system: + return + + try: + # 增强记忆系统使用实时构建,不需要定时构建 + logger.debug("build_memory 调用 - 增强记忆系统使用实时构建") + except Exception as e: + logger.error(f"build_memory 失败: {e}") + + async def forget_memory(self, percentage: float = 0.005): + """兼容原有接口 - 遗忘机制""" + if not self.is_initialized or not self.enhanced_system: + return + + try: + # 增强记忆系统有内置的遗忘机制 + logger.debug(f"forget_memory 调用 - 参数: {percentage}") + # 可以在这里调用增强系统的维护功能 + await self.enhanced_system.maintenance() + except Exception as e: + logger.error(f"forget_memory 失败: {e}") + + async def consolidate_memory(self): + """兼容原有接口 - 记忆巩固""" + if not self.is_initialized or not self.enhanced_system: + return + + try: + # 增强记忆系统自动处理记忆巩固 + logger.debug("consolidate_memory 调用 - 增强记忆系统自动处理") + except Exception as e: + logger.error(f"consolidate_memory 失败: {e}") + + async def get_memory_from_text( + self, + text: str, + chat_id: str, + user_id: str, + max_memory_num: int = 3, + max_memory_length: int = 2, + time_weight: float = 1.0, + keyword_weight: float = 1.0 + ) -> List[Tuple[str, str]]: + """从文本获取相关记忆 - 兼容原有接口""" + if not self.is_initialized or not self.enhanced_system: + return [] + + try: + # 使用增强记忆系统检索 + context = { + "chat_id": chat_id, + "expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE] + } + + relevant_memories = await self.enhanced_system.retrieve_relevant_memories( + query=text, + user_id=user_id, + context=context, + limit=max_memory_num + ) + + # 转换为原有格式 (topic, content) + results = [] + for memory in relevant_memories: + topic = memory.memory_type.value + content = memory.text_content + results.append((topic, content)) + + logger.debug(f"从文本检索到 {len(results)} 条相关记忆") + return results + + except Exception as e: + logger.error(f"get_memory_from_text 失败: {e}") + return [] + + async def get_memory_from_topic( + self, + valid_keywords: List[str], + max_memory_num: int = 3, + max_memory_length: int = 2, + max_depth: int = 3 + ) -> List[Tuple[str, str]]: + """从关键词获取记忆 - 兼容原有接口""" + if not self.is_initialized or not self.enhanced_system: + return [] + + try: + # 将关键词转换为查询文本 + query_text = " ".join(valid_keywords) + + # 
使用增强记忆系统检索 + context = { + "keywords": valid_keywords, + "expected_memory_types": [ + MemoryType.PERSONAL_FACT, + MemoryType.EVENT, + MemoryType.PREFERENCE, + MemoryType.OPINION + ] + } + + relevant_memories = await self.enhanced_system.retrieve_relevant_memories( + query_text=query_text, + user_id="default_user", # 可以根据实际需要传递 + context=context, + limit=max_memory_num + ) + + # 转换为原有格式 (topic, content) + results = [] + for memory in relevant_memories: + topic = memory.memory_type.value + content = memory.text_content + results.append((topic, content)) + + logger.debug(f"从关键词 {valid_keywords} 检索到 {len(results)} 条相关记忆") + return results + + except Exception as e: + logger.error(f"get_memory_from_topic 失败: {e}") + return [] + + def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: + """从单个关键词获取记忆 - 兼容原有接口""" + if not self.is_initialized or not self.enhanced_system: + return [] + + try: + # 同步方法,返回空列表 + logger.debug(f"get_memory_from_keyword 调用 - 关键词: {keyword}") + return [] + except Exception as e: + logger.error(f"get_memory_from_keyword 失败: {e}") + return [] + + async def process_conversation( + self, + conversation_text: str, + context: Dict[str, Any], + user_id: str, + timestamp: Optional[float] = None + ) -> List[MemoryChunk]: + """处理对话并构建记忆 - 新增功能""" + if not self.is_initialized or not self.enhanced_system: + return [] + + try: + result = await self.enhanced_system.process_conversation_memory( + conversation_text=conversation_text, + context=context, + user_id=user_id, + timestamp=timestamp + ) + + # 从结果中提取记忆块 + memory_chunks = [] + if result.get("success"): + memory_chunks = result.get("created_memories", []) + + logger.info(f"从对话构建了 {len(memory_chunks)} 条记忆") + return memory_chunks + + except Exception as e: + logger.error(f"process_conversation 失败: {e}") + return [] + + async def get_enhanced_memory_context( + self, + query_text: str, + user_id: str, + context: Optional[Dict[str, Any]] = None, + limit: int = 5 + ) -> List[EnhancedMemoryResult]: + """获取增强记忆上下文 - 新增功能""" + if not self.is_initialized or not self.enhanced_system: + return [] + + try: + relevant_memories = await self.enhanced_system.retrieve_relevant_memories( + query=query_text, + user_id=user_id, + context=context or {}, + limit=limit + ) + + results = [] + for memory in relevant_memories: + result = EnhancedMemoryResult( + content=memory.text_content, + memory_type=memory.memory_type.value, + confidence=memory.metadata.confidence.value, + importance=memory.metadata.importance.value, + timestamp=memory.metadata.created_at, + source="enhanced_memory", + relevance_score=memory.metadata.relevance_score + ) + results.append(result) + + return results + + except Exception as e: + logger.error(f"get_enhanced_memory_context 失败: {e}") + return [] + + async def shutdown(self): + """关闭增强记忆系统""" + if not self.is_initialized: + return + + try: + if self.enhanced_system: + await self.enhanced_system.shutdown() + logger.info("✅ 增强记忆系统已关闭") + except Exception as e: + logger.error(f"关闭增强记忆系统失败: {e}") + + +# 全局增强记忆管理器实例 +enhanced_memory_manager = EnhancedMemoryManager() \ No newline at end of file diff --git a/src/chat/memory_system/hybrid_memory_design.md b/src/chat/memory_system/hybrid_memory_design.md deleted file mode 100644 index f47d3cacc..000000000 --- a/src/chat/memory_system/hybrid_memory_design.md +++ /dev/null @@ -1,168 +0,0 @@ -# 混合瞬时记忆系统设计 - -## 系统概述 - -融合 `instant_memory.py`(LLM系统)和 `vector_instant_memory.py`(向量系统)的混合记忆系统,智能选择最优策略,无需配置文件控制。 - -## 融合架构 - -``` -聊天输入 → 智能调度器 → 选择策略 → 双重存储 → 融合检索 → 统一输出 
-``` - -## 核心组件设计 - -### 1. HybridInstantMemory (主类) - -**职责**: 统一接口,智能调度两套记忆系统 - -**关键方法**: -- `__init__(chat_id)` - 初始化两套子系统 -- `create_and_store_memory(text)` - 智能存储记忆 -- `get_memory(target)` - 融合检索记忆 -- `get_stats()` - 统计信息 - -### 2. MemoryStrategy (策略判断器) - -**职责**: 判断使用哪种记忆策略 - -**判断规则**: -- 文本长度 < 30字符 → 优先向量系统(快速) -- 包含情感词汇/重要信息 → 使用LLM系统(准确) -- 复杂场景 → 双重验证 - -**实现方法**: -```python -def decide_strategy(self, text: str) -> MemoryMode: - # 长度判断 - if len(text) < 30: - return MemoryMode.VECTOR_ONLY - - # 情感关键词检测 - if self._contains_emotional_content(text): - return MemoryMode.LLM_PREFERRED - - # 默认混合模式 - return MemoryMode.HYBRID -``` - -### 3. MemorySync (同步器) - -**职责**: 处理两套系统间的记忆同步和去重 - -**同步策略**: -- 向量系统存储的记忆 → 异步同步到LLM系统 -- LLM系统生成的高质量记忆 → 生成向量存储 -- 定期去重,避免重复记忆 - -### 4. HybridRetriever (检索器) - -**职责**: 融合两种检索方式,提供最优结果 - -**检索策略**: -1. 并行查询向量系统和LLM系统 -2. 按相似度/相关性排序 -3. 去重合并,返回最相关的记忆 - -## 智能调度逻辑 - -### 快速路径 (Vector Path) -- 适用: 短文本、常规对话、快速查询 -- 优势: 响应速度快,资源消耗低 -- 时机: 文本简单、无特殊情感内容 - -### 准确路径 (LLM Path) -- 适用: 重要信息、情感表达、复杂语义 -- 优势: 语义理解深度,记忆质量高 -- 时机: 检测到重要性标志 - -### 混合路径 (Hybrid Path) -- 适用: 中等复杂度内容 -- 策略: 向量快速筛选 + LLM精确处理 -- 平衡: 速度与准确性 - -## 记忆存储策略 - -### 双重备份机制 -1. **主存储**: 根据策略选择主要存储方式 -2. **备份存储**: 异步备份到另一系统 -3. **同步检查**: 定期校验两边数据一致性 - -### 存储优化 -- 向量系统: 立即存储,快速可用 -- LLM系统: 批量处理,高质量整理 -- 重复检测: 跨系统去重 - -## 检索融合策略 - -### 并行检索 -```python -async def get_memory(self, target: str): - # 并行查询两个系统 - vector_task = self.vector_memory.get_memory(target) - llm_task = self.llm_memory.get_memory(target) - - vector_results, llm_results = await asyncio.gather( - vector_task, llm_task, return_exceptions=True - ) - - # 融合结果 - return self._merge_results(vector_results, llm_results) -``` - -### 结果融合 -1. **相似度评分**: 统一两种系统的相似度计算 -2. **权重调整**: 根据查询类型调整系统权重 -3. **去重合并**: 移除重复内容,保留最相关的 - -## 性能优化 - -### 异步处理 -- 向量检索: 同步快速响应 -- LLM处理: 异步后台处理 -- 批量操作: 减少系统调用开销 - -### 缓存策略 -- 热点记忆缓存 -- 查询结果缓存 -- 向量计算缓存 - -### 降级机制 -- 向量系统故障 → 只使用LLM系统 -- LLM系统故障 → 只使用向量系统 -- 全部故障 → 返回空结果,记录错误 - -## 实现计划 - -1. **基础框架**: 创建HybridInstantMemory主类 -2. **策略判断**: 实现智能调度逻辑 -3. **存储融合**: 实现双重存储机制 -4. **检索融合**: 实现并行检索和结果合并 -5. **同步机制**: 实现跨系统数据同步 -6. **性能优化**: 异步处理和缓存优化 -7. 
**错误处理**: 降级机制和异常处理 - -## 使用接口 - -```python -# 初始化混合记忆系统 -hybrid_memory = HybridInstantMemory(chat_id="user_123") - -# 智能存储记忆 -await hybrid_memory.create_and_store_memory("今天天气真好,我去公园散步了") - -# 融合检索记忆 -memories = await hybrid_memory.get_memory("天气") - -# 获取系统状态 -stats = hybrid_memory.get_stats() -print(f"向量记忆: {stats['vector_count']} 条") -print(f"LLM记忆: {stats['llm_count']} 条") -``` - -## 预期效果 - -- **响应速度**: 比纯LLM系统快60%+ -- **记忆质量**: 比纯向量系统准确30%+ -- **资源使用**: 智能调度,按需使用资源 -- **可靠性**: 双系统备份,单点故障不影响服务 \ No newline at end of file diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py deleted file mode 100644 index 3c3c33514..000000000 --- a/src/chat/memory_system/instant_memory.py +++ /dev/null @@ -1,254 +0,0 @@ -# -*- coding: utf-8 -*- -import time -import re -import orjson -import traceback - -from json_repair import repair_json -from datetime import datetime, timedelta - -from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import Memory # SQLAlchemy Models导入 -from src.common.database.sqlalchemy_database_api import get_db_session -from src.config.config import model_config - -from sqlalchemy import select - -logger = get_logger(__name__) - - -class MemoryItem: - def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]): - self.memory_id = memory_id - self.chat_id = chat_id - self.memory_text: str = memory_text - self.keywords: list[str] = keywords - self.create_time: float = time.time() - self.last_view_time: float = time.time() - - -class InstantMemory: - def __init__(self, chat_id): - self.chat_id = chat_id - self.last_view_time = time.time() - self.summary_model = LLMRequest( - model_set=model_config.model_task_config.utils, - request_type="memory.summary", - ) - - async def if_need_build(self, text): - prompt = f""" -请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0 -{text} -请只输出1或0就好 - """ - - try: - response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5) - print(prompt) - print(response) - - return "1" in response - except Exception as e: - logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}") - return False - - async def build_memory(self, text): - prompt = f""" - 以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出 - {text} - 请以json格式输出一段概括的记忆内容和关键词 - {{ - "memory_text": "记忆内容", - "keywords": "关键词,用/划分" - }} - """ - try: - response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5) - # print(prompt) - # print(response) - if not response: - return None - try: - repaired = repair_json(response) - result = orjson.loads(repaired) - memory_text = result.get("memory_text", "") - keywords = result.get("keywords", "") - if isinstance(keywords, str): - keywords_list = [k.strip() for k in keywords.split("/") if k.strip()] - elif isinstance(keywords, list): - keywords_list = keywords - else: - keywords_list = [] - return {"memory_text": memory_text, "keywords": keywords_list} - except Exception as parse_e: - logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}") - return None - except Exception as e: - logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}") - return None - - async def create_and_store_memory(self, text): - if_need = await self.if_need_build(text) - if if_need: - logger.info(f"需要记忆:{text}") - memory = await self.build_memory(text) - if memory and memory.get("memory_text"): - memory_id = f"{self.chat_id}_{time.time()}" - memory_item = MemoryItem( - memory_id=memory_id, - 
chat_id=self.chat_id, - memory_text=memory["memory_text"], - keywords=memory.get("keywords", []), - ) - await self.store_memory(memory_item) - else: - logger.info(f"不需要记忆:{text}") - - @staticmethod - async def store_memory(memory_item: MemoryItem): - async with get_db_session() as session: - memory = Memory( - memory_id=memory_item.memory_id, - chat_id=memory_item.chat_id, - memory_text=memory_item.memory_text, - keywords=orjson.dumps(memory_item.keywords).decode("utf-8"), - create_time=memory_item.create_time, - last_view_time=memory_item.last_view_time, - ) - session.add(memory) - await session.commit() - - async def get_memory(self, target: str): - from json_repair import repair_json - - prompt = f""" - 请根据以下发言内容,判断是否需要提取记忆 - {target} - 请用json格式输出,包含以下字段: - 其中,time的要求是: - 可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD - 可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前 - 可以选择留空进行模糊搜索 - {{ - "need_memory": 1, - "keywords": "希望获取的记忆关键词,用/划分", - "time": "希望获取的记忆大致时间" - }} - 请只输出json格式,不要输出其他多余内容 - """ - try: - response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5) - print(prompt) - print(response) - if not response: - return None - try: - repaired = repair_json(response) - result = orjson.loads(repaired) - # 解析keywords - keywords = result.get("keywords", "") - if isinstance(keywords, str): - keywords_list = [k.strip() for k in keywords.split("/") if k.strip()] - elif isinstance(keywords, list): - keywords_list = keywords - else: - keywords_list = [] - # 解析time为时间段 - time_str = result.get("time", "").strip() - start_time, end_time = self._parse_time_range(time_str) - logger.info(f"start_time: {start_time}, end_time: {end_time}") - # 检索包含关键词的记忆 - memories_set = set() - async with get_db_session() as session: - if start_time and end_time: - start_ts = start_time.timestamp() - end_ts = end_time.timestamp() - - query = (await session.execute( - select(Memory).where( - (Memory.chat_id == self.chat_id) - & (Memory.create_time >= start_ts) - & (Memory.create_time < end_ts) - ) - )).scalars() - else: - result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id)) - query = result.scalars() - for mem in query: - # 对每条记忆 - mem_keywords_str = mem.keywords or "[]" - try: - mem_keywords = orjson.loads(mem_keywords_str) - except orjson.JSONDecodeError: - mem_keywords = [] - # logger.info(f"mem_keywords: {mem_keywords}") - # logger.info(f"keywords_list: {keywords_list}") - for kw in keywords_list: - # logger.info(f"kw: {kw}") - # logger.info(f"kw in mem_keywords: {kw in mem_keywords}") - if kw in mem_keywords: - # logger.info(f"mem.memory_text: {mem.memory_text}") - memories_set.add(mem.memory_text) - break - return list(memories_set) - except Exception as parse_e: - logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}") - return None - except Exception as e: - logger.error(f"获取记忆出现错误:{str(e)} {traceback.format_exc()}") - return None - - @staticmethod - def _parse_time_range(time_str): - # sourcery skip: extract-duplicate-method, use-contextlib-suppress - """ - 支持解析如下格式: - - 具体日期时间:YYYY-MM-DD HH:MM:SS - - 具体日期:YYYY-MM-DD - - 相对时间:今天,昨天,前天,N天前,N个月前 - - 空字符串:返回(None, None) - """ - now = datetime.now() - if not time_str: - return 0, now - time_str = time_str.strip() - # 具体日期时间 - try: - dt = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S") - return dt, dt + timedelta(hours=1) - except Exception: - ... - # 具体日期 - try: - dt = datetime.strptime(time_str, "%Y-%m-%d") - return dt, dt + timedelta(days=1) - except Exception: - ... 
- # 相对时间 - if time_str == "今天": - start = now.replace(hour=0, minute=0, second=0, microsecond=0) - end = start + timedelta(days=1) - return start, end - if time_str == "昨天": - start = (now - timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0) - end = start + timedelta(days=1) - return start, end - if time_str == "前天": - start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0) - end = start + timedelta(days=1) - return start, end - if m := re.match(r"(\d+)天前", time_str): - days = int(m.group(1)) - start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0) - end = start + timedelta(days=1) - return start, end - if m := re.match(r"(\d+)个月前", time_str): - months = int(m.group(1)) - # 近似每月30天 - start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0) - end = start + timedelta(days=1) - return start, end - # 其他无法解析 - return 0, now diff --git a/src/chat/memory_system/integration_layer.py b/src/chat/memory_system/integration_layer.py new file mode 100644 index 000000000..db3be9d6d --- /dev/null +++ b/src/chat/memory_system/integration_layer.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +""" +增强记忆系统集成层 +现在只管理新的增强记忆系统,旧系统已被完全移除 +""" + +import time +import asyncio +from typing import Dict, List, Optional, Any +from dataclasses import dataclass +from enum import Enum + +from src.common.logger import get_logger +from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel +from src.llm_models.utils_model import LLMRequest + +logger = get_logger(__name__) + + +class IntegrationMode(Enum): + """集成模式""" + REPLACE = "replace" # 完全替换现有记忆系统 + ENHANCED_ONLY = "enhanced_only" # 仅使用增强记忆系统 + + +@dataclass +class IntegrationConfig: + """集成配置""" + mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY + enable_enhanced_memory: bool = True + memory_value_threshold: float = 0.6 + fusion_threshold: float = 0.85 + max_retrieval_results: int = 10 + enable_learning: bool = True + + +class MemoryIntegrationLayer: + """记忆系统集成层 - 现在只管理增强记忆系统""" + + def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None): + self.llm_model = llm_model + self.config = config or IntegrationConfig() + + # 只初始化增强记忆系统 + self.enhanced_memory: Optional[EnhancedMemorySystem] = None + + # 集成统计 + self.integration_stats = { + "total_queries": 0, + "enhanced_queries": 0, + "memory_creations": 0, + "average_response_time": 0.0, + "success_rate": 0.0 + } + + # 初始化锁 + self._initialization_lock = asyncio.Lock() + self._initialized = False + + async def initialize(self): + """初始化集成层""" + if self._initialized: + return + + async with self._initialization_lock: + if self._initialized: + return + + logger.info("🚀 开始初始化增强记忆系统集成层...") + + try: + # 初始化增强记忆系统 + if self.config.enable_enhanced_memory: + await self._initialize_enhanced_memory() + + self._initialized = True + logger.info("✅ 增强记忆系统集成层初始化完成") + + except Exception as e: + logger.error(f"❌ 集成层初始化失败: {e}", exc_info=True) + raise + + async def _initialize_enhanced_memory(self): + """初始化增强记忆系统""" + try: + logger.debug("初始化增强记忆系统...") + + # 创建增强记忆系统配置 + from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig + memory_config = MemorySystemConfig.from_global_config() + + # 使用集成配置覆盖部分值 + memory_config.memory_value_threshold = self.config.memory_value_threshold + memory_config.fusion_similarity_threshold = self.config.fusion_threshold 
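+            # 注: 这些阈值覆盖 MemorySystemConfig.from_global_config() 加载的默认值,
+            # 因此集成层(IntegrationConfig)传入的配置优先生效。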
+ memory_config.final_recall_limit = self.config.max_retrieval_results + + # 创建增强记忆系统 + self.enhanced_memory = EnhancedMemorySystem( + config=memory_config + ) + + # 如果外部提供了LLM模型,注入到系统中 + if self.llm_model is not None: + self.enhanced_memory.llm_model = self.llm_model + + # 初始化系统 + await self.enhanced_memory.initialize() + logger.info("✅ 增强记忆系统初始化完成") + + except Exception as e: + logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True) + raise + + async def process_conversation( + self, + conversation_text: str, + context: Dict[str, Any], + user_id: str, + timestamp: Optional[float] = None + ) -> Dict[str, Any]: + """处理对话记忆""" + if not self._initialized or not self.enhanced_memory: + return {"success": False, "error": "Memory system not available"} + + start_time = time.time() + self.integration_stats["total_queries"] += 1 + self.integration_stats["enhanced_queries"] += 1 + + try: + # 直接使用增强记忆系统处理 + result = await self.enhanced_memory.process_conversation_memory( + conversation_text=conversation_text, + context=context, + user_id=user_id, + timestamp=timestamp + ) + + # 更新统计 + processing_time = time.time() - start_time + self._update_response_stats(processing_time, result.get("success", False)) + + if result.get("success"): + created_count = len(result.get("created_memories", [])) + self.integration_stats["memory_creations"] += created_count + logger.debug(f"对话处理完成,创建 {created_count} 条记忆") + + return result + + except Exception as e: + processing_time = time.time() - start_time + self._update_response_stats(processing_time, False) + logger.error(f"处理对话记忆失败: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def retrieve_relevant_memories( + self, + query: str, + user_id: str, + context: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None + ) -> List[MemoryChunk]: + """检索相关记忆""" + if not self._initialized or not self.enhanced_memory: + return [] + + try: + limit = limit or self.config.max_retrieval_results + memories = await self.enhanced_memory.retrieve_relevant_memories( + query=query, + user_id=user_id, + context=context or {}, + limit=limit + ) + + memory_count = len(memories) + logger.debug(f"检索到 {memory_count} 条相关记忆") + return memories + + except Exception as e: + logger.error(f"检索相关记忆失败: {e}", exc_info=True) + return [] + + async def get_system_status(self) -> Dict[str, Any]: + """获取系统状态""" + if not self._initialized: + return {"status": "not_initialized"} + + try: + enhanced_status = {} + if self.enhanced_memory: + enhanced_status = await self.enhanced_memory.get_system_status() + + return { + "status": "initialized", + "mode": self.config.mode.value, + "enhanced_memory": enhanced_status, + "integration_stats": self.integration_stats.copy() + } + + except Exception as e: + logger.error(f"获取系统状态失败: {e}", exc_info=True) + return {"status": "error", "error": str(e)} + + def get_integration_stats(self) -> Dict[str, Any]: + """获取集成统计信息""" + return self.integration_stats.copy() + + def _update_response_stats(self, processing_time: float, success: bool): + """更新响应统计""" + total_queries = self.integration_stats["total_queries"] + if total_queries > 0: + # 更新平均响应时间 + current_avg = self.integration_stats["average_response_time"] + new_avg = (current_avg * (total_queries - 1) + processing_time) / total_queries + self.integration_stats["average_response_time"] = new_avg + + # 更新成功率 + if success: + current_success_rate = self.integration_stats["success_rate"] + new_success_rate = (current_success_rate * (total_queries - 1) + 1) / total_queries + 
self.integration_stats["success_rate"] = new_success_rate + + async def maintenance(self): + """执行维护操作""" + if not self._initialized: + return + + try: + logger.info("🔧 执行记忆系统集成层维护...") + + if self.enhanced_memory: + await self.enhanced_memory.maintenance() + + logger.info("✅ 记忆系统集成层维护完成") + + except Exception as e: + logger.error(f"❌ 集成层维护失败: {e}", exc_info=True) + + async def shutdown(self): + """关闭集成层""" + if not self._initialized: + return + + try: + logger.info("🔄 关闭记忆系统集成层...") + + if self.enhanced_memory: + await self.enhanced_memory.shutdown() + + self._initialized = False + logger.info("✅ 记忆系统集成层已关闭") + + except Exception as e: + logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True) \ No newline at end of file diff --git a/src/chat/memory_system/memory_activator.py b/src/chat/memory_system/memory_activator.py deleted file mode 100644 index 33d22a5dd..000000000 --- a/src/chat/memory_system/memory_activator.py +++ /dev/null @@ -1,144 +0,0 @@ -import difflib -import orjson - -from json_repair import repair_json -from typing import List, Dict -from datetime import datetime - -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger -from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.chat.memory_system.Hippocampus import hippocampus_manager - - -logger = get_logger("memory_activator") - - -def get_keywords_from_json(json_str) -> List: - """ - 从JSON字符串中提取关键词列表 - - Args: - json_str: JSON格式的字符串 - - Returns: - List[str]: 关键词列表 - """ - try: - # 使用repair_json修复JSON格式 - fixed_json = repair_json(json_str) - - # 如果repair_json返回的是字符串,需要解析为Python对象 - result = orjson.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json - return result.get("keywords", []) - except Exception as e: - logger.error(f"解析关键词JSON失败: {e}") - return [] - - -def init_prompt(): - # --- Group Chat Prompt --- - memory_activator_prompt = """ - 你是一个记忆分析器,你需要根据以下信息来进行回忆 - 以下是一段聊天记录,请根据这些信息,总结出几个关键词作为记忆回忆的触发词 - - 聊天记录: - {obs_info_text} - 你想要回复的消息: - {target_message} - - 历史关键词(请避免重复提取这些关键词): - {cached_keywords} - - 请输出一个json格式,包含以下字段: - {{ - "keywords": ["关键词1", "关键词2", "关键词3",......] 
- }} - 不要输出其他多余内容,只输出json格式就好 - """ - - Prompt(memory_activator_prompt, "memory_activator_prompt") - - -class MemoryActivator: - def __init__(self): - self.key_words_model = LLMRequest( - model_set=model_config.model_task_config.utils_small, - request_type="memory.activator", - ) - - self.running_memory = [] - self.cached_keywords = set() # 用于缓存历史关键词 - - async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]: - """ - 激活记忆 - """ - # 如果记忆系统被禁用,直接返回空列表 - if not global_config.memory.enable_memory: - return [] - - # 将缓存的关键词转换为字符串,用于prompt - cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词" - - prompt = await global_prompt_manager.format_prompt( - "memory_activator_prompt", - obs_info_text=chat_history_prompt, - target_message=target_message, - cached_keywords=cached_keywords_str, - ) - - # logger.debug(f"prompt: {prompt}") - - response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async( - prompt, temperature=0.5 - ) - - keywords = list(get_keywords_from_json(response)) - - # 更新关键词缓存 - if keywords: - # 限制缓存大小,最多保留10个关键词 - if len(self.cached_keywords) > 10: - # 转换为列表,移除最早的关键词 - cached_list = list(self.cached_keywords) - self.cached_keywords = set(cached_list[-8:]) - - # 添加新的关键词到缓存 - self.cached_keywords.update(keywords) - - # 调用记忆系统获取相关记忆 - related_memory = await hippocampus_manager.get_memory_from_topic( - valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3 - ) - - logger.debug(f"当前记忆关键词: {self.cached_keywords} ") - logger.debug(f"获取到的记忆: {related_memory}") - - # 激活时,所有已有记忆的duration+1,达到3则移除 - for m in self.running_memory[:]: - m["duration"] = m.get("duration", 1) + 1 - self.running_memory = [m for m in self.running_memory if m["duration"] < 3] - - if related_memory: - for topic, memory in related_memory: - # 检查是否已存在相同topic或相似内容(相似度>=0.7)的记忆 - exists = any( - m["topic"] == topic or difflib.SequenceMatcher(None, m["content"], memory).ratio() >= 0.7 - for m in self.running_memory - ) - if not exists: - self.running_memory.append( - {"topic": topic, "content": memory, "timestamp": datetime.now().isoformat(), "duration": 1} - ) - logger.debug(f"添加新记忆: {topic} - {memory}") - - # 限制同时加载的记忆条数,最多保留最后3条 - if len(self.running_memory) > 3: - self.running_memory = self.running_memory[-3:] - - return self.running_memory - - -init_prompt() diff --git a/src/chat/memory_system/memory_builder.py b/src/chat/memory_system/memory_builder.py new file mode 100644 index 000000000..6a48a4dba --- /dev/null +++ b/src/chat/memory_system/memory_builder.py @@ -0,0 +1,602 @@ +# -*- coding: utf-8 -*- +""" +记忆构建模块 +从对话流中提取高质量、结构化记忆单元 +""" + +import re +import time +import orjson +from typing import Dict, List, Optional, Tuple, Any, Set +from datetime import datetime +from dataclasses import dataclass +from enum import Enum + +from src.common.logger import get_logger +from src.llm_models.utils_model import LLMRequest +from src.chat.memory_system.memory_chunk import ( + MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel, + ContentStructure, MemoryMetadata, create_memory_chunk +) + +logger = get_logger(__name__) + + +class ExtractionStrategy(Enum): + """提取策略""" + LLM_BASED = "llm_based" # 基于LLM的智能提取 + RULE_BASED = "rule_based" # 基于规则的提取 + HYBRID = "hybrid" # 混合策略 + + +@dataclass +class ExtractionResult: + """提取结果""" + memories: List[MemoryChunk] + confidence_scores: List[float] + extraction_time: float + strategy_used: ExtractionStrategy + + +class MemoryBuilder: + 
"""记忆构建器""" + + def __init__(self, llm_model: LLMRequest): + self.llm_model = llm_model + self.extraction_stats = { + "total_extractions": 0, + "successful_extractions": 0, + "failed_extractions": 0, + "average_confidence": 0.0 + } + + async def build_memories( + self, + conversation_text: str, + context: Dict[str, Any], + user_id: str, + timestamp: float + ) -> List[MemoryChunk]: + """从对话中构建记忆""" + start_time = time.time() + + try: + logger.debug(f"开始从对话构建记忆,文本长度: {len(conversation_text)}") + + # 预处理文本 + processed_text = self._preprocess_text(conversation_text) + + # 确定提取策略 + strategy = self._determine_extraction_strategy(processed_text, context) + + # 根据策略提取记忆 + if strategy == ExtractionStrategy.LLM_BASED: + memories = await self._extract_with_llm(processed_text, context, user_id, timestamp) + elif strategy == ExtractionStrategy.RULE_BASED: + memories = self._extract_with_rules(processed_text, context, user_id, timestamp) + else: # HYBRID + memories = await self._extract_with_hybrid(processed_text, context, user_id, timestamp) + + # 后处理和验证 + validated_memories = self._validate_and_enhance_memories(memories, context) + + # 更新统计 + extraction_time = time.time() - start_time + self._update_extraction_stats(len(validated_memories), extraction_time) + + logger.info(f"✅ 成功构建 {len(validated_memories)} 条记忆,耗时 {extraction_time:.2f}秒") + return validated_memories + + except Exception as e: + logger.error(f"❌ 记忆构建失败: {e}", exc_info=True) + self.extraction_stats["failed_extractions"] += 1 + return [] + + def _preprocess_text(self, text: str) -> str: + """预处理文本""" + # 移除多余的空白字符 + text = re.sub(r'\s+', ' ', text.strip()) + + # 移除特殊字符,但保留基本标点 + text = re.sub(r'[^\w\s\u4e00-\u9fff,。!?、;:""''()【】]', '', text) + + # 截断过长的文本 + if len(text) > 2000: + text = text[:2000] + "..." 
+ + return text + + def _determine_extraction_strategy(self, text: str, context: Dict[str, Any]) -> ExtractionStrategy: + """确定提取策略""" + text_length = len(text) + has_structured_data = any(key in context for key in ["structured_data", "entities", "keywords"]) + message_type = context.get("message_type", "normal") + + # 短文本使用规则提取 + if text_length < 50: + return ExtractionStrategy.RULE_BASED + + # 包含结构化数据使用混合策略 + if has_structured_data: + return ExtractionStrategy.HYBRID + + # 系统消息或命令使用规则提取 + if message_type in ["command", "system"]: + return ExtractionStrategy.RULE_BASED + + # 默认使用LLM提取 + return ExtractionStrategy.LLM_BASED + + async def _extract_with_llm( + self, + text: str, + context: Dict[str, Any], + user_id: str, + timestamp: float + ) -> List[MemoryChunk]: + """使用LLM提取记忆""" + try: + prompt = self._build_llm_extraction_prompt(text, context) + + response, _ = await self.llm_model.generate_response_async( + prompt, temperature=0.3 + ) + + # 解析LLM响应 + memories = self._parse_llm_response(response, user_id, timestamp, context) + + return memories + + except Exception as e: + logger.error(f"LLM提取失败: {e}") + return [] + + def _extract_with_rules( + self, + text: str, + context: Dict[str, Any], + user_id: str, + timestamp: float + ) -> List[MemoryChunk]: + """使用规则提取记忆""" + memories = [] + + # 规则1: 检测个人信息 + personal_info = self._extract_personal_info(text, user_id, timestamp, context) + memories.extend(personal_info) + + # 规则2: 检测偏好信息 + preferences = self._extract_preferences(text, user_id, timestamp, context) + memories.extend(preferences) + + # 规则3: 检测事件信息 + events = self._extract_events(text, user_id, timestamp, context) + memories.extend(events) + + return memories + + async def _extract_with_hybrid( + self, + text: str, + context: Dict[str, Any], + user_id: str, + timestamp: float + ) -> List[MemoryChunk]: + """混合策略提取记忆""" + all_memories = [] + + # 首先使用规则提取 + rule_memories = self._extract_with_rules(text, context, user_id, timestamp) + all_memories.extend(rule_memories) + + # 然后使用LLM提取 + llm_memories = await self._extract_with_llm(text, context, user_id, timestamp) + + # 合并和去重 + final_memories = self._merge_hybrid_results(all_memories, llm_memories) + + return final_memories + + def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str: + """构建LLM提取提示""" + current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + chat_id = context.get("chat_id", "unknown") + message_type = context.get("message_type", "normal") + + prompt = f""" +你是一个专业的记忆提取专家。请从以下对话中主动识别并提取所有可能重要的信息,特别是包含个人事实、事件、偏好、观点等要素的内容。 + +当前时间: {current_date} +聊天ID: {chat_id} +消息类型: {message_type} + +对话内容: +{text} + +## 🎯 重点记忆类型识别指南 + +### 1. **个人事实** (personal_fact) - 高优先级记忆 +**包括但不限于:** +- 基本信息:姓名、年龄、职业、学校、专业、工作地点 +- 生活状况:住址、电话、邮箱、社交账号 +- 身份特征:生日、星座、血型、国籍、语言能力 +- 健康信息:身体状况、疾病史、药物过敏、运动习惯 +- 家庭情况:家庭成员、婚姻状况、子女信息、宠物信息 + +**判断标准:** 涉及个人身份和生活的重要信息,都应该记忆 + +### 2. **事件** (event) - 高优先级记忆 +**包括但不限于:** +- 重要时刻:生日聚会、毕业典礼、婚礼、旅行 +- 日常活动:上班、上学、约会、看电影、吃饭 +- 特殊经历:考试、面试、会议、搬家、购物 +- 计划安排:约会、会议、旅行、活动 + +**判断标准:** 涉及时间地点的具体活动和经历,都应该记忆 + +### 3. **偏好** (preference) - 高优先级记忆 +**包括但不限于:** +- 饮食偏好:喜欢的食物、餐厅、口味、禁忌 +- 娱乐喜好:喜欢的电影、音乐、游戏、书籍 +- 生活习惯:作息时间、运动方式、购物习惯 +- 消费偏好:品牌喜好、价格敏感度、购物场所 +- 风格偏好:服装风格、装修风格、颜色喜好 + +**判断标准:** 任何表达"喜欢"、"不喜欢"、"习惯"、"经常"等偏好的内容,都应该记忆 + +### 4. **观点** (opinion) - 高优先级记忆 +**包括但不限于:** +- 评价看法:对事物的评价、意见、建议 +- 价值判断:认为什么重要、什么不重要 +- 态度立场:支持、反对、中立的态度 +- 感受反馈:对经历的感受、反馈 + +**判断标准:** 任何表达主观看法和态度的内容,都应该记忆 + +### 5. 
**关系** (relationship) - 中等优先级记忆 +**包括但不限于:** +- 人际关系:朋友、同事、家人、恋人的关系状态 +- 社交互动:与他人的互动、交流、合作 +- 群体归属:所属团队、组织、社群 + +### 6. **情感** (emotion) - 中等优先级记忆 +**包括但不限于:** +- 情绪状态:开心、难过、生气、焦虑、兴奋 +- 情感变化:情绪的转变、原因和结果 + +### 7. **目标** (goal) - 中等优先级记忆 +**包括但不限于:** +- 计划安排:短期计划、长期目标 +- 愿望期待:想要实现的事情、期望的结果 + +## 📝 记忆提取原则 + +### ✅ 积极提取原则: +1. **宁可错记,不可遗漏** - 对于可能的个人信息优先记忆 +2. **持续追踪** - 相同信息的多次提及要强化记忆 +3. **上下文关联** - 结合对话背景理解信息重要性 +4. **细节丰富** - 记录具体的细节和描述 + +### 🎯 重要性等级标准: +- **4分 (关键)**:个人核心信息(姓名、联系方式、重要日期) +- **3分 (高)**:重要偏好、观点、经历事件 +- **2分 (一般)**:一般性信息、日常活动、感受表达 +- **1分 (低)**:琐碎细节、重复信息、临时状态 + +### 🔍 置信度标准: +- **4分 (已验证)**:用户明确确认的信息 +- **3分 (高)**:用户直接表达的清晰信息 +- **2分 (中等)**:需要推理或上下文判断的信息 +- **1分 (低)**:模糊或不完整的信息 + +输出格式要求: +{{ + "memories": [ + {{ + "type": "记忆类型", + "subject": "主语(通常是用户)", + "predicate": "谓语(动作/状态)", + "object": "宾语(对象/属性)", + "keywords": ["关键词1", "关键词2"], + "importance": "重要性等级(1-4)", + "confidence": "置信度(1-4)", + "reasoning": "提取理由" + }} + ] +}} + +注意: +1. 只提取确实值得记忆的信息,不要提取琐碎内容 +2. 确保提取的信息准确、具体、有价值 +3. 使用主谓宾结构确保信息清晰 +4. 重要性等级: 1=低, 2=一般, 3=高, 4=关键 +5. 置信度: 1=低, 2=中等, 3=高, 4=已验证 +""" + + return prompt + + def _parse_llm_response( + self, + response: str, + user_id: str, + timestamp: float, + context: Dict[str, Any] + ) -> List[MemoryChunk]: + """解析LLM响应""" + memories = [] + + try: + data = orjson.loads(response) + memory_list = data.get("memories", []) + + for mem_data in memory_list: + try: + # 创建记忆块 + memory = create_memory_chunk( + user_id=user_id, + subject=mem_data.get("subject", user_id), + predicate=mem_data.get("predicate", ""), + obj=mem_data.get("object", ""), + memory_type=MemoryType(mem_data.get("type", "contextual")), + chat_id=context.get("chat_id"), + source_context=mem_data.get("reasoning", ""), + importance=ImportanceLevel(mem_data.get("importance", 2)), + confidence=ConfidenceLevel(mem_data.get("confidence", 2)) + ) + + # 添加关键词 + keywords = mem_data.get("keywords", []) + for keyword in keywords: + memory.add_keyword(keyword) + + memories.append(memory) + + except Exception as e: + logger.warning(f"解析单个记忆失败: {e}, 数据: {mem_data}") + continue + + except Exception as e: + logger.error(f"解析LLM响应失败: {e}, 响应: {response}") + + return memories + + def _extract_personal_info( + self, + text: str, + user_id: str, + timestamp: float, + context: Dict[str, Any] + ) -> List[MemoryChunk]: + """提取个人信息""" + memories = [] + + # 常见个人信息模式 + patterns = { + r"我叫(\w+)": ("is_named", {"name": "$1"}), + r"我今年(\d+)岁": ("is_age", {"age": "$1"}), + r"我是(\w+)": ("is_profession", {"profession": "$1"}), + r"我住在(\w+)": ("lives_in", {"location": "$1"}), + r"我的电话是(\d+)": ("has_phone", {"phone": "$1"}), + r"我的邮箱是(\w+@\w+\.\w+)": ("has_email", {"email": "$1"}), + } + + for pattern, (predicate, obj_template) in patterns.items(): + match = re.search(pattern, text) + if match: + obj = obj_template + for i, group in enumerate(match.groups(), 1): + obj = {k: v.replace(f"${i}", group) for k, v in obj.items()} + + memory = create_memory_chunk( + user_id=user_id, + subject=user_id, + predicate=predicate, + obj=obj, + memory_type=MemoryType.PERSONAL_FACT, + chat_id=context.get("chat_id"), + importance=ImportanceLevel.HIGH, + confidence=ConfidenceLevel.HIGH + ) + + memories.append(memory) + + return memories + + def _extract_preferences( + self, + text: str, + user_id: str, + timestamp: float, + context: Dict[str, Any] + ) -> List[MemoryChunk]: + """提取偏好信息""" + memories = [] + + # 偏好模式 + preference_patterns = [ + (r"我喜欢(.+)", "likes"), + (r"我不喜欢(.+)", "dislikes"), + (r"我爱吃(.+)", "likes_food"), + (r"我讨厌(.+)", 
"hates"), + (r"我最喜欢的(.+)", "favorite_is"), + ] + + for pattern, predicate in preference_patterns: + match = re.search(pattern, text) + if match: + memory = create_memory_chunk( + user_id=user_id, + subject=user_id, + predicate=predicate, + obj=match.group(1), + memory_type=MemoryType.PREFERENCE, + chat_id=context.get("chat_id"), + importance=ImportanceLevel.NORMAL, + confidence=ConfidenceLevel.MEDIUM + ) + + memories.append(memory) + + return memories + + def _extract_events( + self, + text: str, + user_id: str, + timestamp: float, + context: Dict[str, Any] + ) -> List[MemoryChunk]: + """提取事件信息""" + memories = [] + + # 事件关键词 + event_keywords = ["明天", "今天", "昨天", "上周", "下周", "约会", "会议", "活动", "旅行", "生日"] + + if any(keyword in text for keyword in event_keywords): + memory = create_memory_chunk( + user_id=user_id, + subject=user_id, + predicate="mentioned_event", + obj={"event_text": text, "timestamp": timestamp}, + memory_type=MemoryType.EVENT, + chat_id=context.get("chat_id"), + importance=ImportanceLevel.NORMAL, + confidence=ConfidenceLevel.MEDIUM + ) + + memories.append(memory) + + return memories + + def _merge_hybrid_results( + self, + rule_memories: List[MemoryChunk], + llm_memories: List[MemoryChunk] + ) -> List[MemoryChunk]: + """合并混合策略结果""" + all_memories = rule_memories.copy() + + # 添加LLM记忆,避免重复 + for llm_memory in llm_memories: + is_duplicate = False + for rule_memory in rule_memories: + if llm_memory.is_similar_to(rule_memory, threshold=0.7): + is_duplicate = True + # 合并置信度 + rule_memory.metadata.confidence = ConfidenceLevel( + max(rule_memory.metadata.confidence.value, llm_memory.metadata.confidence.value) + ) + break + + if not is_duplicate: + all_memories.append(llm_memory) + + return all_memories + + def _validate_and_enhance_memories( + self, + memories: List[MemoryChunk], + context: Dict[str, Any] + ) -> List[MemoryChunk]: + """验证和增强记忆""" + validated_memories = [] + + for memory in memories: + # 基本验证 + if not self._validate_memory(memory): + continue + + # 增强记忆 + enhanced_memory = self._enhance_memory(memory, context) + validated_memories.append(enhanced_memory) + + return validated_memories + + def _validate_memory(self, memory: MemoryChunk) -> bool: + """验证记忆块""" + # 检查基本字段 + if not memory.content.subject or not memory.content.predicate: + logger.debug(f"记忆块缺少主语或谓语: {memory.memory_id}") + return False + + # 检查内容长度 + content_length = len(memory.text_content) + if content_length < 5 or content_length > 500: + logger.debug(f"记忆块内容长度异常: {content_length}") + return False + + # 检查置信度 + if memory.metadata.confidence == ConfidenceLevel.LOW: + logger.debug(f"记忆块置信度过低: {memory.memory_id}") + return False + + return True + + def _enhance_memory( + self, + memory: MemoryChunk, + context: Dict[str, Any] + ) -> MemoryChunk: + """增强记忆块""" + # 添加时间上下文 + if not memory.temporal_context: + memory.temporal_context = { + "timestamp": memory.metadata.created_at, + "timezone": context.get("timezone", "UTC"), + "day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A") + } + + # 添加情感上下文(如果有) + if context.get("sentiment"): + memory.metadata.emotional_context = context["sentiment"] + + # 自动添加标签 + self._auto_tag_memory(memory) + + return memory + + def _auto_tag_memory(self, memory: MemoryChunk): + """自动为记忆添加标签""" + # 基于记忆类型的自动标签 + type_tags = { + MemoryType.PERSONAL_FACT: ["个人信息", "基本资料"], + MemoryType.EVENT: ["事件", "日程"], + MemoryType.PREFERENCE: ["偏好", "喜好"], + MemoryType.OPINION: ["观点", "态度"], + MemoryType.RELATIONSHIP: ["关系", "社交"], + MemoryType.EMOTION: ["情感", "情绪"], + 
MemoryType.KNOWLEDGE: ["知识", "信息"], + MemoryType.SKILL: ["技能", "能力"], + MemoryType.GOAL: ["目标", "计划"], + MemoryType.EXPERIENCE: ["经验", "经历"], + } + + tags = type_tags.get(memory.memory_type, []) + for tag in tags: + memory.add_tag(tag) + + def _update_extraction_stats(self, success_count: int, extraction_time: float): + """更新提取统计""" + self.extraction_stats["total_extractions"] += 1 + self.extraction_stats["successful_extractions"] += success_count + self.extraction_stats["failed_extractions"] += max(0, 1 - success_count) + + # 更新平均置信度 + if self.extraction_stats["successful_extractions"] > 0: + total_confidence = self.extraction_stats["average_confidence"] * (self.extraction_stats["successful_extractions"] - success_count) + # 假设新记忆的平均置信度为0.8 + total_confidence += 0.8 * success_count + self.extraction_stats["average_confidence"] = total_confidence / self.extraction_stats["successful_extractions"] + + def get_extraction_stats(self) -> Dict[str, Any]: + """获取提取统计信息""" + return self.extraction_stats.copy() + + def reset_stats(self): + """重置统计信息""" + self.extraction_stats = { + "total_extractions": 0, + "successful_extractions": 0, + "failed_extractions": 0, + "average_confidence": 0.0 + } \ No newline at end of file diff --git a/src/chat/memory_system/memory_chunk.py b/src/chat/memory_system/memory_chunk.py new file mode 100644 index 000000000..0b9da0180 --- /dev/null +++ b/src/chat/memory_system/memory_chunk.py @@ -0,0 +1,463 @@ +# -*- coding: utf-8 -*- +""" +结构化记忆单元设计 +实现高质量、结构化的记忆单元,符合文档设计规范 +""" + +import time +import uuid +import orjson +from typing import Dict, List, Optional, Any, Union +from dataclasses import dataclass, field, asdict +from datetime import datetime +from enum import Enum +import hashlib + +import numpy as np +from src.common.logger import get_logger + +logger = get_logger(__name__) + + +class MemoryType(Enum): + """记忆类型分类""" + PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等) + EVENT = "event" # 事件(重要经历、约会等) + PREFERENCE = "preference" # 偏好(喜好、习惯等) + OPINION = "opinion" # 观点(对事物的看法) + RELATIONSHIP = "relationship" # 关系(与他人的关系) + EMOTION = "emotion" # 情感状态 + KNOWLEDGE = "knowledge" # 知识信息 + SKILL = "skill" # 技能能力 + GOAL = "goal" # 目标计划 + EXPERIENCE = "experience" # 经验教训 + CONTEXTUAL = "contextual" # 上下文信息 + + +class ConfidenceLevel(Enum): + """置信度等级""" + LOW = 1 # 低置信度,可能不准确 + MEDIUM = 2 # 中等置信度,有一定依据 + HIGH = 3 # 高置信度,有明确来源 + VERIFIED = 4 # 已验证,非常可靠 + + +class ImportanceLevel(Enum): + """重要性等级""" + LOW = 1 # 低重要性,普通信息 + NORMAL = 2 # 一般重要性,日常信息 + HIGH = 3 # 高重要性,重要信息 + CRITICAL = 4 # 关键重要性,核心信息 + + +@dataclass +class ContentStructure: + """主谓宾三元组结构""" + subject: str # 主语(通常为用户) + predicate: str # 谓语(动作、状态、关系) + object: Union[str, Dict] # 宾语(对象、属性、值) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式""" + return { + "subject": self.subject, + "predicate": self.predicate, + "object": self.object + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure": + """从字典创建实例""" + return cls( + subject=data.get("subject", ""), + predicate=data.get("predicate", ""), + object=data.get("object", "") + ) + + def __str__(self) -> str: + """字符串表示""" + if isinstance(self.object, dict): + object_str = str(self.object) + else: + object_str = str(self.object) + return f"{self.subject} {self.predicate} {object_str}" + + +@dataclass +class MemoryMetadata: + """记忆元数据""" + # 基础信息 + memory_id: str # 唯一标识符 + user_id: str # 用户ID + chat_id: Optional[str] = None # 聊天ID(群聊或私聊) + + # 时间信息 + created_at: float = 0.0 # 创建时间戳 + last_accessed: float = 0.0 # 最后访问时间 + 
last_modified: float = 0.0 # 最后修改时间 + + # 统计信息 + access_count: int = 0 # 访问次数 + relevance_score: float = 0.0 # 相关度评分 + + # 信心和重要性 + confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM + importance: ImportanceLevel = ImportanceLevel.NORMAL + + # 情感和关系 + emotional_context: Optional[str] = None # 情感上下文 + relationship_score: float = 0.0 # 关系分(0-1) + + # 来源和验证 + source_context: Optional[str] = None # 来源上下文片段 + verification_status: bool = False # 验证状态 + + def __post_init__(self): + """后初始化处理""" + if not self.memory_id: + self.memory_id = str(uuid.uuid4()) + + if self.created_at == 0: + self.created_at = time.time() + + if self.last_accessed == 0: + self.last_accessed = self.created_at + + if self.last_modified == 0: + self.last_modified = self.created_at + + def update_access(self): + """更新访问信息""" + current_time = time.time() + self.last_accessed = current_time + self.access_count += 1 + + def update_relevance(self, new_score: float): + """更新相关度评分""" + self.relevance_score = max(0.0, min(1.0, new_score)) + self.last_modified = time.time() + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式""" + return { + "memory_id": self.memory_id, + "user_id": self.user_id, + "chat_id": self.chat_id, + "created_at": self.created_at, + "last_accessed": self.last_accessed, + "last_modified": self.last_modified, + "access_count": self.access_count, + "relevance_score": self.relevance_score, + "confidence": self.confidence.value, + "importance": self.importance.value, + "emotional_context": self.emotional_context, + "relationship_score": self.relationship_score, + "source_context": self.source_context, + "verification_status": self.verification_status + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata": + """从字典创建实例""" + return cls( + memory_id=data.get("memory_id", ""), + user_id=data.get("user_id", ""), + chat_id=data.get("chat_id"), + created_at=data.get("created_at", 0), + last_accessed=data.get("last_accessed", 0), + last_modified=data.get("last_modified", 0), + access_count=data.get("access_count", 0), + relevance_score=data.get("relevance_score", 0.0), + confidence=ConfidenceLevel(data.get("confidence", ConfidenceLevel.MEDIUM.value)), + importance=ImportanceLevel(data.get("importance", ImportanceLevel.NORMAL.value)), + emotional_context=data.get("emotional_context"), + relationship_score=data.get("relationship_score", 0.0), + source_context=data.get("source_context"), + verification_status=data.get("verification_status", False) + ) + + +@dataclass +class MemoryChunk: + """结构化记忆单元 - 核心数据结构""" + + # 元数据 + metadata: MemoryMetadata + + # 内容结构 + content: ContentStructure # 主谓宾结构 + memory_type: MemoryType # 记忆类型 + + # 扩展信息 + keywords: List[str] = field(default_factory=list) # 关键词列表 + tags: List[str] = field(default_factory=list) # 标签列表 + categories: List[str] = field(default_factory=list) # 分类列表 + + # 语义信息 + embedding: Optional[List[float]] = None # 语义向量 + semantic_hash: Optional[str] = None # 语义哈希值 + + # 关联信息 + related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表 + temporal_context: Optional[Dict[str, Any]] = None # 时间上下文 + + def __post_init__(self): + """后初始化处理""" + if self.embedding and len(self.embedding) > 0: + self._generate_semantic_hash() + + def _generate_semantic_hash(self): + """生成语义哈希值""" + if not self.embedding: + return + + try: + # 使用向量和内容生成稳定的哈希 + content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}" + embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding])) + + hash_input = 
f"{content_str}|{embedding_str}" + hash_object = hashlib.sha256(hash_input.encode('utf-8')) + self.semantic_hash = hash_object.hexdigest()[:16] + + except Exception as e: + logger.warning(f"生成语义哈希失败: {e}") + self.semantic_hash = str(uuid.uuid4())[:16] + + @property + def memory_id(self) -> str: + """获取记忆ID""" + return self.metadata.memory_id + + @property + def user_id(self) -> str: + """获取用户ID""" + return self.metadata.user_id + + @property + def text_content(self) -> str: + """获取文本内容""" + return str(self.content) + + def update_access(self): + """更新访问信息""" + self.metadata.update_access() + + def update_relevance(self, new_score: float): + """更新相关度评分""" + self.metadata.update_relevance(new_score) + + def add_keyword(self, keyword: str): + """添加关键词""" + if keyword and keyword not in self.keywords: + self.keywords.append(keyword.strip()) + + def add_tag(self, tag: str): + """添加标签""" + if tag and tag not in self.tags: + self.tags.append(tag.strip()) + + def add_category(self, category: str): + """添加分类""" + if category and category not in self.categories: + self.categories.append(category.strip()) + + def add_related_memory(self, memory_id: str): + """添加关联记忆""" + if memory_id and memory_id not in self.related_memories: + self.related_memories.append(memory_id) + + def set_embedding(self, embedding: List[float]): + """设置语义向量""" + self.embedding = embedding + self._generate_semantic_hash() + + def calculate_similarity(self, other: "MemoryChunk") -> float: + """计算与另一个记忆块的相似度""" + if not self.embedding or not other.embedding: + return 0.0 + + try: + # 计算余弦相似度 + v1 = np.array(self.embedding) + v2 = np.array(other.embedding) + + dot_product = np.dot(v1, v2) + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + + if norm1 == 0 or norm2 == 0: + return 0.0 + + similarity = dot_product / (norm1 * norm2) + return max(0.0, min(1.0, similarity)) + + except Exception as e: + logger.warning(f"计算记忆相似度失败: {e}") + return 0.0 + + def to_dict(self) -> Dict[str, Any]: + """转换为完整的字典格式""" + return { + "metadata": self.metadata.to_dict(), + "content": self.content.to_dict(), + "memory_type": self.memory_type.value, + "keywords": self.keywords, + "tags": self.tags, + "categories": self.categories, + "embedding": self.embedding, + "semantic_hash": self.semantic_hash, + "related_memories": self.related_memories, + "temporal_context": self.temporal_context + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk": + """从字典创建实例""" + metadata = MemoryMetadata.from_dict(data.get("metadata", {})) + content = ContentStructure.from_dict(data.get("content", {})) + + chunk = cls( + metadata=metadata, + content=content, + memory_type=MemoryType(data.get("memory_type", MemoryType.CONTEXTUAL.value)), + keywords=data.get("keywords", []), + tags=data.get("tags", []), + categories=data.get("categories", []), + embedding=data.get("embedding"), + semantic_hash=data.get("semantic_hash"), + related_memories=data.get("related_memories", []), + temporal_context=data.get("temporal_context") + ) + + return chunk + + def to_json(self) -> str: + """转换为JSON字符串""" + return orjson.dumps(self.to_dict(), ensure_ascii=False).decode('utf-8') + + @classmethod + def from_json(cls, json_str: str) -> "MemoryChunk": + """从JSON字符串创建实例""" + try: + data = orjson.loads(json_str) + return cls.from_dict(data) + except Exception as e: + logger.error(f"从JSON创建记忆块失败: {e}") + raise + + def is_similar_to(self, other: "MemoryChunk", threshold: float = 0.8) -> bool: + """判断是否与另一个记忆块相似""" + if self.semantic_hash and other.semantic_hash: + 
return self.semantic_hash == other.semantic_hash + + return self.calculate_similarity(other) >= threshold + + def merge_with(self, other: "MemoryChunk") -> bool: + """与另一个记忆块合并(如果相似)""" + if not self.is_similar_to(other): + return False + + try: + # 合并关键词 + for keyword in other.keywords: + self.add_keyword(keyword) + + # 合并标签 + for tag in other.tags: + self.add_tag(tag) + + # 合并分类 + for category in other.categories: + self.add_category(category) + + # 合并关联记忆 + for memory_id in other.related_memories: + self.add_related_memory(memory_id) + + # 更新元数据 + self.metadata.last_modified = time.time() + self.metadata.access_count += other.metadata.access_count + self.metadata.relevance_score = max(self.metadata.relevance_score, other.metadata.relevance_score) + + # 更新置信度 + if other.metadata.confidence.value > self.metadata.confidence.value: + self.metadata.confidence = other.metadata.confidence + + # 更新重要性 + if other.metadata.importance.value > self.metadata.importance.value: + self.metadata.importance = other.metadata.importance + + logger.debug(f"记忆块 {self.memory_id} 合并了记忆块 {other.memory_id}") + return True + + except Exception as e: + logger.error(f"合并记忆块失败: {e}") + return False + + def __str__(self) -> str: + """字符串表示""" + type_emoji = { + MemoryType.PERSONAL_FACT: "👤", + MemoryType.EVENT: "📅", + MemoryType.PREFERENCE: "❤️", + MemoryType.OPINION: "💭", + MemoryType.RELATIONSHIP: "👥", + MemoryType.EMOTION: "😊", + MemoryType.KNOWLEDGE: "📚", + MemoryType.SKILL: "🛠️", + MemoryType.GOAL: "🎯", + MemoryType.EXPERIENCE: "💡", + MemoryType.CONTEXTUAL: "📝" + } + + emoji = type_emoji.get(self.memory_type, "📝") + confidence_icon = "●" * self.metadata.confidence.value + importance_icon = "★" * self.metadata.importance.value + + return f"{emoji} [{self.memory_type.value}] {self.text_content} {confidence_icon} {importance_icon}" + + def __repr__(self) -> str: + """调试表示""" + return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})" + + +def create_memory_chunk( + user_id: str, + subject: str, + predicate: str, + obj: Union[str, Dict], + memory_type: MemoryType, + chat_id: Optional[str] = None, + source_context: Optional[str] = None, + importance: ImportanceLevel = ImportanceLevel.NORMAL, + confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM, + **kwargs +) -> MemoryChunk: + """便捷的内存块创建函数""" + metadata = MemoryMetadata( + memory_id="", + user_id=user_id, + chat_id=chat_id, + created_at=time.time(), + last_accessed=0, + last_modified=0, + confidence=confidence, + importance=importance, + source_context=source_context + ) + + content = ContentStructure( + subject=subject, + predicate=predicate, + object=obj + ) + + chunk = MemoryChunk( + metadata=metadata, + content=content, + memory_type=memory_type, + **kwargs + ) + + return chunk \ No newline at end of file diff --git a/src/chat/memory_system/memory_fusion.py b/src/chat/memory_system/memory_fusion.py new file mode 100644 index 000000000..0ff33b0f9 --- /dev/null +++ b/src/chat/memory_system/memory_fusion.py @@ -0,0 +1,522 @@ +# -*- coding: utf-8 -*- +""" +记忆融合与去重机制 +避免记忆碎片化,确保长期记忆库的高质量 +""" + +import time +import hashlib +from typing import Dict, List, Optional, Tuple, Set, Any +from datetime import datetime, timedelta +from dataclasses import dataclass +from collections import defaultdict +import asyncio + +from src.common.logger import get_logger +from src.chat.memory_system.memory_chunk import ( + MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel +) + +logger = get_logger(__name__) + + +@dataclass +class 
FusionResult: + """融合结果""" + original_count: int + fused_count: int + removed_duplicates: int + merged_memories: List[MemoryChunk] + fusion_time: float + details: List[str] + + +@dataclass +class DuplicateGroup: + """重复记忆组""" + group_id: str + memories: List[MemoryChunk] + similarity_matrix: List[List[float]] + representative_memory: Optional[MemoryChunk] = None + + +class MemoryFusionEngine: + """记忆融合引擎""" + + def __init__(self, similarity_threshold: float = 0.85): + self.similarity_threshold = similarity_threshold + self.fusion_stats = { + "total_fusions": 0, + "memories_fused": 0, + "duplicates_removed": 0, + "average_similarity": 0.0 + } + + # 融合策略配置 + self.fusion_strategies = { + "semantic_similarity": True, # 语义相似性融合 + "temporal_proximity": True, # 时间接近性融合 + "logical_consistency": True, # 逻辑一致性融合 + "confidence_boosting": True, # 置信度提升 + "importance_preservation": True # 重要性保持 + } + + async def fuse_memories( + self, + new_memories: List[MemoryChunk], + existing_memories: Optional[List[MemoryChunk]] = None + ) -> List[MemoryChunk]: + """融合记忆列表""" + start_time = time.time() + + try: + if not new_memories: + return [] + + logger.info(f"开始记忆融合,新记忆: {len(new_memories)},现有记忆: {len(existing_memories or [])}") + + # 1. 检测重复记忆组 + duplicate_groups = await self._detect_duplicate_groups( + new_memories, existing_memories or [] + ) + + # 2. 对每个重复组进行融合 + fused_memories = [] + removed_count = 0 + + for group in duplicate_groups: + if len(group.memories) == 1: + # 单个记忆,直接添加 + fused_memories.append(group.memories[0]) + else: + # 多个记忆,进行融合 + fused_memory = await self._fuse_memory_group(group) + if fused_memory: + fused_memories.append(fused_memory) + removed_count += len(group.memories) - 1 + + # 3. 更新统计 + fusion_time = time.time() - start_time + self._update_fusion_stats(len(new_memories), removed_count, fusion_time) + + logger.info(f"✅ 记忆融合完成: {len(fused_memories)} 条记忆,移除 {removed_count} 条重复") + return fused_memories + + except Exception as e: + logger.error(f"❌ 记忆融合失败: {e}", exc_info=True) + return new_memories # 失败时返回原始记忆 + + async def _detect_duplicate_groups( + self, + new_memories: List[MemoryChunk], + existing_memories: List[MemoryChunk] + ) -> List[DuplicateGroup]: + """检测重复记忆组""" + all_memories = new_memories + existing_memories + groups = [] + processed_ids = set() + + for i, memory1 in enumerate(all_memories): + if memory1.memory_id in processed_ids: + continue + + # 创建新的重复组 + group = DuplicateGroup( + group_id=f"group_{len(groups)}", + memories=[memory1], + similarity_matrix=[[1.0]] + ) + + processed_ids.add(memory1.memory_id) + + # 寻找相似记忆 + for j, memory2 in enumerate(all_memories[i+1:], i+1): + if memory2.memory_id in processed_ids: + continue + + similarity = self._calculate_comprehensive_similarity(memory1, memory2) + + if similarity >= self.similarity_threshold: + group.memories.append(memory2) + processed_ids.add(memory2.memory_id) + + # 更新相似度矩阵 + self._update_similarity_matrix(group, memory2, similarity) + + if len(group.memories) > 1: + # 选择代表性记忆 + group.representative_memory = self._select_representative_memory(group) + groups.append(group) + + logger.debug(f"检测到 {len(groups)} 个重复记忆组") + return groups + + def _calculate_comprehensive_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float: + """计算综合相似度""" + similarity_scores = [] + + # 1. 语义向量相似度 + if self.fusion_strategies["semantic_similarity"]: + semantic_sim = mem1.calculate_similarity(mem2) + similarity_scores.append(("semantic", semantic_sim)) + + # 2. 
文本相似度 + text_sim = self._calculate_text_similarity(mem1.text_content, mem2.text_content) + similarity_scores.append(("text", text_sim)) + + # 3. 关键词重叠度 + keyword_sim = self._calculate_keyword_similarity(mem1.keywords, mem2.keywords) + similarity_scores.append(("keyword", keyword_sim)) + + # 4. 类型一致性 + type_consistency = 1.0 if mem1.memory_type == mem2.memory_type else 0.0 + similarity_scores.append(("type", type_consistency)) + + # 5. 时间接近性 + if self.fusion_strategies["temporal_proximity"]: + temporal_sim = self._calculate_temporal_similarity( + mem1.metadata.created_at, mem2.metadata.created_at + ) + similarity_scores.append(("temporal", temporal_sim)) + + # 6. 逻辑一致性 + if self.fusion_strategies["logical_consistency"]: + logical_sim = self._calculate_logical_similarity(mem1, mem2) + similarity_scores.append(("logical", logical_sim)) + + # 计算加权平均相似度 + weights = { + "semantic": 0.35, + "text": 0.25, + "keyword": 0.15, + "type": 0.10, + "temporal": 0.10, + "logical": 0.05 + } + + weighted_sum = 0.0 + total_weight = 0.0 + + for score_type, score in similarity_scores: + weight = weights.get(score_type, 0.1) + weighted_sum += weight * score + total_weight += weight + + final_similarity = weighted_sum / total_weight if total_weight > 0 else 0.0 + + logger.debug(f"综合相似度计算: {final_similarity:.3f} - {[(t, f'{s:.3f}') for t, s in similarity_scores]}") + + return final_similarity + + def _calculate_text_similarity(self, text1: str, text2: str) -> float: + """计算文本相似度""" + # 简单的词汇重叠度计算 + words1 = set(text1.lower().split()) + words2 = set(text2.lower().split()) + + if not words1 or not words2: + return 0.0 + + intersection = words1 & words2 + union = words1 | words2 + + jaccard_similarity = len(intersection) / len(union) + return jaccard_similarity + + def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float: + """计算关键词相似度""" + if not keywords1 or not keywords2: + return 0.0 + + set1 = set(k.lower() for k in keywords1) + set2 = set(k.lower() for k in keywords2) + + intersection = set1 & set2 + union = set1 | set2 + + return len(intersection) / len(union) if union else 0.0 + + def _calculate_temporal_similarity(self, time1: float, time2: float) -> float: + """计算时间相似度""" + time_diff = abs(time1 - time2) + hours_diff = time_diff / 3600 + + # 24小时内相似度较高 + if hours_diff <= 24: + return 1.0 - (hours_diff / 24) + elif hours_diff <= 168: # 一周内 + return 0.7 - ((hours_diff - 24) / 168) * 0.5 + else: + return 0.2 + + def _calculate_logical_similarity(self, mem1: MemoryChunk, mem2: MemoryChunk) -> float: + """计算逻辑一致性""" + # 检查主谓宾结构的逻辑一致性 + consistency_score = 0.0 + + # 主语一致性 + if mem1.content.subject == mem2.content.subject: + consistency_score += 0.4 + + # 谓语相似性 + predicate_sim = self._calculate_text_similarity(mem1.content.predicate, mem2.content.predicate) + consistency_score += predicate_sim * 0.3 + + # 宾语相似性 + if isinstance(mem1.content.object, str) and isinstance(mem2.content.object, str): + object_sim = self._calculate_text_similarity( + str(mem1.content.object), str(mem2.content.object) + ) + consistency_score += object_sim * 0.3 + + return consistency_score + + def _update_similarity_matrix(self, group: DuplicateGroup, new_memory: MemoryChunk, similarity: float): + """更新组的相似度矩阵""" + # 为新记忆添加行和列 + for i in range(len(group.similarity_matrix)): + group.similarity_matrix[i].append(similarity) + + # 添加新行 + new_row = [similarity] + [1.0] * len(group.similarity_matrix) + group.similarity_matrix.append(new_row) + + def _select_representative_memory(self, group: DuplicateGroup) -> 
MemoryChunk: + """选择代表性记忆""" + if not group.memories: + return None + + # 评分标准 + best_memory = None + best_score = -1.0 + + for memory in group.memories: + score = 0.0 + + # 置信度权重 + score += memory.metadata.confidence.value * 0.3 + + # 重要性权重 + score += memory.metadata.importance.value * 0.3 + + # 访问次数权重 + score += min(memory.metadata.access_count * 0.1, 0.2) + + # 相关度权重 + score += memory.metadata.relevance_score * 0.2 + + if score > best_score: + best_score = score + best_memory = memory + + return best_memory + + async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]: + """融合记忆组""" + if not group.memories: + return None + + if len(group.memories) == 1: + return group.memories[0] + + try: + # 选择基础记忆(通常是代表性记忆) + base_memory = group.representative_memory or group.memories[0] + + # 融合其他记忆的属性 + fused_memory = await self._merge_memory_attributes(base_memory, group.memories) + + # 更新元数据 + self._update_fused_metadata(fused_memory, group) + + logger.debug(f"成功融合记忆组,包含 {len(group.memories)} 条原始记忆") + return fused_memory + + except Exception as e: + logger.error(f"融合记忆组失败: {e}") + # 返回置信度最高的记忆 + return max(group.memories, key=lambda m: m.metadata.confidence.value) + + async def _merge_memory_attributes( + self, + base_memory: MemoryChunk, + memories: List[MemoryChunk] + ) -> MemoryChunk: + """合并记忆属性""" + # 创建基础记忆的深拷贝 + fused_memory = MemoryChunk.from_dict(base_memory.to_dict()) + + # 合并关键词 + all_keywords = set() + for memory in memories: + all_keywords.update(memory.keywords) + fused_memory.keywords = sorted(all_keywords) + + # 合并标签 + all_tags = set() + for memory in memories: + all_tags.update(memory.tags) + fused_memory.tags = sorted(all_tags) + + # 合并分类 + all_categories = set() + for memory in memories: + all_categories.update(memory.categories) + fused_memory.categories = sorted(all_categories) + + # 合并关联记忆 + all_related = set() + for memory in memories: + all_related.update(memory.related_memories) + # 移除对自身和组内记忆的引用 + all_related = {rid for rid in all_related if rid not in [m.memory_id for m in memories]} + fused_memory.related_memories = sorted(all_related) + + # 合并时间上下文 + if self.fusion_strategies["temporal_proximity"]: + fused_memory.temporal_context = self._merge_temporal_context(memories) + + return fused_memory + + def _update_fused_metadata(self, fused_memory: MemoryChunk, group: DuplicateGroup): + """更新融合记忆的元数据""" + # 更新修改时间 + fused_memory.metadata.last_modified = time.time() + + # 计算平均访问次数 + total_access = sum(m.metadata.access_count for m in group.memories) + fused_memory.metadata.access_count = total_access + + # 提升置信度(如果有多个来源支持) + if self.fusion_strategies["confidence_boosting"] and len(group.memories) > 1: + max_confidence = max(m.metadata.confidence.value for m in group.memories) + if max_confidence < ConfidenceLevel.VERIFIED.value: + fused_memory.metadata.confidence = ConfidenceLevel( + min(max_confidence + 1, ConfidenceLevel.VERIFIED.value) + ) + + # 保持最高重要性 + if self.fusion_strategies["importance_preservation"]: + max_importance = max(m.metadata.importance.value for m in group.memories) + fused_memory.metadata.importance = ImportanceLevel(max_importance) + + # 计算平均相关度 + avg_relevance = sum(m.metadata.relevance_score for m in group.memories) / len(group.memories) + fused_memory.metadata.relevance_score = min(avg_relevance * 1.1, 1.0) # 稍微提升相关度 + + # 设置来源信息 + source_ids = [m.memory_id[:8] for m in group.memories] + fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}" + + def 
_merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]: + """合并时间上下文""" + contexts = [m.temporal_context for m in memories if m.temporal_context] + + if not contexts: + return {} + + # 计算时间范围 + timestamps = [m.metadata.created_at for m in memories] + earliest_time = min(timestamps) + latest_time = max(timestamps) + + merged_context = { + "earliest_timestamp": earliest_time, + "latest_timestamp": latest_time, + "time_span_hours": (latest_time - earliest_time) / 3600, + "source_memories": len(memories) + } + + # 合并其他上下文信息 + for context in contexts: + for key, value in context.items(): + if key not in ["timestamp", "earliest_timestamp", "latest_timestamp"]: + if key not in merged_context: + merged_context[key] = value + elif merged_context[key] != value: + merged_context[key] = f"multiple: {value}" + + return merged_context + + async def incremental_fusion( + self, + new_memory: MemoryChunk, + existing_memories: List[MemoryChunk] + ) -> Tuple[MemoryChunk, List[MemoryChunk]]: + """增量融合(单个新记忆与现有记忆融合)""" + # 寻找相似记忆 + similar_memories = [] + + for existing in existing_memories: + similarity = self._calculate_comprehensive_similarity(new_memory, existing) + if similarity >= self.similarity_threshold: + similar_memories.append((existing, similarity)) + + if not similar_memories: + # 没有相似记忆,直接返回 + return new_memory, existing_memories + + # 按相似度排序 + similar_memories.sort(key=lambda x: x[1], reverse=True) + + # 与最相似的记忆融合 + best_match, similarity = similar_memories[0] + + # 创建融合组 + group = DuplicateGroup( + group_id=f"incremental_{int(time.time())}", + memories=[new_memory, best_match], + similarity_matrix=[[1.0, similarity], [similarity, 1.0]] + ) + + # 执行融合 + fused_memory = await self._fuse_memory_group(group) + + # 从现有记忆中移除被融合的记忆 + updated_existing = [m for m in existing_memories if m.memory_id != best_match.memory_id] + updated_existing.append(fused_memory) + + logger.debug(f"增量融合完成,相似度: {similarity:.3f}") + + return fused_memory, updated_existing + + def _update_fusion_stats(self, original_count: int, removed_count: int, fusion_time: float): + """更新融合统计""" + self.fusion_stats["total_fusions"] += 1 + self.fusion_stats["memories_fused"] += original_count + self.fusion_stats["duplicates_removed"] += removed_count + + # 更新平均相似度(估算) + if removed_count > 0: + avg_similarity = 0.9 # 假设平均相似度较高 + total_similarity = self.fusion_stats["average_similarity"] * (self.fusion_stats["total_fusions"] - 1) + total_similarity += avg_similarity + self.fusion_stats["average_similarity"] = total_similarity / self.fusion_stats["total_fusions"] + + async def maintenance(self): + """维护操作""" + try: + logger.info("开始记忆融合引擎维护...") + + # 可以在这里添加定期维护任务,如: + # - 重新评估低置信度记忆 + # - 清理孤立记忆引用 + # - 优化融合策略参数 + + logger.info("✅ 记忆融合引擎维护完成") + + except Exception as e: + logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True) + + def get_fusion_stats(self) -> Dict[str, Any]: + """获取融合统计信息""" + return self.fusion_stats.copy() + + def reset_stats(self): + """重置统计信息""" + self.fusion_stats = { + "total_fusions": 0, + "memories_fused": 0, + "duplicates_removed": 0, + "average_similarity": 0.0 + } \ No newline at end of file diff --git a/src/chat/memory_system/memory_integration_hooks.py b/src/chat/memory_system/memory_integration_hooks.py new file mode 100644 index 000000000..6728613e4 --- /dev/null +++ b/src/chat/memory_system/memory_integration_hooks.py @@ -0,0 +1,542 @@ +# -*- coding: utf-8 -*- +""" +记忆系统集成钩子 +提供与现有MoFox Bot系统的无缝集成点 +""" + +import asyncio +import time +from typing import Dict, List, Optional, Any, Callable 
+from dataclasses import dataclass + +from src.common.logger import get_logger +from src.chat.memory_system.enhanced_memory_adapter import ( + get_enhanced_memory_adapter, + process_conversation_with_enhanced_memory, + retrieve_memories_with_enhanced_system, + get_memory_context_for_prompt +) + +logger = get_logger(__name__) + + +@dataclass +class HookResult: + """钩子执行结果""" + success: bool + data: Any = None + error: Optional[str] = None + processing_time: float = 0.0 + + +class MemoryIntegrationHooks: + """记忆系统集成钩子""" + + def __init__(self): + self.hooks_registered = False + self.hook_stats = { + "message_processing_hooks": 0, + "memory_retrieval_hooks": 0, + "prompt_enhancement_hooks": 0, + "total_hook_executions": 0, + "average_hook_time": 0.0 + } + + async def register_hooks(self): + """注册所有集成钩子""" + if self.hooks_registered: + return + + try: + logger.info("🔗 注册记忆系统集成钩子...") + + # 注册消息处理钩子 + await self._register_message_processing_hooks() + + # 注册记忆检索钩子 + await self._register_memory_retrieval_hooks() + + # 注册提示词增强钩子 + await self._register_prompt_enhancement_hooks() + + # 注册系统维护钩子 + await self._register_maintenance_hooks() + + self.hooks_registered = True + logger.info("✅ 记忆系统集成钩子注册完成") + + except Exception as e: + logger.error(f"❌ 注册记忆系统集成钩子失败: {e}", exc_info=True) + + async def _register_message_processing_hooks(self): + """注册消息处理钩子""" + try: + # 钩子1: 在消息处理后创建记忆 + await self._register_post_message_hook() + + # 钩子2: 在聊天流保存时处理记忆 + await self._register_chat_stream_hook() + + logger.debug("消息处理钩子注册完成") + + except Exception as e: + logger.error(f"注册消息处理钩子失败: {e}") + + async def _register_memory_retrieval_hooks(self): + """注册记忆检索钩子""" + try: + # 钩子1: 在生成回复前检索相关记忆 + await self._register_pre_response_hook() + + # 钩子2: 在知识库查询前增强上下文 + await self._register_knowledge_query_hook() + + logger.debug("记忆检索钩子注册完成") + + except Exception as e: + logger.error(f"注册记忆检索钩子失败: {e}") + + async def _register_prompt_enhancement_hooks(self): + """注册提示词增强钩子""" + try: + # 钩子1: 增强提示词构建 + await self._register_prompt_building_hook() + + logger.debug("提示词增强钩子注册完成") + + except Exception as e: + logger.error(f"注册提示词增强钩子失败: {e}") + + async def _register_maintenance_hooks(self): + """注册系统维护钩子""" + try: + # 钩子1: 系统维护时的记忆系统维护 + await self._register_system_maintenance_hook() + + logger.debug("系统维护钩子注册完成") + + except Exception as e: + logger.error(f"注册系统维护钩子失败: {e}") + + async def _register_post_message_hook(self): + """注册消息后处理钩子""" + try: + # 这里需要根据实际的系统架构来注册钩子 + # 以下是一个示例实现,需要根据实际的插件系统或事件系统来调整 + + # 尝试注册到事件系统 + try: + from src.plugin_system.core.event_manager import event_manager + from src.plugin_system.base.component_types import EventType + + # 注册消息后处理事件 + event_manager.subscribe( + EventType.MESSAGE_PROCESSED, + self._on_message_processed_handler + ) + logger.debug("已注册到事件系统的消息处理钩子") + + except ImportError: + logger.debug("事件系统不可用,跳过事件钩子注册") + + # 尝试注册到消息管理器 + try: + from src.chat.message_manager import message_manager + + # 如果消息管理器支持钩子注册 + if hasattr(message_manager, 'register_post_process_hook'): + message_manager.register_post_process_hook( + self._on_message_processed_hook + ) + logger.debug("已注册到消息管理器的处理钩子") + + except ImportError: + logger.debug("消息管理器不可用,跳过消息管理器钩子注册") + + except Exception as e: + logger.error(f"注册消息后处理钩子失败: {e}") + + async def _register_chat_stream_hook(self): + """注册聊天流钩子""" + try: + # 尝试注册到聊天流管理器 + try: + from src.chat.message_receive.chat_stream import get_chat_manager + + chat_manager = get_chat_manager() + if hasattr(chat_manager, 'register_save_hook'): + chat_manager.register_save_hook( + 
self._on_chat_stream_save_hook + ) + logger.debug("已注册到聊天流管理器的保存钩子") + + except ImportError: + logger.debug("聊天流管理器不可用,跳过聊天流钩子注册") + + except Exception as e: + logger.error(f"注册聊天流钩子失败: {e}") + + async def _register_pre_response_hook(self): + """注册回复前钩子""" + try: + # 尝试注册到回复生成器 + try: + from src.chat.replyer.default_generator import default_generator + + if hasattr(default_generator, 'register_pre_generation_hook'): + default_generator.register_pre_generation_hook( + self._on_pre_response_hook + ) + logger.debug("已注册到回复生成器的前置钩子") + + except ImportError: + logger.debug("回复生成器不可用,跳过回复前钩子注册") + + except Exception as e: + logger.error(f"注册回复前钩子失败: {e}") + + async def _register_knowledge_query_hook(self): + """注册知识库查询钩子""" + try: + # 尝试注册到知识库系统 + try: + from src.chat.knowledge.knowledge_lib import knowledge_manager + + if hasattr(knowledge_manager, 'register_query_enhancer'): + knowledge_manager.register_query_enhancer( + self._on_knowledge_query_hook + ) + logger.debug("已注册到知识库的查询增强钩子") + + except ImportError: + logger.debug("知识库系统不可用,跳过知识库钩子注册") + + except Exception as e: + logger.error(f"注册知识库查询钩子失败: {e}") + + async def _register_prompt_building_hook(self): + """注册提示词构建钩子""" + try: + # 尝试注册到提示词系统 + try: + from src.chat.utils.prompt import prompt_manager + + if hasattr(prompt_manager, 'register_enhancer'): + prompt_manager.register_enhancer( + self._on_prompt_building_hook + ) + logger.debug("已注册到提示词管理器的增强钩子") + + except ImportError: + logger.debug("提示词系统不可用,跳过提示词钩子注册") + + except Exception as e: + logger.error(f"注册提示词构建钩子失败: {e}") + + async def _register_system_maintenance_hook(self): + """注册系统维护钩子""" + try: + # 尝试注册到系统维护器 + try: + from src.manager.async_task_manager import async_task_manager + + # 注册定期维护任务 + async_task_manager.add_task(MemoryMaintenanceTask()) + logger.debug("已注册到系统维护器的定期任务") + + except ImportError: + logger.debug("异步任务管理器不可用,跳过系统维护钩子注册") + + except Exception as e: + logger.error(f"注册系统维护钩子失败: {e}") + + # 钩子处理器方法 + + async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult: + """事件系统的消息处理处理器""" + return await self._on_message_processed_hook(event_data) + + async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult: + """消息后处理钩子""" + start_time = time.time() + + try: + self.hook_stats["message_processing_hooks"] += 1 + + # 提取必要的信息 + message_info = message_data.get("message_info", {}) + user_info = message_info.get("user_info", {}) + conversation_text = message_data.get("processed_plain_text", "") + + if not conversation_text: + return HookResult(success=True, data="No conversation text") + + user_id = str(user_info.get("user_id", "unknown")) + context = { + "chat_id": message_data.get("chat_id"), + "message_type": message_data.get("message_type", "normal"), + "platform": message_info.get("platform", "unknown"), + "interest_value": message_data.get("interest_value", 0.0), + "keywords": message_data.get("key_words", []), + "timestamp": message_data.get("time", time.time()) + } + + # 使用增强记忆系统处理对话 + result = await process_conversation_with_enhanced_memory( + conversation_text, context, user_id + ) + + processing_time = time.time() - start_time + self._update_hook_stats(processing_time) + + if result["success"]: + logger.debug(f"消息处理钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆") + return HookResult(success=True, data=result, processing_time=processing_time) + else: + logger.warning(f"消息处理钩子执行失败: {result.get('error')}") + return HookResult(success=False, error=result.get('error'), 
processing_time=processing_time) + + except Exception as e: + processing_time = time.time() - start_time + logger.error(f"消息处理钩子执行异常: {e}", exc_info=True) + return HookResult(success=False, error=str(e), processing_time=processing_time) + + async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult: + """聊天流保存钩子""" + start_time = time.time() + + try: + self.hook_stats["message_processing_hooks"] += 1 + + # 从聊天流数据中提取对话信息 + stream_context = chat_stream_data.get("stream_context", {}) + user_id = stream_context.get("user_id", "unknown") + messages = stream_context.get("messages", []) + + if not messages: + return HookResult(success=True, data="No messages to process") + + # 构建对话文本 + conversation_parts = [] + for msg in messages[-10:]: # 只处理最近10条消息 + text = msg.get("processed_plain_text", "") + if text: + conversation_parts.append(f"{msg.get('user_nickname', 'User')}: {text}") + + conversation_text = "\n".join(conversation_parts) + if not conversation_text: + return HookResult(success=True, data="No conversation text") + + context = { + "chat_id": chat_stream_data.get("chat_id"), + "stream_id": chat_stream_data.get("stream_id"), + "platform": chat_stream_data.get("platform", "unknown"), + "message_count": len(messages), + "timestamp": time.time() + } + + # 使用增强记忆系统处理对话 + result = await process_conversation_with_enhanced_memory( + conversation_text, context, user_id + ) + + processing_time = time.time() - start_time + self._update_hook_stats(processing_time) + + if result["success"]: + logger.debug(f"聊天流保存钩子执行成功,创建 {len(result.get('created_memories', []))} 条记忆") + return HookResult(success=True, data=result, processing_time=processing_time) + else: + logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}") + return HookResult(success=False, error=result.get('error'), processing_time=processing_time) + + except Exception as e: + processing_time = time.time() - start_time + logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True) + return HookResult(success=False, error=str(e), processing_time=processing_time) + + async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult: + """回复前钩子""" + start_time = time.time() + + try: + self.hook_stats["memory_retrieval_hooks"] += 1 + + # 提取查询信息 + query = response_data.get("query", "") + user_id = response_data.get("user_id", "unknown") + context = response_data.get("context", {}) + + if not query: + return HookResult(success=True, data="No query provided") + + # 检索相关记忆 + memories = await retrieve_memories_with_enhanced_system( + query, user_id, context, limit=5 + ) + + processing_time = time.time() - start_time + self._update_hook_stats(processing_time) + + # 将记忆添加到响应数据中 + response_data["enhanced_memories"] = memories + response_data["enhanced_memory_context"] = await get_memory_context_for_prompt( + query, user_id, context, max_memories=5 + ) + + logger.debug(f"回复前钩子执行成功,检索到 {len(memories)} 条记忆") + return HookResult(success=True, data=memories, processing_time=processing_time) + + except Exception as e: + processing_time = time.time() - start_time + logger.error(f"回复前钩子执行异常: {e}", exc_info=True) + return HookResult(success=False, error=str(e), processing_time=processing_time) + + async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult: + """知识库查询钩子""" + start_time = time.time() + + try: + self.hook_stats["memory_retrieval_hooks"] += 1 + + query = query_data.get("query", "") + user_id = query_data.get("user_id", "unknown") + context = query_data.get("context", {}) + + if not query: + 
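+                # 查询文本为空时跳过记忆上下文增强,直接返回成功结果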
return HookResult(success=True, data="No query provided") + + # 获取记忆上下文并增强查询 + memory_context = await get_memory_context_for_prompt( + query, user_id, context, max_memories=3 + ) + + processing_time = time.time() - start_time + self._update_hook_stats(processing_time) + + # 将记忆上下文添加到查询数据中 + query_data["enhanced_memory_context"] = memory_context + + logger.debug("知识库查询钩子执行成功") + return HookResult(success=True, data=memory_context, processing_time=processing_time) + + except Exception as e: + processing_time = time.time() - start_time + logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True) + return HookResult(success=False, error=str(e), processing_time=processing_time) + + async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult: + """提示词构建钩子""" + start_time = time.time() + + try: + self.hook_stats["prompt_enhancement_hooks"] += 1 + + query = prompt_data.get("query", "") + user_id = prompt_data.get("user_id", "unknown") + context = prompt_data.get("context", {}) + base_prompt = prompt_data.get("base_prompt", "") + + if not query: + return HookResult(success=True, data="No query provided") + + # 获取记忆上下文 + memory_context = await get_memory_context_for_prompt( + query, user_id, context, max_memories=5 + ) + + processing_time = time.time() - start_time + self._update_hook_stats(processing_time) + + # 构建增强的提示词 + enhanced_prompt = base_prompt + if memory_context: + enhanced_prompt += f"\n\n### 相关记忆上下文 ###\n{memory_context}\n" + + # 将增强的提示词添加到数据中 + prompt_data["enhanced_prompt"] = enhanced_prompt + prompt_data["memory_context"] = memory_context + + logger.debug("提示词构建钩子执行成功") + return HookResult(success=True, data=enhanced_prompt, processing_time=processing_time) + + except Exception as e: + processing_time = time.time() - start_time + logger.error(f"提示词构建钩子执行异常: {e}", exc_info=True) + return HookResult(success=False, error=str(e), processing_time=processing_time) + + def _update_hook_stats(self, processing_time: float): + """更新钩子统计""" + self.hook_stats["total_hook_executions"] += 1 + + total_executions = self.hook_stats["total_hook_executions"] + if total_executions > 0: + current_avg = self.hook_stats["average_hook_time"] + new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions + self.hook_stats["average_hook_time"] = new_avg + + def get_hook_stats(self) -> Dict[str, Any]: + """获取钩子统计信息""" + return self.hook_stats.copy() + + +class MemoryMaintenanceTask: + """记忆系统维护任务""" + + def __init__(self): + self.task_name = "enhanced_memory_maintenance" + self.interval = 3600 # 1小时执行一次 + + async def execute(self): + """执行维护任务""" + try: + logger.info("🔧 执行增强记忆系统维护任务...") + + # 获取适配器实例 + try: + from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter + if _enhanced_memory_adapter: + await _enhanced_memory_adapter.maintenance() + logger.info("✅ 增强记忆系统维护任务完成") + else: + logger.debug("增强记忆适配器未初始化,跳过维护") + except Exception as e: + logger.error(f"增强记忆系统维护失败: {e}") + + except Exception as e: + logger.error(f"执行维护任务时发生异常: {e}", exc_info=True) + + def get_interval(self) -> int: + """获取执行间隔""" + return self.interval + + def get_task_name(self) -> str: + """获取任务名称""" + return self.task_name + + +# 全局钩子实例 +_memory_hooks: Optional[MemoryIntegrationHooks] = None + + +async def get_memory_integration_hooks() -> MemoryIntegrationHooks: + """获取全局记忆集成钩子实例""" + global _memory_hooks + + if _memory_hooks is None: + _memory_hooks = MemoryIntegrationHooks() + await _memory_hooks.register_hooks() + + return _memory_hooks + + +async def 
initialize_memory_integration_hooks(): + """初始化记忆集成钩子""" + try: + logger.info("🚀 初始化记忆集成钩子...") + hooks = await get_memory_integration_hooks() + logger.info("✅ 记忆集成钩子初始化完成") + return hooks + except Exception as e: + logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True) + return None \ No newline at end of file diff --git a/src/chat/memory_system/metadata_index.py b/src/chat/memory_system/metadata_index.py new file mode 100644 index 000000000..77aa8c995 --- /dev/null +++ b/src/chat/memory_system/metadata_index.py @@ -0,0 +1,832 @@ +# -*- coding: utf-8 -*- +""" +元数据索引系统 +为记忆系统提供多维度的精准过滤和查询能力 +""" + +import os +import time +import orjson +from typing import Dict, List, Optional, Tuple, Set, Any, Union +from datetime import datetime, timedelta +from dataclasses import dataclass, field +from enum import Enum +import threading +from collections import defaultdict +from pathlib import Path + +from src.common.logger import get_logger +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel + +logger = get_logger(__name__) + + +class IndexType(Enum): + """索引类型""" + MEMORY_TYPE = "memory_type" # 记忆类型索引 + USER_ID = "user_id" # 用户ID索引 + KEYWORD = "keyword" # 关键词索引 + TAG = "tag" # 标签索引 + CATEGORY = "category" # 分类索引 + TIMESTAMP = "timestamp" # 时间索引 + CONFIDENCE = "confidence" # 置信度索引 + IMPORTANCE = "importance" # 重要性索引 + RELATIONSHIP_SCORE = "relationship_score" # 关系分索引 + ACCESS_FREQUENCY = "access_frequency" # 访问频率索引 + SEMANTIC_HASH = "semantic_hash" # 语义哈希索引 + + +@dataclass +class IndexQuery: + """索引查询条件""" + user_ids: Optional[List[str]] = None + memory_types: Optional[List[MemoryType]] = None + keywords: Optional[List[str]] = None + tags: Optional[List[str]] = None + categories: Optional[List[str]] = None + time_range: Optional[Tuple[float, float]] = None + confidence_levels: Optional[List[ConfidenceLevel]] = None + importance_levels: Optional[List[ImportanceLevel]] = None + min_relationship_score: Optional[float] = None + max_relationship_score: Optional[float] = None + min_access_count: Optional[int] = None + semantic_hashes: Optional[List[str]] = None + limit: Optional[int] = None + sort_by: Optional[str] = None # "created_at", "access_count", "relevance_score" + sort_order: str = "desc" # "asc", "desc" + + +@dataclass +class IndexResult: + """索引结果""" + memory_ids: List[str] + total_count: int + query_time: float + filtered_by: List[str] + + +class MetadataIndexManager: + """元数据索引管理器""" + + def __init__(self, index_path: str = "data/memory_metadata"): + self.index_path = Path(index_path) + self.index_path.mkdir(parents=True, exist_ok=True) + + # 各类索引 + self.indices = { + IndexType.MEMORY_TYPE: defaultdict(set), + IndexType.USER_ID: defaultdict(set), + IndexType.KEYWORD: defaultdict(set), + IndexType.TAG: defaultdict(set), + IndexType.CATEGORY: defaultdict(set), + IndexType.CONFIDENCE: defaultdict(set), + IndexType.IMPORTANCE: defaultdict(set), + IndexType.SEMANTIC_HASH: defaultdict(set), + } + + # 时间索引(特殊处理) + self.time_index = [] # [(timestamp, memory_id), ...] + self.relationship_index = [] # [(relationship_score, memory_id), ...] + self.access_frequency_index = [] # [(access_count, memory_id), ...] 
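+        # 以上三个列表均通过对应的 _insert_into_* 方法按数值降序插入维护,
+        # 供按时间、关系分、访问频率排序的查询使用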
+ + # 内存缓存 + self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {} + + # 统计信息 + self.index_stats = { + "total_memories": 0, + "index_build_time": 0.0, + "average_query_time": 0.0, + "total_queries": 0, + "cache_hit_rate": 0.0, + "cache_hits": 0 + } + + # 线程锁 + self._lock = threading.RLock() + self._dirty = False # 标记索引是否有未保存的更改 + + # 自动保存配置 + self.auto_save_interval = 500 # 每500次操作自动保存 + self._operation_count = 0 + + async def index_memories(self, memories: List[MemoryChunk]): + """为记忆建立索引""" + if not memories: + return + + start_time = time.time() + + try: + with self._lock: + for memory in memories: + self._index_single_memory(memory) + + # 标记为需要保存 + self._dirty = True + self._operation_count += len(memories) + + # 自动保存检查 + if self._operation_count >= self.auto_save_interval: + await self.save_index() + self._operation_count = 0 + + index_time = time.time() - start_time + self.index_stats["index_build_time"] = ( + (self.index_stats["index_build_time"] * (len(memories) - 1) + index_time) / + len(memories) + ) + + logger.debug(f"元数据索引完成,{len(memories)} 条记忆,耗时 {index_time:.3f}秒") + + except Exception as e: + logger.error(f"❌ 元数据索引失败: {e}", exc_info=True) + + def _index_single_memory(self, memory: MemoryChunk): + """为单个记忆建立索引""" + memory_id = memory.memory_id + + # 更新内存缓存 + self.memory_metadata_cache[memory_id] = { + "user_id": memory.user_id, + "memory_type": memory.memory_type, + "created_at": memory.metadata.created_at, + "last_accessed": memory.metadata.last_accessed, + "access_count": memory.metadata.access_count, + "confidence": memory.metadata.confidence, + "importance": memory.metadata.importance, + "relationship_score": memory.metadata.relationship_score, + "relevance_score": memory.metadata.relevance_score, + "semantic_hash": memory.semantic_hash + } + + # 记忆类型索引 + self.indices[IndexType.MEMORY_TYPE][memory.memory_type].add(memory_id) + + # 用户ID索引 + self.indices[IndexType.USER_ID][memory.user_id].add(memory_id) + + # 关键词索引 + for keyword in memory.keywords: + self.indices[IndexType.KEYWORD][keyword.lower()].add(memory_id) + + # 标签索引 + for tag in memory.tags: + self.indices[IndexType.TAG][tag.lower()].add(memory_id) + + # 分类索引 + for category in memory.categories: + self.indices[IndexType.CATEGORY][category.lower()].add(memory_id) + + # 置信度索引 + self.indices[IndexType.CONFIDENCE][memory.metadata.confidence].add(memory_id) + + # 重要性索引 + self.indices[IndexType.IMPORTANCE][memory.metadata.importance].add(memory_id) + + # 语义哈希索引 + if memory.semantic_hash: + self.indices[IndexType.SEMANTIC_HASH][memory.semantic_hash].add(memory_id) + + # 时间索引(插入排序保持有序) + self._insert_into_time_index(memory.metadata.created_at, memory_id) + + # 关系分索引(插入排序保持有序) + self._insert_into_relationship_index(memory.metadata.relationship_score, memory_id) + + # 访问频率索引(插入排序保持有序) + self._insert_into_access_frequency_index(memory.metadata.access_count, memory_id) + + # 更新统计 + self.index_stats["total_memories"] += 1 + + def _insert_into_time_index(self, timestamp: float, memory_id: str): + """插入时间索引(保持降序)""" + insert_pos = len(self.time_index) + for i, (ts, _) in enumerate(self.time_index): + if timestamp >= ts: + insert_pos = i + break + + self.time_index.insert(insert_pos, (timestamp, memory_id)) + + def _insert_into_relationship_index(self, relationship_score: float, memory_id: str): + """插入关系分索引(保持降序)""" + insert_pos = len(self.relationship_index) + for i, (score, _) in enumerate(self.relationship_index): + if relationship_score >= score: + insert_pos = i + break + + self.relationship_index.insert(insert_pos, 
(relationship_score, memory_id)) + + def _insert_into_access_frequency_index(self, access_count: int, memory_id: str): + """插入访问频率索引(保持降序)""" + insert_pos = len(self.access_frequency_index) + for i, (count, _) in enumerate(self.access_frequency_index): + if access_count >= count: + insert_pos = i + break + + self.access_frequency_index.insert(insert_pos, (access_count, memory_id)) + + async def query_memories(self, query: IndexQuery) -> IndexResult: + """查询记忆""" + start_time = time.time() + + try: + with self._lock: + # 获取候选记忆ID集合 + candidate_ids = self._get_candidate_memories(query) + + # 应用过滤条件 + filtered_ids = self._apply_filters(candidate_ids, query) + + # 排序 + if query.sort_by: + filtered_ids = self._sort_memories(filtered_ids, query.sort_by, query.sort_order) + + # 限制数量 + if query.limit and len(filtered_ids) > query.limit: + filtered_ids = filtered_ids[:query.limit] + + # 记录查询统计 + query_time = time.time() - start_time + self.index_stats["total_queries"] += 1 + self.index_stats["average_query_time"] = ( + (self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time) / + self.index_stats["total_queries"] + ) + + return IndexResult( + memory_ids=filtered_ids, + total_count=len(filtered_ids), + query_time=query_time, + filtered_by=self._get_applied_filters(query) + ) + + except Exception as e: + logger.error(f"❌ 元数据查询失败: {e}", exc_info=True) + return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[]) + + def _get_candidate_memories(self, query: IndexQuery) -> Set[str]: + """获取候选记忆ID集合""" + candidate_ids = set() + + # 获取所有记忆ID作为起点 + all_memory_ids = set(self.memory_metadata_cache.keys()) + + if not all_memory_ids: + return candidate_ids + + # 应用最严格的过滤条件 + applied_filters = [] + + if query.user_ids: + user_ids_set = set() + for user_id in query.user_ids: + user_ids_set.update(self.indices[IndexType.USER_ID].get(user_id, set())) + candidate_ids.update(user_ids_set) + applied_filters.append("user_ids") + + if query.memory_types: + memory_types_set = set() + for memory_type in query.memory_types: + memory_types_set.update(self.indices[IndexType.MEMORY_TYPE].get(memory_type, set())) + if applied_filters: + candidate_ids &= memory_types_set + else: + candidate_ids.update(memory_types_set) + applied_filters.append("memory_types") + + if query.keywords: + keywords_set = set() + for keyword in query.keywords: + keywords_set.update(self.indices[IndexType.KEYWORD].get(keyword.lower(), set())) + if applied_filters: + candidate_ids &= keywords_set + else: + candidate_ids.update(keywords_set) + applied_filters.append("keywords") + + if query.tags: + tags_set = set() + for tag in query.tags: + tags_set.update(self.indices[IndexType.TAG].get(tag.lower(), set())) + if applied_filters: + candidate_ids &= tags_set + else: + candidate_ids.update(tags_set) + applied_filters.append("tags") + + if query.categories: + categories_set = set() + for category in query.categories: + categories_set.update(self.indices[IndexType.CATEGORY].get(category.lower(), set())) + if applied_filters: + candidate_ids &= categories_set + else: + candidate_ids.update(categories_set) + applied_filters.append("categories") + + # 如果没有应用任何过滤条件,返回所有记忆 + if not applied_filters: + return all_memory_ids + + return candidate_ids + + def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]: + """应用过滤条件""" + filtered_ids = list(candidate_ids) + + # 时间范围过滤 + if query.time_range: + start_time, end_time = query.time_range + filtered_ids = [ + memory_id for memory_id 
in filtered_ids + if self._is_in_time_range(memory_id, start_time, end_time) + ] + + # 置信度过滤 + if query.confidence_levels: + confidence_set = set(query.confidence_levels) + filtered_ids = [ + memory_id for memory_id in filtered_ids + if self.memory_metadata_cache[memory_id]["confidence"] in confidence_set + ] + + # 重要性过滤 + if query.importance_levels: + importance_set = set(query.importance_levels) + filtered_ids = [ + memory_id for memory_id in filtered_ids + if self.memory_metadata_cache[memory_id]["importance"] in importance_set + ] + + # 关系分范围过滤 + if query.min_relationship_score is not None: + filtered_ids = [ + memory_id for memory_id in filtered_ids + if self.memory_metadata_cache[memory_id]["relationship_score"] >= query.min_relationship_score + ] + + if query.max_relationship_score is not None: + filtered_ids = [ + memory_id for memory_id in filtered_ids + if self.memory_metadata_cache[memory_id]["relationship_score"] <= query.max_relationship_score + ] + + # 最小访问次数过滤 + if query.min_access_count is not None: + filtered_ids = [ + memory_id for memory_id in filtered_ids + if self.memory_metadata_cache[memory_id]["access_count"] >= query.min_access_count + ] + + # 语义哈希过滤 + if query.semantic_hashes: + hash_set = set(query.semantic_hashes) + filtered_ids = [ + memory_id for memory_id in filtered_ids + if self.memory_metadata_cache[memory_id]["semantic_hash"] in hash_set + ] + + return filtered_ids + + def _is_in_time_range(self, memory_id: str, start_time: float, end_time: float) -> bool: + """检查记忆是否在时间范围内""" + created_at = self.memory_metadata_cache[memory_id]["created_at"] + return start_time <= created_at <= end_time + + def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]: + """对记忆进行排序""" + if sort_by == "created_at": + # 使用时间索引(已经有序) + if sort_order == "desc": + return memory_ids # 时间索引已经是降序 + else: + return memory_ids[::-1] # 反转为升序 + + elif sort_by == "access_count": + # 使用访问频率索引(已经有序) + if sort_order == "desc": + return memory_ids # 访问频率索引已经是降序 + else: + return memory_ids[::-1] # 反转为升序 + + elif sort_by == "relevance_score": + # 按相关度排序 + memory_ids.sort( + key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"], + reverse=(sort_order == "desc") + ) + + elif sort_by == "relationship_score": + # 使用关系分索引(已经有序) + if sort_order == "desc": + return memory_ids # 关系分索引已经是降序 + else: + return memory_ids[::-1] # 反转为升序 + + elif sort_by == "last_accessed": + # 按最后访问时间排序 + memory_ids.sort( + key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"], + reverse=(sort_order == "desc") + ) + + return memory_ids + + def _get_applied_filters(self, query: IndexQuery) -> List[str]: + """获取应用的过滤器列表""" + filters = [] + if query.user_ids: + filters.append("user_ids") + if query.memory_types: + filters.append("memory_types") + if query.keywords: + filters.append("keywords") + if query.tags: + filters.append("tags") + if query.categories: + filters.append("categories") + if query.time_range: + filters.append("time_range") + if query.confidence_levels: + filters.append("confidence_levels") + if query.importance_levels: + filters.append("importance_levels") + if query.min_relationship_score is not None or query.max_relationship_score is not None: + filters.append("relationship_score_range") + if query.min_access_count is not None: + filters.append("min_access_count") + if query.semantic_hashes: + filters.append("semantic_hashes") + return filters + + async def update_memory_index(self, memory: MemoryChunk): + """更新记忆索引""" + with self._lock: + try: + 
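+                # 更新策略:先移除该记忆的旧索引条目,再重新建立全部索引,避免残留引用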
memory_id = memory.memory_id + + # 如果记忆已存在,先删除旧索引 + if memory_id in self.memory_metadata_cache: + await self.remove_memory_index(memory_id) + + # 重新建立索引 + self._index_single_memory(memory) + self._dirty = True + self._operation_count += 1 + + # 自动保存检查 + if self._operation_count >= self.auto_save_interval: + await self.save_index() + self._operation_count = 0 + + logger.debug(f"更新记忆索引完成: {memory_id}") + + except Exception as e: + logger.error(f"❌ 更新记忆索引失败: {e}") + + async def remove_memory_index(self, memory_id: str): + """移除记忆索引""" + with self._lock: + try: + if memory_id not in self.memory_metadata_cache: + return + + # 获取记忆元数据 + metadata = self.memory_metadata_cache[memory_id] + + # 从各类索引中移除 + self.indices[IndexType.MEMORY_TYPE][metadata["memory_type"]].discard(memory_id) + self.indices[IndexType.USER_ID][metadata["user_id"]].discard(memory_id) + + # 从时间索引中移除 + self.time_index = [(ts, mid) for ts, mid in self.time_index if mid != memory_id] + + # 从关系分索引中移除 + self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid != memory_id] + + # 从访问频率索引中移除 + self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid != memory_id] + + # 注意:关键词、标签、分类索引需要从原始记忆中获取,这里简化处理 + # 实际实现中可能需要重新加载记忆或维护反向索引 + + # 从缓存中移除 + del self.memory_metadata_cache[memory_id] + + # 更新统计 + self.index_stats["total_memories"] = max(0, self.index_stats["total_memories"] - 1) + self._dirty = True + + logger.debug(f"移除记忆索引完成: {memory_id}") + + except Exception as e: + logger.error(f"❌ 移除记忆索引失败: {e}") + + async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]: + """获取记忆元数据""" + return self.memory_metadata_cache.get(memory_id) + + async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]: + """获取用户的所有记忆ID""" + user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set())) + + if limit and len(user_memory_ids) > limit: + user_memory_ids = user_memory_ids[:limit] + + return user_memory_ids + + async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]: + """获取记忆统计信息""" + stats = { + "total_memories": self.index_stats["total_memories"], + "memory_types": {}, + "average_confidence": 0.0, + "average_importance": 0.0, + "average_relationship_score": 0.0, + "top_keywords": [], + "top_tags": [] + } + + if user_id: + # 限定用户统计 + user_memory_ids = self.indices[IndexType.USER_ID].get(user_id, set()) + stats["user_total_memories"] = len(user_memory_ids) + + if not user_memory_ids: + return stats + + # 用户记忆类型分布 + user_types = {} + for memory_type, memory_ids in self.indices[IndexType.MEMORY_TYPE].items(): + user_count = len(user_memory_ids & memory_ids) + if user_count > 0: + user_types[memory_type.value] = user_count + stats["memory_types"] = user_types + + # 计算用户平均值 + user_confidences = [] + user_importances = [] + user_relationship_scores = [] + + for memory_id in user_memory_ids: + metadata = self.memory_metadata_cache.get(memory_id, {}) + if metadata: + user_confidences.append(metadata["confidence"].value) + user_importances.append(metadata["importance"].value) + user_relationship_scores.append(metadata["relationship_score"]) + + if user_confidences: + stats["average_confidence"] = sum(user_confidences) / len(user_confidences) + if user_importances: + stats["average_importance"] = sum(user_importances) / len(user_importances) + if user_relationship_scores: + stats["average_relationship_score"] = sum(user_relationship_scores) / len(user_relationship_scores) + + else: + # 全局统计 
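+            # 未限定用户时,统计全库记忆的类型分布与平均置信度/重要性/关系分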
+ for memory_type, memory_ids in self.indices[IndexType.MEMORY_TYPE].items(): + stats["memory_types"][memory_type.value] = len(memory_ids) + + # 计算全局平均值 + if self.memory_metadata_cache: + all_confidences = [m["confidence"].value for m in self.memory_metadata_cache.values()] + all_importances = [m["importance"].value for m in self.memory_metadata_cache.values()] + all_relationship_scores = [m["relationship_score"] for m in self.memory_metadata_cache.values()] + + if all_confidences: + stats["average_confidence"] = sum(all_confidences) / len(all_confidences) + if all_importances: + stats["average_importance"] = sum(all_importances) / len(all_importances) + if all_relationship_scores: + stats["average_relationship_score"] = sum(all_relationship_scores) / len(all_relationship_scores) + + # 统计热门关键词和标签 + keyword_counts = [(keyword, len(memory_ids)) for keyword, memory_ids in self.indices[IndexType.KEYWORD].items()] + keyword_counts.sort(key=lambda x: x[1], reverse=True) + stats["top_keywords"] = keyword_counts[:10] + + tag_counts = [(tag, len(memory_ids)) for tag, memory_ids in self.indices[IndexType.TAG].items()] + tag_counts.sort(key=lambda x: x[1], reverse=True) + stats["top_tags"] = tag_counts[:10] + + return stats + + async def save_index(self): + """保存索引到文件""" + if not self._dirty: + return + + try: + logger.info("正在保存元数据索引...") + + # 保存各类索引 + indices_data = {} + for index_type, index_data in self.indices.items(): + indices_data[index_type.value] = { + key: list(values) for key, values in index_data.items() + } + + indices_file = self.index_path / "indices.json" + with open(indices_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode('utf-8')) + + # 保存时间索引 + time_index_file = self.index_path / "time_index.json" + with open(time_index_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode('utf-8')) + + # 保存关系分索引 + relationship_index_file = self.index_path / "relationship_index.json" + with open(relationship_index_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode('utf-8')) + + # 保存访问频率索引 + access_frequency_index_file = self.index_path / "access_frequency_index.json" + with open(access_frequency_index_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode('utf-8')) + + # 保存元数据缓存 + metadata_cache_file = self.index_path / "metadata_cache.json" + with open(metadata_cache_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(self.memory_metadata_cache, option=orjson.OPT_INDENT_2).decode('utf-8')) + + # 保存统计信息 + stats_file = self.index_path / "index_stats.json" + with open(stats_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode('utf-8')) + + self._dirty = False + logger.info("✅ 元数据索引保存完成") + + except Exception as e: + logger.error(f"❌ 保存元数据索引失败: {e}") + + async def load_index(self): + """从文件加载索引""" + try: + logger.info("正在加载元数据索引...") + + # 加载各类索引 + indices_file = self.index_path / "indices.json" + if indices_file.exists(): + with open(indices_file, 'r', encoding='utf-8') as f: + indices_data = orjson.loads(f.read()) + + for index_type_value, index_data in indices_data.items(): + index_type = IndexType(index_type_value) + self.indices[index_type] = { + key: set(values) for key, values in index_data.items() + } + + # 加载时间索引 + time_index_file = self.index_path / "time_index.json" + if 
time_index_file.exists(): + with open(time_index_file, 'r', encoding='utf-8') as f: + self.time_index = orjson.loads(f.read()) + + # 加载关系分索引 + relationship_index_file = self.index_path / "relationship_index.json" + if relationship_index_file.exists(): + with open(relationship_index_file, 'r', encoding='utf-8') as f: + self.relationship_index = orjson.loads(f.read()) + + # 加载访问频率索引 + access_frequency_index_file = self.index_path / "access_frequency_index.json" + if access_frequency_index_file.exists(): + with open(access_frequency_index_file, 'r', encoding='utf-8') as f: + self.access_frequency_index = orjson.loads(f.read()) + + # 加载元数据缓存 + metadata_cache_file = self.index_path / "metadata_cache.json" + if metadata_cache_file.exists(): + with open(metadata_cache_file, 'r', encoding='utf-8') as f: + cache_data = orjson.loads(f.read()) + + # 转换置信度和重要性为枚举类型 + for memory_id, metadata in cache_data.items(): + if isinstance(metadata["confidence"], str): + metadata["confidence"] = ConfidenceLevel(metadata["confidence"]) + if isinstance(metadata["importance"], str): + metadata["importance"] = ImportanceLevel(metadata["importance"]) + + self.memory_metadata_cache = cache_data + + # 加载统计信息 + stats_file = self.index_path / "index_stats.json" + if stats_file.exists(): + with open(stats_file, 'r', encoding='utf-8') as f: + self.index_stats = orjson.loads(f.read()) + + # 更新记忆计数 + self.index_stats["total_memories"] = len(self.memory_metadata_cache) + + logger.info(f"✅ 元数据索引加载完成,{self.index_stats['total_memories']} 个记忆") + + except Exception as e: + logger.error(f"❌ 加载元数据索引失败: {e}") + + async def optimize_index(self): + """优化索引""" + try: + logger.info("开始元数据索引优化...") + + # 清理无效引用 + self._cleanup_invalid_references() + + # 重建有序索引 + self._rebuild_ordered_indices() + + # 清理低频关键词和标签 + self._cleanup_low_frequency_terms() + + # 更新统计信息 + if self.index_stats["total_queries"] > 0: + self.index_stats["cache_hit_rate"] = ( + self.index_stats["cache_hits"] / self.index_stats["total_queries"] + ) + + logger.info("✅ 元数据索引优化完成") + + except Exception as e: + logger.error(f"❌ 元数据索引优化失败: {e}") + + def _cleanup_invalid_references(self): + """清理无效引用""" + valid_memory_ids = set(self.memory_metadata_cache.keys()) + + # 清理各类索引中的无效引用 + for index_type in self.indices: + for key in list(self.indices[index_type].keys()): + valid_ids = self.indices[index_type][key] & valid_memory_ids + self.indices[index_type][key] = valid_ids + + # 如果某类别下没有记忆了,删除该类别 + if not valid_ids: + del self.indices[index_type][key] + + # 清理时间索引中的无效引用 + self.time_index = [(ts, mid) for ts, mid in self.time_index if mid in valid_memory_ids] + + # 清理关系分索引中的无效引用 + self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid in valid_memory_ids] + + # 清理访问频率索引中的无效引用 + self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids] + + # 更新总记忆数 + self.index_stats["total_memories"] = len(valid_memory_ids) + + def _rebuild_ordered_indices(self): + """重建有序索引""" + # 重建时间索引 + self.time_index.sort(key=lambda x: x[0], reverse=True) + + # 重建关系分索引 + self.relationship_index.sort(key=lambda x: x[0], reverse=True) + + # 重建访问频率索引 + self.access_frequency_index.sort(key=lambda x: x[0], reverse=True) + + def _cleanup_low_frequency_terms(self, min_frequency: int = 2): + """清理低频术语""" + # 清理低频关键词 + for keyword in list(self.indices[IndexType.KEYWORD].keys()): + if len(self.indices[IndexType.KEYWORD][keyword]) < min_frequency: + del self.indices[IndexType.KEYWORD][keyword] + + # 清理低频标签 + for tag in 
list(self.indices[IndexType.TAG].keys()): + if len(self.indices[IndexType.TAG][tag]) < min_frequency: + del self.indices[IndexType.TAG][tag] + + # 清理低频分类 + for category in list(self.indices[IndexType.CATEGORY].keys()): + if len(self.indices[IndexType.CATEGORY][category]) < min_frequency: + del self.indices[IndexType.CATEGORY][category] + + def get_index_stats(self) -> Dict[str, Any]: + """获取索引统计信息""" + stats = self.index_stats.copy() + if stats["total_queries"] > 0: + stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_queries"] + else: + stats["cache_hit_rate"] = 0.0 + + # 添加索引详细信息 + stats["index_details"] = { + "memory_types": len(self.indices[IndexType.MEMORY_TYPE]), + "user_ids": len(self.indices[IndexType.USER_ID]), + "keywords": len(self.indices[IndexType.KEYWORD]), + "tags": len(self.indices[IndexType.TAG]), + "categories": len(self.indices[IndexType.CATEGORY]), + "confidence_levels": len(self.indices[IndexType.CONFIDENCE]), + "importance_levels": len(self.indices[IndexType.IMPORTANCE]), + "semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH]) + } + + return stats \ No newline at end of file diff --git a/src/chat/memory_system/multi_stage_retrieval.py b/src/chat/memory_system/multi_stage_retrieval.py new file mode 100644 index 000000000..d8e7afe7b --- /dev/null +++ b/src/chat/memory_system/multi_stage_retrieval.py @@ -0,0 +1,595 @@ +# -*- coding: utf-8 -*- +""" +多阶段召回机制 +实现粗粒度到细粒度的记忆检索优化 +""" + +import time +import asyncio +from typing import Dict, List, Optional, Tuple, Set, Any +from dataclasses import dataclass +from enum import Enum +import numpy as np + +from src.common.logger import get_logger +from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel + +logger = get_logger(__name__) + + +class RetrievalStage(Enum): + """检索阶段""" + METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段 + VECTOR_SEARCH = "vector_search" # 向量搜索阶段 + SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段 + CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段 + + +@dataclass +class RetrievalConfig: + """检索配置""" + # 各阶段配置 + metadata_filter_limit: int = 100 # 元数据过滤阶段返回数量 + vector_search_limit: int = 50 # 向量搜索阶段返回数量 + semantic_rerank_limit: int = 20 # 语义重排序阶段返回数量 + final_result_limit: int = 10 # 最终结果数量 + + # 相似度阈值 + vector_similarity_threshold: float = 0.7 # 向量相似度阈值 + semantic_similarity_threshold: float = 0.6 # 语义相似度阈值 + + # 权重配置 + vector_weight: float = 0.4 # 向量相似度权重 + semantic_weight: float = 0.3 # 语义相似度权重 + context_weight: float = 0.2 # 上下文权重 + recency_weight: float = 0.1 # 时效性权重 + + @classmethod + def from_global_config(cls): + """从全局配置创建配置实例""" + from src.config.config import global_config + + return cls( + # 各阶段配置 + metadata_filter_limit=global_config.memory.metadata_filter_limit, + vector_search_limit=global_config.memory.vector_search_limit, + semantic_rerank_limit=global_config.memory.semantic_rerank_limit, + final_result_limit=global_config.memory.final_result_limit, + + # 相似度阈值 + vector_similarity_threshold=global_config.memory.vector_similarity_threshold, + semantic_similarity_threshold=0.6, # 保持默认值 + + # 权重配置 + vector_weight=global_config.memory.vector_weight, + semantic_weight=global_config.memory.semantic_weight, + context_weight=global_config.memory.context_weight, + recency_weight=global_config.memory.recency_weight + ) + + +@dataclass +class StageResult: + """阶段结果""" + stage: RetrievalStage + memory_ids: List[str] + processing_time: float + filtered_count: int + score_threshold: float + + +@dataclass +class RetrievalResult: 
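+    # 记录单次多阶段检索的最终记忆与各阶段耗时、过滤数量,便于排查召回质量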
+ """检索结果""" + query: str + user_id: str + final_memories: List[MemoryChunk] + stage_results: List[StageResult] + total_processing_time: float + total_filtered: int + retrieval_stats: Dict[str, Any] + + +class MultiStageRetrieval: + """多阶段召回系统""" + + def __init__(self, config: Optional[RetrievalConfig] = None): + self.config = config or RetrievalConfig.from_global_config() + self.retrieval_stats = { + "total_queries": 0, + "average_retrieval_time": 0.0, + "stage_stats": { + "metadata_filtering": {"calls": 0, "avg_time": 0.0}, + "vector_search": {"calls": 0, "avg_time": 0.0}, + "semantic_reranking": {"calls": 0, "avg_time": 0.0}, + "contextual_filtering": {"calls": 0, "avg_time": 0.0} + } + } + + async def retrieve_memories( + self, + query: str, + user_id: str, + context: Dict[str, Any], + metadata_index, + vector_storage, + all_memories_cache: Dict[str, MemoryChunk], + limit: Optional[int] = None + ) -> RetrievalResult: + """多阶段记忆检索""" + start_time = time.time() + limit = limit or self.config.final_result_limit + + stage_results = [] + current_memory_ids = set() + + try: + logger.debug(f"开始多阶段检索:query='{query}', user_id='{user_id}'") + + # 阶段1:元数据过滤 + stage1_result = await self._metadata_filtering_stage( + query, user_id, context, metadata_index, all_memories_cache + ) + stage_results.append(stage1_result) + current_memory_ids.update(stage1_result.memory_ids) + + # 阶段2:向量搜索 + stage2_result = await self._vector_search_stage( + query, user_id, context, vector_storage, current_memory_ids, all_memories_cache + ) + stage_results.append(stage2_result) + current_memory_ids.update(stage2_result.memory_ids) + + # 阶段3:语义重排序 + stage3_result = await self._semantic_reranking_stage( + query, user_id, context, current_memory_ids, all_memories_cache + ) + stage_results.append(stage3_result) + + # 阶段4:上下文过滤 + stage4_result = await self._contextual_filtering_stage( + query, user_id, context, stage3_result.memory_ids, all_memories_cache, limit + ) + stage_results.append(stage4_result) + + # 获取最终记忆对象 + final_memories = [] + for memory_id in stage4_result.memory_ids: + if memory_id in all_memories_cache: + final_memories.append(all_memories_cache[memory_id]) + + # 更新统计 + total_time = time.time() - start_time + self._update_retrieval_stats(total_time, stage_results) + + total_filtered = sum(result.filtered_count for result in stage_results) + + logger.debug(f"多阶段检索完成:返回 {len(final_memories)} 条记忆,耗时 {total_time:.3f}s") + + return RetrievalResult( + query=query, + user_id=user_id, + final_memories=final_memories, + stage_results=stage_results, + total_processing_time=total_time, + total_filtered=total_filtered, + retrieval_stats=self.retrieval_stats.copy() + ) + + except Exception as e: + logger.error(f"多阶段检索失败: {e}", exc_info=True) + # 返回空结果 + return RetrievalResult( + query=query, + user_id=user_id, + final_memories=[], + stage_results=stage_results, + total_processing_time=time.time() - start_time, + total_filtered=0, + retrieval_stats=self.retrieval_stats.copy() + ) + + async def _metadata_filtering_stage( + self, + query: str, + user_id: str, + context: Dict[str, Any], + metadata_index, + all_memories_cache: Dict[str, MemoryChunk] + ) -> StageResult: + """阶段1:元数据过滤""" + start_time = time.time() + + try: + from .metadata_index import IndexQuery + + # 构建索引查询 + index_query = IndexQuery( + user_ids=[user_id], + memory_types=self._extract_memory_types_from_context(context), + keywords=self._extract_keywords_from_query(query), + limit=self.config.metadata_filter_limit, + sort_by="last_accessed", + 
sort_order="desc" + ) + + # 执行查询 + result = await metadata_index.query_memories(index_query) + filtered_count = result.total_count - len(result.memory_ids) + + logger.debug(f"元数据过滤:{result.total_count} -> {len(result.memory_ids)} 条记忆") + + return StageResult( + stage=RetrievalStage.METADATA_FILTERING, + memory_ids=result.memory_ids, + processing_time=time.time() - start_time, + filtered_count=filtered_count, + score_threshold=0.0 + ) + + except Exception as e: + logger.error(f"元数据过滤阶段失败: {e}") + return StageResult( + stage=RetrievalStage.METADATA_FILTERING, + memory_ids=[], + processing_time=time.time() - start_time, + filtered_count=0, + score_threshold=0.0 + ) + + async def _vector_search_stage( + self, + query: str, + user_id: str, + context: Dict[str, Any], + vector_storage, + candidate_ids: Set[str], + all_memories_cache: Dict[str, MemoryChunk] + ) -> StageResult: + """阶段2:向量搜索""" + start_time = time.time() + + try: + # 生成查询向量 + query_embedding = await self._generate_query_embedding(query, context) + + if not query_embedding: + return StageResult( + stage=RetrievalStage.VECTOR_SEARCH, + memory_ids=[], + processing_time=time.time() - start_time, + filtered_count=0, + score_threshold=self.config.vector_similarity_threshold + ) + + # 执行向量搜索 + search_result = await vector_storage.search_similar( + query_embedding, + limit=self.config.vector_search_limit + ) + + # 过滤候选记忆 + filtered_memories = [] + for memory_id, similarity in search_result: + if memory_id in candidate_ids and similarity >= self.config.vector_similarity_threshold: + filtered_memories.append((memory_id, similarity)) + + # 按相似度排序 + filtered_memories.sort(key=lambda x: x[1], reverse=True) + result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]] + + filtered_count = len(candidate_ids) - len(result_ids) + + logger.debug(f"向量搜索:{len(candidate_ids)} -> {len(result_ids)} 条记忆") + + return StageResult( + stage=RetrievalStage.VECTOR_SEARCH, + memory_ids=result_ids, + processing_time=time.time() - start_time, + filtered_count=filtered_count, + score_threshold=self.config.vector_similarity_threshold + ) + + except Exception as e: + logger.error(f"向量搜索阶段失败: {e}") + return StageResult( + stage=RetrievalStage.VECTOR_SEARCH, + memory_ids=[], + processing_time=time.time() - start_time, + filtered_count=0, + score_threshold=self.config.vector_similarity_threshold + ) + + async def _semantic_reranking_stage( + self, + query: str, + user_id: str, + context: Dict[str, Any], + candidate_ids: Set[str], + all_memories_cache: Dict[str, MemoryChunk] + ) -> StageResult: + """阶段3:语义重排序""" + start_time = time.time() + + try: + reranked_memories = [] + + for memory_id in candidate_ids: + if memory_id not in all_memories_cache: + continue + + memory = all_memories_cache[memory_id] + + # 计算综合语义相似度 + semantic_score = await self._calculate_semantic_similarity(query, memory, context) + + if semantic_score >= self.config.semantic_similarity_threshold: + reranked_memories.append((memory_id, semantic_score)) + + # 按语义相似度排序 + reranked_memories.sort(key=lambda x: x[1], reverse=True) + result_ids = [memory_id for memory_id, _ in reranked_memories[:self.config.semantic_rerank_limit]] + + filtered_count = len(candidate_ids) - len(result_ids) + + logger.debug(f"语义重排序:{len(candidate_ids)} -> {len(result_ids)} 条记忆") + + return StageResult( + stage=RetrievalStage.SEMANTIC_RERANKING, + memory_ids=result_ids, + processing_time=time.time() - start_time, + filtered_count=filtered_count, + 
score_threshold=self.config.semantic_similarity_threshold + ) + + except Exception as e: + logger.error(f"语义重排序阶段失败: {e}") + return StageResult( + stage=RetrievalStage.SEMANTIC_RERANKING, + memory_ids=list(candidate_ids), # 失败时返回原候选集 + processing_time=time.time() - start_time, + filtered_count=0, + score_threshold=self.config.semantic_similarity_threshold + ) + + async def _contextual_filtering_stage( + self, + query: str, + user_id: str, + context: Dict[str, Any], + candidate_ids: List[str], + all_memories_cache: Dict[str, MemoryChunk], + limit: int + ) -> StageResult: + """阶段4:上下文过滤""" + start_time = time.time() + + try: + final_memories = [] + + for memory_id in candidate_ids: + if memory_id not in all_memories_cache: + continue + + memory = all_memories_cache[memory_id] + + # 计算上下文相关度评分 + context_score = await self._calculate_context_relevance(query, memory, context) + + # 结合多因子评分 + final_score = await self._calculate_final_score(query, memory, context, context_score) + + final_memories.append((memory_id, final_score)) + + # 按最终评分排序 + final_memories.sort(key=lambda x: x[1], reverse=True) + result_ids = [memory_id for memory_id, _ in final_memories[:limit]] + + filtered_count = len(candidate_ids) - len(result_ids) + + logger.debug(f"上下文过滤:{len(candidate_ids)} -> {len(result_ids)} 条记忆") + + return StageResult( + stage=RetrievalStage.CONTEXTUAL_FILTERING, + memory_ids=result_ids, + processing_time=time.time() - start_time, + filtered_count=filtered_count, + score_threshold=0.0 # 动态阈值 + ) + + except Exception as e: + logger.error(f"上下文过滤阶段失败: {e}") + return StageResult( + stage=RetrievalStage.CONTEXTUAL_FILTERING, + memory_ids=candidate_ids[:limit], # 失败时返回前limit个 + processing_time=time.time() - start_time, + filtered_count=0, + score_threshold=0.0 + ) + + async def _generate_query_embedding(self, query: str, context: Dict[str, Any]) -> Optional[List[float]]: + """生成查询向量""" + try: + # 这里应该调用embedding模型 + # 由于我们可能没有直接的embedding模型,返回None或使用简单的方法 + # 在实际实现中,这里应该调用与记忆存储相同的embedding模型 + return None + except Exception as e: + logger.warning(f"生成查询向量失败: {e}") + return None + + async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: + """计算语义相似度""" + try: + # 简单的文本相似度计算 + query_words = set(query.lower().split()) + memory_words = set(memory.text_content.lower().split()) + + if not query_words or not memory_words: + return 0.0 + + intersection = query_words & memory_words + union = query_words | memory_words + + jaccard_similarity = len(intersection) / len(union) + return jaccard_similarity + + except Exception as e: + logger.warning(f"计算语义相似度失败: {e}") + return 0.0 + + async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float: + """计算上下文相关度""" + try: + score = 0.0 + + # 检查记忆类型是否匹配上下文 + if context.get("expected_memory_types"): + if memory.memory_type in context["expected_memory_types"]: + score += 0.3 + + # 检查关键词匹配 + if context.get("keywords"): + memory_keywords = set(memory.keywords) + context_keywords = set(context["keywords"]) + overlap = memory_keywords & context_keywords + if overlap: + score += len(overlap) / max(len(context_keywords), 1) * 0.4 + + # 检查时效性 + if context.get("recent_only", False): + memory_age = time.time() - memory.metadata.created_at + if memory_age < 7 * 24 * 3600: # 7天内 + score += 0.3 + + return min(score, 1.0) + + except Exception as e: + logger.warning(f"计算上下文相关度失败: {e}") + return 0.0 + + async def _calculate_final_score(self, query: str, memory: MemoryChunk, 
context: Dict[str, Any], context_score: float) -> float: + """计算最终评分""" + try: + # 语义相似度 + semantic_score = await self._calculate_semantic_similarity(query, memory, context) + + # 向量相似度(如果有) + vector_score = 0.0 + if memory.embedding: + # 这里应该有向量相似度计算,简化处理 + vector_score = 0.5 + + # 时效性评分 + recency_score = self._calculate_recency_score(memory.metadata.created_at) + + # 权重组合 + final_score = ( + semantic_score * self.config.semantic_weight + + vector_score * self.config.vector_weight + + context_score * self.config.context_weight + + recency_score * self.config.recency_weight + ) + + # 加入记忆重要性权重 + importance_weight = memory.metadata.importance.value / 4.0 # 标准化到0-1 + final_score = final_score * (0.7 + importance_weight * 0.3) # 重要性影响30% + + return final_score + + except Exception as e: + logger.warning(f"计算最终评分失败: {e}") + return 0.0 + + def _calculate_recency_score(self, timestamp: float) -> float: + """计算时效性评分""" + try: + age = time.time() - timestamp + age_days = age / (24 * 3600) + + if age_days < 1: + return 1.0 + elif age_days < 7: + return 0.8 + elif age_days < 30: + return 0.6 + elif age_days < 90: + return 0.4 + else: + return 0.2 + + except Exception: + return 0.5 + + def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]: + """从上下文中提取记忆类型""" + try: + if "expected_memory_types" in context: + return context["expected_memory_types"] + + # 根据上下文推断记忆类型 + if "message_type" in context: + message_type = context["message_type"] + if message_type in ["personal_info", "fact"]: + return [MemoryType.PERSONAL_FACT] + elif message_type in ["event", "activity"]: + return [MemoryType.EVENT] + elif message_type in ["preference", "like"]: + return [MemoryType.PREFERENCE] + elif message_type in ["opinion", "view"]: + return [MemoryType.OPINION] + + return [] + + except Exception: + return [] + + def _extract_keywords_from_query(self, query: str) -> List[str]: + """从查询中提取关键词""" + try: + # 简单的关键词提取 + words = query.lower().split() + # 过滤停用词 + stopwords = {"的", "是", "在", "有", "我", "你", "他", "她", "它", "这", "那", "了", "吗", "呢"} + keywords = [word for word in words if len(word) > 1 and word not in stopwords] + return keywords[:10] # 最多返回10个关键词 + except Exception: + return [] + + def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]): + """更新检索统计""" + self.retrieval_stats["total_queries"] += 1 + + # 更新平均检索时间 + current_avg = self.retrieval_stats["average_retrieval_time"] + total_queries = self.retrieval_stats["total_queries"] + new_avg = (current_avg * (total_queries - 1) + total_time) / total_queries + self.retrieval_stats["average_retrieval_time"] = new_avg + + # 更新各阶段统计 + for result in stage_results: + stage_name = result.stage.value + if stage_name in self.retrieval_stats["stage_stats"]: + stage_stat = self.retrieval_stats["stage_stats"][stage_name] + stage_stat["calls"] += 1 + + current_stage_avg = stage_stat["avg_time"] + new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat["calls"] + stage_stat["avg_time"] = new_stage_avg + + def get_retrieval_stats(self) -> Dict[str, Any]: + """获取检索统计信息""" + return self.retrieval_stats.copy() + + def reset_stats(self): + """重置统计信息""" + self.retrieval_stats = { + "total_queries": 0, + "average_retrieval_time": 0.0, + "stage_stats": { + "metadata_filtering": {"calls": 0, "avg_time": 0.0}, + "vector_search": {"calls": 0, "avg_time": 0.0}, + "semantic_reranking": {"calls": 0, "avg_time": 0.0}, + "contextual_filtering": {"calls": 0, "avg_time": 0.0} + } + } \ 
No newline at end of file diff --git a/src/chat/memory_system/sample_distribution.py b/src/chat/memory_system/sample_distribution.py deleted file mode 100644 index d1dc3a22d..000000000 --- a/src/chat/memory_system/sample_distribution.py +++ /dev/null @@ -1,126 +0,0 @@ -import numpy as np -from datetime import datetime, timedelta -from rich.traceback import install - -install(extra_lines=3) - - -class MemoryBuildScheduler: - def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50): - """ - 初始化记忆构建调度器 - - 参数: - n_hours1 (float): 第一个分布的均值(距离现在的小时数) - std_hours1 (float): 第一个分布的标准差(小时) - weight1 (float): 第一个分布的权重 - n_hours2 (float): 第二个分布的均值(距离现在的小时数) - std_hours2 (float): 第二个分布的标准差(小时) - weight2 (float): 第二个分布的权重 - total_samples (int): 要生成的总时间点数量 - """ - # 验证参数 - if total_samples <= 0: - raise ValueError("total_samples 必须大于0") - if weight1 < 0 or weight2 < 0: - raise ValueError("权重必须为非负数") - if std_hours1 < 0 or std_hours2 < 0: - raise ValueError("标准差必须为非负数") - - # 归一化权重 - total_weight = weight1 + weight2 - if total_weight == 0: - raise ValueError("权重总和不能为0") - self.weight1 = weight1 / total_weight - self.weight2 = weight2 / total_weight - - self.n_hours1 = n_hours1 - self.std_hours1 = std_hours1 - self.n_hours2 = n_hours2 - self.std_hours2 = std_hours2 - self.total_samples = total_samples - self.base_time = datetime.now() - - def generate_time_samples(self): - """生成混合分布的时间采样点""" - # 根据权重计算每个分布的样本数 - samples1 = max(1, int(self.total_samples * self.weight1)) - samples2 = max(1, self.total_samples - samples1) # 确保 samples2 至少为1 - - # 生成两个正态分布的小时偏移 - hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1) - hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2) - - # 合并两个分布的偏移 - hours_offset = np.concatenate([hours_offset1, hours_offset2]) - - # 将偏移转换为实际时间戳(使用绝对值确保时间点在过去) - timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset] - - # 按时间排序(从最早到最近) - return sorted(timestamps) - - def get_timestamp_array(self): - """返回时间戳数组""" - timestamps = self.generate_time_samples() - return [int(t.timestamp()) for t in timestamps] - - -# def print_time_samples(timestamps, show_distribution=True): -# """打印时间样本和分布信息""" -# print(f"\n生成的{len(timestamps)}个时间点分布:") -# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)") -# print("-" * 50) - -# now = datetime.now() -# time_diffs = [] - -# for i, timestamp in enumerate(timestamps, 1): -# hours_diff = (now - timestamp).total_seconds() / 3600 -# time_diffs.append(hours_diff) -# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}") - -# # 打印统计信息 -# print("\n统计信息:") -# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时") -# print(f"标准差:{np.std(time_diffs):.2f}小时") -# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)") -# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)") - -# if show_distribution: -# # 计算时间分布的直方图 -# hist, bins = np.histogram(time_diffs, bins=40) -# print("\n时间分布(每个*代表一个时间点):") -# for i in range(len(hist)): -# if hist[i] > 0: -# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}") - - -# # 使用示例 -# if __name__ == "__main__": -# # 创建一个双峰分布的记忆调度器 -# scheduler = MemoryBuildScheduler( -# n_hours1=12, # 第一个分布均值(12小时前) -# std_hours1=8, # 第一个分布标准差 -# weight1=0.7, # 第一个分布权重 70% -# n_hours2=36, # 第二个分布均值(36小时前) -# std_hours2=24, # 第二个分布标准差 -# weight2=0.3, # 第二个分布权重 30% -# total_samples=50, # 总共生成50个时间点 
-# ) - -# # 生成时间分布 -# timestamps = scheduler.generate_time_samples() - -# # 打印结果,包含分布可视化 -# print_time_samples(timestamps, show_distribution=True) - -# # 打印时间戳数组 -# timestamp_array = scheduler.get_timestamp_array() -# print("\n时间戳数组(Unix时间戳):") -# print("[", end="") -# for i, ts in enumerate(timestamp_array): -# if i > 0: -# print(", ", end="") -# print(ts, end="") -# print("]") diff --git a/src/chat/memory_system/vector_instant_memory.py b/src/chat/memory_system/vector_instant_memory.py deleted file mode 100644 index 12d9622e0..000000000 --- a/src/chat/memory_system/vector_instant_memory.py +++ /dev/null @@ -1,359 +0,0 @@ -import asyncio -import time -from typing import List, Dict, Any -from dataclasses import dataclass -import threading - -from src.common.logger import get_logger -from src.chat.utils.utils import get_embedding -from src.common.vector_db import vector_db_service - - -logger = get_logger("vector_instant_memory_v2") - - -@dataclass -class ChatMessage: - """聊天消息数据结构""" - - message_id: str - chat_id: str - content: str - timestamp: float - sender: str = "unknown" - message_type: str = "text" - - -class VectorInstantMemoryV2: - """重构的向量瞬时记忆系统 V2 - - 新设计理念: - 1. 全量存储 - 所有聊天记录都存储为向量 - 2. 定时清理 - 定期清理过期记录 - 3. 实时匹配 - 新消息与历史记录做向量相似度匹配 - """ - - def __init__(self, chat_id: str, retention_hours: int = 24, cleanup_interval: int = 3600): - """ - 初始化向量瞬时记忆系统 - - Args: - chat_id: 聊天ID - retention_hours: 记忆保留时长(小时) - cleanup_interval: 清理间隔(秒) - """ - self.chat_id = chat_id - self.retention_hours = retention_hours - self.cleanup_interval = cleanup_interval - self.collection_name = "instant_memory" - - # 清理任务相关 - self.cleanup_task = None - self.is_running = True - - # 初始化系统 - self._init_chroma() - self._start_cleanup_task() - - logger.info(f"向量瞬时记忆系统V2初始化完成: {chat_id} (保留{retention_hours}小时)") - - def _init_chroma(self): - """使用全局服务初始化向量数据库集合""" - try: - # 现在我们只获取集合,而不是创建新的客户端 - vector_db_service.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"}) - logger.info(f"向量记忆集合 '{self.collection_name}' 已准备就绪") - except Exception as e: - logger.error(f"获取向量记忆集合失败: {e}") - - def _start_cleanup_task(self): - """启动定时清理任务""" - - def cleanup_worker(): - while self.is_running: - try: - self._cleanup_expired_messages() - time.sleep(self.cleanup_interval) - except Exception as e: - logger.error(f"清理任务异常: {e}") - time.sleep(60) # 异常时等待1分钟再继续 - - self.cleanup_task = threading.Thread(target=cleanup_worker, daemon=True) - self.cleanup_task.start() - logger.info(f"定时清理任务已启动,间隔{self.cleanup_interval}秒") - - def _cleanup_expired_messages(self): - """清理过期的聊天记录""" - try: - expire_time = time.time() - (self.retention_hours * 3600) - - # 采用 get -> filter -> delete 模式,避免复杂的 where 查询 - # 1. 获取当前 chat_id 的所有文档 - results = vector_db_service.get( - collection_name=self.collection_name, where={"chat_id": self.chat_id}, include=["metadatas"] - ) - - if not results or not results.get("ids"): - logger.info(f"chat_id '{self.chat_id}' 没有找到任何记录,无需清理") - return - - # 2. 在内存中过滤出过期的文档 - expired_ids = [] - metadatas = results.get("metadatas", []) - ids = results.get("ids", []) - - for i, metadata in enumerate(metadatas): - if metadata and metadata.get("timestamp", float("inf")) < expire_time: - expired_ids.append(ids[i]) - - # 3. 
如果有过期文档,根据 ID 进行删除 - if expired_ids: - vector_db_service.delete(collection_name=self.collection_name, ids=expired_ids) - logger.info(f"为 chat_id '{self.chat_id}' 清理了 {len(expired_ids)} 条过期记录") - else: - logger.info(f"chat_id '{self.chat_id}' 没有需要清理的过期记录") - - except Exception as e: - logger.error(f"清理过期记录失败: {e}") - - async def store_message(self, content: str, sender: str = "user") -> bool: - """ - 存储聊天消息到向量库 - - Args: - content: 消息内容 - sender: 发送者 - - Returns: - bool: 是否存储成功 - """ - if not content.strip(): - return False - - try: - # 生成消息向量 - message_vector = await get_embedding(content) - if not message_vector: - logger.warning(f"消息向量生成失败: {content[:50]}...") - return False - - message_id = f"{self.chat_id}_{int(time.time() * 1000)}_{hash(content) % 10000}" - - message = ChatMessage( - message_id=message_id, chat_id=self.chat_id, content=content, timestamp=time.time(), sender=sender - ) - - # 使用新的服务存储 - vector_db_service.add( - collection_name=self.collection_name, - embeddings=[message_vector], - documents=[content], - metadatas=[ - { - "message_id": message.message_id, - "chat_id": message.chat_id, - "timestamp": message.timestamp, - "sender": message.sender, - "message_type": message.message_type, - } - ], - ids=[message_id], - ) - - logger.debug(f"消息已存储: {content[:50]}...") - return True - - except Exception as e: - logger.error(f"存储消息失败: {e}") - return False - - async def find_similar_messages( - self, query: str, top_k: int = 5, similarity_threshold: float = 0.7 - ) -> List[Dict[str, Any]]: - """ - 查找与查询相似的历史消息 - - Args: - query: 查询内容 - top_k: 返回的最相似消息数量 - similarity_threshold: 相似度阈值 - - Returns: - List[Dict]: 相似消息列表,包含content、similarity、timestamp等信息 - """ - if not query.strip(): - return [] - - try: - query_vector = await get_embedding(query) - if not query_vector: - return [] - - # 使用新的服务进行查询 - results = vector_db_service.query( - collection_name=self.collection_name, - query_embeddings=[query_vector], - n_results=top_k, - where={"chat_id": self.chat_id}, - ) - - if not results.get("documents") or not results["documents"][0]: - return [] - - # 处理搜索结果 - similar_messages = [] - documents = results["documents"][0] - distances = results["distances"][0] if results["distances"] else [] - metadatas = results["metadatas"][0] if results["metadatas"] else [] - - for i, doc in enumerate(documents): - # 计算相似度(ChromaDB返回距离,需转换) - distance = distances[i] if i < len(distances) else 1.0 - similarity = 1 - distance - - # 过滤低相似度结果 - if similarity < similarity_threshold: - continue - - # 获取元数据 - metadata = metadatas[i] if i < len(metadatas) else {} - - # 安全获取timestamp - timestamp = metadata.get("timestamp", 0) if isinstance(metadata, dict) else 0 - timestamp = float(timestamp) if isinstance(timestamp, (int, float)) else 0.0 - - similar_messages.append( - { - "content": doc, - "similarity": similarity, - "timestamp": timestamp, - "sender": metadata.get("sender", "unknown") if isinstance(metadata, dict) else "unknown", - "message_id": metadata.get("message_id", "") if isinstance(metadata, dict) else "", - "time_ago": self._format_time_ago(timestamp), - } - ) - - # 按相似度排序 - similar_messages.sort(key=lambda x: x["similarity"], reverse=True) - - logger.debug(f"找到 {len(similar_messages)} 条相似消息 (查询: {query[:30]}...)") - return similar_messages - - except Exception as e: - logger.error(f"查找相似消息失败: {e}") - return [] - - @staticmethod - def _format_time_ago(timestamp: float) -> str: - """格式化时间差显示""" - if timestamp <= 0: - return "未知时间" - - try: - now = time.time() - diff = now - timestamp - - if diff < 60: 
- return f"{int(diff)}秒前" - elif diff < 3600: - return f"{int(diff / 60)}分钟前" - elif diff < 86400: - return f"{int(diff / 3600)}小时前" - else: - return f"{int(diff / 86400)}天前" - except Exception: - return "时间格式错误" - - async def get_memory_for_context(self, current_message: str, context_size: int = 3) -> str: - """ - 获取与当前消息相关的记忆上下文 - - Args: - current_message: 当前消息 - context_size: 上下文消息数量 - - Returns: - str: 格式化的记忆上下文 - """ - similar_messages = await self.find_similar_messages( - current_message, - top_k=context_size, - similarity_threshold=0.6, # 降低阈值以获得更多上下文 - ) - - if not similar_messages: - return "" - - # 格式化上下文 - context_lines = [] - for msg in similar_messages: - context_lines.append( - f"[{msg['time_ago']}] {msg['sender']}: {msg['content']} (相似度: {msg['similarity']:.2f})" - ) - - return "相关的历史记忆:\n" + "\n".join(context_lines) - - def get_stats(self) -> Dict[str, Any]: - """获取记忆系统统计信息""" - stats = { - "chat_id": self.chat_id, - "retention_hours": self.retention_hours, - "cleanup_interval": self.cleanup_interval, - "system_status": "running" if self.is_running else "stopped", - "total_messages": 0, - "db_status": "connected", - } - - try: - # 注意:count() 现在没有 chat_id 过滤,返回的是整个集合的数量 - # 若要精确计数,需要 get(where={"chat_id": ...}) 然后 len(results['ids']) - # 这里为了简化,暂时显示集合总数 - result = vector_db_service.count(collection_name=self.collection_name) - stats["total_messages"] = result - except Exception: - stats["total_messages"] = "查询失败" - stats["db_status"] = "disconnected" - - return stats - - def stop(self): - """停止记忆系统""" - self.is_running = False - if self.cleanup_task and self.cleanup_task.is_alive(): - logger.info("正在停止定时清理任务...") - logger.info(f"向量瞬时记忆系统已停止: {self.chat_id}") - - -# 为了兼容现有代码,提供工厂函数 -def create_vector_memory_v2(chat_id: str, retention_hours: int = 24) -> VectorInstantMemoryV2: - """创建向量瞬时记忆系统V2实例""" - return VectorInstantMemoryV2(chat_id, retention_hours) - - -# 使用示例 -async def demo(): - """使用演示""" - memory = VectorInstantMemoryV2("demo_chat") - - # 存储一些测试消息 - await memory.store_message("今天天气不错,出去散步了", "用户") - await memory.store_message("刚才买了个冰淇淋,很好吃", "用户") - await memory.store_message("明天要开会,有点紧张", "用户") - - # 查找相似消息 - similar = await memory.find_similar_messages("天气怎么样") - print("相似消息:", similar) - - # 获取上下文 - context = await memory.get_memory_for_context("今天心情如何") - print("记忆上下文:", context) - - # 查看统计信息 - stats = memory.get_stats() - print("系统状态:", stats) - - memory.stop() - - -if __name__ == "__main__": - asyncio.run(demo()) diff --git a/src/chat/memory_system/vector_storage.py b/src/chat/memory_system/vector_storage.py new file mode 100644 index 000000000..4ad8d8271 --- /dev/null +++ b/src/chat/memory_system/vector_storage.py @@ -0,0 +1,723 @@ +# -*- coding: utf-8 -*- +""" +向量数据库存储接口 +为记忆系统提供高效的向量存储和语义搜索能力 +""" + +import os +import time +import orjson +import asyncio +from typing import Dict, List, Optional, Tuple, Set, Any +from dataclasses import dataclass +from datetime import datetime +import threading +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import pandas as pd +from pathlib import Path + +from src.common.logger import get_logger +from src.llm_models.utils_model import LLMRequest +from src.config.config import model_config, global_config +from src.chat.memory_system.memory_chunk import MemoryChunk + +logger = get_logger(__name__) + +# 尝试导入FAISS,如果不可用则使用简单替代 +try: + import faiss + FAISS_AVAILABLE = True +except ImportError: + FAISS_AVAILABLE = False + logger.warning("FAISS not available, using simple vector storage") + + +@dataclass 
+class VectorStorageConfig: + """向量存储配置""" + dimension: int = 768 + similarity_threshold: float = 0.8 + index_type: str = "flat" # flat, ivf, hnsw + max_index_size: int = 100000 + storage_path: str = "data/memory_vectors" + auto_save_interval: int = 100 # 每N次操作自动保存 + enable_compression: bool = True + + +class VectorStorageManager: + """向量存储管理器""" + + def __init__(self, config: Optional[VectorStorageConfig] = None): + self.config = config or VectorStorageConfig() + self.storage_path = Path(self.config.storage_path) + self.storage_path.mkdir(parents=True, exist_ok=True) + + # 向量索引 + self.vector_index = None + self.memory_id_to_index = {} # memory_id -> vector index + self.index_to_memory_id = {} # vector index -> memory_id + + # 内存缓存 + self.memory_cache: Dict[str, MemoryChunk] = {} + self.vector_cache: Dict[str, List[float]] = {} + + # 统计信息 + self.storage_stats = { + "total_vectors": 0, + "index_build_time": 0.0, + "average_search_time": 0.0, + "cache_hit_rate": 0.0, + "total_searches": 0, + "cache_hits": 0 + } + + # 线程锁 + self._lock = threading.RLock() + self._operation_count = 0 + + # 初始化索引 + self._initialize_index() + + # 嵌入模型 + self.embedding_model: LLMRequest = None + + def _initialize_index(self): + """初始化向量索引""" + try: + if FAISS_AVAILABLE: + if self.config.index_type == "flat": + self.vector_index = faiss.IndexFlatIP(self.config.dimension) + elif self.config.index_type == "ivf": + quantizer = faiss.IndexFlatIP(self.config.dimension) + nlist = min(100, max(1, self.config.max_index_size // 1000)) + self.vector_index = faiss.IndexIVFFlat(quantizer, self.config.dimension, nlist) + elif self.config.index_type == "hnsw": + self.vector_index = faiss.IndexHNSWFlat(self.config.dimension, 32) + self.vector_index.hnsw.efConstruction = 40 + else: + self.vector_index = faiss.IndexFlatIP(self.config.dimension) + else: + # 简单的向量存储实现 + self.vector_index = SimpleVectorIndex(self.config.dimension) + + logger.info(f"✅ 向量索引初始化完成,类型: {self.config.index_type}") + + except Exception as e: + logger.error(f"❌ 向量索引初始化失败: {e}") + # 回退到简单实现 + self.vector_index = SimpleVectorIndex(self.config.dimension) + + async def initialize_embedding_model(self): + """初始化嵌入模型""" + if self.embedding_model is None: + self.embedding_model = LLMRequest( + model_set=model_config.model_task_config.embedding, + request_type="memory.embedding" + ) + logger.info("✅ 嵌入模型初始化完成") + + async def store_memories(self, memories: List[MemoryChunk]): + """存储记忆向量""" + if not memories: + return + + start_time = time.time() + + try: + # 确保嵌入模型已初始化 + await self.initialize_embedding_model() + + # 批量获取嵌入向量 + embedding_tasks = [] + memory_texts = [] + + for memory in memories: + if memory.embedding is None: + # 如果没有嵌入向量,需要生成 + text = self._prepare_embedding_text(memory) + memory_texts.append((memory.memory_id, text)) + else: + # 已有嵌入向量,直接使用 + await self._add_single_memory(memory, memory.embedding) + + # 批量生成缺失的嵌入向量 + if memory_texts: + await self._batch_generate_and_store_embeddings(memory_texts) + + # 自动保存检查 + self._operation_count += len(memories) + if self._operation_count >= self.config.auto_save_interval: + await self.save_storage() + self._operation_count = 0 + + storage_time = time.time() - start_time + logger.debug(f"向量存储完成,{len(memories)} 条记忆,耗时 {storage_time:.3f}秒") + + except Exception as e: + logger.error(f"❌ 向量存储失败: {e}", exc_info=True) + + def _prepare_embedding_text(self, memory: MemoryChunk) -> str: + """准备用于嵌入的文本""" + # 构建包含丰富信息的文本 + text_parts = [ + memory.text_content, + f"类型: {memory.memory_type.value}", + f"关键词: {', 
'.join(memory.keywords)}", + f"标签: {', '.join(memory.tags)}" + ] + + if memory.metadata.emotional_context: + text_parts.append(f"情感: {memory.metadata.emotional_context}") + + return " | ".join(text_parts) + + async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]): + """批量生成和存储嵌入向量""" + if not memory_texts: + return + + try: + texts = [text for _, text in memory_texts] + memory_ids = [memory_id for memory_id, _ in memory_texts] + + # 批量生成嵌入向量 + embeddings = await self._batch_generate_embeddings(texts) + + # 存储向量和记忆 + for memory_id, embedding in zip(memory_ids, embeddings): + if embedding and len(embedding) == self.config.dimension: + memory = self.memory_cache.get(memory_id) + if memory: + await self._add_single_memory(memory, embedding) + + except Exception as e: + logger.error(f"❌ 批量生成嵌入向量失败: {e}") + + async def _batch_generate_embeddings(self, texts: List[str]) -> List[List[float]]: + """批量生成嵌入向量""" + if not texts: + return [] + + try: + # 创建新的事件循环来运行异步操作 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # 使用线程池并行生成嵌入向量 + with ThreadPoolExecutor(max_workers=min(4, len(texts))) as executor: + tasks = [] + for text in texts: + task = loop.run_in_executor( + executor, + self._generate_single_embedding, + text + ) + tasks.append(task) + + embeddings = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理结果 + valid_embeddings = [] + for i, embedding in enumerate(embeddings): + if isinstance(embedding, Exception): + logger.warning(f"生成第 {i} 个文本的嵌入向量失败: {embedding}") + valid_embeddings.append([]) + elif embedding and len(embedding) == self.config.dimension: + valid_embeddings.append(embedding) + else: + logger.warning(f"第 {i} 个文本的嵌入向量格式异常") + valid_embeddings.append([]) + + return valid_embeddings + + finally: + loop.close() + + except Exception as e: + logger.error(f"❌ 批量生成嵌入向量失败: {e}") + return [[] for _ in texts] + + def _generate_single_embedding(self, text: str) -> List[float]: + """生成单个文本的嵌入向量""" + try: + # 创建新的事件循环 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # 使用模型生成嵌入向量 + embedding, _ = loop.run_until_complete( + self.embedding_model.get_embedding(text) + ) + + if embedding and len(embedding) == self.config.dimension: + return embedding + else: + logger.warning(f"嵌入向量维度不匹配: 期望 {self.config.dimension}, 实际 {len(embedding) if embedding else 0}") + return [] + + finally: + loop.close() + + except Exception as e: + logger.error(f"生成嵌入向量失败: {e}") + return [] + + async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]): + """添加单个记忆到向量存储""" + with self._lock: + try: + # 规范化向量 + if embedding: + embedding = self._normalize_vector(embedding) + + # 添加到缓存 + self.memory_cache[memory.memory_id] = memory + self.vector_cache[memory.memory_id] = embedding + + # 更新记忆的嵌入向量 + memory.set_embedding(embedding) + + # 添加到向量索引 + if hasattr(self.vector_index, 'add'): + # FAISS索引 + if isinstance(embedding, np.ndarray): + vector_array = embedding.reshape(1, -1).astype('float32') + else: + vector_array = np.array([embedding], dtype='float32') + + # 特殊处理IVF索引 + if self.config.index_type == "ivf" and self.vector_index.ntotal == 0: + # IVF索引需要先训练 + logger.debug("训练IVF索引...") + self.vector_index.train(vector_array) + + self.vector_index.add(vector_array) + index_id = self.vector_index.ntotal - 1 + + else: + # 简单索引 + index_id = self.vector_index.add_vector(embedding) + + # 更新映射关系 + self.memory_id_to_index[memory.memory_id] = index_id + self.index_to_memory_id[index_id] = memory.memory_id + + # 更新统计 + 
self.storage_stats["total_vectors"] += 1 + + except Exception as e: + logger.error(f"❌ 添加记忆到向量存储失败: {e}") + + def _normalize_vector(self, vector: List[float]) -> List[float]: + """L2归一化向量""" + if not vector: + return vector + + try: + vector_array = np.array(vector, dtype=np.float32) + norm = np.linalg.norm(vector_array) + if norm == 0: + return vector + + normalized = vector_array / norm + return normalized.tolist() + + except Exception as e: + logger.warning(f"向量归一化失败: {e}") + return vector + + async def search_similar_memories( + self, + query_vector: List[float], + limit: int = 10, + user_id: Optional[str] = None + ) -> List[Tuple[str, float]]: + """搜索相似记忆""" + start_time = time.time() + + try: + # 规范化查询向量 + query_vector = self._normalize_vector(query_vector) + + # 执行向量搜索 + with self._lock: + if hasattr(self.vector_index, 'search'): + # FAISS索引 + if isinstance(query_vector, np.ndarray): + query_array = query_vector.reshape(1, -1).astype('float32') + else: + query_array = np.array([query_vector], dtype='float32') + + if self.config.index_type == "ivf" and self.vector_index.ntotal > 0: + # 设置IVF搜索参数 + nprobe = min(self.vector_index.nlist, 10) + self.vector_index.nprobe = nprobe + + distances, indices = self.vector_index.search(query_array, min(limit, self.storage_stats["total_vectors"])) + distances = distances.flatten().tolist() + indices = indices.flatten().tolist() + else: + # 简单索引 + results = self.vector_index.search(query_vector, limit) + distances = [score for _, score in results] + indices = [idx for idx, _ in results] + + # 处理搜索结果 + results = [] + for distance, index in zip(distances, indices): + if index == -1: # FAISS的无效索引标记 + continue + + memory_id = self.index_to_memory_id.get(index) + if memory_id: + # 应用用户过滤 + if user_id: + memory = self.memory_cache.get(memory_id) + if memory and memory.user_id != user_id: + continue + + similarity = max(0.0, min(1.0, distance)) # 确保在0-1范围内 + results.append((memory_id, similarity)) + + # 更新统计 + search_time = time.time() - start_time + self.storage_stats["total_searches"] += 1 + self.storage_stats["average_search_time"] = ( + (self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time) / + self.storage_stats["total_searches"] + ) + + return results[:limit] + + except Exception as e: + logger.error(f"❌ 向量搜索失败: {e}") + return [] + + async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]: + """根据ID获取记忆""" + # 先检查缓存 + if memory_id in self.memory_cache: + self.storage_stats["cache_hits"] += 1 + return self.memory_cache[memory_id] + + self.storage_stats["total_searches"] += 1 + return None + + async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]): + """更新记忆的嵌入向量""" + with self._lock: + try: + if memory_id not in self.memory_id_to_index: + logger.warning(f"记忆 {memory_id} 不存在于向量索引中") + return + + # 获取旧索引 + old_index = self.memory_id_to_index[memory_id] + + # 删除旧向量(如果支持) + if hasattr(self.vector_index, 'remove_ids'): + try: + self.vector_index.remove_ids(np.array([old_index])) + except: + logger.warning("无法删除旧向量,将直接添加新向量") + + # 规范化新向量 + new_embedding = self._normalize_vector(new_embedding) + + # 添加新向量 + if hasattr(self.vector_index, 'add'): + if isinstance(new_embedding, np.ndarray): + vector_array = new_embedding.reshape(1, -1).astype('float32') + else: + vector_array = np.array([new_embedding], dtype='float32') + + self.vector_index.add(vector_array) + new_index = self.vector_index.ntotal - 1 + else: + new_index = self.vector_index.add_vector(new_embedding) 
+ + # 更新映射关系 + self.memory_id_to_index[memory_id] = new_index + self.index_to_memory_id[new_index] = memory_id + + # 更新缓存 + self.vector_cache[memory_id] = new_embedding + + # 更新记忆对象 + memory = self.memory_cache.get(memory_id) + if memory: + memory.set_embedding(new_embedding) + + logger.debug(f"更新记忆 {memory_id} 的嵌入向量") + + except Exception as e: + logger.error(f"❌ 更新记忆嵌入向量失败: {e}") + + async def delete_memory(self, memory_id: str): + """删除记忆""" + with self._lock: + try: + if memory_id not in self.memory_id_to_index: + return + + # 获取索引 + index = self.memory_id_to_index[memory_id] + + # 从向量索引中删除(如果支持) + if hasattr(self.vector_index, 'remove_ids'): + try: + self.vector_index.remove_ids(np.array([index])) + except: + logger.warning("无法从向量索引中删除,仅从缓存中移除") + + # 删除映射关系 + del self.memory_id_to_index[memory_id] + if index in self.index_to_memory_id: + del self.index_to_memory_id[index] + + # 从缓存中删除 + self.memory_cache.pop(memory_id, None) + self.vector_cache.pop(memory_id, None) + + # 更新统计 + self.storage_stats["total_vectors"] = max(0, self.storage_stats["total_vectors"] - 1) + + logger.debug(f"删除记忆 {memory_id}") + + except Exception as e: + logger.error(f"❌ 删除记忆失败: {e}") + + async def save_storage(self): + """保存向量存储到文件""" + try: + logger.info("正在保存向量存储...") + + # 保存记忆缓存 + cache_data = { + memory_id: memory.to_dict() + for memory_id, memory in self.memory_cache.items() + } + + cache_file = self.storage_path / "memory_cache.json" + with open(cache_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode('utf-8')) + + # 保存向量缓存 + vector_cache_file = self.storage_path / "vector_cache.json" + with open(vector_cache_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode('utf-8')) + + # 保存映射关系 + mapping_file = self.storage_path / "id_mapping.json" + mapping_data = { + "memory_id_to_index": self.memory_id_to_index, + "index_to_memory_id": self.index_to_memory_id + } + with open(mapping_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8')) + + # 保存FAISS索引(如果可用) + if FAISS_AVAILABLE and hasattr(self.vector_index, 'save'): + index_file = self.storage_path / "vector_index.faiss" + faiss.write_index(self.vector_index, str(index_file)) + + # 保存统计信息 + stats_file = self.storage_path / "storage_stats.json" + with open(stats_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode('utf-8')) + + logger.info("✅ 向量存储保存完成") + + except Exception as e: + logger.error(f"❌ 保存向量存储失败: {e}") + + async def load_storage(self): + """从文件加载向量存储""" + try: + logger.info("正在加载向量存储...") + + # 加载记忆缓存 + cache_file = self.storage_path / "memory_cache.json" + if cache_file.exists(): + with open(cache_file, 'r', encoding='utf-8') as f: + cache_data = orjson.loads(f.read()) + + self.memory_cache = { + memory_id: MemoryChunk.from_dict(memory_data) + for memory_id, memory_data in cache_data.items() + } + + # 加载向量缓存 + vector_cache_file = self.storage_path / "vector_cache.json" + if vector_cache_file.exists(): + with open(vector_cache_file, 'r', encoding='utf-8') as f: + self.vector_cache = orjson.loads(f.read()) + + # 加载映射关系 + mapping_file = self.storage_path / "id_mapping.json" + if mapping_file.exists(): + with open(mapping_file, 'r', encoding='utf-8') as f: + mapping_data = orjson.loads(f.read()) + self.memory_id_to_index = mapping_data.get("memory_id_to_index", {}) + self.index_to_memory_id = 
mapping_data.get("index_to_memory_id", {}) + + # 加载FAISS索引(如果可用) + if FAISS_AVAILABLE: + index_file = self.storage_path / "vector_index.faiss" + if index_file.exists() and hasattr(self.vector_index, 'load'): + try: + loaded_index = faiss.read_index(str(index_file)) + # 如果索引类型匹配,则替换 + if type(loaded_index) == type(self.vector_index): + self.vector_index = loaded_index + logger.info("✅ FAISS索引加载完成") + else: + logger.warning("索引类型不匹配,重新构建索引") + await self._rebuild_index() + except Exception as e: + logger.warning(f"加载FAISS索引失败: {e},重新构建") + await self._rebuild_index() + + # 加载统计信息 + stats_file = self.storage_path / "storage_stats.json" + if stats_file.exists(): + with open(stats_file, 'r', encoding='utf-8') as f: + self.storage_stats = orjson.loads(f.read()) + + # 更新向量计数 + self.storage_stats["total_vectors"] = len(self.memory_id_to_index) + + logger.info(f"✅ 向量存储加载完成,{self.storage_stats['total_vectors']} 个向量") + + except Exception as e: + logger.error(f"❌ 加载向量存储失败: {e}") + + async def _rebuild_index(self): + """重建向量索引""" + try: + logger.info("正在重建向量索引...") + + # 重新初始化索引 + self._initialize_index() + + # 重新添加所有向量 + for memory_id, embedding in self.vector_cache.items(): + if embedding: + memory = self.memory_cache.get(memory_id) + if memory: + await self._add_single_memory(memory, embedding) + + logger.info("✅ 向量索引重建完成") + + except Exception as e: + logger.error(f"❌ 重建向量索引失败: {e}") + + async def optimize_storage(self): + """优化存储""" + try: + logger.info("开始向量存储优化...") + + # 清理无效引用 + self._cleanup_invalid_references() + + # 重新构建索引(如果碎片化严重) + if self.storage_stats["total_vectors"] > 1000: + await self._rebuild_index() + + # 更新缓存命中率 + if self.storage_stats["total_searches"] > 0: + self.storage_stats["cache_hit_rate"] = ( + self.storage_stats["cache_hits"] / self.storage_stats["total_searches"] + ) + + logger.info("✅ 向量存储优化完成") + + except Exception as e: + logger.error(f"❌ 向量存储优化失败: {e}") + + def _cleanup_invalid_references(self): + """清理无效引用""" + with self._lock: + # 清理无效的memory_id到index的映射 + valid_memory_ids = set(self.memory_cache.keys()) + invalid_memory_ids = set(self.memory_id_to_index.keys()) - valid_memory_ids + + for memory_id in invalid_memory_ids: + index = self.memory_id_to_index[memory_id] + del self.memory_id_to_index[memory_id] + if index in self.index_to_memory_id: + del self.index_to_memory_id[index] + + if invalid_memory_ids: + logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用") + + def get_storage_stats(self) -> Dict[str, Any]: + """获取存储统计信息""" + stats = self.storage_stats.copy() + if stats["total_searches"] > 0: + stats["cache_hit_rate"] = stats["cache_hits"] / stats["total_searches"] + else: + stats["cache_hit_rate"] = 0.0 + return stats + + +class SimpleVectorIndex: + """简单的向量索引实现(当FAISS不可用时的替代方案)""" + + def __init__(self, dimension: int): + self.dimension = dimension + self.vectors: List[List[float]] = [] + self.vector_ids: List[int] = [] + self.next_id = 0 + + def add_vector(self, vector: List[float]) -> int: + """添加向量""" + if len(vector) != self.dimension: + raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}") + + vector_id = self.next_id + self.vectors.append(vector.copy()) + self.vector_ids.append(vector_id) + self.next_id += 1 + + return vector_id + + def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]: + """搜索相似向量""" + if len(query_vector) != self.dimension: + raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}") + + results = [] + + for i, vector in enumerate(self.vectors): + similarity = 
self._calculate_cosine_similarity(query_vector, vector) + results.append((self.vector_ids[i], similarity)) + + # 按相似度排序 + results.sort(key=lambda x: x[1], reverse=True) + + return results[:limit] + + def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float: + """计算余弦相似度""" + try: + dot_product = sum(x * y for x, y in zip(v1, v2)) + norm1 = sum(x * x for x in v1) ** 0.5 + norm2 = sum(x * x for x in v2) ** 0.5 + + if norm1 == 0 or norm2 == 0: + return 0.0 + + return dot_product / (norm1 * norm2) + + except Exception: + return 0.0 + + @property + def ntotal(self) -> int: + """向量总数""" + return len(self.vectors) \ No newline at end of file diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index c2a86f979..a6f875623 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -28,8 +28,8 @@ from src.chat.utils.chat_message_builder import ( replace_user_references_sync, ) from src.chat.express.expression_selector import expression_selector -from src.chat.memory_system.memory_activator import MemoryActivator -from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2 +# 旧记忆系统已被移除 +# 旧记忆系统已被移除 from src.mood.mood_manager import mood_manager from src.person_info.person_info import get_person_info_manager from src.plugin_system.base.component_types import ActionInfo, EventType @@ -231,9 +231,12 @@ class DefaultReplyer: self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) self.heart_fc_sender = HeartFCSender() - self.memory_activator = MemoryActivator() - # 使用纯向量瞬时记忆系统V2,支持自定义保留时间 - self.instant_memory = VectorInstantMemoryV2(chat_id=self.chat_stream.stream_id, retention_hours=1) + # 使用新的增强记忆系统 + # from src.chat.memory_system.enhanced_memory_activator import EnhancedMemoryActivator + # self.memory_activator = EnhancedMemoryActivator() + self.memory_activator = None # 暂时禁用记忆激活器 + # 旧的即时记忆系统已被移除,现在使用增强记忆系统 + # self.instant_memory = VectorInstantMemoryV2(chat_id=self.chat_stream.stream_id, retention_hours=1) from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖 @@ -459,90 +462,65 @@ class DefaultReplyer: instant_memory = None - running_memories = await self.memory_activator.activate_memory_with_chat_history( - target_message=target, chat_history_prompt=chat_history - ) + # 使用新的增强记忆系统检索记忆 + running_memories = [] + instant_memory = None if global_config.memory.enable_instant_memory: - # 使用异步记忆包装器(最优化的非阻塞模式) try: - from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory + # 使用新的增强记忆系统 + from src.chat.memory_system.enhanced_memory_integration import recall_memories, remember_message - # 获取异步记忆包装器 - async_memory = get_async_instant_memory(self.chat_stream.stream_id) - - # 后台存储聊天历史(完全非阻塞) - async_memory.store_memory_background(chat_history) - - # 快速检索记忆,最大超时2秒 - instant_memory = await async_memory.get_memory_with_fallback(target, max_timeout=2.0) - - logger.info(f"异步瞬时记忆:{instant_memory}") - - except ImportError: - # 如果异步包装器不可用,尝试使用异步记忆管理器 - try: - from src.chat.memory_system.async_memory_optimizer import ( - retrieve_memory_nonblocking, - store_memory_nonblocking, + # 异步存储聊天历史(非阻塞) + asyncio.create_task( + remember_message( + message=chat_history, + user_id=str(self.chat_stream.stream_id), + chat_id=self.chat_stream.stream_id ) + ) - # 异步存储聊天历史(非阻塞) - asyncio.create_task( - store_memory_nonblocking(chat_id=self.chat_stream.stream_id, content=chat_history) - ) + # 检索相关记忆 + 
enhanced_memories = await recall_memories( + query=target, + user_id=str(self.chat_stream.stream_id), + chat_id=self.chat_stream.stream_id + ) - # 尝试从缓存获取瞬时记忆 - instant_memory = await retrieve_memory_nonblocking(chat_id=self.chat_stream.stream_id, query=target) + # 转换格式以兼容现有代码 + running_memories = [] + if enhanced_memories and enhanced_memories.get("has_memories"): + for memory in enhanced_memories.get("memories", []): + running_memories.append({ + 'content': memory.get("content", ""), + 'score': memory.get("confidence", 0.0), + 'memory_type': memory.get("type", "unknown") + }) - # 如果没有缓存结果,快速检索一次 - if instant_memory is None: - try: - instant_memory = await asyncio.wait_for( - self.instant_memory.get_memory_for_context(target), timeout=1.5 - ) - except asyncio.TimeoutError: - logger.warning("瞬时记忆检索超时,使用空结果") - instant_memory = "" + # 构建瞬时记忆字符串 + if enhanced_memories and enhanced_memories.get("has_memories"): + instant_memory = "\\n".join([ + f"{memory.get('content', '')} (相似度: {memory.get('confidence', 0.0):.2f})" + for memory in enhanced_memories.get("memories", [])[:3] # 取前3条 + ]) - logger.info(f"向量瞬时记忆:{instant_memory}") - - except ImportError: - # 最后的fallback:使用原有逻辑但加上超时控制 - logger.warning("异步记忆系统不可用,使用带超时的同步方式") - - # 异步存储聊天历史 - asyncio.create_task(self.instant_memory.store_message(chat_history)) - - # 带超时的记忆检索 - try: - instant_memory = await asyncio.wait_for( - self.instant_memory.get_memory_for_context(target), - timeout=1.0, # 最保守的1秒超时 - ) - except asyncio.TimeoutError: - logger.warning("瞬时记忆检索超时,跳过记忆获取") - instant_memory = "" - except Exception as e: - logger.error(f"瞬时记忆检索失败: {e}") - instant_memory = "" - - logger.info(f"同步瞬时记忆:{instant_memory}") + logger.info(f"增强记忆系统检索到 {len(running_memories)} 条记忆") except Exception as e: - logger.error(f"瞬时记忆系统异常: {e}") + logger.warning(f"增强记忆系统检索失败: {e}") + running_memories = [] instant_memory = "" # 构建记忆字符串,即使某种记忆为空也要继续 memory_str = "" has_any_memory = False - # 添加长期记忆 + # 添加长期记忆(来自增强记忆系统) if running_memories: if not memory_str: memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" for running_memory in running_memories: - memory_str += f"- {running_memory['content']}\n" + memory_str += f"- {running_memory['content']} (类型: {running_memory['memory_type']}, 相似度: {running_memory['score']:.2f})\n" has_any_memory = True # 添加瞬时记忆 diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index e17a71069..72543efdd 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -371,28 +371,35 @@ class Prompt: tasks.append(self._build_cross_context()) task_names.append("cross_context") - # 性能优化 - base_timeout = 10.0 - task_timeout = 2.0 - timeout_seconds = min( - max(base_timeout, len(tasks) * task_timeout), - 30.0, - ) + # 性能优化 - 为不同任务设置不同的超时时间 + task_timeouts = { + "memory_block": 5.0, # 记忆系统可能较慢,单独设置超时 + "tool_info": 3.0, # 工具信息中等速度 + "relation_info": 2.0, # 关系信息通常较快 + "knowledge_info": 3.0, # 知识库查询中等速度 + "cross_context": 2.0, # 上下文处理通常较快 + "expression_habits": 1.5, # 表达习惯处理很快 + } - max_concurrent_tasks = 5 - if len(tasks) > max_concurrent_tasks: - results = [] - for i in range(0, len(tasks), max_concurrent_tasks): - batch_tasks = tasks[i : i + max_concurrent_tasks] + # 分别处理每个任务,避免慢任务影响快任务 + results = [] + for i, task in enumerate(tasks): + task_name = task_names[i] if i < len(task_names) else f"task_{i}" + task_timeout = task_timeouts.get(task_name, 2.0) # 默认2秒 - batch_results = await asyncio.wait_for( - asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds - ) - results.extend(batch_results) - else: - results = await 
asyncio.wait_for( - asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds - ) + try: + result = await asyncio.wait_for(task, timeout=task_timeout) + results.append(result) + logger.debug(f"构建任务{task_name}完成 ({task_timeout}s)") + except asyncio.TimeoutError: + logger.warning(f"构建任务{task_name}超时 ({task_timeout}s),使用默认值") + # 为超时任务提供默认值 + default_result = self._get_default_result_for_task(task_name) + results.append(default_result) + except Exception as e: + logger.error(f"构建任务{task_name}失败: {str(e)}") + default_result = self._get_default_result_for_task(task_name) + results.append(default_result) # 处理结果 context_data = {} @@ -528,8 +535,7 @@ class Prompt: return {"memory_block": ""} try: - from src.chat.memory_system.memory_activator import MemoryActivator - from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory + from src.chat.memory_system.enhanced_memory_activator import enhanced_memory_activator # 获取聊天历史 chat_history = "" @@ -539,15 +545,38 @@ class Prompt: recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) - # 激活长期记忆 - memory_activator = MemoryActivator() - running_memories = await memory_activator.activate_memory_with_chat_history( - target_message=self.parameters.target, chat_history_prompt=chat_history - ) + # 并行执行记忆查询以提高性能 + import asyncio - # 获取即时记忆 - async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id) - instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target) + # 创建记忆查询任务 + memory_tasks = [ + enhanced_memory_activator.activate_memory_with_chat_history( + target_message=self.parameters.target, chat_history_prompt=chat_history + ), + enhanced_memory_activator.get_instant_memory( + target_message=self.parameters.target, chat_id=self.parameters.chat_id + ) + ] + + # 等待所有记忆查询完成(最多3秒) + try: + running_memories, instant_memory = await asyncio.wait_for( + asyncio.gather(*memory_tasks, return_exceptions=True), + timeout=3.0 + ) + + # 处理可能的异常结果 + if isinstance(running_memories, Exception): + logger.warning(f"长期记忆查询失败: {running_memories}") + running_memories = [] + if isinstance(instant_memory, Exception): + logger.warning(f"即时记忆查询失败: {instant_memory}") + instant_memory = None + + except asyncio.TimeoutError: + logger.warning("记忆查询超时,使用部分结果") + running_memories = [] + instant_memory = None # 构建记忆块 memory_parts = [] @@ -870,6 +899,32 @@ class Prompt: return await relationship_fetcher.build_relation_info(person_id, points_num=5) + def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]: + """ + 为超时的任务提供默认结果 + + Args: + task_name: 任务名称 + + Returns: + Dict: 默认结果 + """ + defaults = { + "memory_block": {"memory_block": ""}, + "tool_info": {"tool_info_block": ""}, + "relation_info": {"relation_info_block": ""}, + "knowledge_info": {"knowledge_prompt": ""}, + "cross_context": {"cross_context_block": ""}, + "expression_habits": {"expression_habits_block": ""}, + } + + if task_name in defaults: + logger.info(f"为超时任务 {task_name} 提供默认值") + return defaults[task_name] + else: + logger.warning(f"未知任务类型 {task_name},返回空结果") + return {} + @staticmethod async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str: """ diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 7afedfae7..1dd517834 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -459,6 +459,40 @@ class MemoryConfig(ValidatedConfigBase): enable_llm_instant_memory: bool = 
Field(default=True, description="启用基于LLM的瞬时记忆") enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆") + # 增强记忆系统配置 + enable_enhanced_memory: bool = Field(default=True, description="启用增强记忆系统") + enhanced_memory_auto_save: bool = Field(default=True, description="自动保存增强记忆") + + # 记忆构建配置 + min_memory_length: int = Field(default=10, description="最小记忆长度") + max_memory_length: int = Field(default=500, description="最大记忆长度") + memory_value_threshold: float = Field(default=0.7, description="记忆价值阈值") + + # 向量存储配置 + vector_dimension: int = Field(default=768, description="向量维度") + vector_similarity_threshold: float = Field(default=0.8, description="向量相似度阈值") + + # 多阶段检索配置 + metadata_filter_limit: int = Field(default=100, description="元数据过滤阶段返回数量") + vector_search_limit: int = Field(default=50, description="向量搜索阶段返回数量") + semantic_rerank_limit: int = Field(default=20, description="语义重排序阶段返回数量") + final_result_limit: int = Field(default=10, description="最终结果数量") + + # 检索权重配置 + vector_weight: float = Field(default=0.4, description="向量相似度权重") + semantic_weight: float = Field(default=0.3, description="语义相似度权重") + context_weight: float = Field(default=0.2, description="上下文权重") + recency_weight: float = Field(default=0.1, description="时效性权重") + + # 记忆融合配置 + fusion_similarity_threshold: float = Field(default=0.85, description="融合相似度阈值") + deduplication_window_hours: int = Field(default=24, description="去重时间窗口(小时)") + + # 缓存配置 + enable_memory_cache: bool = Field(default=True, description="启用记忆缓存") + cache_ttl_seconds: int = Field(default=300, description="缓存生存时间(秒)") + max_cache_size: int = Field(default=1000, description="最大缓存大小") + class MoodConfig(ValidatedConfigBase): """情绪配置类""" diff --git a/src/main.py b/src/main.py index 9fae02a28..dc1ed5289 100644 --- a/src/main.py +++ b/src/main.py @@ -34,54 +34,8 @@ from src.plugin_system.core.plugin_hot_reload import hot_reload_manager # 导入消息API和traceback模块 from src.common.message import get_global_api -from src.chat.memory_system.Hippocampus import hippocampus_manager - -if not global_config.memory.enable_memory: - import src.chat.memory_system.Hippocampus as hippocampus_module - - class MockHippocampusManager: - def initialize(self): - pass - - def get_hippocampus(self): - return None - - async def build_memory(self): - pass - - async def forget_memory(self, percentage: float = 0.005): - pass - - async def consolidate_memory(self): - pass - - async def get_memory_from_text( - self, - text: str, - max_memory_num: int = 3, - max_memory_length: int = 2, - max_depth: int = 3, - fast_retrieval: bool = False, - ) -> list: - return [] - - async def get_memory_from_topic( - self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 - ) -> list: - return [] - - async def get_activate_from_text( - self, text: str, max_depth: int = 3, fast_retrieval: bool = False - ) -> tuple[float, list[str]]: - return 0.0, [] - - def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: - return [] - - def get_all_node_names(self) -> list: - return [] - - hippocampus_module.hippocampus_manager = MockHippocampusManager() +# 导入增强记忆系统管理器 +from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager # 插件系统现在使用统一的插件加载器 @@ -106,7 +60,8 @@ def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float): class MainSystem: def __init__(self): - self.hippocampus_manager = hippocampus_manager + # 使用增强记忆系统 + self.enhanced_memory_manager = enhanced_memory_manager 
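# Note (editorial, not part of the original patch): enhanced_memory_manager is the
# module-level singleton imported at the top of this file; MainSystem only stores the
# reference here, while its async initialize() and shutdown() are awaited in the startup
# and stop hunks later in this diff, replacing the old hippocampus build / forget /
# consolidate timers.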
self.individuality: Individuality = get_individuality() @@ -169,19 +124,18 @@ class MainSystem: logger.error(f"停止热重载系统时出错: {e}") try: - # 停止异步记忆管理器 + # 停止增强记忆系统 if global_config.memory.enable_memory: - from src.chat.memory_system.async_memory_optimizer import async_memory_manager import asyncio loop = asyncio.get_event_loop() if loop.is_running(): - asyncio.create_task(async_memory_manager.shutdown()) + asyncio.create_task(self.enhanced_memory_manager.shutdown()) else: - loop.run_until_complete(async_memory_manager.shutdown()) - logger.info("🛑 记忆管理器已停止") + loop.run_until_complete(self.enhanced_memory_manager.shutdown()) + logger.info("🛑 增强记忆系统已停止") except Exception as e: - logger.error(f"停止记忆管理器时出错: {e}") + logger.error(f"停止增强记忆系统时出错: {e}") async def _message_process_wrapper(self, message_data: Dict[str, Any]): """并行处理消息的包装器""" @@ -304,9 +258,11 @@ MoFox_Bot(第三方修改版) logger.info("聊天管理器初始化成功") - # 初始化记忆系统 - self.hippocampus_manager.initialize() - logger.info("记忆系统初始化成功") + # 初始化增强记忆系统 + await self.enhanced_memory_manager.initialize() + logger.info("增强记忆系统初始化成功") + + # 老记忆系统已完全删除 # 初始化LPMM知识库 from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge @@ -314,14 +270,8 @@ MoFox_Bot(第三方修改版) initialize_lpmm_knowledge() logger.info("LPMM知识库初始化成功") - # 初始化异步记忆管理器 - try: - from src.chat.memory_system.async_memory_optimizer import async_memory_manager - - await async_memory_manager.initialize() - logger.info("记忆管理器初始化成功") - except Exception as e: - logger.error(f"记忆管理器初始化失败: {e}") + # 异步记忆管理器已禁用,增强记忆系统有内置的优化机制 + logger.info("异步记忆管理器已禁用 - 使用增强记忆系统内置优化") # await asyncio.sleep(0.5) #防止logger输出飞了 @@ -376,81 +326,12 @@ MoFox_Bot(第三方修改版) self.server.run(), ] - # 添加记忆系统相关任务 - tasks.extend( - [ - self.build_memory_task(), - self.forget_memory_task(), - self.consolidate_memory_task(), - ] - ) + # 增强记忆系统不需要定时任务,已禁用原有记忆系统的定时任务 + logger.info("原有记忆系统定时任务已禁用 - 使用增强记忆系统") await asyncio.gather(*tasks) - async def build_memory_task(self): - """记忆构建任务""" - while True: - await asyncio.sleep(global_config.memory.memory_build_interval) - - try: - # 使用异步记忆管理器进行非阻塞记忆构建 - from src.chat.memory_system.async_memory_optimizer import build_memory_nonblocking - - logger.info("正在启动记忆构建") - - # 定义构建完成的回调函数 - def build_completed(result): - if result: - logger.info("记忆构建完成") - else: - logger.warning("记忆构建失败") - - # 启动异步构建,不等待完成 - task_id = await build_memory_nonblocking() - logger.info(f"记忆构建任务已提交:{task_id}") - - except ImportError: - # 如果异步优化器不可用,使用原有的同步方式(但在单独的线程中运行) - logger.warning("记忆优化器不可用,使用线性运行执行记忆构建") - - def sync_build_memory(): - """在线程池中执行同步记忆构建""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - result = loop.run_until_complete(self.hippocampus_manager.build_memory()) - logger.info("记忆构建完成") - return result - except Exception as e: - logger.error(f"记忆构建失败: {e}") - return None - finally: - loop.close() - - # 在线程池中执行记忆构建 - asyncio.get_event_loop().run_in_executor(None, sync_build_memory) - - except Exception as e: - logger.error(f"记忆构建任务启动失败: {e}") - # fallback到原有的同步方式 - logger.info("正在进行记忆构建(同步模式)") - await self.hippocampus_manager.build_memory() # type: ignore - - async def forget_memory_task(self): - """记忆遗忘任务""" - while True: - await asyncio.sleep(global_config.memory.forget_memory_interval) - logger.info("[记忆遗忘] 开始遗忘记忆...") - await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore - logger.info("[记忆遗忘] 记忆遗忘完成") - - async def consolidate_memory_task(self): - """记忆整合任务""" - while True: - await 
diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py
index e3682a450..4f07cdf7a 100644
--- a/src/mais4u/mais4u_chat/s4u_msg_processor.py
+++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py
@@ -2,7 +2,8 @@ import asyncio
 import math
 from typing import Tuple
 
-from src.chat.memory_system.Hippocampus import hippocampus_manager
+# 旧的Hippocampus系统已被移除,现在使用增强记忆系统
+# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
 from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
 from maim_message.message_base import GroupInfo
 from src.chat.message_receive.storage import MessageStorage
@@ -40,11 +41,31 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
 
     if global_config.memory.enable_memory:
         with Timer("记忆激活"):
-            interested_rate, _ = await hippocampus_manager.get_activate_from_text(
-                message.processed_plain_text,
-                fast_retrieval=True,
-            )
-            logger.debug(f"记忆激活率: {interested_rate:.2f}")
+            # 使用新的增强记忆系统计算兴趣度
+            try:
+                from src.chat.memory_system.enhanced_memory_integration import recall_memories
+
+                # 检索相关记忆来估算兴趣度
+                enhanced_memories = await recall_memories(
+                    query=message.processed_plain_text,
+                    user_id=str(message.user_info.user_id),
+                    chat_id=message.chat_id
+                )
+
+                # 基于检索结果计算兴趣度
+                if enhanced_memories:
+                    # 有相关记忆,兴趣度基于相似度计算
+                    max_score = max(score for _, score in enhanced_memories)
+                    interested_rate = min(max_score, 1.0)  # 限制在0-1之间
+                else:
+                    # 没有相关记忆,给予基础兴趣度
+                    interested_rate = 0.1
+
+                logger.debug(f"增强记忆系统兴趣度: {interested_rate:.2f}")
+
+            except Exception as e:
+                logger.warning(f"增强记忆系统兴趣度计算失败: {e}")
+                interested_rate = 0.1  # 默认基础兴趣度
 
     text_len = len(message.processed_plain_text)
     # 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
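
The new interest calculation above treats the recall_memories result as a list of (memory, score) pairs, takes the best score clamped to [0, 1], and falls back to a small baseline when nothing is recalled. Factored out as a standalone helper for clarity (the return shape is an assumption taken from this call site, not an API defined elsewhere in the diff):

    from typing import Any, List, Tuple

    def interest_from_recall(results: List[Tuple[Any, float]], fallback: float = 0.1) -> float:
        """Best recall score clamped to [0, 1]; a small baseline when nothing matches."""
        if not results:
            return fallback
        best = max(score for _, score in results)
        return min(max(best, 0.0), 1.0)

    # interest_from_recall([]) == 0.1
    # interest_from_recall([("m1", 0.42), ("m2", 1.3)]) == 1.0
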
diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py
index 336346e25..268af7e1f 100644
--- a/src/mais4u/mais4u_chat/s4u_prompt.py
+++ b/src/mais4u/mais4u_chat/s4u_prompt.py
@@ -4,7 +4,8 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager
 from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
 import time
 from src.chat.utils.utils import get_recent_group_speaker
-from src.chat.memory_system.Hippocampus import hippocampus_manager
+# 旧的Hippocampus系统已被移除,现在使用增强记忆系统
+# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
 import random
 from datetime import datetime
 import asyncio
@@ -171,16 +172,26 @@ class PromptBuilder:
 
     @staticmethod
     async def build_memory_block(text: str) -> str:
-        related_memory = await hippocampus_manager.get_memory_from_text(
-            text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
-        )
+        # 使用新的增强记忆系统检索记忆
+        try:
+            from src.chat.memory_system.enhanced_memory_integration import recall_memories
 
-        related_memory_info = ""
-        if related_memory:
-            for memory in related_memory:
-                related_memory_info += memory[1]
-            return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
-        return ""
+            enhanced_memories = await recall_memories(
+                query=text,
+                user_id="system",  # 系统查询
+                chat_id="system"
+            )
+
+            related_memory_info = ""
+            if enhanced_memories and enhanced_memories.get("has_memories"):
+                for memory in enhanced_memories.get("memories", []):
+                    related_memory_info += memory.get("content", "") + " "
+                return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info.strip())
+            return ""
+
+        except Exception as e:
+            logger.warning(f"增强记忆系统检索失败: {e}")
+            return ""
 
     @staticmethod
     async def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U):
diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py
index b21bd6b3e..ad6621b3a 100644
--- a/src/plugin_system/apis/send_api.py
+++ b/src/plugin_system/apis/send_api.py
@@ -98,7 +98,6 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
 
         message_recv = MessageRecv(new_message_dict)
         logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
-        logger.info(message_recv)
 
         return message_recv
 
diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py
index d75ffb574..0091f75f7 100644
--- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py
+++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py
@@ -11,7 +11,8 @@ from typing import Any, Dict, List, Optional
 
 from json_repair import repair_json
 
-from src.chat.memory_system.Hippocampus import hippocampus_manager
+# 旧的Hippocampus系统已被移除,现在使用增强记忆系统
+# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
 from src.chat.utils.chat_message_builder import (
     build_readable_actions,
     build_readable_messages_with_id,
@@ -602,14 +603,32 @@ class ChatterPlanFilter:
             else:
                 keywords.append("晚上")
-            retrieved_memories = await hippocampus_manager.get_memory_from_topic(
-                valid_keywords=keywords, max_memory_num=5, max_memory_length=1
-            )
+            # 使用新的增强记忆系统检索记忆
+            try:
+                from src.chat.memory_system.enhanced_memory_integration import recall_memories
 
-            if not retrieved_memories:
+                # 将关键词转换为查询字符串
+                query = " ".join(keywords)
+                enhanced_memories = await recall_memories(
+                    query=query,
+                    user_id="system",  # 系统查询
+                    chat_id="system"
+                )
+
+                if not enhanced_memories:
+                    return "最近没有什么特别的记忆。"
+
+                # 转换格式以兼容现有代码
+                retrieved_memories = []
+                if enhanced_memories and enhanced_memories.get("has_memories"):
+                    for memory in enhanced_memories.get("memories", []):
+                        retrieved_memories.append((memory.get("type", "unknown"), memory.get("content", "")))
+
+                memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
+
+            except Exception as e:
+                logger.warning(f"增强记忆系统检索失败,使用默认回复: {e}")
                 return "最近没有什么特别的记忆。"
-
-            memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
 
             return " ".join(memory_statements)
         except Exception as e:
             logger.error(f"获取长期记忆时出错: {e}")
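
Note that the s4u_prompt and plan_filter call sites treat the recall_memories result as a dict with "has_memories"/"memories" keys, while the s4u_msg_processor call site iterates it as (memory, score) pairs. Whichever shape the integration layer actually returns, the compatibility conversion used in plan_filter reduces to the following sketch (the key names are taken from the diff; everything else is hypothetical):

    from typing import Any, Dict, List, Tuple

    def to_legacy_tuples(payload: Dict[str, Any]) -> List[Tuple[str, str]]:
        """Convert the enhanced-memory payload into the (topic, content) tuples the old code expects."""
        if not payload or not payload.get("has_memories"):
            return []
        return [
            (m.get("type", "unknown"), m.get("content", ""))
            for m in payload.get("memories", [])
        ]

    def format_memory_statements(pairs: List[Tuple[str, str]]) -> str:
        """Render the tuples the same way plan_filter does."""
        if not pairs:
            return "最近没有什么特别的记忆。"
        return " ".join(f"关于'{topic}', 你记得'{content}'。" for topic, content in pairs)
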