diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py
index 3f47cd116..463d477ab 100644
--- a/src/chat/memory_system/Hippocampus.py
+++ b/src/chat/memory_system/Hippocampus.py
@@ -888,26 +888,20 @@ class EntorhinalCortex:
     async def sync_memory_to_db(self):
         """Sync the memory graph to the database."""
         start_time = time.time()
+        current_time = datetime.datetime.now().timestamp()

         # Fetch all nodes stored in the database and all nodes held in memory
-        db_load_start = time.time()
         db_nodes = {node.concept: node for node in GraphNodes.select()}
         memory_nodes = list(self.memory_graph.G.nodes(data=True))
-        db_load_end = time.time()
-        logger.info(f"[Sync] Loading the database took {db_load_end - db_load_start:.2f}s")
-
+
         # Prepare node data in batches
         nodes_to_create = []
         nodes_to_update = []
-        current_time = datetime.datetime.now().timestamp()
+        nodes_to_delete = set()

-        # Check and update nodes
-        node_process_start = time.time()
+        # Process nodes
         for concept, data in memory_nodes:
-            # Check that the concept is valid
             if not concept or not isinstance(concept, str):
-                logger.warning(f"[Sync] Found invalid concept, removing node: {concept}")
-                # Remove the node from the graph (related edges are removed automatically)
                 self.memory_graph.G.remove_node(concept)
                 continue

@@ -915,357 +909,218 @@ class EntorhinalCortex:
             if not isinstance(memory_items, list):
                 memory_items = [memory_items] if memory_items else []

-            # Check whether the memory items are empty
             if not memory_items:
-                logger.warning(f"[Sync] Found empty memory node, removing node: {concept}")
-                # Remove the node from the graph (related edges are removed automatically)
                 self.memory_graph.G.remove_node(concept)
                 continue

             # Compute the hash of the in-memory node
             memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items)
-
-            # Read the timestamps
             created_time = data.get("created_time", current_time)
             last_modified = data.get("last_modified", current_time)

             # Serialize memory_items to a JSON string
             try:
-                # Make sure every entry in memory_items is a string
                 memory_items = [str(item) for item in memory_items]
                 memory_items_json = json.dumps(memory_items, ensure_ascii=False)
-                if not memory_items_json:  # Make sure the JSON string is not empty
-                    raise ValueError("Serialized JSON string is empty")
-                # Verify that the JSON string is valid
-                json.loads(memory_items_json)
-            except Exception as e:
-                logger.error(f"[Sync] Failed to serialize memory items, removing node: {concept}, error: {e}")
-                # Remove the node from the graph (related edges are removed automatically)
+                if not memory_items_json:
+                    continue
+            except Exception:
                 self.memory_graph.G.remove_node(concept)
                 continue

             if concept not in db_nodes:
-                # Node missing from the database; add it to the create list
-                nodes_to_create.append(
-                    {
+                nodes_to_create.append({
+                    "concept": concept,
+                    "memory_items": memory_items_json,
+                    "hash": memory_hash,
+                    "created_time": created_time,
+                    "last_modified": last_modified,
+                })
+            else:
+                db_node = db_nodes[concept]
+                if db_node.hash != memory_hash:
+                    nodes_to_update.append({
                         "concept": concept,
                         "memory_items": memory_items_json,
                         "hash": memory_hash,
-                        "created_time": created_time,
                         "last_modified": last_modified,
-                    }
-                )
-                logger.debug(f"[Sync] Queued node for creation: {concept}, memory_items length: {len(memory_items)}")
-            else:
-                # Read the stored node's hash from the database
-                db_node = db_nodes[concept]
-                db_hash = db_node.hash
+                    })

-                # If the hash differs, add the node to the update list
-                if db_hash != memory_hash:
-                    nodes_to_update.append(
-                        {
-                            "concept": concept,
-                            "memory_items": memory_items_json,
-                            "hash": memory_hash,
-                            "last_modified": last_modified,
-                        }
-                    )
-
-        # Check which nodes need to be deleted
+        # Compute the nodes that need to be deleted
         memory_concepts = {concept for concept, _ in memory_nodes}
-        db_concepts = set(db_nodes.keys())
-        nodes_to_delete = db_concepts - memory_concepts
+        nodes_to_delete = set(db_nodes.keys()) - memory_concepts

-        node_process_end = time.time()
-        logger.info(f"[Sync] Processing node data took {node_process_end - node_process_start:.2f}s")
-        logger.info(
-            f"[Sync] About to create {len(nodes_to_create)} nodes, update {len(nodes_to_update)} nodes, delete {len(nodes_to_delete)} nodes"
-        )
-
-        # Asynchronously create new nodes in batches
-        node_create_start = time.time()
+        # Process nodes in batches
         if nodes_to_create:
-            try:
-                # Validate all node data queued for creation
-                valid_nodes_to_create = []
-                for node_data in nodes_to_create:
-                    if not node_data.get("memory_items"):
-                        logger.warning(f"[Sync] Skipping creation of node {node_data['concept']}: memory_items is empty")
-                        continue
-                    try:
-                        # Validate the JSON string
-                        json.loads(node_data["memory_items"])
-                        valid_nodes_to_create.append(node_data)
-                    except json.JSONDecodeError:
-                        logger.warning(
-                            f"[Sync] Skipping creation of node {node_data['concept']}: memory_items is not a valid JSON string"
-                        )
-                        continue
+            batch_size = 100
+            for i in range(0, len(nodes_to_create), batch_size):
+                batch = nodes_to_create[i:i + batch_size]
+                GraphNodes.insert_many(batch).execute()

-                if valid_nodes_to_create:
-                    # Use async batch insertion
-                    batch_size = 100
-                    for i in range(0, len(valid_nodes_to_create), batch_size):
-                        batch = valid_nodes_to_create[i : i + batch_size]
-                        await self._async_batch_create_nodes(batch)
-                    logger.info(f"[Sync] Successfully created {len(valid_nodes_to_create)} nodes")
-                else:
-                    logger.warning("[Sync] No valid nodes to create")
-            except Exception as e:
-                logger.error(f"[Sync] Failed to create nodes: {e}")
-                # Try creating nodes one by one to find the problematic one
-                for node_data in nodes_to_create:
-                    try:
-                        if not node_data.get("memory_items"):
-                            logger.warning(f"[Sync] Skipping creation of node {node_data['concept']}: memory_items is empty")
-                            continue
-                        try:
-                            json.loads(node_data["memory_items"])
-                        except json.JSONDecodeError:
-                            logger.warning(
-                                f"[Sync] Skipping creation of node {node_data['concept']}: memory_items is not a valid JSON string"
-                            )
-                            continue
-                        await self._async_create_node(node_data)
-                    except Exception as e:
-                        logger.error(f"[Sync] Failed to create node: {node_data['concept']}, error: {e}")
-                        # Remove the problematic node from the graph
-                        self.memory_graph.G.remove_node(node_data["concept"])
-        node_create_end = time.time()
-        logger.info(
-            f"[Sync] Creating new nodes took {node_create_end - node_create_start:.2f}s (created {len(nodes_to_create)} nodes)"
-        )
-
-        # Asynchronously update nodes in batches
-        node_update_start = time.time()
         if nodes_to_update:
-            # Update nodes in batches of 100
             batch_size = 100
             for i in range(0, len(nodes_to_update), batch_size):
-                batch = nodes_to_update[i : i + batch_size]
-                try:
-                    # Validate each node in the batch
-                    valid_batch = []
-                    for node_data in batch:
-                        # Make sure memory_items is non-empty and a valid JSON string
-                        if not node_data.get("memory_items"):
-                            logger.warning(f"[Sync] Skipping update of node {node_data['concept']}: memory_items is empty")
-                            continue
-                        try:
-                            # Verify that the JSON string is valid
-                            json.loads(node_data["memory_items"])
-                            valid_batch.append(node_data)
-                        except json.JSONDecodeError:
-                            logger.warning(
-                                f"[Sync] Skipping update of node {node_data['concept']}: memory_items is not a valid JSON string"
-                            )
-                            continue
+                batch = nodes_to_update[i:i + batch_size]
+                for node_data in batch:
+                    GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
+                        GraphNodes.concept == node_data["concept"]
+                    ).execute()

-                    if not valid_batch:
-                        logger.warning(f"[Sync] Batch {i // batch_size + 1} has no valid nodes to update")
-                        continue
-
-                    # Asynchronously update the batch of nodes
-                    await self._async_batch_update_nodes(valid_batch)
-                    logger.debug(f"[Sync] Successfully updated {len(valid_batch)} nodes in batch {i // batch_size + 1}")
-                except Exception as e:
-                    logger.error(f"[Sync] Batch node update failed: {e}")
-                    # If the batch update fails, fall back to updating one by one
-                    for node_data in valid_batch:
-                        try:
-                            await self._async_update_node(node_data)
-                        except Exception as e:
-                            logger.error(f"[Sync] Failed to update node: {node_data['concept']}, error: {e}")
-                            # Remove the problematic node from the graph
-                            self.memory_graph.G.remove_node(node_data["concept"])
-
-        node_update_end = time.time()
-        logger.info(
-            f"[Sync] Updating nodes took {node_update_end - node_update_start:.2f}s (updated {len(nodes_to_update)} nodes)"
-        )
-
-        # Asynchronously delete nodes that no longer exist
-        node_delete_start = time.time()
         if nodes_to_delete:
-            await self._async_delete_nodes(nodes_to_delete)
-            node_delete_end = time.time()
-            logger.info(
-                f"[Sync] Deleting nodes took {node_delete_end - node_delete_start:.2f}s (deleted {len(nodes_to_delete)} nodes)"
-            )
+            GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute()

         # Process the edge data
-        edge_load_start = time.time()
         db_edges = list(GraphEdges.select())
         memory_edges = list(self.memory_graph.G.edges(data=True))
-        edge_load_end = time.time()
-        logger.info(f"[Sync] Loading edge data took {edge_load_end - edge_load_start:.2f}s")

         # Build a dictionary of edge hashes
-        edge_dict_start = time.time()
         db_edge_dict = {}
         for edge in db_edges:
             edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
             db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
-        edge_dict_end = time.time()
-        logger.info(f"[Sync] Building the edge dictionary took {edge_dict_end - edge_dict_start:.2f}s")

         # Prepare edge data in batches
         edges_to_create = []
         edges_to_update = []

-        # Check and update edges
-        edge_process_start = time.time()
+        # Process edges
         for source, target, data in memory_edges:
             edge_hash = self.hippocampus.calculate_edge_hash(source, target)
             edge_key = (source, target)
             strength = data.get("strength", 1)
-
-            # Read the edge timestamps
             created_time = data.get("created_time", current_time)
             last_modified = data.get("last_modified", current_time)

             if edge_key not in db_edge_dict:
-                # Add the new edge to the create list
-                edges_to_create.append(
-                    {
-                        "source": source,
-                        "target": target,
-                        "strength": strength,
-                        "hash": edge_hash,
-                        "created_time": created_time,
-                        "last_modified": last_modified,
-                    }
-                )
-            else:
-                # Check whether the edge hash changed
-                if db_edge_dict[edge_key]["hash"] != edge_hash:
-                    edges_to_update.append(
-                        {
-                            "source": source,
-                            "target": target,
-                            "strength": strength,
-                            "hash": edge_hash,
-                            "last_modified": last_modified,
-                        }
-                    )
-        edge_process_end = time.time()
-        logger.info(f"[Sync] Processing edge data took {edge_process_end - edge_process_start:.2f}s")
+                edges_to_create.append({
+                    "source": source,
+                    "target": target,
+                    "strength": strength,
+                    "hash": edge_hash,
+                    "created_time": created_time,
+                    "last_modified": last_modified,
+                })
+            elif db_edge_dict[edge_key]["hash"] != edge_hash:
+                edges_to_update.append({
+                    "source": source,
+                    "target": target,
+                    "strength": strength,
+                    "hash": edge_hash,
+                    "last_modified": last_modified,
+                })

-        # Asynchronously create new edges in batches
-        edge_create_start = time.time()
+        # Compute the edges that need to be deleted
+        memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
+        edges_to_delete = set(db_edge_dict.keys()) - memory_edge_keys
+
+        # Process edges in batches
         if edges_to_create:
             batch_size = 100
             for i in range(0, len(edges_to_create), batch_size):
-                batch = edges_to_create[i : i + batch_size]
-                await self._async_batch_create_edges(batch)
-        edge_create_end = time.time()
-        logger.info(
-            f"[Sync] Creating new edges took {edge_create_end - edge_create_start:.2f}s (created {len(edges_to_create)} edges)"
-        )
+                batch = edges_to_create[i:i + batch_size]
+                GraphEdges.insert_many(batch).execute()

-        # Asynchronously update edges in batches
-        edge_update_start = time.time()
         if edges_to_update:
             batch_size = 100
             for i in range(0, len(edges_to_update), batch_size):
-                batch = edges_to_update[i : i + batch_size]
-                await self._async_batch_update_edges(batch)
-        edge_update_end = time.time()
-        logger.info(
-            f"[Sync] Updating edges took {edge_update_end - edge_update_start:.2f}s (updated {len(edges_to_update)} edges)"
-        )
+                batch = edges_to_update[i:i + batch_size]
+                for edge_data in batch:
+                    GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
+                        (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
+                    ).execute()

-        # Check which edges need to be deleted
-        memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
-        db_edge_keys = {(edge.source, edge.target) for edge in db_edges}
-        edges_to_delete = db_edge_keys - memory_edge_keys
-
-        # Asynchronously delete edges that no longer exist
-        edge_delete_start = time.time()
         if edges_to_delete:
-            await self._async_delete_edges(edges_to_delete)
-            edge_delete_end = time.time()
-            logger.info(
-                f"[Sync] Deleting edges took {edge_delete_end - edge_delete_start:.2f}s (deleted {len(edges_to_delete)} edges)"
-            )
+            for source, target in edges_to_delete:
+                GraphEdges.delete().where(
+                    (GraphEdges.source == source) & (GraphEdges.target == target)
+                ).execute()

         end_time = time.time()
         logger.success(f"[Sync] Total time: {end_time - start_time:.2f}s")
         logger.success(f"[Sync] Synced {len(memory_nodes)} nodes and {len(memory_edges)} edges")

-    async def _async_batch_create_nodes(self, nodes_data):
-        """Asynchronously create nodes in batches"""
-        try:
-            GraphNodes.insert_many(nodes_data).execute()
-        except Exception as e:
-            logger.error(f"[Sync] Batch node creation failed: {e}")
-            raise
+    async def resync_memory_to_db(self):
+        """Clear the database and re-sync all memory data."""
+        start_time = time.time()
+        logger.info("[Database] Starting re-sync of all memory data...")

-    async def _async_create_node(self, node_data):
-        """Asynchronously create a single node"""
-        try:
-            GraphNodes.create(**node_data)
-        except Exception as e:
-            logger.error(f"[Sync] Node creation failed: {e}")
-            raise
+        # Clear the database
+        clear_start = time.time()
+        GraphNodes.delete().execute()
+        GraphEdges.delete().execute()
+        clear_end = time.time()
+        logger.info(f"[Database] Clearing the database took {clear_end - clear_start:.2f}s")

-    async def _async_batch_update_nodes(self, nodes_data):
-        """Asynchronously update nodes in batches"""
-        try:
-            for node_data in nodes_data:
-                GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
-                    GraphNodes.concept == node_data["concept"]
-                ).execute()
-        except Exception as e:
-            logger.error(f"[Sync] Batch node update failed: {e}")
-            raise
+        # Fetch all nodes and edges
+        memory_nodes = list(self.memory_graph.G.nodes(data=True))
+        memory_edges = list(self.memory_graph.G.edges(data=True))
+        current_time = datetime.datetime.now().timestamp()

-    async def _async_update_node(self, node_data):
-        """Asynchronously update a single node"""
-        try:
-            GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
-                GraphNodes.concept == node_data["concept"]
-            ).execute()
-        except Exception as e:
-            logger.error(f"[Sync] Node update failed: {e}")
-            raise
+        # Prepare node data in batches
+        nodes_data = []
+        for concept, data in memory_nodes:
+            memory_items = data.get("memory_items", [])
+            if not isinstance(memory_items, list):
+                memory_items = [memory_items] if memory_items else []

-    async def _async_delete_nodes(self, concepts):
-        """Asynchronously delete nodes"""
-        try:
-            GraphNodes.delete().where(GraphNodes.concept.in_(concepts)).execute()
-        except Exception as e:
-            logger.error(f"[Sync] Node deletion failed: {e}")
-            raise
+            try:
+                memory_items = [str(item) for item in memory_items]
+                memory_items_json = json.dumps(memory_items, ensure_ascii=False)
+                if not memory_items_json:
+                    continue

-    async def _async_batch_create_edges(self, edges_data):
-        """Asynchronously create edges in batches"""
-        try:
-            GraphEdges.insert_many(edges_data).execute()
-        except Exception as e:
-            logger.error(f"[Sync] Batch edge creation failed: {e}")
-            raise
+                nodes_data.append({
+                    "concept": concept,
+                    "memory_items": memory_items_json,
+                    "hash": self.hippocampus.calculate_node_hash(concept, memory_items),
+                    "created_time": data.get("created_time", current_time),
+                    "last_modified": data.get("last_modified", current_time),
+                })
+            except Exception as e:
+                logger.error(f"Error while preparing data for node {concept}: {e}")
+                continue

-    async def _async_batch_update_edges(self, edges_data):
-        """Asynchronously update edges in batches"""
-        try:
-            for edge_data in edges_data:
-                GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
-                    (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
-                ).execute()
-        except Exception as e:
-            logger.error(f"[Sync] Batch edge update failed: {e}")
-            raise
+        # Prepare edge data in batches
+        edges_data = []
+        for source, target, data in memory_edges:
+            try:
+                edges_data.append({
+                    "source": source,
+                    "target": target,
+                    "strength": data.get("strength", 1),
+                    "hash": self.hippocampus.calculate_edge_hash(source, target),
+                    "created_time": data.get("created_time", current_time),
+                    "last_modified": data.get("last_modified", current_time),
+                })
+            except Exception as e:
+                logger.error(f"Error while preparing data for edge {source}-{target}: {e}")
+                continue

-    async def _async_delete_edges(self, edge_keys):
-        """Asynchronously delete edges"""
-        try:
-            for source, target in edge_keys:
-                GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
-        except Exception as e:
-            logger.error(f"[Sync] Edge deletion failed: {e}")
-            raise
+        # Batch-write nodes inside a transaction
+        node_start = time.time()
+        if nodes_data:
+            batch_size = 500  # larger batch size
+            with GraphNodes._meta.database.atomic():
+                for i in range(0, len(nodes_data), batch_size):
+                    batch = nodes_data[i:i + batch_size]
+                    GraphNodes.insert_many(batch).execute()
+        node_end = time.time()
+        logger.info(f"[Database] Writing {len(nodes_data)} nodes took {node_end - node_start:.2f}s")
+
+        # Batch-write edges inside a transaction
+        edge_start = time.time()
+        if edges_data:
+            batch_size = 500  # larger batch size
+            with GraphEdges._meta.database.atomic():
+                for i in range(0, len(edges_data), batch_size):
+                    batch = edges_data[i:i + batch_size]
+                    GraphEdges.insert_many(batch).execute()
+        edge_end = time.time()
+        logger.info(f"[Database] Writing {len(edges_data)} edges took {edge_end - edge_start:.2f}s")
+
+        end_time = time.time()
+        logger.success(f"[Database] Re-sync finished, total time: {end_time - start_time:.2f}s")
+        logger.success(f"[Database] Synced {len(nodes_data)} nodes and {len(edges_data)} edges")

     def sync_memory_from_db(self):
         """Sync data from the database into the in-memory graph structure."""
@@ -1279,31 +1134,34 @@ class EntorhinalCortex:
         nodes = list(GraphNodes.select())
         for node in nodes:
             concept = node.concept
-            memory_items = json.loads(node.memory_items)
-            if not isinstance(memory_items, list):
-                memory_items = [memory_items] if memory_items else []
+            try:
+                memory_items = json.loads(node.memory_items)
+                if not isinstance(memory_items, list):
+                    memory_items = [memory_items] if memory_items else []

-            # Check whether the time fields exist
-            if not node.created_time or not node.last_modified:
-                need_update = True
-                # Update the node in the database
-                update_data = {}
-                if not node.created_time:
-                    update_data["created_time"] = current_time
-                if not node.last_modified:
-                    update_data["last_modified"] = current_time
+                # Check whether the time fields exist
+                if not node.created_time or not node.last_modified:
+                    need_update = True
+                    # Update the node in the database
+                    update_data = {}
+                    if not node.created_time:
+                        update_data["created_time"] = current_time
+                    if not node.last_modified:
+                        update_data["last_modified"] = current_time

-                GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
-                logger.info(f"[Time update] Added missing time fields for node {concept}")
+                    GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()

-            # Read the timestamps (fall back to the current time if missing)
-            created_time = node.created_time or current_time
-            last_modified = node.last_modified or current_time
+                # Read the timestamps (fall back to the current time if missing)
+                created_time = node.created_time or current_time
+                last_modified = node.last_modified or current_time

-            # Add the node to the graph
-            self.memory_graph.G.add_node(
-                concept, memory_items=memory_items, created_time=created_time, last_modified=last_modified
-            )
+                # Add the node to the graph
+                self.memory_graph.G.add_node(
+                    concept, memory_items=memory_items, created_time=created_time, last_modified=last_modified
+                )
+            except Exception as e:
+                logger.error(f"Error while loading node {concept}: {e}")
+                continue

         # Load all edges from the database
         edges = list(GraphEdges.select())
@@ -1325,7 +1183,6 @@ class EntorhinalCortex:
                 GraphEdges.update(**update_data).where(
                     (GraphEdges.source == source) & (GraphEdges.target == target)
                 ).execute()
-                logger.info(f"[Time update] Added missing time fields for edge {source} - {target}")

             # Read the timestamps (fall back to the current time if missing)
             created_time = edge.created_time or current_time
@@ -1340,57 +1197,6 @@ class EntorhinalCortex:
         if need_update:
             logger.success("[Database] Backfilled missing time fields")

-    async def resync_memory_to_db(self):
-        """Clear the database and re-sync all memory data."""
-        start_time = time.time()
-        logger.info("[Database] Starting re-sync of all memory data...")
-
-        # Clear the database
-        clear_start = time.time()
-        GraphNodes.delete().execute()
-        GraphEdges.delete().execute()
-        clear_end = time.time()
-        logger.info(f"[Database] Clearing the database took {clear_end - clear_start:.2f}s")
-
-        # Fetch all nodes and edges
-        memory_nodes = list(self.memory_graph.G.nodes(data=True))
-        memory_edges = list(self.memory_graph.G.edges(data=True))
-
-        # Rewrite nodes
-        node_start = time.time()
-        for concept, data in memory_nodes:
-            memory_items = data.get("memory_items", [])
-            if not isinstance(memory_items, list):
-                memory_items = [memory_items] if memory_items else []
-
-            GraphNodes.create(
-                concept=concept,
-                memory_items=json.dumps(memory_items),
-                hash=self.hippocampus.calculate_node_hash(concept, memory_items),
-                created_time=data.get("created_time", datetime.datetime.now().timestamp()),
-                last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
-            )
-        node_end = time.time()
-        logger.info(f"[Database] Writing {len(memory_nodes)} nodes took {node_end - node_start:.2f}s")
-
-        # Rewrite edges
-        edge_start = time.time()
-        for source, target, data in memory_edges:
-            GraphEdges.create(
-                source=source,
-                target=target,
-                strength=data.get("strength", 1),
-                hash=self.hippocampus.calculate_edge_hash(source, target),
-                created_time=data.get("created_time", datetime.datetime.now().timestamp()),
-                last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
-            )
-        edge_end = time.time()
-        logger.info(f"[Database] Writing {len(memory_edges)} edges took {edge_end - edge_start:.2f}s")
-
-        end_time = time.time()
-        logger.success(f"[Database] Re-sync finished, total time: {end_time - start_time:.2f}s")
-        logger.success(f"[Database] Synced {len(memory_nodes)} nodes and {len(memory_edges)} edges")
-

 # Responsible for consolidating, forgetting, and merging memories
 class ParahippocampalGyrus:
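Usage note: after this refactor the sync methods remain coroutines, but all peewee calls inside them are synchronous (the removed `_async_*` wrappers were thin pass-throughs around blocking ORM calls anyway), so the event loop stalls for the duration of the writes. A minimal sketch of how the three entry points fit together, assuming the surrounding `Hippocampus` object exposes this `EntorhinalCortex` as an `entorhinal_cortex` attribute; that attribute name is illustrative and not confirmed by this diff:

```python
import asyncio

async def persist_memory(hippocampus):
    # Hypothetical accessor; adjust to the real attribute name in this repo.
    cortex = hippocampus.entorhinal_cortex

    # Load DB state into the in-memory graph (plain synchronous call).
    cortex.sync_memory_from_db()

    # ... mutate the memory graph here ...

    # Incremental write-back: diffs nodes/edges against the DB and only
    # creates/updates/deletes what changed. The peewee calls are blocking,
    # so the loop is unresponsive while this runs.
    await cortex.sync_memory_to_db()

    # Alternative: drop both tables and rewrite everything with batched,
    # transactional insert_many calls (faster for a full rebuild).
    # await cortex.resync_memory_to_db()

# Entry-point sketch; `hippocampus` must be constructed by the caller.
# asyncio.run(persist_memory(hippocampus))
```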