三次修改
This commit is contained in:
@@ -201,7 +201,7 @@ class Hippocampus:
|
||||
self.entorhinal_cortex = EntorhinalCortex(self)
|
||||
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
|
||||
# 从数据库加载记忆图
|
||||
self.entorhinal_cortex.sync_memory_from_db()
|
||||
# self.entorhinal_cortex.sync_memory_from_db() # 改为异步启动
|
||||
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.small")
|
||||
|
||||
def get_all_node_names(self) -> list:
|
||||
@@ -789,7 +789,7 @@ class EntorhinalCortex:
|
||||
self.hippocampus = hippocampus
|
||||
self.memory_graph = hippocampus.memory_graph
|
||||
|
||||
def get_memory_sample(self):
|
||||
async def get_memory_sample(self):
|
||||
"""从数据库获取记忆样本"""
|
||||
# 硬编码:每条消息最大记忆次数
|
||||
max_memorized_time_per_msg = 2
|
||||
@@ -812,7 +812,7 @@ class EntorhinalCortex:
|
||||
logger.debug(f"回忆往事: {readable_timestamp}")
|
||||
chat_samples = []
|
||||
for timestamp in timestamps:
|
||||
if messages := self.random_get_msg_snippet(
|
||||
if messages := await self.random_get_msg_snippet(
|
||||
timestamp,
|
||||
global_config.memory.memory_build_sample_length,
|
||||
max_memorized_time_per_msg,
|
||||
@@ -826,7 +826,9 @@ class EntorhinalCortex:
|
||||
return chat_samples
|
||||
|
||||
@staticmethod
|
||||
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
|
||||
async def random_get_msg_snippet(
|
||||
target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int
|
||||
) -> list | None:
|
||||
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
|
||||
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
|
||||
time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟
|
||||
@@ -864,13 +866,13 @@ class EntorhinalCortex:
|
||||
for message in messages:
|
||||
# 确保在更新前获取最新的 memorized_times
|
||||
current_memorized_times = message.get("memorized_times", 0)
|
||||
with get_db_session() as session:
|
||||
session.execute(
|
||||
async with get_db_session() as session:
|
||||
await session.execute(
|
||||
update(Messages)
|
||||
.where(Messages.message_id == message["message_id"])
|
||||
.values(memorized_times=current_memorized_times + 1)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return messages # 直接返回原始的消息列表
|
||||
|
||||
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
|
||||
@@ -884,8 +886,8 @@ class EntorhinalCortex:
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
with get_db_session() as session:
|
||||
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
|
||||
async with get_db_session() as session:
|
||||
db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()}
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 批量准备节点数据
|
||||
@@ -954,24 +956,24 @@ class EntorhinalCortex:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_create), batch_size):
|
||||
batch = nodes_to_create[i : i + batch_size]
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
await session.execute(insert(GraphNodes), batch)
|
||||
|
||||
if nodes_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_update), batch_size):
|
||||
batch = nodes_to_update[i : i + batch_size]
|
||||
for node_data in batch:
|
||||
session.execute(
|
||||
await session.execute(
|
||||
update(GraphNodes)
|
||||
.where(GraphNodes.concept == node_data["concept"])
|
||||
.values(**{k: v for k, v in node_data.items() if k != "concept"})
|
||||
)
|
||||
|
||||
if nodes_to_delete:
|
||||
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
|
||||
await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(session.execute(select(GraphEdges)).scalars())
|
||||
db_edges = list((await session.execute(select(GraphEdges))).scalars())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -1023,14 +1025,14 @@ class EntorhinalCortex:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_create), batch_size):
|
||||
batch = edges_to_create[i : i + batch_size]
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
await session.execute(insert(GraphEdges), batch)
|
||||
|
||||
if edges_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_update), batch_size):
|
||||
batch = edges_to_update[i : i + batch_size]
|
||||
for edge_data in batch:
|
||||
session.execute(
|
||||
await session.execute(
|
||||
update(GraphEdges)
|
||||
.where(
|
||||
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
|
||||
@@ -1040,12 +1042,12 @@ class EntorhinalCortex:
|
||||
|
||||
if edges_to_delete:
|
||||
for source, target in edges_to_delete:
|
||||
session.execute(
|
||||
await session.execute(
|
||||
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
|
||||
)
|
||||
|
||||
# 提交事务
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
|
||||
@@ -1057,10 +1059,10 @@ class EntorhinalCortex:
|
||||
logger.info("[数据库] 开始重新同步所有记忆数据...")
|
||||
|
||||
# 清空数据库
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
clear_start = time.time()
|
||||
session.execute(delete(GraphNodes))
|
||||
session.execute(delete(GraphEdges))
|
||||
await session.execute(delete(GraphNodes))
|
||||
await session.execute(delete(GraphEdges))
|
||||
|
||||
clear_end = time.time()
|
||||
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")
|
||||
@@ -1119,7 +1121,7 @@ class EntorhinalCortex:
|
||||
batch_size = 500 # 增加批量大小
|
||||
for i in range(0, len(nodes_data), batch_size):
|
||||
batch = nodes_data[i : i + batch_size]
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
await session.execute(insert(GraphNodes), batch)
|
||||
|
||||
node_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒")
|
||||
@@ -1130,8 +1132,8 @@ class EntorhinalCortex:
|
||||
batch_size = 500 # 增加批量大小
|
||||
for i in range(0, len(edges_data), batch_size):
|
||||
batch = edges_data[i : i + batch_size]
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
session.commit()
|
||||
await session.execute(insert(GraphEdges), batch)
|
||||
await session.commit()
|
||||
|
||||
edge_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
||||
@@ -1140,7 +1142,7 @@ class EntorhinalCortex:
|
||||
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒")
|
||||
logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边")
|
||||
|
||||
def sync_memory_from_db(self):
|
||||
async def sync_memory_from_db(self):
|
||||
"""从数据库同步数据到内存中的图结构"""
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
need_update = False
|
||||
@@ -1149,8 +1151,8 @@ class EntorhinalCortex:
|
||||
self.memory_graph.G.clear()
|
||||
|
||||
# 从数据库加载所有节点
|
||||
with get_db_session() as session:
|
||||
nodes = list(session.execute(select(GraphNodes)).scalars())
|
||||
async with get_db_session() as session:
|
||||
nodes = list((await session.execute(select(GraphNodes))).scalars())
|
||||
for node in nodes:
|
||||
concept = node.concept
|
||||
try:
|
||||
@@ -1168,7 +1170,9 @@ class EntorhinalCortex:
|
||||
if not node.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
session.execute(update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data))
|
||||
await session.execute(
|
||||
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
|
||||
)
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = node.created_time or current_time
|
||||
@@ -1183,7 +1187,7 @@ class EntorhinalCortex:
|
||||
continue
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = list(session.execute(select(GraphEdges)).scalars())
|
||||
edges = list((await session.execute(select(GraphEdges))).scalars())
|
||||
for edge in edges:
|
||||
source = edge.source
|
||||
target = edge.target
|
||||
@@ -1199,7 +1203,7 @@ class EntorhinalCortex:
|
||||
if not edge.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
session.execute(
|
||||
await session.execute(
|
||||
update(GraphEdges)
|
||||
.where((GraphEdges.source == source) & (GraphEdges.target == target))
|
||||
.values(**update_data)
|
||||
@@ -1214,7 +1218,7 @@ class EntorhinalCortex:
|
||||
self.memory_graph.G.add_edge(
|
||||
source, target, strength=strength, created_time=created_time, last_modified=last_modified
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
if need_update:
|
||||
logger.info("[数据库] 已为缺失的时间字段进行补充")
|
||||
@@ -1254,7 +1258,7 @@ class ParahippocampalGyrus:
|
||||
|
||||
# 1. 使用 build_readable_messages 生成格式化文本
|
||||
# build_readable_messages 只返回一个字符串,不需要解包
|
||||
input_text = build_readable_messages(
|
||||
input_text = await build_readable_messages(
|
||||
messages,
|
||||
merge_messages=True, # 合并连续消息
|
||||
timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
|
||||
@@ -1342,7 +1346,7 @@ class ParahippocampalGyrus:
|
||||
# sourcery skip: merge-list-appends-into-extend
|
||||
logger.info("------------------------------------开始构建记忆--------------------------------------")
|
||||
start_time = time.time()
|
||||
memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
|
||||
memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample()
|
||||
all_added_nodes = []
|
||||
all_connected_nodes = []
|
||||
all_added_edges = []
|
||||
@@ -1620,7 +1624,7 @@ class HippocampusManager:
|
||||
return self._hippocampus
|
||||
|
||||
self._hippocampus = Hippocampus()
|
||||
self._hippocampus.initialize()
|
||||
# self._hippocampus.initialize() # 改为异步启动
|
||||
self._initialized = True
|
||||
|
||||
# 输出记忆图统计信息
|
||||
@@ -1639,6 +1643,13 @@ class HippocampusManager:
|
||||
|
||||
return self._hippocampus
|
||||
|
||||
# Async entry point for bringing up the hippocampus. It runs the synchronous
# initialize() when the manager is not yet flagged as initialized, calls
# initialize() on the Hippocampus instance, and awaits the load of the memory
# graph from the database.
# NOTE(review): indentation was lost in this view, so it is unclear whether the
# last two calls sit inside the `if` guard or run on every call — confirm
# against the real file.
async def initialize_async(self):
|
||||
"""Asynchronously initialize the hippocampus instance."""
|
||||
if not self._initialized:
|
||||
self.initialize() # run the synchronous part of the initialization first
|
||||
self._hippocampus.initialize()
|
||||
# Must be awaited: sync_memory_from_db is declared with `async def`.
await self._hippocampus.entorhinal_cortex.sync_memory_from_db()
|
||||
|
||||
def get_hippocampus(self):
|
||||
if not self._initialized:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
|
||||
Reference in New Issue
Block a user