perf(memory): optimize memory-system database operations and fix concurrency issues
Reworks how message memorization counts are updated: instead of one database write per message, the message_ids are collected and written in a single batch update at the end of the memory-build task, sharply reducing write volume and significantly improving performance. Also adds an async lock to `HippocampusManager` to prevent race conditions when memory consolidation and forgetting run concurrently, and adds node deduplication that checks for duplicate concepts before inserting into the database, keeping the data consistent.
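A minimal sketch of the patterns applied here, assuming SQLAlchemy 2.x async sessions over aiosqlite; the Messages model, engine, lock, and helper names below are simplified stand-ins for the project's own code, not its actual modules:

import asyncio

from sqlalchemy import Integer, String, update
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Messages(Base):
    # Simplified stand-in for the project's Messages model.
    __tablename__ = "messages"
    message_id: Mapped[str] = mapped_column(String, primary_key=True)
    memorized_times: Mapped[int] = mapped_column(Integer, default=0)


engine = create_async_engine("sqlite+aiosqlite:///:memory:")
Session = async_sessionmaker(engine)

# One lock serializes consolidation and forgetting, mirroring the
# _db_lock this commit adds to HippocampusManager.
db_lock = asyncio.Lock()


async def bump_memorized_times(message_ids: list[str]) -> None:
    # One UPDATE ... WHERE message_id IN (...) with a server-side
    # increment, instead of a read-modify-write and commit per message.
    if not message_ids:
        return
    async with db_lock:
        async with Session() as session:
            await session.execute(
                update(Messages)
                .where(Messages.message_id.in_(message_ids))
                .values(memorized_times=Messages.memorized_times + 1)
            )
            await session.commit()


def dedupe_by_concept(nodes: list[dict], existing_concepts: set[str]) -> list[dict]:
    # Keep the first occurrence of each concept not already in the database.
    seen = set(existing_concepts)
    unique = []
    for node in nodes:
        if node["concept"] not in seen:
            unique.append(node)
            seen.add(node["concept"])
    return unique

Because the increment happens inside the UPDATE itself, the old per-message read of memorized_times (and the races it invited) disappears along with the per-row commits.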
@@ -3,6 +3,7 @@ import datetime
 import math
 import random
 import time
+import asyncio
 import re
 import orjson
 import jieba
@@ -789,7 +790,7 @@ class EntorhinalCortex:
         self.hippocampus = hippocampus
         self.memory_graph = hippocampus.memory_graph
 
-    async def get_memory_sample(self):
+    async def get_memory_sample(self) -> tuple[list, list[str]]:
         """Fetch memory samples from the database."""
         # Hard-coded: maximum number of times each message can be memorized
         max_memorized_time_per_msg = 2
@@ -811,24 +812,27 @@ class EntorhinalCortex:
         for _, readable_timestamp in zip(timestamps, readable_timestamps, strict=False):
             logger.debug(f"Recalling past messages: {readable_timestamp}")
         chat_samples = []
+        all_message_ids_to_update = []
         for timestamp in timestamps:
-            if messages := await self.random_get_msg_snippet(
+            if result := await self.random_get_msg_snippet(
                 timestamp,
                 global_config.memory.memory_build_sample_length,
                 max_memorized_time_per_msg,
             ):
+                messages, message_ids_to_update = result
                 time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
                 logger.info(f"Sampled messages from {time_diff:.1f} hours ago: {len(messages)} in total")
                 chat_samples.append(messages)
+                all_message_ids_to_update.extend(message_ids_to_update)
             else:
                 logger.debug(f"Messages at timestamp {timestamp} need no memorization")
 
-        return chat_samples
+        return chat_samples, all_message_ids_to_update
 
     @staticmethod
     async def random_get_msg_snippet(
         target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int
-    ) -> list | None:
+    ) -> tuple[list, list[str]] | None:
        # sourcery skip: invert-any-all, use-any, use-named-expression, use-next
        """Randomly fetch a message snippet near the given timestamp from the database (via chat_message_builder)."""
        time_window_seconds = random.randint(300, 1800)  # random time window, 5 to 30 minutes
@@ -862,18 +866,9 @@ class EntorhinalCortex:
 
             # If every message is valid
             if all_valid:
-                # Update memorized counts in the database
-                for message in messages:
-                    # Make sure to read the latest memorized_times before updating
-                    current_memorized_times = message.get("memorized_times", 0)
-                    async with get_db_session() as session:
-                        await session.execute(
-                            update(Messages)
-                            .where(Messages.message_id == message["message_id"])
-                            .values(memorized_times=current_memorized_times + 1)
-                        )
-                        await session.commit()
-                return messages  # Return the original message list as-is
+                # Return the messages and the message_ids that need updating
+                message_ids_to_update = [msg["message_id"] for msg in messages]
+                return messages, message_ids_to_update
 
             target_timestamp -= 120  # If the first attempt fails, shift the timestamp back slightly and retry
 
@@ -953,10 +948,20 @@ class EntorhinalCortex:
 
             # Process nodes in batches
             if nodes_to_create:
-                batch_size = 100
-                for i in range(0, len(nodes_to_create), batch_size):
-                    batch = nodes_to_create[i : i + batch_size]
-                    await session.execute(insert(GraphNodes), batch)
+                # Deduplicate concepts before inserting
+                unique_nodes_to_create = []
+                seen_concepts = set(db_nodes.keys())
+                for node_data in nodes_to_create:
+                    concept = node_data["concept"]
+                    if concept not in seen_concepts:
+                        unique_nodes_to_create.append(node_data)
+                        seen_concepts.add(concept)
+
+                if unique_nodes_to_create:
+                    batch_size = 100
+                    for i in range(0, len(unique_nodes_to_create), batch_size):
+                        batch = unique_nodes_to_create[i : i + batch_size]
+                        await session.execute(insert(GraphNodes), batch)
 
             if nodes_to_update:
                 batch_size = 100
@@ -1346,7 +1351,7 @@ class ParahippocampalGyrus:
         # sourcery skip: merge-list-appends-into-extend
         logger.info("------------------------------------Starting memory build--------------------------------------")
         start_time = time.time()
-        memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample()
+        memory_samples, all_message_ids_to_update = await self.hippocampus.entorhinal_cortex.get_memory_sample()
         all_added_nodes = []
         all_connected_nodes = []
         all_added_edges = []
@@ -1409,8 +1414,21 @@ class ParahippocampalGyrus:
         if all_connected_nodes:
             logger.info(f"Reinforced connected nodes: {', '.join(all_connected_nodes)}")
 
+        # Sync the memory graph first
         await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
 
+        # Finally, batch-update the memorized counts of the sampled messages
+        if all_message_ids_to_update:
+            async with get_db_session() as session:
+                # Use the in_ operator for a single batch update
+                await session.execute(
+                    update(Messages)
+                    .where(Messages.message_id.in_(all_message_ids_to_update))
+                    .values(memorized_times=Messages.memorized_times + 1)
+                )
+                await session.commit()
+            logger.info(f"Batch-updated memorized counts for {len(all_message_ids_to_update)} messages")
+
         end_time = time.time()
         logger.info(f"---------------------Memory build took: {end_time - start_time:.2f} s---------------------")
 
@@ -1617,6 +1635,7 @@ class HippocampusManager:
     def __init__(self):
         self._hippocampus: Hippocampus = None  # type: ignore
         self._initialized = False
+        self._db_lock = asyncio.Lock()
 
     def initialize(self):
         """Initialize the hippocampus instance."""
@@ -1665,14 +1684,16 @@ class HippocampusManager:
         """Public interface for forgetting memories."""
         if not self._initialized:
             raise RuntimeError("HippocampusManager is not initialized; call initialize first")
-        return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)
+        async with self._db_lock:
+            return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)
 
     async def consolidate_memory(self):
         """Public interface for consolidating memories."""
         if not self._initialized:
             raise RuntimeError("HippocampusManager is not initialized; call initialize first")
         # Consolidate memories via the operation_build_memory method
-        return await self._hippocampus.parahippocampal_gyrus.operation_build_memory()
+        async with self._db_lock:
+            return await self._hippocampus.parahippocampal_gyrus.operation_build_memory()
 
     async def get_memory_from_text(
         self,
@@ -339,7 +339,7 @@ class NoticeHandler:
                 message_id=raw_message.get("message_id",""),
                 emoji_id=like_emoji_id
             )
-            seg_data = Seg(type="text",data=f"{user_name} used emoji {QQ_FACE.get(like_emoji_id,"")} to reply to your message [{target_message_text}]")
+            seg_data = Seg(type="text",data=f"{user_name} used emoji {QQ_FACE.get(like_emoji_id, '')} to reply to your message [{target_message_text}]")
             return seg_data, user_info
 
     async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
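For the concurrency fix, a hypothetical caller illustrates the effect of the new lock; the forget_memory name and its percentage argument are assumed here for illustration, since only the method bodies appear in the diff:

import asyncio

async def run_maintenance(manager) -> None:
    # Both entry points acquire manager._db_lock internally, so even when
    # scheduled together they execute one at a time rather than racing.
    await asyncio.gather(
        manager.consolidate_memory(),
        manager.forget_memory(percentage=0.005),  # hypothetical signature
    )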