feat:添加海马体支持
This commit is contained in:
@@ -4,6 +4,7 @@ import math
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import re
|
import re
|
||||||
|
import json
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
|
|
||||||
import jieba
|
import jieba
|
||||||
@@ -22,7 +23,7 @@ from ..utils.utils import translate_timestamp_to_human_readable
|
|||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
from src.common.database.database_model import Messages # Peewee Messages 模型导入
|
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -877,12 +878,9 @@ class EntorhinalCortex:
|
|||||||
async def sync_memory_to_db(self):
|
async def sync_memory_to_db(self):
|
||||||
"""将记忆图同步到数据库"""
|
"""将记忆图同步到数据库"""
|
||||||
# 获取数据库中所有节点和内存中所有节点
|
# 获取数据库中所有节点和内存中所有节点
|
||||||
db_nodes = list(db.graph_data.nodes.find())
|
db_nodes = {node.concept: node for node in GraphNodes.select()}
|
||||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||||
|
|
||||||
# 转换数据库节点为字典格式,方便查找
|
|
||||||
db_nodes_dict = {node["concept"]: node for node in db_nodes}
|
|
||||||
|
|
||||||
# 检查并更新节点
|
# 检查并更新节点
|
||||||
for concept, data in memory_nodes:
|
for concept, data in memory_nodes:
|
||||||
memory_items = data.get("memory_items", [])
|
memory_items = data.get("memory_items", [])
|
||||||
@@ -896,44 +894,39 @@ class EntorhinalCortex:
|
|||||||
created_time = data.get("created_time", datetime.datetime.now().timestamp())
|
created_time = data.get("created_time", datetime.datetime.now().timestamp())
|
||||||
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
|
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
|
||||||
|
|
||||||
if concept not in db_nodes_dict:
|
# 将memory_items转换为JSON字符串
|
||||||
|
memory_items_json = json.dumps(memory_items, ensure_ascii=False)
|
||||||
|
|
||||||
|
if concept not in db_nodes:
|
||||||
# 数据库中缺少的节点,添加
|
# 数据库中缺少的节点,添加
|
||||||
node_data = {
|
GraphNodes.create(
|
||||||
"concept": concept,
|
concept=concept,
|
||||||
"memory_items": memory_items,
|
memory_items=memory_items_json,
|
||||||
"hash": memory_hash,
|
hash=memory_hash,
|
||||||
"created_time": created_time,
|
created_time=created_time,
|
||||||
"last_modified": last_modified,
|
last_modified=last_modified,
|
||||||
}
|
)
|
||||||
db.graph_data.nodes.insert_one(node_data)
|
|
||||||
else:
|
else:
|
||||||
# 获取数据库中节点的特征值
|
# 获取数据库中节点的特征值
|
||||||
db_node = db_nodes_dict[concept]
|
db_node = db_nodes[concept]
|
||||||
db_hash = db_node.get("hash", None)
|
db_hash = db_node.hash
|
||||||
|
|
||||||
# 如果特征值不同,则更新节点
|
# 如果特征值不同,则更新节点
|
||||||
if db_hash != memory_hash:
|
if db_hash != memory_hash:
|
||||||
db.graph_data.nodes.update_one(
|
db_node.memory_items = memory_items_json
|
||||||
{"concept": concept},
|
db_node.hash = memory_hash
|
||||||
{
|
db_node.last_modified = last_modified
|
||||||
"$set": {
|
db_node.save()
|
||||||
"memory_items": memory_items,
|
|
||||||
"hash": memory_hash,
|
|
||||||
"created_time": created_time,
|
|
||||||
"last_modified": last_modified,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
# 处理边的信息
|
# 处理边的信息
|
||||||
db_edges = list(db.graph_data.edges.find())
|
db_edges = list(GraphEdges.select())
|
||||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||||
|
|
||||||
# 创建边的哈希值字典
|
# 创建边的哈希值字典
|
||||||
db_edge_dict = {}
|
db_edge_dict = {}
|
||||||
for edge in db_edges:
|
for edge in db_edges:
|
||||||
edge_hash = self.hippocampus.calculate_edge_hash(edge["source"], edge["target"])
|
edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
|
||||||
db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "strength": edge.get("strength", 1)}
|
db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
|
||||||
|
|
||||||
# 检查并更新边
|
# 检查并更新边
|
||||||
for source, target, data in memory_edges:
|
for source, target, data in memory_edges:
|
||||||
@@ -947,29 +940,22 @@ class EntorhinalCortex:
|
|||||||
|
|
||||||
if edge_key not in db_edge_dict:
|
if edge_key not in db_edge_dict:
|
||||||
# 添加新边
|
# 添加新边
|
||||||
edge_data = {
|
GraphEdges.create(
|
||||||
"source": source,
|
source=source,
|
||||||
"target": target,
|
target=target,
|
||||||
"strength": strength,
|
strength=strength,
|
||||||
"hash": edge_hash,
|
hash=edge_hash,
|
||||||
"created_time": created_time,
|
created_time=created_time,
|
||||||
"last_modified": last_modified,
|
last_modified=last_modified,
|
||||||
}
|
)
|
||||||
db.graph_data.edges.insert_one(edge_data)
|
|
||||||
else:
|
else:
|
||||||
# 检查边的特征值是否变化
|
# 检查边的特征值是否变化
|
||||||
if db_edge_dict[edge_key]["hash"] != edge_hash:
|
if db_edge_dict[edge_key]["hash"] != edge_hash:
|
||||||
db.graph_data.edges.update_one(
|
edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target)
|
||||||
{"source": source, "target": target},
|
edge.hash = edge_hash
|
||||||
{
|
edge.strength = strength
|
||||||
"$set": {
|
edge.last_modified = last_modified
|
||||||
"hash": edge_hash,
|
edge.save()
|
||||||
"strength": strength,
|
|
||||||
"created_time": created_time,
|
|
||||||
"last_modified": last_modified,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
def sync_memory_from_db(self):
|
def sync_memory_from_db(self):
|
||||||
"""从数据库同步数据到内存中的图结构"""
|
"""从数据库同步数据到内存中的图结构"""
|
||||||
@@ -980,29 +966,31 @@ class EntorhinalCortex:
|
|||||||
self.memory_graph.G.clear()
|
self.memory_graph.G.clear()
|
||||||
|
|
||||||
# 从数据库加载所有节点
|
# 从数据库加载所有节点
|
||||||
nodes = list(db.graph_data.nodes.find())
|
nodes = list(GraphNodes.select())
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
concept = node["concept"]
|
concept = node.concept
|
||||||
memory_items = node.get("memory_items", [])
|
memory_items = json.loads(node.memory_items)
|
||||||
if not isinstance(memory_items, list):
|
if not isinstance(memory_items, list):
|
||||||
memory_items = [memory_items] if memory_items else []
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
# 检查时间字段是否存在
|
# 检查时间字段是否存在
|
||||||
if "created_time" not in node or "last_modified" not in node:
|
if not node.created_time or not node.last_modified:
|
||||||
need_update = True
|
need_update = True
|
||||||
# 更新数据库中的节点
|
# 更新数据库中的节点
|
||||||
update_data = {}
|
update_data = {}
|
||||||
if "created_time" not in node:
|
if not node.created_time:
|
||||||
update_data["created_time"] = current_time
|
update_data["created_time"] = current_time
|
||||||
if "last_modified" not in node:
|
if not node.last_modified:
|
||||||
update_data["last_modified"] = current_time
|
update_data["last_modified"] = current_time
|
||||||
|
|
||||||
db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data}).execute()
|
GraphNodes.update(
|
||||||
|
**update_data
|
||||||
|
).where(GraphNodes.concept == concept).execute()
|
||||||
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
|
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
|
||||||
|
|
||||||
# 获取时间信息(如果不存在则使用当前时间)
|
# 获取时间信息(如果不存在则使用当前时间)
|
||||||
created_time = node.get("created_time", current_time)
|
created_time = node.created_time or current_time
|
||||||
last_modified = node.get("last_modified", current_time)
|
last_modified = node.last_modified or current_time
|
||||||
|
|
||||||
# 添加节点到图中
|
# 添加节点到图中
|
||||||
self.memory_graph.G.add_node(
|
self.memory_graph.G.add_node(
|
||||||
@@ -1010,28 +998,32 @@ class EntorhinalCortex:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 从数据库加载所有边
|
# 从数据库加载所有边
|
||||||
edges = list(db.graph_data.edges.find())
|
edges = list(GraphEdges.select())
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
source = edge["source"]
|
source = edge.source
|
||||||
target = edge["target"]
|
target = edge.target
|
||||||
strength = edge.get("strength", 1)
|
strength = edge.strength
|
||||||
|
|
||||||
# 检查时间字段是否存在
|
# 检查时间字段是否存在
|
||||||
if "created_time" not in edge or "last_modified" not in edge:
|
if not edge.created_time or not edge.last_modified:
|
||||||
need_update = True
|
need_update = True
|
||||||
# 更新数据库中的边
|
# 更新数据库中的边
|
||||||
update_data = {}
|
update_data = {}
|
||||||
if "created_time" not in edge:
|
if not edge.created_time:
|
||||||
update_data["created_time"] = current_time
|
update_data["created_time"] = current_time
|
||||||
if "last_modified" not in edge:
|
if not edge.last_modified:
|
||||||
update_data["last_modified"] = current_time
|
update_data["last_modified"] = current_time
|
||||||
|
|
||||||
db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data}).execute()
|
GraphEdges.update(
|
||||||
|
**update_data
|
||||||
|
).where(
|
||||||
|
(GraphEdges.source == source) & (GraphEdges.target == target)
|
||||||
|
).execute()
|
||||||
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
|
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
|
||||||
|
|
||||||
# 获取时间信息(如果不存在则使用当前时间)
|
# 获取时间信息(如果不存在则使用当前时间)
|
||||||
created_time = edge.get("created_time", current_time)
|
created_time = edge.created_time or current_time
|
||||||
last_modified = edge.get("last_modified", current_time)
|
last_modified = edge.last_modified or current_time
|
||||||
|
|
||||||
# 只有当源节点和目标节点都存在时才添加边
|
# 只有当源节点和目标节点都存在时才添加边
|
||||||
if source in self.memory_graph.G and target in self.memory_graph.G:
|
if source in self.memory_graph.G and target in self.memory_graph.G:
|
||||||
@@ -1049,8 +1041,8 @@ class EntorhinalCortex:
|
|||||||
|
|
||||||
# 清空数据库
|
# 清空数据库
|
||||||
clear_start = time.time()
|
clear_start = time.time()
|
||||||
db.graph_data.nodes.delete_many({})
|
GraphNodes.delete().execute()
|
||||||
db.graph_data.edges.delete_many({})
|
GraphEdges.delete().execute()
|
||||||
clear_end = time.time()
|
clear_end = time.time()
|
||||||
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")
|
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")
|
||||||
|
|
||||||
@@ -1065,29 +1057,27 @@ class EntorhinalCortex:
|
|||||||
if not isinstance(memory_items, list):
|
if not isinstance(memory_items, list):
|
||||||
memory_items = [memory_items] if memory_items else []
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
node_data = {
|
GraphNodes.create(
|
||||||
"concept": concept,
|
concept=concept,
|
||||||
"memory_items": memory_items,
|
memory_items=json.dumps(memory_items),
|
||||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
hash=self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||||
"created_time": data.get("created_time", datetime.datetime.now().timestamp()),
|
created_time=data.get("created_time", datetime.datetime.now().timestamp()),
|
||||||
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
|
last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
|
||||||
}
|
)
|
||||||
db.graph_data.nodes.insert_one(node_data)
|
|
||||||
node_end = time.time()
|
node_end = time.time()
|
||||||
logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}秒")
|
logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}秒")
|
||||||
|
|
||||||
# 重新写入边
|
# 重新写入边
|
||||||
edge_start = time.time()
|
edge_start = time.time()
|
||||||
for source, target, data in memory_edges:
|
for source, target, data in memory_edges:
|
||||||
edge_data = {
|
GraphEdges.create(
|
||||||
"source": source,
|
source=source,
|
||||||
"target": target,
|
target=target,
|
||||||
"strength": data.get("strength", 1),
|
strength=data.get("strength", 1),
|
||||||
"hash": self.hippocampus.calculate_edge_hash(source, target),
|
hash=self.hippocampus.calculate_edge_hash(source, target),
|
||||||
"created_time": data.get("created_time", datetime.datetime.now().timestamp()),
|
created_time=data.get("created_time", datetime.datetime.now().timestamp()),
|
||||||
"last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
|
last_modified=data.get("last_modified", datetime.datetime.now().timestamp()),
|
||||||
}
|
)
|
||||||
db.graph_data.edges.insert_one(edge_data)
|
|
||||||
edge_end = time.time()
|
edge_end = time.time()
|
||||||
logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
||||||
|
|
||||||
|
|||||||
@@ -275,6 +275,35 @@ class RecalledMessages(BaseModel):
|
|||||||
table_name = "recalled_messages"
|
table_name = "recalled_messages"
|
||||||
|
|
||||||
|
|
||||||
|
class GraphNodes(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储记忆图节点的模型
|
||||||
|
"""
|
||||||
|
concept = TextField(unique=True, index=True) # 节点概念
|
||||||
|
memory_items = TextField() # JSON格式存储的记忆列表
|
||||||
|
hash = TextField() # 节点哈希值
|
||||||
|
created_time = FloatField() # 创建时间戳
|
||||||
|
last_modified = FloatField() # 最后修改时间戳
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "graph_nodes"
|
||||||
|
|
||||||
|
|
||||||
|
class GraphEdges(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储记忆图边的模型
|
||||||
|
"""
|
||||||
|
source = TextField(index=True) # 源节点
|
||||||
|
target = TextField(index=True) # 目标节点
|
||||||
|
strength = IntegerField() # 连接强度
|
||||||
|
hash = TextField() # 边哈希值
|
||||||
|
created_time = FloatField() # 创建时间戳
|
||||||
|
last_modified = FloatField() # 最后修改时间戳
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "graph_edges"
|
||||||
|
|
||||||
|
|
||||||
def create_tables():
|
def create_tables():
|
||||||
"""
|
"""
|
||||||
创建所有在模型中定义的数据库表。
|
创建所有在模型中定义的数据库表。
|
||||||
@@ -293,6 +322,8 @@ def create_tables():
|
|||||||
Knowledges,
|
Knowledges,
|
||||||
ThinkingLog,
|
ThinkingLog,
|
||||||
RecalledMessages, # 添加新模型
|
RecalledMessages, # 添加新模型
|
||||||
|
GraphNodes, # 添加图节点表
|
||||||
|
GraphEdges, # 添加图边表
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -315,7 +346,9 @@ def initialize_database():
|
|||||||
PersonInfo,
|
PersonInfo,
|
||||||
Knowledges,
|
Knowledges,
|
||||||
ThinkingLog,
|
ThinkingLog,
|
||||||
RecalledMessages, # 添加新模型
|
RecalledMessages,
|
||||||
|
GraphNodes, # 添加图节点表
|
||||||
|
GraphEdges, # 添加图边表
|
||||||
]
|
]
|
||||||
|
|
||||||
needs_creation = False
|
needs_creation = False
|
||||||
|
|||||||
Reference in New Issue
Block a user