三次修改

This commit is contained in:
tt-P607
2025-09-20 02:21:53 +08:00
parent 6a98ae6208
commit 0cc4f5bb27
20 changed files with 478 additions and 481 deletions

View File

@@ -201,7 +201,7 @@ class Hippocampus:
self.entorhinal_cortex = EntorhinalCortex(self)
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
# self.entorhinal_cortex.sync_memory_from_db() # 改为异步启动
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.small")
def get_all_node_names(self) -> list:
@@ -789,7 +789,7 @@ class EntorhinalCortex:
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
def get_memory_sample(self):
async def get_memory_sample(self):
"""从数据库获取记忆样本"""
# 硬编码:每条消息最大记忆次数
max_memorized_time_per_msg = 2
@@ -812,7 +812,7 @@ class EntorhinalCortex:
logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = []
for timestamp in timestamps:
if messages := self.random_get_msg_snippet(
if messages := await self.random_get_msg_snippet(
timestamp,
global_config.memory.memory_build_sample_length,
max_memorized_time_per_msg,
@@ -826,7 +826,9 @@ class EntorhinalCortex:
return chat_samples
@staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
async def random_get_msg_snippet(
target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int
) -> list | None:
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟
@@ -864,13 +866,13 @@ class EntorhinalCortex:
for message in messages:
# 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0)
with get_db_session() as session:
session.execute(
async with get_db_session() as session:
await session.execute(
update(Messages)
.where(Messages.message_id == message["message_id"])
.values(memorized_times=current_memorized_times + 1)
)
session.commit()
await session.commit()
return messages # 直接返回原始的消息列表
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
@@ -884,8 +886,8 @@ class EntorhinalCortex:
current_time = datetime.datetime.now().timestamp()
# 获取数据库中所有节点和内存中所有节点
with get_db_session() as session:
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
async with get_db_session() as session:
db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 批量准备节点数据
@@ -954,24 +956,24 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(nodes_to_create), batch_size):
batch = nodes_to_create[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
await session.execute(insert(GraphNodes), batch)
if nodes_to_update:
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i : i + batch_size]
for node_data in batch:
session.execute(
await session.execute(
update(GraphNodes)
.where(GraphNodes.concept == node_data["concept"])
.values(**{k: v for k, v in node_data.items() if k != "concept"})
)
if nodes_to_delete:
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
# 处理边的信息
db_edges = list(session.execute(select(GraphEdges)).scalars())
db_edges = list((await session.execute(select(GraphEdges))).scalars())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
@@ -1023,14 +1025,14 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i : i + batch_size]
session.execute(insert(GraphEdges), batch)
await session.execute(insert(GraphEdges), batch)
if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i : i + batch_size]
for edge_data in batch:
session.execute(
await session.execute(
update(GraphEdges)
.where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
@@ -1040,12 +1042,12 @@ class EntorhinalCortex:
if edges_to_delete:
for source, target in edges_to_delete:
session.execute(
await session.execute(
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
)
# 提交事务
session.commit()
await session.commit()
end_time = time.time()
logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}")
@@ -1057,10 +1059,10 @@ class EntorhinalCortex:
logger.info("[数据库] 开始重新同步所有记忆数据...")
# 清空数据库
with get_db_session() as session:
async with get_db_session() as session:
clear_start = time.time()
session.execute(delete(GraphNodes))
session.execute(delete(GraphEdges))
await session.execute(delete(GraphNodes))
await session.execute(delete(GraphEdges))
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1119,7 +1121,7 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
await session.execute(insert(GraphNodes), batch)
node_end = time.time()
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}")
@@ -1130,8 +1132,8 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i : i + batch_size]
session.execute(insert(GraphEdges), batch)
session.commit()
await session.execute(insert(GraphEdges), batch)
await session.commit()
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}")
@@ -1140,7 +1142,7 @@ class EntorhinalCortex:
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}")
logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边")
def sync_memory_from_db(self):
async def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
current_time = datetime.datetime.now().timestamp()
need_update = False
@@ -1149,8 +1151,8 @@ class EntorhinalCortex:
self.memory_graph.G.clear()
# 从数据库加载所有节点
with get_db_session() as session:
nodes = list(session.execute(select(GraphNodes)).scalars())
async with get_db_session() as session:
nodes = list((await session.execute(select(GraphNodes))).scalars())
for node in nodes:
concept = node.concept
try:
@@ -1168,7 +1170,9 @@ class EntorhinalCortex:
if not node.last_modified:
update_data["last_modified"] = current_time
session.execute(update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data))
await session.execute(
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
)
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.created_time or current_time
@@ -1183,7 +1187,7 @@ class EntorhinalCortex:
continue
# 从数据库加载所有边
edges = list(session.execute(select(GraphEdges)).scalars())
edges = list((await session.execute(select(GraphEdges))).scalars())
for edge in edges:
source = edge.source
target = edge.target
@@ -1199,7 +1203,7 @@ class EntorhinalCortex:
if not edge.last_modified:
update_data["last_modified"] = current_time
session.execute(
await session.execute(
update(GraphEdges)
.where((GraphEdges.source == source) & (GraphEdges.target == target))
.values(**update_data)
@@ -1214,7 +1218,7 @@ class EntorhinalCortex:
self.memory_graph.G.add_edge(
source, target, strength=strength, created_time=created_time, last_modified=last_modified
)
session.commit()
await session.commit()
if need_update:
logger.info("[数据库] 已为缺失的时间字段进行补充")
@@ -1254,7 +1258,7 @@ class ParahippocampalGyrus:
# 1. 使用 build_readable_messages 生成格式化文本
# build_readable_messages 只返回一个字符串,不需要解包
input_text = build_readable_messages(
input_text = await build_readable_messages(
messages,
merge_messages=True, # 合并连续消息
timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
@@ -1342,7 +1346,7 @@ class ParahippocampalGyrus:
# sourcery skip: merge-list-appends-into-extend
logger.info("------------------------------------开始构建记忆--------------------------------------")
start_time = time.time()
memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample()
all_added_nodes = []
all_connected_nodes = []
all_added_edges = []
@@ -1620,7 +1624,7 @@ class HippocampusManager:
return self._hippocampus
self._hippocampus = Hippocampus()
self._hippocampus.initialize()
# self._hippocampus.initialize() # 改为异步启动
self._initialized = True
# 输出记忆图统计信息
@@ -1639,6 +1643,13 @@ class HippocampusManager:
return self._hippocampus
async def initialize_async(self):
"""异步初始化海马体实例"""
if not self._initialized:
self.initialize() # 先进行同步部分的初始化
self._hippocampus.initialize()
await self._hippocampus.entorhinal_cortex.sync_memory_from_db()
def get_hippocampus(self):
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")