diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index d1117e9e6..3721b169e 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -3,6 +3,7 @@ import datetime import math import random import time +import asyncio import re import orjson import jieba @@ -877,7 +878,7 @@ class EntorhinalCortex: self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph - async def get_memory_sample(self): + async def get_memory_sample(self) -> tuple[list, list[str]]: """从数据库获取记忆样本""" # 硬编码:每条消息最大记忆次数 max_memorized_time_per_msg = 2 @@ -899,24 +900,27 @@ class EntorhinalCortex: for _, readable_timestamp in zip(timestamps, readable_timestamps, strict=False): logger.debug(f"回忆往事: {readable_timestamp}") chat_samples = [] + all_message_ids_to_update = [] for timestamp in timestamps: - if messages := await self.random_get_msg_snippet( + if result := await self.random_get_msg_snippet( timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg, ): + messages, message_ids_to_update = result time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条") chat_samples.append(messages) + all_message_ids_to_update.extend(message_ids_to_update) else: logger.debug(f"时间戳 {timestamp} 的消息无需记忆") - return chat_samples + return chat_samples, all_message_ids_to_update @staticmethod async def random_get_msg_snippet( target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int - ) -> list | None: + ) -> tuple[list, list[str]] | None: # sourcery skip: invert-any-all, use-any, use-named-expression, use-next """从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)""" time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟 @@ -950,18 +954,9 @@ class EntorhinalCortex: # 如果所有消息都有效 if all_valid: - # 更新数据库中的记忆次数 - for message in messages: - # 确保在更新前获取最新的 memorized_times - current_memorized_times = message.get("memorized_times", 0) - async with get_db_session() as session: - await session.execute( - update(Messages) - .where(Messages.message_id == message["message_id"]) - .values(memorized_times=current_memorized_times + 1) - ) - await session.commit() - return messages # 直接返回原始的消息列表 + # 返回消息和需要更新的message_id + message_ids_to_update = [msg["message_id"] for msg in messages] + return messages, message_ids_to_update target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试 @@ -1042,10 +1037,20 @@ class EntorhinalCortex: # 批量处理节点 if nodes_to_create: - batch_size = 100 - for i in range(0, len(nodes_to_create), batch_size): - batch = nodes_to_create[i : i + batch_size] - await session.execute(insert(GraphNodes), batch) + # 在插入前进行去重检查 + unique_nodes_to_create = [] + seen_concepts = set(db_nodes.keys()) + for node_data in nodes_to_create: + concept = node_data["concept"] + if concept not in seen_concepts: + unique_nodes_to_create.append(node_data) + seen_concepts.add(concept) + + if unique_nodes_to_create: + batch_size = 100 + for i in range(0, len(unique_nodes_to_create), batch_size): + batch = unique_nodes_to_create[i : i + batch_size] + await session.execute(insert(GraphNodes), batch) if nodes_to_update: batch_size = 100 @@ -1440,7 +1445,7 @@ class ParahippocampalGyrus: # sourcery skip: merge-list-appends-into-extend logger.info("------------------------------------开始构建记忆--------------------------------------") start_time = time.time() - memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample() + memory_samples, all_message_ids_to_update = await self.hippocampus.entorhinal_cortex.get_memory_sample() all_added_nodes = [] all_connected_nodes = [] all_added_edges = [] @@ -1503,8 +1508,21 @@ class ParahippocampalGyrus: if all_connected_nodes: logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}") + # 先同步记忆图 await self.hippocampus.entorhinal_cortex.sync_memory_to_db() + # 最后批量更新消息的记忆次数 + if all_message_ids_to_update: + async with get_db_session() as session: + # 使用 in_ 操作符进行批量更新 + await session.execute( + update(Messages) + .where(Messages.message_id.in_(all_message_ids_to_update)) + .values(memorized_times=Messages.memorized_times + 1) + ) + await session.commit() + logger.info(f"批量更新了 {len(all_message_ids_to_update)} 条消息的记忆次数") + end_time = time.time() logger.info(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------") @@ -1694,6 +1712,7 @@ class HippocampusManager: def __init__(self): self._hippocampus: Hippocampus = None # type: ignore self._initialized = False + self._db_lock = asyncio.Lock() def initialize(self): """初始化海马体实例""" @@ -1742,14 +1761,16 @@ class HippocampusManager: """遗忘记忆的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage) + async with self._db_lock: + return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage) async def consolidate_memory(self): """整合记忆的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") # 使用 operation_build_memory 方法来整合记忆 - return await self._hippocampus.parahippocampal_gyrus.operation_build_memory() + async with self._db_lock: + return await self._hippocampus.parahippocampal_gyrus.operation_build_memory() async def get_memory_from_text( self, diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index 4a32657a7..58b7f23b9 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -339,7 +339,7 @@ class NoticeHandler: message_id=raw_message.get("message_id",""), emoji_id=like_emoji_id ) - seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]") + seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]") return seg_data, user_info async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: