fix: 构建记忆时重复读取同一段消息,导致token消耗暴增
This commit is contained in:
@@ -27,6 +27,7 @@ class Message(MessageBase):
|
|||||||
reply: Optional["Message"] = None
|
reply: Optional["Message"] = None
|
||||||
detailed_plain_text: str = ""
|
detailed_plain_text: str = ""
|
||||||
processed_plain_text: str = ""
|
processed_plain_text: str = ""
|
||||||
|
memorized_times: int = 0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class MessageStorage:
|
|||||||
"processed_plain_text": message.processed_plain_text,
|
"processed_plain_text": message.processed_plain_text,
|
||||||
"detailed_plain_text": message.detailed_plain_text,
|
"detailed_plain_text": message.detailed_plain_text,
|
||||||
"topic": topic,
|
"topic": topic,
|
||||||
|
"memorized_times": message.memorized_times,
|
||||||
}
|
}
|
||||||
db.messages.insert_one(message_data)
|
db.messages.insert_one(message_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -104,10 +104,13 @@ def get_closest_chat_from_db(length: int, timestamp: str):
|
|||||||
# 转换记录格式
|
# 转换记录格式
|
||||||
formatted_records = []
|
formatted_records = []
|
||||||
for record in chat_records:
|
for record in chat_records:
|
||||||
|
# 兼容行为,前向兼容老数据
|
||||||
formatted_records.append({
|
formatted_records.append({
|
||||||
|
'_id': record["_id"],
|
||||||
'time': record["time"],
|
'time': record["time"],
|
||||||
'chat_id': record["chat_id"],
|
'chat_id': record["chat_id"],
|
||||||
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
|
'detailed_plain_text': record.get("detailed_plain_text", ""), # 添加文本内容
|
||||||
|
'memorized_times': record.get("memorized_times", 0) # 添加记忆次数
|
||||||
})
|
})
|
||||||
|
|
||||||
return formatted_records
|
return formatted_records
|
||||||
|
|||||||
@@ -178,33 +178,80 @@ class Hippocampus:
|
|||||||
nodes = sorted([source, target])
|
nodes = sorted([source, target])
|
||||||
return hash(f"{nodes[0]}:{nodes[1]}")
|
return hash(f"{nodes[0]}:{nodes[1]}")
|
||||||
|
|
||||||
|
def random_get_msg_snippet(self, target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
|
||||||
|
"""随机抽取一段时间内的消息片段
|
||||||
|
Args:
|
||||||
|
- target_timestamp: 目标时间戳
|
||||||
|
- chat_size: 抽取的消息数量
|
||||||
|
- max_memorized_time_per_msg: 每条消息的最大记忆次数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- list: 抽取出的消息记录列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
try_count = 0
|
||||||
|
# 最多尝试三次抽取
|
||||||
|
while try_count < 3:
|
||||||
|
messages = get_closest_chat_from_db(length=chat_size, timestamp=target_timestamp)
|
||||||
|
if messages:
|
||||||
|
# 检查messages是否均没有达到记忆次数限制
|
||||||
|
for message in messages:
|
||||||
|
if message["memorized_times"] >= max_memorized_time_per_msg:
|
||||||
|
messages = None
|
||||||
|
break
|
||||||
|
if messages:
|
||||||
|
# 成功抽取短期消息样本
|
||||||
|
# 数据写回:增加记忆次数
|
||||||
|
for message in messages:
|
||||||
|
db.messages.update_one({"_id": message["_id"]}, {"$set": {"memorized_times": message["memorized_times"] + 1}})
|
||||||
|
return messages
|
||||||
|
try_count += 1
|
||||||
|
# 三次尝试均失败
|
||||||
|
return None
|
||||||
|
|
||||||
def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
|
def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
|
||||||
"""获取记忆样本
|
"""获取记忆样本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: 消息记录列表,每个元素是一个消息记录字典列表
|
list: 消息记录列表,每个元素是一个消息记录字典列表
|
||||||
"""
|
"""
|
||||||
|
# 硬编码:每条消息最大记忆次数
|
||||||
|
# 如有需求可写入global_config
|
||||||
|
max_memorized_time_per_msg = 3
|
||||||
|
|
||||||
current_timestamp = datetime.datetime.now().timestamp()
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
chat_samples = []
|
chat_samples = []
|
||||||
|
|
||||||
# 短期:1h 中期:4h 长期:24h
|
# 短期:1h 中期:4h 长期:24h
|
||||||
for _ in range(time_frequency.get('near')):
|
logger.debug(f"正在抽取短期消息样本")
|
||||||
|
for i in range(time_frequency.get('near')):
|
||||||
random_time = current_timestamp - random.randint(1, 3600)
|
random_time = current_timestamp - random.randint(1, 3600)
|
||||||
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
|
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
|
||||||
if messages:
|
if messages:
|
||||||
|
logger.debug(f"成功抽取短期消息样本{len(messages)}条")
|
||||||
chat_samples.append(messages)
|
chat_samples.append(messages)
|
||||||
|
else:
|
||||||
|
logger.warning(f"第{i}次短期消息样本抽取失败")
|
||||||
|
|
||||||
for _ in range(time_frequency.get('mid')):
|
logger.debug(f"正在抽取中期消息样本")
|
||||||
|
for i in range(time_frequency.get('mid')):
|
||||||
random_time = current_timestamp - random.randint(3600, 3600 * 4)
|
random_time = current_timestamp - random.randint(3600, 3600 * 4)
|
||||||
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
|
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
|
||||||
if messages:
|
if messages:
|
||||||
|
logger.debug(f"成功抽取中期消息样本{len(messages)}条")
|
||||||
chat_samples.append(messages)
|
chat_samples.append(messages)
|
||||||
|
else:
|
||||||
|
logger.warning(f"第{i}次中期消息样本抽取失败")
|
||||||
|
|
||||||
for _ in range(time_frequency.get('far')):
|
logger.debug(f"正在抽取长期消息样本")
|
||||||
|
for i in range(time_frequency.get('far')):
|
||||||
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
|
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
|
||||||
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
|
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
|
||||||
if messages:
|
if messages:
|
||||||
|
logger.debug(f"成功抽取长期消息样本{len(messages)}条")
|
||||||
chat_samples.append(messages)
|
chat_samples.append(messages)
|
||||||
|
else:
|
||||||
|
logger.warning(f"第{i}次长期消息样本抽取失败")
|
||||||
|
|
||||||
return chat_samples
|
return chat_samples
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user