fix:优化记忆同步算法,修复记忆构建没有chat_id的问题

This commit is contained in:
SengokuCola
2025-06-03 13:10:48 +08:00
parent 94c7072cc9
commit 691172766e
2 changed files with 373 additions and 75 deletions

View File

@@ -17,12 +17,14 @@ from src.chat.memory_system.sample_distribution import MemoryBuildScheduler #
from ..utils.chat_message_builder import (
get_raw_msg_by_timestamp,
build_readable_messages,
get_raw_msg_by_timestamp_with_chat,
) # 导入 build_readable_messages
from ..utils.utils import translate_timestamp_to_human_readable
from rich.traceback import install
from ...config.config import global_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from peewee import Case
install(extra_lines=3)
@@ -215,15 +217,18 @@ class Hippocampus:
"""计算节点的特征值"""
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
sorted_items = sorted(memory_items)
content = f"{concept}:{'|'.join(sorted_items)}"
# 使用集合来去重,避免排序
unique_items = set(str(item) for item in memory_items)
# 使用frozenset来保证顺序一致性
content = f"{concept}:{frozenset(unique_items)}"
return hash(content)
@staticmethod
def calculate_edge_hash(source, target) -> int:
"""计算边的特征值"""
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
# 直接使用元组,保证顺序一致性
return hash((source, target))
@staticmethod
def find_topic_llm(text, topic_num):
@@ -811,7 +816,8 @@ class EntorhinalCortex:
timestamps = sample_scheduler.get_timestamp_array()
# 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
logger.info(f"回忆往事: {readable_timestamps}")
for timestamp, readable_timestamp in zip(timestamps, readable_timestamps):
logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = []
for timestamp in timestamps:
# 调用修改后的 random_get_msg_snippet
@@ -820,10 +826,10 @@ class EntorhinalCortex:
)
if messages:
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
logger.debug(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}")
logger.success(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}")
chat_samples.append(messages)
else:
logger.debug(f"时间戳 {timestamp} 的消息样本抽取失败")
logger.debug(f"时间戳 {timestamp} 的消息无需记忆")
return chat_samples
@@ -838,10 +844,15 @@ class EntorhinalCortex:
timestamp_start = target_timestamp
timestamp_end = target_timestamp + time_window_seconds
# 使用 chat_message_builder 的函数获取消息
# limit_mode='earliest' 获取这个时间窗口内最早的 chat_size 条消息
messages = get_raw_msg_by_timestamp(
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest"
chosen_message = get_raw_msg_by_timestamp(
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
)
if chosen_message:
chat_id = chosen_message[0].get("chat_id")
messages = get_raw_msg_by_timestamp_with_chat(
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest", chat_id=chat_id
)
if messages:
@@ -873,85 +884,361 @@ class EntorhinalCortex:
async def sync_memory_to_db(self):
"""将记忆图同步到数据库"""
start_time = time.time()
# 获取数据库中所有节点和内存中所有节点
db_load_start = time.time()
db_nodes = {node.concept: node for node in GraphNodes.select()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
db_load_end = time.time()
logger.info(f"[同步] 加载数据库耗时: {db_load_end - db_load_start:.2f}")
# 批量准备节点数据
nodes_to_create = []
nodes_to_update = []
current_time = datetime.datetime.now().timestamp()
# 检查并更新节点
node_process_start = time.time()
for concept, data in memory_nodes:
# 检查概念是否有效
if not concept or not isinstance(concept, str):
logger.warning(f"[同步] 发现无效概念,将移除节点: {concept}")
# 从图中移除节点(这会自动移除相关的边)
self.memory_graph.G.remove_node(concept)
continue
memory_items = data.get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 检查记忆项是否为空
if not memory_items:
logger.warning(f"[同步] 发现空记忆节点,将移除节点: {concept}")
# 从图中移除节点(这会自动移除相关的边)
self.memory_graph.G.remove_node(concept)
continue
# 计算内存中节点的特征值
memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items)
# 获取时间信息
created_time = data.get("created_time", datetime.datetime.now().timestamp())
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
created_time = data.get("created_time", current_time)
last_modified = data.get("last_modified", current_time)
# 将memory_items转换为JSON字符串
try:
# 确保memory_items中的每个项都是字符串
memory_items = [str(item) for item in memory_items]
memory_items_json = json.dumps(memory_items, ensure_ascii=False)
if not memory_items_json: # 确保JSON字符串不为空
raise ValueError("序列化后的JSON字符串为空")
# 验证JSON字符串是否有效
json.loads(memory_items_json)
except Exception as e:
logger.error(f"[同步] 序列化记忆项失败,将移除节点: {concept}, 错误: {e}")
# 从图中移除节点(这会自动移除相关的边)
self.memory_graph.G.remove_node(concept)
continue
if concept not in db_nodes:
# 数据库中缺少的节点,添加
GraphNodes.create(
concept=concept,
memory_items=memory_items_json,
hash=memory_hash,
created_time=created_time,
last_modified=last_modified,
)
# 数据库中缺少的节点,添加到创建列表
nodes_to_create.append({
'concept': concept,
'memory_items': memory_items_json,
'hash': memory_hash,
'created_time': created_time,
'last_modified': last_modified
})
logger.debug(f"[同步] 准备创建节点: {concept}, memory_items长度: {len(memory_items)}")
else:
# 获取数据库中节点的特征值
db_node = db_nodes[concept]
db_hash = db_node.hash
# 如果特征值不同,则更新节点
# 如果特征值不同,则添加到更新列表
if db_hash != memory_hash:
db_node.memory_items = memory_items_json
db_node.hash = memory_hash
db_node.last_modified = last_modified
db_node.save()
nodes_to_update.append({
'concept': concept,
'memory_items': memory_items_json,
'hash': memory_hash,
'last_modified': last_modified
})
# 检查需要删除的节点
memory_concepts = {concept for concept, _ in memory_nodes}
db_concepts = set(db_nodes.keys())
nodes_to_delete = db_concepts - memory_concepts
node_process_end = time.time()
logger.info(f"[同步] 处理节点数据耗时: {node_process_end - node_process_start:.2f}")
logger.info(f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点")
# 异步批量创建新节点
node_create_start = time.time()
if nodes_to_create:
try:
# 验证所有要创建的节点数据
valid_nodes_to_create = []
for node_data in nodes_to_create:
if not node_data.get('memory_items'):
logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
continue
try:
# 验证 JSON 字符串
json.loads(node_data['memory_items'])
valid_nodes_to_create.append(node_data)
except json.JSONDecodeError:
logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
continue
if valid_nodes_to_create:
# 使用异步批量插入
batch_size = 100
for i in range(0, len(valid_nodes_to_create), batch_size):
batch = valid_nodes_to_create[i:i + batch_size]
await self._async_batch_create_nodes(batch)
logger.info(f"[同步] 成功创建 {len(valid_nodes_to_create)} 个节点")
else:
logger.warning("[同步] 没有有效的节点可以创建")
except Exception as e:
logger.error(f"[同步] 创建节点失败: {e}")
# 尝试逐个创建以找出问题节点
for node_data in nodes_to_create:
try:
if not node_data.get('memory_items'):
logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
continue
try:
json.loads(node_data['memory_items'])
except json.JSONDecodeError:
logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
continue
await self._async_create_node(node_data)
except Exception as e:
logger.error(f"[同步] 创建节点失败: {node_data['concept']}, 错误: {e}")
# 从图中移除问题节点
self.memory_graph.G.remove_node(node_data['concept'])
node_create_end = time.time()
logger.info(f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)")
# 异步批量更新节点
node_update_start = time.time()
if nodes_to_update:
# 按批次更新节点每批100个
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i:i + batch_size]
try:
# 验证批次中的每个节点数据
valid_batch = []
for node_data in batch:
# 确保 memory_items 不为空且是有效的 JSON 字符串
if not node_data.get('memory_items'):
logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 为空")
continue
try:
# 验证 JSON 字符串是否有效
json.loads(node_data['memory_items'])
valid_batch.append(node_data)
except json.JSONDecodeError:
logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
continue
if not valid_batch:
logger.warning(f"[同步] 批次 {i//batch_size + 1} 没有有效的节点可以更新")
continue
# 异步批量更新节点
await self._async_batch_update_nodes(valid_batch)
logger.debug(f"[同步] 成功更新批次 {i//batch_size + 1} 中的 {len(valid_batch)} 个节点")
except Exception as e:
logger.error(f"[同步] 批量更新节点失败: {e}")
# 如果批量更新失败,尝试逐个更新
for node_data in valid_batch:
try:
await self._async_update_node(node_data)
except Exception as e:
logger.error(f"[同步] 更新节点失败: {node_data['concept']}, 错误: {e}")
# 从图中移除问题节点
self.memory_graph.G.remove_node(node_data['concept'])
node_update_end = time.time()
logger.info(f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)")
# 异步删除不存在的节点
node_delete_start = time.time()
if nodes_to_delete:
await self._async_delete_nodes(nodes_to_delete)
node_delete_end = time.time()
logger.info(f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)")
# 处理边的信息
edge_load_start = time.time()
db_edges = list(GraphEdges.select())
memory_edges = list(self.memory_graph.G.edges(data=True))
edge_load_end = time.time()
logger.info(f"[同步] 加载边数据耗时: {edge_load_end - edge_load_start:.2f}")
# 创建边的哈希值字典
edge_dict_start = time.time()
db_edge_dict = {}
for edge in db_edges:
edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
edge_dict_end = time.time()
logger.info(f"[同步] 创建边字典耗时: {edge_dict_end - edge_dict_start:.2f}")
# 批量准备边数据
edges_to_create = []
edges_to_update = []
# 检查并更新边
edge_process_start = time.time()
for source, target, data in memory_edges:
edge_hash = self.hippocampus.calculate_edge_hash(source, target)
edge_key = (source, target)
strength = data.get("strength", 1)
# 获取边的时间信息
created_time = data.get("created_time", datetime.datetime.now().timestamp())
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
created_time = data.get("created_time", current_time)
last_modified = data.get("last_modified", current_time)
if edge_key not in db_edge_dict:
# 添加新边
GraphEdges.create(
source=source,
target=target,
strength=strength,
hash=edge_hash,
created_time=created_time,
last_modified=last_modified,
)
# 添加新边到创建列表
edges_to_create.append({
'source': source,
'target': target,
'strength': strength,
'hash': edge_hash,
'created_time': created_time,
'last_modified': last_modified
})
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]["hash"] != edge_hash:
edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target)
edge.hash = edge_hash
edge.strength = strength
edge.last_modified = last_modified
edge.save()
edges_to_update.append({
'source': source,
'target': target,
'strength': strength,
'hash': edge_hash,
'last_modified': last_modified
})
edge_process_end = time.time()
logger.info(f"[同步] 处理边数据耗时: {edge_process_end - edge_process_start:.2f}")
# 异步批量创建新边
edge_create_start = time.time()
if edges_to_create:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i:i + batch_size]
await self._async_batch_create_edges(batch)
edge_create_end = time.time()
logger.info(f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)")
# 异步批量更新边
edge_update_start = time.time()
if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i:i + batch_size]
await self._async_batch_update_edges(batch)
edge_update_end = time.time()
logger.info(f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)")
# 检查需要删除的边
memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
db_edge_keys = {(edge.source, edge.target) for edge in db_edges}
edges_to_delete = db_edge_keys - memory_edge_keys
# 异步删除不存在的边
edge_delete_start = time.time()
if edges_to_delete:
await self._async_delete_edges(edges_to_delete)
edge_delete_end = time.time()
logger.info(f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)")
end_time = time.time()
logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}")
logger.success(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")
async def _async_batch_create_nodes(self, nodes_data):
"""异步批量创建节点"""
try:
GraphNodes.insert_many(nodes_data).execute()
except Exception as e:
logger.error(f"[同步] 批量创建节点失败: {e}")
raise
async def _async_create_node(self, node_data):
"""异步创建单个节点"""
try:
GraphNodes.create(**node_data)
except Exception as e:
logger.error(f"[同步] 创建节点失败: {e}")
raise
async def _async_batch_update_nodes(self, nodes_data):
"""异步批量更新节点"""
try:
for node_data in nodes_data:
GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
GraphNodes.concept == node_data['concept']
).execute()
except Exception as e:
logger.error(f"[同步] 批量更新节点失败: {e}")
raise
async def _async_update_node(self, node_data):
"""异步更新单个节点"""
try:
GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
GraphNodes.concept == node_data['concept']
).execute()
except Exception as e:
logger.error(f"[同步] 更新节点失败: {e}")
raise
async def _async_delete_nodes(self, concepts):
"""异步删除节点"""
try:
GraphNodes.delete().where(GraphNodes.concept.in_(concepts)).execute()
except Exception as e:
logger.error(f"[同步] 删除节点失败: {e}")
raise
async def _async_batch_create_edges(self, edges_data):
"""异步批量创建边"""
try:
GraphEdges.insert_many(edges_data).execute()
except Exception as e:
logger.error(f"[同步] 批量创建边失败: {e}")
raise
async def _async_batch_update_edges(self, edges_data):
"""异步批量更新边"""
try:
for edge_data in edges_data:
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ['source', 'target']}).where(
(GraphEdges.source == edge_data['source']) &
(GraphEdges.target == edge_data['target'])
).execute()
except Exception as e:
logger.error(f"[同步] 批量更新边失败: {e}")
raise
async def _async_delete_edges(self, edge_keys):
"""异步删除边"""
try:
for source, target in edge_keys:
GraphEdges.delete().where(
(GraphEdges.source == source) &
(GraphEdges.target == target)
).execute()
except Exception as e:
logger.error(f"[同步] 删除边失败: {e}")
raise
def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
@@ -1111,7 +1398,7 @@ class ParahippocampalGyrus:
input_text = await build_readable_messages(
messages,
merge_messages=True, # 合并连续消息
timestamp_mode="normal", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
replace_bot_name=False, # 保留原始用户名
)
@@ -1120,7 +1407,11 @@ class ParahippocampalGyrus:
logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。")
return set(), {}
logger.debug(f"用于压缩的格式化文本:\n{input_text}")
current_YMD_time = datetime.datetime.now().strftime("%Y-%m-%d")
current_YMD_time_str = f"当前日期: {current_YMD_time}"
input_text = f"{current_YMD_time_str}\n{input_text}"
logger.debug(f"记忆来源:\n{input_text}")
# 2. 使用LLM提取关键主题
topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
@@ -1191,7 +1482,7 @@ class ParahippocampalGyrus:
return compressed_memory, similar_topics_dict
async def operation_build_memory(self):
logger.debug("------------------------------------开始构建记忆--------------------------------------")
logger.info("------------------------------------开始构建记忆--------------------------------------")
start_time = time.time()
memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
all_added_nodes = []
@@ -1199,19 +1490,16 @@ class ParahippocampalGyrus:
all_added_edges = []
for i, messages in enumerate(memory_samples, 1):
all_topics = []
progress = (i / len(memory_samples)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_samples))
bar = "" * filled_length + "-" * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
compress_rate = global_config.memory.memory_compress_rate
try:
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
except Exception as e:
logger.error(f"压缩记忆时发生错误: {e}")
continue
logger.debug(f"压缩后记忆数量: {compressed_memory},似曾相识的话题: {similar_topics_dict}")
for topic, memory in compressed_memory:
logger.info(f"取得记忆: {topic} - {memory}")
for topic, similar_topics in similar_topics_dict.items():
logger.debug(f"相似话题: {topic} - {similar_topics}")
current_time = datetime.datetime.now().timestamp()
logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}")
@@ -1246,8 +1534,18 @@ class ParahippocampalGyrus:
all_added_edges.append(f"{topic1}-{topic2}")
self.memory_graph.connect_dot(topic1, topic2)
progress = (i / len(memory_samples)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_samples))
bar = "" * filled_length + "-" * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
if all_added_nodes:
logger.success(f"更新记忆: {', '.join(all_added_nodes)}")
if all_added_edges:
logger.debug(f"强化连接: {', '.join(all_added_edges)}")
if all_connected_nodes:
logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}")
await self.hippocampus.entorhinal_cortex.sync_memory_to_db()

View File

@@ -342,7 +342,7 @@ async def _build_readable_messages_internal(
# 使用指定的 timestamp_mode 格式化时间
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
header = f"{readable_time}{merged['name']} :"
header = f"{readable_time}, {merged['name']} :"
output_lines.append(header)
# 将内容合并,并添加缩进
for line in merged["content"]: