diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 67df2b817..23a296c8d 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -4,6 +4,7 @@ import math import random import time import re +import json from itertools import combinations import jieba @@ -22,7 +23,7 @@ from ..utils.utils import translate_timestamp_to_human_readable from rich.traceback import install from ...config.config import global_config -from src.common.database.database_model import Messages # Peewee Messages 模型导入 +from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入 install(extra_lines=3) @@ -877,12 +878,9 @@ class EntorhinalCortex: async def sync_memory_to_db(self): """将记忆图同步到数据库""" # 获取数据库中所有节点和内存中所有节点 - db_nodes = list(db.graph_data.nodes.find()) + db_nodes = {node.concept: node for node in GraphNodes.select()} memory_nodes = list(self.memory_graph.G.nodes(data=True)) - # 转换数据库节点为字典格式,方便查找 - db_nodes_dict = {node["concept"]: node for node in db_nodes} - # 检查并更新节点 for concept, data in memory_nodes: memory_items = data.get("memory_items", []) @@ -896,44 +894,39 @@ class EntorhinalCortex: created_time = data.get("created_time", datetime.datetime.now().timestamp()) last_modified = data.get("last_modified", datetime.datetime.now().timestamp()) - if concept not in db_nodes_dict: + # 将memory_items转换为JSON字符串 + memory_items_json = json.dumps(memory_items, ensure_ascii=False) + + if concept not in db_nodes: # 数据库中缺少的节点,添加 - node_data = { - "concept": concept, - "memory_items": memory_items, - "hash": memory_hash, - "created_time": created_time, - "last_modified": last_modified, - } - db.graph_data.nodes.insert_one(node_data) + GraphNodes.create( + concept=concept, + memory_items=memory_items_json, + hash=memory_hash, + created_time=created_time, + last_modified=last_modified, + ) else: # 获取数据库中节点的特征值 - db_node = db_nodes_dict[concept] - db_hash = db_node.get("hash", None) + db_node = db_nodes[concept] + db_hash = db_node.hash # 如果特征值不同,则更新节点 if db_hash != memory_hash: - db.graph_data.nodes.update_one( - {"concept": concept}, - { - "$set": { - "memory_items": memory_items, - "hash": memory_hash, - "created_time": created_time, - "last_modified": last_modified, - } - }, - ).execute() + db_node.memory_items = memory_items_json + db_node.hash = memory_hash + db_node.last_modified = last_modified + db_node.save() # 处理边的信息 - db_edges = list(db.graph_data.edges.find()) + db_edges = list(GraphEdges.select()) memory_edges = list(self.memory_graph.G.edges(data=True)) # 创建边的哈希值字典 db_edge_dict = {} for edge in db_edges: - edge_hash = self.hippocampus.calculate_edge_hash(edge["source"], edge["target"]) - db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "strength": edge.get("strength", 1)} + edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target) + db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength} # 检查并更新边 for source, target, data in memory_edges: @@ -947,29 +940,22 @@ class EntorhinalCortex: if edge_key not in db_edge_dict: # 添加新边 - edge_data = { - "source": source, - "target": target, - "strength": strength, - "hash": edge_hash, - "created_time": created_time, - "last_modified": last_modified, - } - db.graph_data.edges.insert_one(edge_data) + GraphEdges.create( + source=source, + target=target, + strength=strength, + hash=edge_hash, + created_time=created_time, + last_modified=last_modified, + ) else: # 检查边的特征值是否变化 if db_edge_dict[edge_key]["hash"] != edge_hash: - db.graph_data.edges.update_one( - {"source": source, "target": target}, - { - "$set": { - "hash": edge_hash, - "strength": strength, - "created_time": created_time, - "last_modified": last_modified, - } - }, - ).execute() + edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target) + edge.hash = edge_hash + edge.strength = strength + edge.last_modified = last_modified + edge.save() def sync_memory_from_db(self): """从数据库同步数据到内存中的图结构""" @@ -980,29 +966,31 @@ class EntorhinalCortex: self.memory_graph.G.clear() # 从数据库加载所有节点 - nodes = list(db.graph_data.nodes.find()) + nodes = list(GraphNodes.select()) for node in nodes: - concept = node["concept"] - memory_items = node.get("memory_items", []) + concept = node.concept + memory_items = json.loads(node.memory_items) if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] # 检查时间字段是否存在 - if "created_time" not in node or "last_modified" not in node: + if not node.created_time or not node.last_modified: need_update = True # 更新数据库中的节点 update_data = {} - if "created_time" not in node: + if not node.created_time: update_data["created_time"] = current_time - if "last_modified" not in node: + if not node.last_modified: update_data["last_modified"] = current_time - db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data}).execute() + GraphNodes.update( + **update_data + ).where(GraphNodes.concept == concept).execute() logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段") # 获取时间信息(如果不存在则使用当前时间) - created_time = node.get("created_time", current_time) - last_modified = node.get("last_modified", current_time) + created_time = node.created_time or current_time + last_modified = node.last_modified or current_time # 添加节点到图中 self.memory_graph.G.add_node( @@ -1010,28 +998,32 @@ class EntorhinalCortex: ) # 从数据库加载所有边 - edges = list(db.graph_data.edges.find()) + edges = list(GraphEdges.select()) for edge in edges: - source = edge["source"] - target = edge["target"] - strength = edge.get("strength", 1) + source = edge.source + target = edge.target + strength = edge.strength # 检查时间字段是否存在 - if "created_time" not in edge or "last_modified" not in edge: + if not edge.created_time or not edge.last_modified: need_update = True # 更新数据库中的边 update_data = {} - if "created_time" not in edge: + if not edge.created_time: update_data["created_time"] = current_time - if "last_modified" not in edge: + if not edge.last_modified: update_data["last_modified"] = current_time - db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data}).execute() + GraphEdges.update( + **update_data + ).where( + (GraphEdges.source == source) & (GraphEdges.target == target) + ).execute() logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段") # 获取时间信息(如果不存在则使用当前时间) - created_time = edge.get("created_time", current_time) - last_modified = edge.get("last_modified", current_time) + created_time = edge.created_time or current_time + last_modified = edge.last_modified or current_time # 只有当源节点和目标节点都存在时才添加边 if source in self.memory_graph.G and target in self.memory_graph.G: @@ -1049,8 +1041,8 @@ class EntorhinalCortex: # 清空数据库 clear_start = time.time() - db.graph_data.nodes.delete_many({}) - db.graph_data.edges.delete_many({}) + GraphNodes.delete().execute() + GraphEdges.delete().execute() clear_end = time.time() logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒") @@ -1065,29 +1057,27 @@ class EntorhinalCortex: if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] - node_data = { - "concept": concept, - "memory_items": memory_items, - "hash": self.hippocampus.calculate_node_hash(concept, memory_items), - "created_time": data.get("created_time", datetime.datetime.now().timestamp()), - "last_modified": data.get("last_modified", datetime.datetime.now().timestamp()), - } - db.graph_data.nodes.insert_one(node_data) + GraphNodes.create( + concept=concept, + memory_items=json.dumps(memory_items), + hash=self.hippocampus.calculate_node_hash(concept, memory_items), + created_time=data.get("created_time", datetime.datetime.now().timestamp()), + last_modified=data.get("last_modified", datetime.datetime.now().timestamp()), + ) node_end = time.time() logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}秒") # 重新写入边 edge_start = time.time() for source, target, data in memory_edges: - edge_data = { - "source": source, - "target": target, - "strength": data.get("strength", 1), - "hash": self.hippocampus.calculate_edge_hash(source, target), - "created_time": data.get("created_time", datetime.datetime.now().timestamp()), - "last_modified": data.get("last_modified", datetime.datetime.now().timestamp()), - } - db.graph_data.edges.insert_one(edge_data) + GraphEdges.create( + source=source, + target=target, + strength=data.get("strength", 1), + hash=self.hippocampus.calculate_edge_hash(source, target), + created_time=data.get("created_time", datetime.datetime.now().timestamp()), + last_modified=data.get("last_modified", datetime.datetime.now().timestamp()), + ) edge_end = time.time() logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}秒") diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index bd7a2d319..bf192ca6a 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -275,6 +275,35 @@ class RecalledMessages(BaseModel): table_name = "recalled_messages" +class GraphNodes(BaseModel): + """ + 用于存储记忆图节点的模型 + """ + concept = TextField(unique=True, index=True) # 节点概念 + memory_items = TextField() # JSON格式存储的记忆列表 + hash = TextField() # 节点哈希值 + created_time = FloatField() # 创建时间戳 + last_modified = FloatField() # 最后修改时间戳 + + class Meta: + table_name = "graph_nodes" + + +class GraphEdges(BaseModel): + """ + 用于存储记忆图边的模型 + """ + source = TextField(index=True) # 源节点 + target = TextField(index=True) # 目标节点 + strength = IntegerField() # 连接强度 + hash = TextField() # 边哈希值 + created_time = FloatField() # 创建时间戳 + last_modified = FloatField() # 最后修改时间戳 + + class Meta: + table_name = "graph_edges" + + def create_tables(): """ 创建所有在模型中定义的数据库表。 @@ -293,6 +322,8 @@ def create_tables(): Knowledges, ThinkingLog, RecalledMessages, # 添加新模型 + GraphNodes, # 添加图节点表 + GraphEdges, # 添加图边表 ] ) @@ -315,7 +346,9 @@ def initialize_database(): PersonInfo, Knowledges, ThinkingLog, - RecalledMessages, # 添加新模型 + RecalledMessages, + GraphNodes, # 添加图节点表 + GraphEdges, # 添加图边表 ] needs_creation = False