fix: 构建记忆时重复读取同一段消息,导致token消耗暴增

This commit is contained in:
Oct-autumn
2025-03-14 13:37:23 +08:00
parent cccea98ba2
commit 33df5981b4
4 changed files with 59 additions and 7 deletions

View File

@@ -27,6 +27,7 @@ class Message(MessageBase):
reply: Optional["Message"] = None
detailed_plain_text: str = ""
processed_plain_text: str = ""
memorized_times: int = 0
def __init__(
self,

View File

@@ -19,6 +19,7 @@ class MessageStorage:
"processed_plain_text": message.processed_plain_text,
"detailed_plain_text": message.detailed_plain_text,
"topic": topic,
"memorized_times": message.memorized_times,
}
db.messages.insert_one(message_data)
except Exception:

View File

@@ -104,10 +104,13 @@ def get_closest_chat_from_db(length: int, timestamp: str):
# 转换记录格式
formatted_records = []
for record in chat_records:
# 兼容行为,前向兼容老数据
formatted_records.append({
'_id': record["_id"],
'time': record["time"],
'chat_id': record["chat_id"],
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
'detailed_plain_text': record.get("detailed_plain_text", ""), # 添加文本内容
'memorized_times': record.get("memorized_times", 0) # 添加记忆次数
})
return formatted_records

View File

@@ -178,33 +178,80 @@ class Hippocampus:
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
def random_get_msg_snippet(self, target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
"""随机抽取一段时间内的消息片段
Args:
- target_timestamp: 目标时间戳
- chat_size: 抽取的消息数量
- max_memorized_time_per_msg: 每条消息的最大记忆次数
Returns:
- list: 抽取出的消息记录列表
"""
try_count = 0
# 最多尝试三次抽取
while try_count < 3:
messages = get_closest_chat_from_db(length=chat_size, timestamp=target_timestamp)
if messages:
# 检查messages是否均没有达到记忆次数限制
for message in messages:
if message["memorized_times"] >= max_memorized_time_per_msg:
messages = None
break
if messages:
# 成功抽取短期消息样本
# 数据写回:增加记忆次数
for message in messages:
db.messages.update_one({"_id": message["_id"]}, {"$set": {"memorized_times": message["memorized_times"] + 1}})
return messages
try_count += 1
# 三次尝试均失败
return None
def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
"""获取记忆样本
Returns:
list: 消息记录列表,每个元素是一个消息记录字典列表
"""
# 硬编码:每条消息最大记忆次数
# 如有需求可写入global_config
max_memorized_time_per_msg = 3
current_timestamp = datetime.datetime.now().timestamp()
chat_samples = []
# 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')):
logger.debug(f"正在抽取短期消息样本")
for i in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
if messages:
logger.debug(f"成功抽取短期消息样本{len(messages)}")
chat_samples.append(messages)
else:
logger.warning(f"{i}次短期消息样本抽取失败")
for _ in range(time_frequency.get('mid')):
logger.debug(f"正在抽取中期消息样本")
for i in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600, 3600 * 4)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
if messages:
logger.debug(f"成功抽取中期消息样本{len(messages)}")
chat_samples.append(messages)
else:
logger.warning(f"{i}次中期消息样本抽取失败")
for _ in range(time_frequency.get('far')):
logger.debug(f"正在抽取长期消息样本")
for i in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
if messages:
logger.debug(f"成功抽取长期消息样本{len(messages)}")
chat_samples.append(messages)
else:
logger.warning(f"{i}次长期消息样本抽取失败")
return chat_samples