feat:添加海马体支持
This commit is contained in:
@@ -4,6 +4,7 @@ import math
|
||||
import random
|
||||
import time
|
||||
import re
|
||||
import json
|
||||
from itertools import combinations
|
||||
|
||||
import jieba
|
||||
@@ -22,7 +23,7 @@ from ..utils.utils import translate_timestamp_to_human_readable
|
||||
from rich.traceback import install
|
||||
|
||||
from ...config.config import global_config
|
||||
from src.common.database.database_model import Messages # Peewee Messages 模型导入
|
||||
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -877,12 +878,9 @@ class EntorhinalCortex:
|
||||
async def sync_memory_to_db(self):
|
||||
"""将记忆图同步到数据库"""
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = list(db.graph_data.nodes.find())
|
||||
db_nodes = {node.concept: node for node in GraphNodes.select()}
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 转换数据库节点为字典格式,方便查找
|
||||
db_nodes_dict = {node["concept"]: node for node in db_nodes}
|
||||
|
||||
# 检查并更新节点
|
||||
for concept, data in memory_nodes:
|
||||
memory_items = data.get("memory_items", [])
|
||||
@@ -896,44 +894,39 @@ class EntorhinalCortex:
|
||||
created_time = data.get("created_time", datetime.datetime.now().timestamp())
|
||||
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
|
||||
|
||||
if concept not in db_nodes_dict:
|
||||
# 将memory_items转换为JSON字符串
|
||||
memory_items_json = json.dumps(memory_items, ensure_ascii=False)
|
||||
|
||||
if concept not in db_nodes:
|
||||
# 数据库中缺少的节点,添加
|
||||
node_data = {
|
||||
"concept": concept,
|
||||
"memory_items": memory_items,
|
||||
"hash": memory_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
GraphNodes.create(
|
||||
concept=concept,
|
||||
memory_items=memory_items_json,
|
||||
hash=memory_hash,
|
||||
created_time=created_time,
|
||||
last_modified=last_modified,
|
||||
)
|
||||
else:
|
||||
# 获取数据库中节点的特征值
|
||||
db_node = db_nodes_dict[concept]
|
||||
db_hash = db_node.get("hash", None)
|
||||
db_node = db_nodes[concept]
|
||||
db_hash = db_node.hash
|
||||
|
||||
# 如果特征值不同,则更新节点
|
||||
if db_hash != memory_hash:
|
||||
db.graph_data.nodes.update_one(
|
||||
{"concept": concept},
|
||||
{
|
||||
"$set": {
|
||||
"memory_items": memory_items,
|
||||
"hash": memory_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
},
|
||||
).execute()
|
||||
db_node.memory_items = memory_items_json
|
||||
db_node.hash = memory_hash
|
||||
db_node.last_modified = last_modified
|
||||
db_node.save()
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(db.graph_data.edges.find())
|
||||
db_edges = list(GraphEdges.select())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
db_edge_dict = {}
|
||||
for edge in db_edges:
|
||||
edge_hash = self.hippocampus.calculate_edge_hash(edge["source"], edge["target"])
|
||||
db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "strength": edge.get("strength", 1)}
|
||||
edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
|
||||
db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
|
||||
|
||||
# 检查并更新边
|
||||
for source, target, data in memory_edges:
|
||||
@@ -947,29 +940,22 @@ class EntorhinalCortex:
|
||||
|
||||
if edge_key not in db_edge_dict:
|
||||
# 添加新边
|
||||
edge_data = {
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
GraphEdges.create(
|
||||
source=source,
|
||||
target=target,
|
||||
strength=strength,
|
||||
hash=edge_hash,
|
||||
created_time=created_time,
|
||||
last_modified=last_modified,
|
||||
)
|
||||
else:
|
||||
# 检查边的特征值是否变化
|
||||
if db_edge_dict[edge_key]["hash"] != edge_hash:
|
||||
db.graph_data.edges.update_one(
|
||||
{"source": source, "target": target},
|
||||
{
|
||||
"$set": {
|
||||
"hash": edge_hash,
|
||||
"strength": strength,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
},
|
||||
).execute()
|
||||
edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target)
|
||||
edge.hash = edge_hash
|
||||
edge.strength = strength
|
||||
edge.last_modified = last_modified
|
||||
edge.save()
|
||||
|
||||
def sync_memory_from_db(self):
|
||||
"""从数据库同步数据到内存中的图结构"""
|
||||
@@ -980,29 +966,31 @@ class EntorhinalCortex:
|
||||
self.memory_graph.G.clear()
|
||||
|
||||
# 从数据库加载所有节点
|
||||
nodes = list(db.graph_data.nodes.find())
|
||||
nodes = list(GraphNodes.select())
|
||||
for node in nodes:
|
||||
concept = node["concept"]
|
||||
memory_items = node.get("memory_items", [])
|
||||
concept = node.concept
|
||||
memory_items = json.loads(node.memory_items)
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
|
||||
# 检查时间字段是否存在
|
||||
if "created_time" not in node or "last_modified" not in node:
|
||||
if not node.created_time or not node.last_modified:
|
||||
need_update = True
|
||||
# 更新数据库中的节点
|
||||
update_data = {}
|
||||
if "created_time" not in node:
|
||||
if not node.created_time:
|
||||
update_data["created_time"] = current_time
|
||||
if "last_modified" not in node:
|
||||
if not node.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data}).execute()
|
||||
GraphNodes.update(
|
||||
**update_data
|
||||
).where(GraphNodes.concept == concept).execute()
|
||||
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = node.get("created_time", current_time)
|
||||
last_modified = node.get("last_modified", current_time)
|
||||
created_time = node.created_time or current_time
|
||||
last_modified = node.last_modified or current_time
|
||||
|
||||
# 添加节点到图中
|
||||
self.memory_graph.G.add_node(
|
||||
@@ -1010,28 +998,32 @@ class EntorhinalCortex:
|
||||
)
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = list(db.graph_data.edges.find())
|
||||
edges = list(GraphEdges.select())
|
||||
for edge in edges:
|
||||
source = edge["source"]
|
||||
target = edge["target"]
|
||||
strength = edge.get("strength", 1)
|
||||
source = edge.source
|
||||
target = edge.target
|
||||
strength = edge.strength
|
||||
|
||||
# 检查时间字段是否存在
|
||||
if "created_time" not in edge or "last_modified" not in edge:
|
||||
if not edge.created_time or not edge.last_modified:
|
||||
need_update = True
|
||||
# 更新数据库中的边
|
||||
update_data = {}
|
||||
if "created_time" not in edge:
|
||||
if not edge.created_time:
|
||||
update_data["created_time"] = current_time
|
||||
if "last_modified" not in edge:
|
||||
if not edge.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data}).execute()
|
||||
GraphEdges.update(
|
||||
**update_data
|
||||
).where(
|
||||
(GraphEdges.source == source) & (GraphEdges.target == target)
|
||||
).execute()
|
||||
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = edge.get("created_time", current_time)
|
||||
last_modified = edge.get("last_modified", current_time)
|
||||
created_time = edge.created_time or current_time
|
||||
last_modified = edge.last_modified or current_time
|
||||
|
||||
# 只有当源节点和目标节点都存在时才添加边
|
||||
if source in self.memory_graph.G and target in self.memory_graph.G:
|
||||
@@ -1049,8 +1041,8 @@ class EntorhinalCortex:
|
||||
|
||||
# 清空数据库
|
||||
clear_start = time.time()
|
||||
db.graph_data.nodes.delete_many({})
|
||||
db.graph_data.edges.delete_many({})
|
||||
GraphNodes.delete().execute()
|
||||
GraphEdges.delete().execute()
|
||||
clear_end = time.time()
|
||||
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")
|
||||
|
||||
@@ -1065,29 +1057,27 @@ class EntorhinalCortex:
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
|
||||
node_data = {
|
||||
"concept": concept,
|
||||
"memory_items": memory_items,
|
||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||
"created_time": data.get("created_time", datetime.datetime.now().timestamp()),
|
||||
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
|
||||
}
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
GraphNodes.create(
|
||||
concept=concept,
|
||||
memory_items=json.dumps(memory_items),
|
||||
hash=self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||
created_time=data.get("created_time", datetime.datetime.now().timestamp()),
|
||||
last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
|
||||
)
|
||||
node_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}秒")
|
||||
|
||||
# 重新写入边
|
||||
edge_start = time.time()
|
||||
for source, target, data in memory_edges:
|
||||
edge_data = {
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": data.get("strength", 1),
|
||||
"hash": self.hippocampus.calculate_edge_hash(source, target),
|
||||
"created_time": data.get("created_time", datetime.datetime.now().timestamp()),
|
||||
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
|
||||
}
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
GraphEdges.create(
|
||||
source=source,
|
||||
target=target,
|
||||
strength=data.get("strength", 1),
|
||||
hash=self.hippocampus.calculate_edge_hash(source, target),
|
||||
created_time=data.get("created_time", datetime.datetime.now().timestamp()),
|
||||
last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
|
||||
)
|
||||
edge_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user