feat:添加海马体支持

This commit is contained in:
SengokuCola
2025-05-17 14:46:01 +08:00
parent 06a3479c0f
commit e6cd2a8e8f
2 changed files with 112 additions and 89 deletions

View File

@@ -4,6 +4,7 @@ import math
import random import random
import time import time
import re import re
import json
from itertools import combinations from itertools import combinations
import jieba import jieba
@@ -22,7 +23,7 @@ from ..utils.utils import translate_timestamp_to_human_readable
from rich.traceback import install from rich.traceback import install
from ...config.config import global_config from ...config.config import global_config
from src.common.database.database_model import Messages # Peewee Messages 模型导入 from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
install(extra_lines=3) install(extra_lines=3)
@@ -877,12 +878,9 @@ class EntorhinalCortex:
async def sync_memory_to_db(self): async def sync_memory_to_db(self):
"""将记忆图同步到数据库""" """将记忆图同步到数据库"""
# 获取数据库中所有节点和内存中所有节点 # 获取数据库中所有节点和内存中所有节点
db_nodes = list(db.graph_data.nodes.find()) db_nodes = {node.concept: node for node in GraphNodes.select()}
memory_nodes = list(self.memory_graph.G.nodes(data=True)) memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node["concept"]: node for node in db_nodes}
# 检查并更新节点 # 检查并更新节点
for concept, data in memory_nodes: for concept, data in memory_nodes:
memory_items = data.get("memory_items", []) memory_items = data.get("memory_items", [])
@@ -896,44 +894,39 @@ class EntorhinalCortex:
created_time = data.get("created_time", datetime.datetime.now().timestamp()) created_time = data.get("created_time", datetime.datetime.now().timestamp())
last_modified = data.get("last_modified", datetime.datetime.now().timestamp()) last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
if concept not in db_nodes_dict: # 将memory_items转换为JSON字符串
memory_items_json = json.dumps(memory_items, ensure_ascii=False)
if concept not in db_nodes:
# 数据库中缺少的节点,添加 # 数据库中缺少的节点,添加
node_data = { GraphNodes.create(
"concept": concept, concept=concept,
"memory_items": memory_items, memory_items=memory_items_json,
"hash": memory_hash, hash=memory_hash,
"created_time": created_time, created_time=created_time,
"last_modified": last_modified, last_modified=last_modified,
} )
db.graph_data.nodes.insert_one(node_data)
else: else:
# 获取数据库中节点的特征值 # 获取数据库中节点的特征值
db_node = db_nodes_dict[concept] db_node = db_nodes[concept]
db_hash = db_node.get("hash", None) db_hash = db_node.hash
# 如果特征值不同,则更新节点 # 如果特征值不同,则更新节点
if db_hash != memory_hash: if db_hash != memory_hash:
db.graph_data.nodes.update_one( db_node.memory_items = memory_items_json
{"concept": concept}, db_node.hash = memory_hash
{ db_node.last_modified = last_modified
"$set": { db_node.save()
"memory_items": memory_items,
"hash": memory_hash,
"created_time": created_time,
"last_modified": last_modified,
}
},
).execute()
# 处理边的信息 # 处理边的信息
db_edges = list(db.graph_data.edges.find()) db_edges = list(GraphEdges.select())
memory_edges = list(self.memory_graph.G.edges(data=True)) memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典 # 创建边的哈希值字典
db_edge_dict = {} db_edge_dict = {}
for edge in db_edges: for edge in db_edges:
edge_hash = self.hippocampus.calculate_edge_hash(edge["source"], edge["target"]) edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "strength": edge.get("strength", 1)} db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
# 检查并更新边 # 检查并更新边
for source, target, data in memory_edges: for source, target, data in memory_edges:
@@ -947,29 +940,22 @@ class EntorhinalCortex:
if edge_key not in db_edge_dict: if edge_key not in db_edge_dict:
# 添加新边 # 添加新边
edge_data = { GraphEdges.create(
"source": source, source=source,
"target": target, target=target,
"strength": strength, strength=strength,
"hash": edge_hash, hash=edge_hash,
"created_time": created_time, created_time=created_time,
"last_modified": last_modified, last_modified=last_modified,
} )
db.graph_data.edges.insert_one(edge_data)
else: else:
# 检查边的特征值是否变化 # 检查边的特征值是否变化
if db_edge_dict[edge_key]["hash"] != edge_hash: if db_edge_dict[edge_key]["hash"] != edge_hash:
db.graph_data.edges.update_one( edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target)
{"source": source, "target": target}, edge.hash = edge_hash
{ edge.strength = strength
"$set": { edge.last_modified = last_modified
"hash": edge_hash, edge.save()
"strength": strength,
"created_time": created_time,
"last_modified": last_modified,
}
},
).execute()
def sync_memory_from_db(self): def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构""" """从数据库同步数据到内存中的图结构"""
@@ -980,29 +966,31 @@ class EntorhinalCortex:
self.memory_graph.G.clear() self.memory_graph.G.clear()
# 从数据库加载所有节点 # 从数据库加载所有节点
nodes = list(db.graph_data.nodes.find()) nodes = list(GraphNodes.select())
for node in nodes: for node in nodes:
concept = node["concept"] concept = node.concept
memory_items = node.get("memory_items", []) memory_items = json.loads(node.memory_items)
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
# 检查时间字段是否存在 # 检查时间字段是否存在
if "created_time" not in node or "last_modified" not in node: if not node.created_time or not node.last_modified:
need_update = True need_update = True
# 更新数据库中的节点 # 更新数据库中的节点
update_data = {} update_data = {}
if "created_time" not in node: if not node.created_time:
update_data["created_time"] = current_time update_data["created_time"] = current_time
if "last_modified" not in node: if not node.last_modified:
update_data["last_modified"] = current_time update_data["last_modified"] = current_time
db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data}).execute() GraphNodes.update(
**update_data
).where(GraphNodes.concept == concept).execute()
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段") logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间) # 获取时间信息(如果不存在则使用当前时间)
created_time = node.get("created_time", current_time) created_time = node.created_time or current_time
last_modified = node.get("last_modified", current_time) last_modified = node.last_modified or current_time
# 添加节点到图中 # 添加节点到图中
self.memory_graph.G.add_node( self.memory_graph.G.add_node(
@@ -1010,28 +998,32 @@ class EntorhinalCortex:
) )
# 从数据库加载所有边 # 从数据库加载所有边
edges = list(db.graph_data.edges.find()) edges = list(GraphEdges.select())
for edge in edges: for edge in edges:
source = edge["source"] source = edge.source
target = edge["target"] target = edge.target
strength = edge.get("strength", 1) strength = edge.strength
# 检查时间字段是否存在 # 检查时间字段是否存在
if "created_time" not in edge or "last_modified" not in edge: if not edge.created_time or not edge.last_modified:
need_update = True need_update = True
# 更新数据库中的边 # 更新数据库中的边
update_data = {} update_data = {}
if "created_time" not in edge: if not edge.created_time:
update_data["created_time"] = current_time update_data["created_time"] = current_time
if "last_modified" not in edge: if not edge.last_modified:
update_data["last_modified"] = current_time update_data["last_modified"] = current_time
db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data}).execute() GraphEdges.update(
**update_data
).where(
(GraphEdges.source == source) & (GraphEdges.target == target)
).execute()
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段") logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间) # 获取时间信息(如果不存在则使用当前时间)
created_time = edge.get("created_time", current_time) created_time = edge.created_time or current_time
last_modified = edge.get("last_modified", current_time) last_modified = edge.last_modified or current_time
# 只有当源节点和目标节点都存在时才添加边 # 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G: if source in self.memory_graph.G and target in self.memory_graph.G:
@@ -1049,8 +1041,8 @@ class EntorhinalCortex:
# 清空数据库 # 清空数据库
clear_start = time.time() clear_start = time.time()
db.graph_data.nodes.delete_many({}) GraphNodes.delete().execute()
db.graph_data.edges.delete_many({}) GraphEdges.delete().execute()
clear_end = time.time() clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}") logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1065,29 +1057,27 @@ class EntorhinalCortex:
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
node_data = { GraphNodes.create(
"concept": concept, concept=concept,
"memory_items": memory_items, memory_items=json.dumps(memory_items),
"hash": self.hippocampus.calculate_node_hash(concept, memory_items), hash=self.hippocampus.calculate_node_hash(concept, memory_items),
"created_time": data.get("created_time", datetime.datetime.now().timestamp()), created_time=data.get("created_time", datetime.datetime.now().timestamp()),
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()), last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
} )
db.graph_data.nodes.insert_one(node_data)
node_end = time.time() node_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}") logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}")
# 重新写入边 # 重新写入边
edge_start = time.time() edge_start = time.time()
for source, target, data in memory_edges: for source, target, data in memory_edges:
edge_data = { GraphEdges.create(
"source": source, source=source,
"target": target, target=target,
"strength": data.get("strength", 1), strength=data.get("strength", 1),
"hash": self.hippocampus.calculate_edge_hash(source, target), hash=self.hippocampus.calculate_edge_hash(source, target),
"created_time": data.get("created_time", datetime.datetime.now().timestamp()), created_time=data.get("created_time", datetime.datetime.now().timestamp()),
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()), last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
} )
db.graph_data.edges.insert_one(edge_data)
edge_end = time.time() edge_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}") logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}")

View File

@@ -275,6 +275,35 @@ class RecalledMessages(BaseModel):
table_name = "recalled_messages" table_name = "recalled_messages"
class GraphNodes(BaseModel):
"""
用于存储记忆图节点的模型
"""
concept = TextField(unique=True, index=True) # 节点概念
memory_items = TextField() # JSON格式存储的记忆列表
hash = TextField() # 节点哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
table_name = "graph_nodes"
class GraphEdges(BaseModel):
"""
用于存储记忆图边的模型
"""
source = TextField(index=True) # 源节点
target = TextField(index=True) # 目标节点
strength = IntegerField() # 连接强度
hash = TextField() # 边哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
table_name = "graph_edges"
def create_tables(): def create_tables():
""" """
创建所有在模型中定义的数据库表。 创建所有在模型中定义的数据库表。
@@ -293,6 +322,8 @@ def create_tables():
Knowledges, Knowledges,
ThinkingLog, ThinkingLog,
RecalledMessages, # 添加新模型 RecalledMessages, # 添加新模型
GraphNodes, # 添加图节点表
GraphEdges, # 添加图边表
] ]
) )
@@ -315,7 +346,9 @@ def initialize_database():
PersonInfo, PersonInfo,
Knowledges, Knowledges,
ThinkingLog, ThinkingLog,
RecalledMessages, # 添加新模型 RecalledMessages,
GraphNodes, # 添加图节点表
GraphEdges, # 添加图边表
] ]
needs_creation = False needs_creation = False