feat:添加海马体支持

This commit is contained in:
SengokuCola
2025-05-17 14:46:01 +08:00
parent 06a3479c0f
commit e6cd2a8e8f
2 changed files with 112 additions and 89 deletions

View File

@@ -4,6 +4,7 @@ import math
import random
import time
import re
import json
from itertools import combinations
import jieba
@@ -22,7 +23,7 @@ from ..utils.utils import translate_timestamp_to_human_readable
from rich.traceback import install
from ...config.config import global_config
from src.common.database.database_model import Messages # Peewee Messages 模型导入
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
install(extra_lines=3)
@@ -877,12 +878,9 @@ class EntorhinalCortex:
async def sync_memory_to_db(self):
"""将记忆图同步到数据库"""
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(db.graph_data.nodes.find())
db_nodes = {node.concept: node for node in GraphNodes.select()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node["concept"]: node for node in db_nodes}
# 检查并更新节点
for concept, data in memory_nodes:
memory_items = data.get("memory_items", [])
@@ -896,44 +894,39 @@ class EntorhinalCortex:
created_time = data.get("created_time", datetime.datetime.now().timestamp())
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
if concept not in db_nodes_dict:
# 将memory_items转换为JSON字符串
memory_items_json = json.dumps(memory_items, ensure_ascii=False)
if concept not in db_nodes:
# 数据库中缺少的节点,添加
node_data = {
"concept": concept,
"memory_items": memory_items,
"hash": memory_hash,
"created_time": created_time,
"last_modified": last_modified,
}
db.graph_data.nodes.insert_one(node_data)
GraphNodes.create(
concept=concept,
memory_items=memory_items_json,
hash=memory_hash,
created_time=created_time,
last_modified=last_modified,
)
else:
# 获取数据库中节点的特征值
db_node = db_nodes_dict[concept]
db_hash = db_node.get("hash", None)
db_node = db_nodes[concept]
db_hash = db_node.hash
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
db.graph_data.nodes.update_one(
{"concept": concept},
{
"$set": {
"memory_items": memory_items,
"hash": memory_hash,
"created_time": created_time,
"last_modified": last_modified,
}
},
).execute()
db_node.memory_items = memory_items_json
db_node.hash = memory_hash
db_node.last_modified = last_modified
db_node.save()
# 处理边的信息
db_edges = list(db.graph_data.edges.find())
db_edges = list(GraphEdges.select())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
db_edge_dict = {}
for edge in db_edges:
edge_hash = self.hippocampus.calculate_edge_hash(edge["source"], edge["target"])
db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "strength": edge.get("strength", 1)}
edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
# 检查并更新边
for source, target, data in memory_edges:
@@ -947,29 +940,22 @@ class EntorhinalCortex:
if edge_key not in db_edge_dict:
# 添加新边
edge_data = {
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"created_time": created_time,
"last_modified": last_modified,
}
db.graph_data.edges.insert_one(edge_data)
GraphEdges.create(
source=source,
target=target,
strength=strength,
hash=edge_hash,
created_time=created_time,
last_modified=last_modified,
)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]["hash"] != edge_hash:
db.graph_data.edges.update_one(
{"source": source, "target": target},
{
"$set": {
"hash": edge_hash,
"strength": strength,
"created_time": created_time,
"last_modified": last_modified,
}
},
).execute()
edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target)
edge.hash = edge_hash
edge.strength = strength
edge.last_modified = last_modified
edge.save()
def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
@@ -980,29 +966,31 @@ class EntorhinalCortex:
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = list(db.graph_data.nodes.find())
nodes = list(GraphNodes.select())
for node in nodes:
concept = node["concept"]
memory_items = node.get("memory_items", [])
concept = node.concept
memory_items = json.loads(node.memory_items)
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 检查时间字段是否存在
if "created_time" not in node or "last_modified" not in node:
if not node.created_time or not node.last_modified:
need_update = True
# 更新数据库中的节点
update_data = {}
if "created_time" not in node:
if not node.created_time:
update_data["created_time"] = current_time
if "last_modified" not in node:
if not node.last_modified:
update_data["last_modified"] = current_time
db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data}).execute()
GraphNodes.update(
**update_data
).where(GraphNodes.concept == concept).execute()
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.get("created_time", current_time)
last_modified = node.get("last_modified", current_time)
created_time = node.created_time or current_time
last_modified = node.last_modified or current_time
# 添加节点到图中
self.memory_graph.G.add_node(
@@ -1010,28 +998,32 @@ class EntorhinalCortex:
)
# 从数据库加载所有边
edges = list(db.graph_data.edges.find())
edges = list(GraphEdges.select())
for edge in edges:
source = edge["source"]
target = edge["target"]
strength = edge.get("strength", 1)
source = edge.source
target = edge.target
strength = edge.strength
# 检查时间字段是否存在
if "created_time" not in edge or "last_modified" not in edge:
if not edge.created_time or not edge.last_modified:
need_update = True
# 更新数据库中的边
update_data = {}
if "created_time" not in edge:
if not edge.created_time:
update_data["created_time"] = current_time
if "last_modified" not in edge:
if not edge.last_modified:
update_data["last_modified"] = current_time
db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data}).execute()
GraphEdges.update(
**update_data
).where(
(GraphEdges.source == source) & (GraphEdges.target == target)
).execute()
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
created_time = edge.get("created_time", current_time)
last_modified = edge.get("last_modified", current_time)
created_time = edge.created_time or current_time
last_modified = edge.last_modified or current_time
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
@@ -1049,8 +1041,8 @@ class EntorhinalCortex:
# 清空数据库
clear_start = time.time()
db.graph_data.nodes.delete_many({})
db.graph_data.edges.delete_many({})
GraphNodes.delete().execute()
GraphEdges.delete().execute()
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1065,29 +1057,27 @@ class EntorhinalCortex:
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
node_data = {
"concept": concept,
"memory_items": memory_items,
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
"created_time": data.get("created_time", datetime.datetime.now().timestamp()),
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
}
db.graph_data.nodes.insert_one(node_data)
GraphNodes.create(
concept=concept,
memory_items=json.dumps(memory_items),
hash=self.hippocampus.calculate_node_hash(concept, memory_items),
created_time=data.get("created_time", datetime.datetime.now().timestamp()),
last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
)
node_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}")
# 重新写入边
edge_start = time.time()
for source, target, data in memory_edges:
edge_data = {
"source": source,
"target": target,
"strength": data.get("strength", 1),
"hash": self.hippocampus.calculate_edge_hash(source, target),
"created_time": data.get("created_time", datetime.datetime.now().timestamp()),
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
}
db.graph_data.edges.insert_one(edge_data)
GraphEdges.create(
source=source,
target=target,
strength=data.get("strength", 1),
hash=self.hippocampus.calculate_edge_hash(source, target),
created_time=data.get("created_time", datetime.datetime.now().timestamp()),
last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
)
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}")

View File

@@ -275,6 +275,35 @@ class RecalledMessages(BaseModel):
table_name = "recalled_messages"
class GraphNodes(BaseModel):
"""
用于存储记忆图节点的模型
"""
concept = TextField(unique=True, index=True) # 节点概念
memory_items = TextField() # JSON格式存储的记忆列表
hash = TextField() # 节点哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
table_name = "graph_nodes"
class GraphEdges(BaseModel):
"""
用于存储记忆图边的模型
"""
source = TextField(index=True) # 源节点
target = TextField(index=True) # 目标节点
strength = IntegerField() # 连接强度
hash = TextField() # 边哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
table_name = "graph_edges"
def create_tables():
"""
创建所有在模型中定义的数据库表。
@@ -293,6 +322,8 @@ def create_tables():
Knowledges,
ThinkingLog,
RecalledMessages, # 添加新模型
GraphNodes, # 添加图节点表
GraphEdges, # 添加图边表
]
)
@@ -315,7 +346,9 @@ def initialize_database():
PersonInfo,
Knowledges,
ThinkingLog,
RecalledMessages, # 添加新模型
RecalledMessages,
GraphNodes, # 添加图节点表
GraphEdges, # 添加图边表
]
needs_creation = False