From 691172766e3a95d348409986da2b012070d6c083 Mon Sep 17 00:00:00 2001
From: SengokuCola <1026294844@qq.com>
Date: Tue, 3 Jun 2025 13:10:48 +0800
Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9A=E4=BC=98=E5=8C=96=E8=AE=B0?=
 =?UTF-8?q?=E5=BF=86=E5=90=8C=E6=AD=A5=E7=AE=97=E6=B3=95=EF=BC=8C=E4=BF=AE?=
 =?UTF-8?q?=E5=A4=8D=E8=AE=B0=E5=BF=86=E6=9E=84=E5=BB=BA=E6=B2=A1=E6=9C=89?=
 =?UTF-8?q?chat=5Fid=E7=9A=84=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/chat/memory_system/Hippocampus.py  | 446 +++++++++++++++++++++----
 src/chat/utils/chat_message_builder.py |   2 +-
 2 files changed, 373 insertions(+), 75 deletions(-)

diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py
index e63840f11..43cc1fef6 100644
--- a/src/chat/memory_system/Hippocampus.py
+++ b/src/chat/memory_system/Hippocampus.py
@@ -17,12 +17,14 @@ from src.chat.memory_system.sample_distribution import MemoryBuildScheduler  #
 from ..utils.chat_message_builder import (
     get_raw_msg_by_timestamp,
     build_readable_messages,
+    get_raw_msg_by_timestamp_with_chat,
 )  # 导入 build_readable_messages
 from ..utils.utils import translate_timestamp_to_human_readable
 from rich.traceback import install
 from ...config.config import global_config
 
 from src.common.database.database_model import Messages, GraphNodes, GraphEdges  # Peewee Models导入
+from peewee import Case
 
 install(extra_lines=3)
 
@@ -215,15 +217,18 @@ class Hippocampus:
         """计算节点的特征值"""
         if not isinstance(memory_items, list):
             memory_items = [memory_items] if memory_items else []
-        sorted_items = sorted(memory_items)
-        content = f"{concept}:{'|'.join(sorted_items)}"
+
+        # 使用集合来去重,避免排序
+        unique_items = set(str(item) for item in memory_items)
+        # 使用frozenset来保证顺序一致性
+        content = f"{concept}:{frozenset(unique_items)}"
         return hash(content)
 
     @staticmethod
     def calculate_edge_hash(source, target) -> int:
         """计算边的特征值"""
-        nodes = sorted([source, target])
-        return hash(f"{nodes[0]}:{nodes[1]}")
+        # 直接使用元组,保证顺序一致性
+        return hash((source, target))
 
     @staticmethod
     def find_topic_llm(text, topic_num):
@@ -811,7 +816,8 @@ class EntorhinalCortex:
         timestamps = sample_scheduler.get_timestamp_array()
         # 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
         readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
-        logger.info(f"回忆往事: {readable_timestamps}")
+        for timestamp, readable_timestamp in zip(timestamps, readable_timestamps):
+            logger.debug(f"回忆往事: {readable_timestamp}")
         chat_samples = []
         for timestamp in timestamps:
             # 调用修改后的 random_get_msg_snippet
@@ -820,10 +826,10 @@ class EntorhinalCortex:
             )
             if messages:
                 time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
-                logger.debug(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
+                logger.success(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
                 chat_samples.append(messages)
             else:
-                logger.debug(f"时间戳 {timestamp} 的消息样本抽取失败")
+                logger.debug(f"时间戳 {timestamp} 的消息无需记忆")
 
         return chat_samples
 
@@ -837,32 +843,37 @@ class EntorhinalCortex:
             # 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds
             timestamp_start = target_timestamp
             timestamp_end = target_timestamp + time_window_seconds
-
-            # 使用 chat_message_builder 的函数获取消息
-            # limit_mode='earliest' 获取这个时间窗口内最早的 chat_size 条消息
-            messages = get_raw_msg_by_timestamp(
-                timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest"
+
+            chosen_message = get_raw_msg_by_timestamp(
+                timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
limit_mode="earliest" ) + + if chosen_message: + chat_id = chosen_message[0].get("chat_id") - if messages: - # 检查获取到的所有消息是否都未达到最大记忆次数 - all_valid = True - for message in messages: - if message.get("memorized_times", 0) >= max_memorized_time_per_msg: - all_valid = False - break + messages = get_raw_msg_by_timestamp_with_chat( + timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest", chat_id=chat_id + ) - # 如果所有消息都有效 - if all_valid: - # 更新数据库中的记忆次数 + if messages: + # 检查获取到的所有消息是否都未达到最大记忆次数 + all_valid = True for message in messages: - # 确保在更新前获取最新的 memorized_times - current_memorized_times = message.get("memorized_times", 0) - # 使用 Peewee 更新记录 - Messages.update(memorized_times=current_memorized_times + 1).where( - Messages.message_id == message["message_id"] - ).execute() - return messages # 直接返回原始的消息列表 + if message.get("memorized_times", 0) >= max_memorized_time_per_msg: + all_valid = False + break + + # 如果所有消息都有效 + if all_valid: + # 更新数据库中的记忆次数 + for message in messages: + # 确保在更新前获取最新的 memorized_times + current_memorized_times = message.get("memorized_times", 0) + # 使用 Peewee 更新记录 + Messages.update(memorized_times=current_memorized_times + 1).where( + Messages.message_id == message["message_id"] + ).execute() + return messages # 直接返回原始的消息列表 # 如果获取失败或消息无效,增加尝试次数 try_count += 1 @@ -873,85 +884,361 @@ class EntorhinalCortex: async def sync_memory_to_db(self): """将记忆图同步到数据库""" + start_time = time.time() + # 获取数据库中所有节点和内存中所有节点 + db_load_start = time.time() db_nodes = {node.concept: node for node in GraphNodes.select()} memory_nodes = list(self.memory_graph.G.nodes(data=True)) + db_load_end = time.time() + logger.info(f"[同步] 加载数据库耗时: {db_load_end - db_load_start:.2f}秒") + + # 批量准备节点数据 + nodes_to_create = [] + nodes_to_update = [] + current_time = datetime.datetime.now().timestamp() # 检查并更新节点 + node_process_start = time.time() for concept, data in memory_nodes: + # 检查概念是否有效 + if not concept or not isinstance(concept, str): + logger.warning(f"[同步] 发现无效概念,将移除节点: {concept}") + # 从图中移除节点(这会自动移除相关的边) + self.memory_graph.G.remove_node(concept) + continue + memory_items = data.get("memory_items", []) if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] + # 检查记忆项是否为空 + if not memory_items: + logger.warning(f"[同步] 发现空记忆节点,将移除节点: {concept}") + # 从图中移除节点(这会自动移除相关的边) + self.memory_graph.G.remove_node(concept) + continue + # 计算内存中节点的特征值 memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items) # 获取时间信息 - created_time = data.get("created_time", datetime.datetime.now().timestamp()) - last_modified = data.get("last_modified", datetime.datetime.now().timestamp()) + created_time = data.get("created_time", current_time) + last_modified = data.get("last_modified", current_time) # 将memory_items转换为JSON字符串 - memory_items_json = json.dumps(memory_items, ensure_ascii=False) + try: + # 确保memory_items中的每个项都是字符串 + memory_items = [str(item) for item in memory_items] + memory_items_json = json.dumps(memory_items, ensure_ascii=False) + if not memory_items_json: # 确保JSON字符串不为空 + raise ValueError("序列化后的JSON字符串为空") + # 验证JSON字符串是否有效 + json.loads(memory_items_json) + except Exception as e: + logger.error(f"[同步] 序列化记忆项失败,将移除节点: {concept}, 错误: {e}") + # 从图中移除节点(这会自动移除相关的边) + self.memory_graph.G.remove_node(concept) + continue if concept not in db_nodes: - # 数据库中缺少的节点,添加 - GraphNodes.create( - concept=concept, - memory_items=memory_items_json, - hash=memory_hash, - created_time=created_time, - last_modified=last_modified, - ) + # 
+                nodes_to_create.append({
+                    'concept': concept,
+                    'memory_items': memory_items_json,
+                    'hash': memory_hash,
+                    'created_time': created_time,
+                    'last_modified': last_modified
+                })
+                logger.debug(f"[同步] 准备创建节点: {concept}, memory_items长度: {len(memory_items)}")
             else:
                 # 获取数据库中节点的特征值
                 db_node = db_nodes[concept]
                 db_hash = db_node.hash
 
-                # 如果特征值不同,则更新节点
+                # 如果特征值不同,则添加到更新列表
                 if db_hash != memory_hash:
-                    db_node.memory_items = memory_items_json
-                    db_node.hash = memory_hash
-                    db_node.last_modified = last_modified
-                    db_node.save()
+                    nodes_to_update.append({
+                        'concept': concept,
+                        'memory_items': memory_items_json,
+                        'hash': memory_hash,
+                        'last_modified': last_modified
+                    })
+
+        # 检查需要删除的节点
+        memory_concepts = {concept for concept, _ in memory_nodes}
+        db_concepts = set(db_nodes.keys())
+        nodes_to_delete = db_concepts - memory_concepts
+
+        node_process_end = time.time()
+        logger.info(f"[同步] 处理节点数据耗时: {node_process_end - node_process_start:.2f}秒")
+        logger.info(f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点")
+
+        # 异步批量创建新节点
+        node_create_start = time.time()
+        if nodes_to_create:
+            try:
+                # 验证所有要创建的节点数据
+                valid_nodes_to_create = []
+                for node_data in nodes_to_create:
+                    if not node_data.get('memory_items'):
+                        logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
+                        continue
+                    try:
+                        # 验证 JSON 字符串
+                        json.loads(node_data['memory_items'])
+                        valid_nodes_to_create.append(node_data)
+                    except json.JSONDecodeError:
+                        logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                        continue
+
+                if valid_nodes_to_create:
+                    # 使用异步批量插入
+                    batch_size = 100
+                    for i in range(0, len(valid_nodes_to_create), batch_size):
+                        batch = valid_nodes_to_create[i:i + batch_size]
+                        await self._async_batch_create_nodes(batch)
+                    logger.info(f"[同步] 成功创建 {len(valid_nodes_to_create)} 个节点")
+                else:
+                    logger.warning("[同步] 没有有效的节点可以创建")
+            except Exception as e:
+                logger.error(f"[同步] 创建节点失败: {e}")
+                # 尝试逐个创建以找出问题节点
+                for node_data in nodes_to_create:
+                    try:
+                        if not node_data.get('memory_items'):
+                            logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
+                            continue
+                        try:
+                            json.loads(node_data['memory_items'])
+                        except json.JSONDecodeError:
+                            logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                            continue
+                        await self._async_create_node(node_data)
+                    except Exception as e:
+                        logger.error(f"[同步] 创建节点失败: {node_data['concept']}, 错误: {e}")
+                        # 从图中移除问题节点
+                        self.memory_graph.G.remove_node(node_data['concept'])
+        node_create_end = time.time()
+        logger.info(f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)")
+
+        # 异步批量更新节点
+        node_update_start = time.time()
+        if nodes_to_update:
+            # 按批次更新节点,每批100个
+            batch_size = 100
+            for i in range(0, len(nodes_to_update), batch_size):
+                batch = nodes_to_update[i:i + batch_size]
+                try:
+                    # 验证批次中的每个节点数据
+                    valid_batch = []
+                    for node_data in batch:
+                        # 确保 memory_items 不为空且是有效的 JSON 字符串
+                        if not node_data.get('memory_items'):
+                            logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 为空")
+                            continue
+                        try:
+                            # 验证 JSON 字符串是否有效
+                            json.loads(node_data['memory_items'])
+                            valid_batch.append(node_data)
+                        except json.JSONDecodeError:
+                            logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                            continue
+
+                    if not valid_batch:
+                        logger.warning(f"[同步] 批次 {i//batch_size + 1} 没有有效的节点可以更新")
+                        continue
+
+                    # 异步批量更新节点
+                    await self._async_batch_update_nodes(valid_batch)
+                    logger.debug(f"[同步] 成功更新批次 {i//batch_size + 1} 中的 {len(valid_batch)} 个节点")
+                except Exception as e:
+                    logger.error(f"[同步] 批量更新节点失败: {e}")
+                    # 如果批量更新失败,尝试逐个更新
+                    for node_data in valid_batch:
+                        try:
+                            await self._async_update_node(node_data)
+                        except Exception as e:
+                            logger.error(f"[同步] 更新节点失败: {node_data['concept']}, 错误: {e}")
+                            # 从图中移除问题节点
+                            self.memory_graph.G.remove_node(node_data['concept'])
+
+        node_update_end = time.time()
+        logger.info(f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)")
+
+        # 异步删除不存在的节点
+        node_delete_start = time.time()
+        if nodes_to_delete:
+            await self._async_delete_nodes(nodes_to_delete)
+        node_delete_end = time.time()
+        logger.info(f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)")
 
         # 处理边的信息
+        edge_load_start = time.time()
         db_edges = list(GraphEdges.select())
         memory_edges = list(self.memory_graph.G.edges(data=True))
+        edge_load_end = time.time()
+        logger.info(f"[同步] 加载边数据耗时: {edge_load_end - edge_load_start:.2f}秒")
 
         # 创建边的哈希值字典
+        edge_dict_start = time.time()
         db_edge_dict = {}
         for edge in db_edges:
             edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
             db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
+        edge_dict_end = time.time()
+        logger.info(f"[同步] 创建边字典耗时: {edge_dict_end - edge_dict_start:.2f}秒")
+
+        # 批量准备边数据
+        edges_to_create = []
+        edges_to_update = []
 
         # 检查并更新边
+        edge_process_start = time.time()
         for source, target, data in memory_edges:
             edge_hash = self.hippocampus.calculate_edge_hash(source, target)
             edge_key = (source, target)
             strength = data.get("strength", 1)
 
             # 获取边的时间信息
-            created_time = data.get("created_time", datetime.datetime.now().timestamp())
-            last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
+            created_time = data.get("created_time", current_time)
+            last_modified = data.get("last_modified", current_time)
 
             if edge_key not in db_edge_dict:
-                # 添加新边
-                GraphEdges.create(
-                    source=source,
-                    target=target,
-                    strength=strength,
-                    hash=edge_hash,
-                    created_time=created_time,
-                    last_modified=last_modified,
-                )
+                # 添加新边到创建列表
+                edges_to_create.append({
+                    'source': source,
+                    'target': target,
+                    'strength': strength,
+                    'hash': edge_hash,
+                    'created_time': created_time,
+                    'last_modified': last_modified
+                })
             else:
                 # 检查边的特征值是否变化
                 if db_edge_dict[edge_key]["hash"] != edge_hash:
-                    edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target)
-                    edge.hash = edge_hash
-                    edge.strength = strength
-                    edge.last_modified = last_modified
-                    edge.save()
+                    edges_to_update.append({
+                        'source': source,
+                        'target': target,
+                        'strength': strength,
+                        'hash': edge_hash,
+                        'last_modified': last_modified
+                    })
+        edge_process_end = time.time()
+        logger.info(f"[同步] 处理边数据耗时: {edge_process_end - edge_process_start:.2f}秒")
+
+        # 异步批量创建新边
+        edge_create_start = time.time()
+        if edges_to_create:
+            batch_size = 100
+            for i in range(0, len(edges_to_create), batch_size):
+                batch = edges_to_create[i:i + batch_size]
+                await self._async_batch_create_edges(batch)
+        edge_create_end = time.time()
+        logger.info(f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)")
+
+        # 异步批量更新边
+        edge_update_start = time.time()
+        if edges_to_update:
+            batch_size = 100
+            for i in range(0, len(edges_to_update), batch_size):
+                batch = edges_to_update[i:i + batch_size]
+                await self._async_batch_update_edges(batch)
+        edge_update_end = time.time()
+        logger.info(f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)")
+
+        # 检查需要删除的边
+        memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
+        db_edge_keys = {(edge.source, edge.target) for edge in db_edges}
+        edges_to_delete = db_edge_keys - memory_edge_keys
+
+        # 异步删除不存在的边
+        edge_delete_start = time.time()
+        if edges_to_delete:
+            await self._async_delete_edges(edges_to_delete)
+        edge_delete_end = time.time()
+        logger.info(f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)")
+
+        end_time = time.time()
+        logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
+        logger.success(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")
+
+    async def _async_batch_create_nodes(self, nodes_data):
+        """异步批量创建节点"""
+        try:
+            GraphNodes.insert_many(nodes_data).execute()
+        except Exception as e:
+            logger.error(f"[同步] 批量创建节点失败: {e}")
+            raise
+
+    async def _async_create_node(self, node_data):
+        """异步创建单个节点"""
+        try:
+            GraphNodes.create(**node_data)
+        except Exception as e:
+            logger.error(f"[同步] 创建节点失败: {e}")
+            raise
+
+    async def _async_batch_update_nodes(self, nodes_data):
+        """异步批量更新节点"""
+        try:
+            for node_data in nodes_data:
+                GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
+                    GraphNodes.concept == node_data['concept']
+                ).execute()
+        except Exception as e:
+            logger.error(f"[同步] 批量更新节点失败: {e}")
+            raise
+
+    async def _async_update_node(self, node_data):
+        """异步更新单个节点"""
+        try:
+            GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
+                GraphNodes.concept == node_data['concept']
+            ).execute()
+        except Exception as e:
+            logger.error(f"[同步] 更新节点失败: {e}")
+            raise
+
+    async def _async_delete_nodes(self, concepts):
+        """异步删除节点"""
+        try:
+            GraphNodes.delete().where(GraphNodes.concept.in_(concepts)).execute()
+        except Exception as e:
+            logger.error(f"[同步] 删除节点失败: {e}")
+            raise
+
+    async def _async_batch_create_edges(self, edges_data):
+        """异步批量创建边"""
+        try:
+            GraphEdges.insert_many(edges_data).execute()
+        except Exception as e:
+            logger.error(f"[同步] 批量创建边失败: {e}")
+            raise
+
+    async def _async_batch_update_edges(self, edges_data):
+        """异步批量更新边"""
+        try:
+            for edge_data in edges_data:
+                GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ['source', 'target']}).where(
+                    (GraphEdges.source == edge_data['source']) &
+                    (GraphEdges.target == edge_data['target'])
+                ).execute()
+        except Exception as e:
+            logger.error(f"[同步] 批量更新边失败: {e}")
+            raise
+
+    async def _async_delete_edges(self, edge_keys):
+        """异步删除边"""
+        try:
+            for source, target in edge_keys:
+                GraphEdges.delete().where(
+                    (GraphEdges.source == source) &
+                    (GraphEdges.target == target)
+                ).execute()
+        except Exception as e:
+            logger.error(f"[同步] 删除边失败: {e}")
+            raise
 
     def sync_memory_from_db(self):
         """从数据库同步数据到内存中的图结构"""
@@ -1111,7 +1398,7 @@ class ParahippocampalGyrus:
         input_text = await build_readable_messages(
             messages,
             merge_messages=True,  # 合并连续消息
-            timestamp_mode="normal",  # 使用 'YYYY-MM-DD HH:MM:SS' 格式
+            timestamp_mode="normal_no_YMD",  # 使用 'YYYY-MM-DD HH:MM:SS' 格式
             replace_bot_name=False,  # 保留原始用户名
         )
 
@@ -1119,8 +1406,12 @@ class ParahippocampalGyrus:
         if not input_text:
             logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。")
             return set(), {}
+
+        current_YMD_time = datetime.datetime.now().strftime("%Y-%m-%d")
+        current_YMD_time_str = f"当前日期: {current_YMD_time}"
+        input_text = f"{current_YMD_time_str}\n{input_text}"
 
-        logger.debug(f"用于压缩的格式化文本:\n{input_text}")
+        logger.debug(f"记忆来源:\n{input_text}")
 
         # 2. 使用LLM提取关键主题
         topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
@@ -1191,7 +1482,7 @@ class ParahippocampalGyrus:
         return compressed_memory, similar_topics_dict
 
     async def operation_build_memory(self):
-        logger.debug("------------------------------------开始构建记忆--------------------------------------")
+        logger.info("------------------------------------开始构建记忆--------------------------------------")
         start_time = time.time()
         memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
         all_added_nodes = []
@@ -1199,19 +1490,16 @@ class ParahippocampalGyrus:
         all_added_edges = []
         for i, messages in enumerate(memory_samples, 1):
             all_topics = []
-            progress = (i / len(memory_samples)) * 100
-            bar_length = 30
-            filled_length = int(bar_length * i // len(memory_samples))
-            bar = "█" * filled_length + "-" * (bar_length - filled_length)
-            logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
-
             compress_rate = global_config.memory.memory_compress_rate
             try:
                 compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
             except Exception as e:
                 logger.error(f"压缩记忆时发生错误: {e}")
                 continue
-            logger.debug(f"压缩后记忆数量: {compressed_memory},似曾相识的话题: {similar_topics_dict}")
+            for topic, memory in compressed_memory:
+                logger.info(f"取得记忆: {topic} - {memory}")
+            for topic, similar_topics in similar_topics_dict.items():
+                logger.debug(f"相似话题: {topic} - {similar_topics}")
 
             current_time = datetime.datetime.now().timestamp()
             logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}")
@@ -1245,10 +1533,20 @@ class ParahippocampalGyrus:
                     logger.debug(f"连接同批次节点: {topic1} 和 {topic2}")
                     all_added_edges.append(f"{topic1}-{topic2}")
                     self.memory_graph.connect_dot(topic1, topic2)
+
+
+            progress = (i / len(memory_samples)) * 100
+            bar_length = 30
+            filled_length = int(bar_length * i // len(memory_samples))
+            bar = "█" * filled_length + "-" * (bar_length - filled_length)
+            logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
 
-        logger.success(f"更新记忆: {', '.join(all_added_nodes)}")
-        logger.debug(f"强化连接: {', '.join(all_added_edges)}")
-        logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}")
+        if all_added_nodes:
+            logger.success(f"更新记忆: {', '.join(all_added_nodes)}")
+        if all_added_edges:
+            logger.debug(f"强化连接: {', '.join(all_added_edges)}")
+        if all_connected_nodes:
+            logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}")
 
         await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
 
diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py
index e896420aa..59cac2139 100644
--- a/src/chat/utils/chat_message_builder.py
+++ b/src/chat/utils/chat_message_builder.py
@@ -342,7 +342,7 @@ async def _build_readable_messages_internal(
 
         # 使用指定的 timestamp_mode 格式化时间
         readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
-        header = f"{readable_time}{merged['name']} 说:"
+        header = f"{readable_time}, {merged['name']} :"
         output_lines.append(header)
         # 将内容合并,并添加缩进
         for line in merged["content"]: