diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py
index 96308c50b..f05139279 100644
--- a/src/plugins/chat/message.py
+++ b/src/plugins/chat/message.py
@@ -27,6 +27,7 @@ class Message(MessageBase):
     reply: Optional["Message"] = None
     detailed_plain_text: str = ""
     processed_plain_text: str = ""
+    memorized_times: int = 0
 
     def __init__(
         self,
diff --git a/src/plugins/chat/storage.py b/src/plugins/chat/storage.py
index ad6662f2b..33099d6b6 100644
--- a/src/plugins/chat/storage.py
+++ b/src/plugins/chat/storage.py
@@ -19,6 +19,7 @@ class MessageStorage:
                 "processed_plain_text": message.processed_plain_text,
                 "detailed_plain_text": message.detailed_plain_text,
                 "topic": topic,
+                "memorized_times": message.memorized_times,
             }
             db.messages.insert_one(message_data)
         except Exception:
diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py
index f28d0e192..28e6b7f36 100644
--- a/src/plugins/chat/utils.py
+++ b/src/plugins/chat/utils.py
@@ -104,10 +104,13 @@ def get_closest_chat_from_db(length: int, timestamp: str):
         # 转换记录格式
         formatted_records = []
         for record in chat_records:
+            # Compatibility: older records may lack the newer fields, hence the .get() defaults.
             formatted_records.append({
+                '_id': record["_id"],
                 'time': record["time"],
                 'chat_id': record["chat_id"],
-                'detailed_plain_text': record.get("detailed_plain_text", "")  # 添加文本内容
+                'detailed_plain_text': record.get("detailed_plain_text", ""),  # message text
+                'memorized_times': record.get("memorized_times", 0)  # times already memorized
             })
 
         return formatted_records
diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py
index f87f037d5..c5ec2ddcb 100644
--- a/src/plugins/memory_system/memory.py
+++ b/src/plugins/memory_system/memory.py
@@ -178,33 +178,79 @@ class Hippocampus:
         nodes = sorted([source, target])
         return hash(f"{nodes[0]}:{nodes[1]}")
 
+    def random_get_msg_snippet(self, target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
+        """Fetch a message snippet near a timestamp and mark it as memorized.
+
+        Args:
+            target_timestamp: timestamp the snippet should be close to
+            chat_size: number of messages to fetch
+            max_memorized_time_per_msg: reject the snippet when any of its
+                messages has already been memorized this many times
+
+        Returns:
+            list: the message records, or an empty list when every attempt failed
+        """
+        # Try at most three times.
+        # NOTE(review): each attempt repeats the same query, so retrying only
+        # helps if get_closest_chat_from_db can return a different result.
+        for _ in range(3):
+            messages = get_closest_chat_from_db(length=chat_size, timestamp=target_timestamp)
+            if not messages:
+                continue
+            # Skip the snippet when any message reached the memorization limit.
+            if any(message["memorized_times"] >= max_memorized_time_per_msg for message in messages):
+                continue
+            # Write back: bump the counters server-side with $inc instead of
+            # $set-ting a value computed from a possibly stale read.
+            message_ids = [message["_id"] for message in messages]
+            db.messages.update_many({"_id": {"$in": message_ids}}, {"$inc": {"memorized_times": 1}})
+            return messages
+        # All attempts failed.
+        return []
+
     def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
         """获取记忆样本
 
         Returns:
             list: 消息记录列表,每个元素是一个消息记录字典列表
         """
+        # Hard-coded cap on how many times a single message may be memorized.
+        # Could be moved into global_config if it needs to be tunable.
+        max_memorized_time_per_msg = 3
+
         current_timestamp = datetime.datetime.now().timestamp()
         chat_samples = []
 
         # 短期:1h 中期:4h 长期:24h
-        for _ in range(time_frequency.get('near')):
+        logger.debug("正在抽取短期消息样本")
+        for i in range(time_frequency.get('near')):
             random_time = current_timestamp - random.randint(1, 3600)
-            messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
+            messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
             if messages:
+                logger.debug(f"成功抽取短期消息样本{len(messages)}条")
                 chat_samples.append(messages)
+            else:
+                logger.warning(f"第{i + 1}次短期消息样本抽取失败")
 
-        for _ in range(time_frequency.get('mid')):
+        logger.debug("正在抽取中期消息样本")
+        for i in range(time_frequency.get('mid')):
             random_time = current_timestamp - random.randint(3600, 3600 * 4)
-            messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
+            messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
             if messages:
+                logger.debug(f"成功抽取中期消息样本{len(messages)}条")
                 chat_samples.append(messages)
+            else:
+                logger.warning(f"第{i + 1}次中期消息样本抽取失败")
 
-        for _ in range(time_frequency.get('far')):
+        logger.debug("正在抽取长期消息样本")
+        for i in range(time_frequency.get('far')):
             random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
-            messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
+            messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
             if messages:
+                logger.debug(f"成功抽取长期消息样本{len(messages)}条")
                 chat_samples.append(messages)
+            else:
+                logger.warning(f"第{i + 1}次长期消息样本抽取失败")
 
         return chat_samples
 