fix: optimize the memory sync algorithm and fix memory samples being built without a chat_id
@@ -17,12 +17,14 @@ from src.chat.memory_system.sample_distribution import MemoryBuildScheduler #
 from ..utils.chat_message_builder import (
     get_raw_msg_by_timestamp,
     build_readable_messages,
+    get_raw_msg_by_timestamp_with_chat,
 ) # 导入 build_readable_messages
 from ..utils.utils import translate_timestamp_to_human_readable
 from rich.traceback import install
 
 from ...config.config import global_config
 from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
+from peewee import Case
 
 install(extra_lines=3)
 
@@ -215,15 +217,18 @@ class Hippocampus:
         """计算节点的特征值"""
         if not isinstance(memory_items, list):
             memory_items = [memory_items] if memory_items else []
-        sorted_items = sorted(memory_items)
-        content = f"{concept}:{'|'.join(sorted_items)}"
+        # 使用集合来去重,避免排序
+        unique_items = set(str(item) for item in memory_items)
+        # 使用frozenset来保证顺序一致性
+        content = f"{concept}:{frozenset(unique_items)}"
         return hash(content)
 
     @staticmethod
     def calculate_edge_hash(source, target) -> int:
         """计算边的特征值"""
-        nodes = sorted([source, target])
-        return hash(f"{nodes[0]}:{nodes[1]}")
+        # 直接使用元组,保证顺序一致性
+        return hash((source, target))
 
     @staticmethod
     def find_topic_llm(text, topic_num):
@@ -811,7 +816,8 @@ class EntorhinalCortex:
         timestamps = sample_scheduler.get_timestamp_array()
         # 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
         readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
-        logger.info(f"回忆往事: {readable_timestamps}")
+        for timestamp, readable_timestamp in zip(timestamps, readable_timestamps):
+            logger.debug(f"回忆往事: {readable_timestamp}")
         chat_samples = []
         for timestamp in timestamps:
             # 调用修改后的 random_get_msg_snippet
@@ -820,10 +826,10 @@ class EntorhinalCortex:
             )
             if messages:
                 time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
-                logger.debug(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
+                logger.success(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
                 chat_samples.append(messages)
             else:
-                logger.debug(f"时间戳 {timestamp} 的消息样本抽取失败")
+                logger.debug(f"时间戳 {timestamp} 的消息无需记忆")
 
         return chat_samples
 
@@ -838,10 +844,15 @@ class EntorhinalCortex:
         timestamp_start = target_timestamp
         timestamp_end = target_timestamp + time_window_seconds
 
-        # 使用 chat_message_builder 的函数获取消息
-        # limit_mode='earliest' 获取这个时间窗口内最早的 chat_size 条消息
-        messages = get_raw_msg_by_timestamp(
-            timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest"
-        )
+        chosen_message = get_raw_msg_by_timestamp(
+            timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
+        )
+
+        if chosen_message:
+            chat_id = chosen_message[0].get("chat_id")
+
+            messages = get_raw_msg_by_timestamp_with_chat(
+                timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest", chat_id=chat_id
+            )
 
         if messages:
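For context, here is a minimal standalone sketch (not part of the diff) of the sampling flow this hunk introduces: the earliest message in the time window is fetched first, its chat_id is read, and the snippet is then pulled from that single chat so one sample never mixes conversations. The two helper signatures are taken from the call sites above; the wrapper function name and the early-return fallback are illustrative assumptions.

# Illustrative sketch only; get_raw_msg_by_timestamp and get_raw_msg_by_timestamp_with_chat
# come from the chat_message_builder module imported in the first hunk.
def sample_chat_snippet(target_timestamp: float, chat_size: int, time_window_seconds: int):
    timestamp_start = target_timestamp
    timestamp_end = target_timestamp + time_window_seconds

    # Step 1: pick an anchor message in the window to decide which chat to sample from.
    chosen_message = get_raw_msg_by_timestamp(
        timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
    )
    if not chosen_message:
        return None  # nothing happened in this window

    # Step 2: re-query the same window, restricted to that chat_id.
    chat_id = chosen_message[0].get("chat_id")
    return get_raw_msg_by_timestamp_with_chat(
        timestamp_start=timestamp_start,
        timestamp_end=timestamp_end,
        limit=chat_size,
        limit_mode="earliest",
        chat_id=chat_id,
    )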
@@ -873,85 +884,361 @@ class EntorhinalCortex:
 
     async def sync_memory_to_db(self):
         """将记忆图同步到数据库"""
+        start_time = time.time()
+
         # 获取数据库中所有节点和内存中所有节点
+        db_load_start = time.time()
         db_nodes = {node.concept: node for node in GraphNodes.select()}
         memory_nodes = list(self.memory_graph.G.nodes(data=True))
+        db_load_end = time.time()
+        logger.info(f"[同步] 加载数据库耗时: {db_load_end - db_load_start:.2f}秒")
+
+        # 批量准备节点数据
+        nodes_to_create = []
+        nodes_to_update = []
+        current_time = datetime.datetime.now().timestamp()
+
         # 检查并更新节点
+        node_process_start = time.time()
         for concept, data in memory_nodes:
+            # 检查概念是否有效
+            if not concept or not isinstance(concept, str):
+                logger.warning(f"[同步] 发现无效概念,将移除节点: {concept}")
+                # 从图中移除节点(这会自动移除相关的边)
+                self.memory_graph.G.remove_node(concept)
+                continue
+
             memory_items = data.get("memory_items", [])
             if not isinstance(memory_items, list):
                 memory_items = [memory_items] if memory_items else []
+
+            # 检查记忆项是否为空
+            if not memory_items:
+                logger.warning(f"[同步] 发现空记忆节点,将移除节点: {concept}")
+                # 从图中移除节点(这会自动移除相关的边)
+                self.memory_graph.G.remove_node(concept)
+                continue
+
             # 计算内存中节点的特征值
             memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items)
 
             # 获取时间信息
-            created_time = data.get("created_time", datetime.datetime.now().timestamp())
-            last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
+            created_time = data.get("created_time", current_time)
+            last_modified = data.get("last_modified", current_time)
 
             # 将memory_items转换为JSON字符串
+            try:
+                # 确保memory_items中的每个项都是字符串
+                memory_items = [str(item) for item in memory_items]
                 memory_items_json = json.dumps(memory_items, ensure_ascii=False)
+                if not memory_items_json: # 确保JSON字符串不为空
+                    raise ValueError("序列化后的JSON字符串为空")
+                # 验证JSON字符串是否有效
+                json.loads(memory_items_json)
+            except Exception as e:
+                logger.error(f"[同步] 序列化记忆项失败,将移除节点: {concept}, 错误: {e}")
+                # 从图中移除节点(这会自动移除相关的边)
+                self.memory_graph.G.remove_node(concept)
+                continue
 
             if concept not in db_nodes:
-                # 数据库中缺少的节点,添加
-                GraphNodes.create(
-                    concept=concept,
-                    memory_items=memory_items_json,
-                    hash=memory_hash,
-                    created_time=created_time,
-                    last_modified=last_modified,
-                )
+                # 数据库中缺少的节点,添加到创建列表
+                nodes_to_create.append({
+                    'concept': concept,
+                    'memory_items': memory_items_json,
+                    'hash': memory_hash,
+                    'created_time': created_time,
+                    'last_modified': last_modified
+                })
+                logger.debug(f"[同步] 准备创建节点: {concept}, memory_items长度: {len(memory_items)}")
             else:
                 # 获取数据库中节点的特征值
                 db_node = db_nodes[concept]
                 db_hash = db_node.hash
 
-                # 如果特征值不同,则更新节点
+                # 如果特征值不同,则添加到更新列表
                 if db_hash != memory_hash:
-                    db_node.memory_items = memory_items_json
-                    db_node.hash = memory_hash
-                    db_node.last_modified = last_modified
-                    db_node.save()
+                    nodes_to_update.append({
+                        'concept': concept,
+                        'memory_items': memory_items_json,
+                        'hash': memory_hash,
+                        'last_modified': last_modified
+                    })
+
+        # 检查需要删除的节点
+        memory_concepts = {concept for concept, _ in memory_nodes}
+        db_concepts = set(db_nodes.keys())
+        nodes_to_delete = db_concepts - memory_concepts
+
+        node_process_end = time.time()
+        logger.info(f"[同步] 处理节点数据耗时: {node_process_end - node_process_start:.2f}秒")
+        logger.info(f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点")
+
+        # 异步批量创建新节点
+        node_create_start = time.time()
+        if nodes_to_create:
+            try:
+                # 验证所有要创建的节点数据
+                valid_nodes_to_create = []
+                for node_data in nodes_to_create:
+                    if not node_data.get('memory_items'):
+                        logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
+                        continue
+                    try:
+                        # 验证 JSON 字符串
+                        json.loads(node_data['memory_items'])
+                        valid_nodes_to_create.append(node_data)
+                    except json.JSONDecodeError:
+                        logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                        continue
+
+                if valid_nodes_to_create:
+                    # 使用异步批量插入
+                    batch_size = 100
+                    for i in range(0, len(valid_nodes_to_create), batch_size):
+                        batch = valid_nodes_to_create[i:i + batch_size]
+                        await self._async_batch_create_nodes(batch)
+                    logger.info(f"[同步] 成功创建 {len(valid_nodes_to_create)} 个节点")
+                else:
+                    logger.warning("[同步] 没有有效的节点可以创建")
+            except Exception as e:
+                logger.error(f"[同步] 创建节点失败: {e}")
+                # 尝试逐个创建以找出问题节点
+                for node_data in nodes_to_create:
+                    try:
+                        if not node_data.get('memory_items'):
+                            logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
+                            continue
+                        try:
+                            json.loads(node_data['memory_items'])
+                        except json.JSONDecodeError:
+                            logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                            continue
+                        await self._async_create_node(node_data)
+                    except Exception as e:
+                        logger.error(f"[同步] 创建节点失败: {node_data['concept']}, 错误: {e}")
+                        # 从图中移除问题节点
+                        self.memory_graph.G.remove_node(node_data['concept'])
+        node_create_end = time.time()
+        logger.info(f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)")
+
+        # 异步批量更新节点
+        node_update_start = time.time()
+        if nodes_to_update:
+            # 按批次更新节点,每批100个
+            batch_size = 100
+            for i in range(0, len(nodes_to_update), batch_size):
+                batch = nodes_to_update[i:i + batch_size]
+                try:
+                    # 验证批次中的每个节点数据
+                    valid_batch = []
+                    for node_data in batch:
+                        # 确保 memory_items 不为空且是有效的 JSON 字符串
+                        if not node_data.get('memory_items'):
+                            logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 为空")
+                            continue
+                        try:
+                            # 验证 JSON 字符串是否有效
+                            json.loads(node_data['memory_items'])
+                            valid_batch.append(node_data)
+                        except json.JSONDecodeError:
+                            logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                            continue
+
+                    if not valid_batch:
+                        logger.warning(f"[同步] 批次 {i//batch_size + 1} 没有有效的节点可以更新")
+                        continue
+
+                    # 异步批量更新节点
+                    await self._async_batch_update_nodes(valid_batch)
+                    logger.debug(f"[同步] 成功更新批次 {i//batch_size + 1} 中的 {len(valid_batch)} 个节点")
+                except Exception as e:
+                    logger.error(f"[同步] 批量更新节点失败: {e}")
+                    # 如果批量更新失败,尝试逐个更新
+                    for node_data in valid_batch:
+                        try:
+                            await self._async_update_node(node_data)
+                        except Exception as e:
+                            logger.error(f"[同步] 更新节点失败: {node_data['concept']}, 错误: {e}")
+                            # 从图中移除问题节点
+                            self.memory_graph.G.remove_node(node_data['concept'])
+
+        node_update_end = time.time()
+        logger.info(f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)")
+
+        # 异步删除不存在的节点
+        node_delete_start = time.time()
+        if nodes_to_delete:
+            await self._async_delete_nodes(nodes_to_delete)
+        node_delete_end = time.time()
+        logger.info(f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)")
+
         # 处理边的信息
+        edge_load_start = time.time()
         db_edges = list(GraphEdges.select())
         memory_edges = list(self.memory_graph.G.edges(data=True))
+        edge_load_end = time.time()
+        logger.info(f"[同步] 加载边数据耗时: {edge_load_end - edge_load_start:.2f}秒")
 
         # 创建边的哈希值字典
+        edge_dict_start = time.time()
         db_edge_dict = {}
         for edge in db_edges:
             edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
             db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}
+        edge_dict_end = time.time()
+        logger.info(f"[同步] 创建边字典耗时: {edge_dict_end - edge_dict_start:.2f}秒")
+
+        # 批量准备边数据
+        edges_to_create = []
+        edges_to_update = []
 
         # 检查并更新边
+        edge_process_start = time.time()
         for source, target, data in memory_edges:
             edge_hash = self.hippocampus.calculate_edge_hash(source, target)
             edge_key = (source, target)
             strength = data.get("strength", 1)
 
             # 获取边的时间信息
-            created_time = data.get("created_time", datetime.datetime.now().timestamp())
-            last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
+            created_time = data.get("created_time", current_time)
+            last_modified = data.get("last_modified", current_time)
 
             if edge_key not in db_edge_dict:
-                # 添加新边
-                GraphEdges.create(
-                    source=source,
-                    target=target,
-                    strength=strength,
-                    hash=edge_hash,
-                    created_time=created_time,
-                    last_modified=last_modified,
-                )
+                # 添加新边到创建列表
+                edges_to_create.append({
+                    'source': source,
+                    'target': target,
+                    'strength': strength,
+                    'hash': edge_hash,
+                    'created_time': created_time,
+                    'last_modified': last_modified
+                })
             else:
                 # 检查边的特征值是否变化
                 if db_edge_dict[edge_key]["hash"] != edge_hash:
-                    edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target)
-                    edge.hash = edge_hash
-                    edge.strength = strength
-                    edge.last_modified = last_modified
-                    edge.save()
+                    edges_to_update.append({
+                        'source': source,
+                        'target': target,
+                        'strength': strength,
+                        'hash': edge_hash,
+                        'last_modified': last_modified
+                    })
+
+        edge_process_end = time.time()
+        logger.info(f"[同步] 处理边数据耗时: {edge_process_end - edge_process_start:.2f}秒")
+
+        # 异步批量创建新边
+        edge_create_start = time.time()
+        if edges_to_create:
+            batch_size = 100
+            for i in range(0, len(edges_to_create), batch_size):
+                batch = edges_to_create[i:i + batch_size]
+                await self._async_batch_create_edges(batch)
+        edge_create_end = time.time()
+        logger.info(f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)")
+
+        # 异步批量更新边
+        edge_update_start = time.time()
+        if edges_to_update:
+            batch_size = 100
+            for i in range(0, len(edges_to_update), batch_size):
+                batch = edges_to_update[i:i + batch_size]
+                await self._async_batch_update_edges(batch)
+        edge_update_end = time.time()
+        logger.info(f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)")
+
+        # 检查需要删除的边
+        memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
+        db_edge_keys = {(edge.source, edge.target) for edge in db_edges}
+        edges_to_delete = db_edge_keys - memory_edge_keys
+
+        # 异步删除不存在的边
+        edge_delete_start = time.time()
+        if edges_to_delete:
+            await self._async_delete_edges(edges_to_delete)
+        edge_delete_end = time.time()
+        logger.info(f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)")
+
+        end_time = time.time()
+        logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
+        logger.success(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")
+
+    async def _async_batch_create_nodes(self, nodes_data):
+        """异步批量创建节点"""
+        try:
+            GraphNodes.insert_many(nodes_data).execute()
+        except Exception as e:
+            logger.error(f"[同步] 批量创建节点失败: {e}")
+            raise
+
+    async def _async_create_node(self, node_data):
+        """异步创建单个节点"""
+        try:
+            GraphNodes.create(**node_data)
+        except Exception as e:
+            logger.error(f"[同步] 创建节点失败: {e}")
+            raise
+
+    async def _async_batch_update_nodes(self, nodes_data):
+        """异步批量更新节点"""
+        try:
+            for node_data in nodes_data:
+                GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
+                    GraphNodes.concept == node_data['concept']
+                ).execute()
+        except Exception as e:
+            logger.error(f"[同步] 批量更新节点失败: {e}")
+            raise
+
+    async def _async_update_node(self, node_data):
+        """异步更新单个节点"""
+        try:
+            GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
+                GraphNodes.concept == node_data['concept']
+            ).execute()
+        except Exception as e:
+            logger.error(f"[同步] 更新节点失败: {e}")
+            raise
+
+    async def _async_delete_nodes(self, concepts):
+        """异步删除节点"""
+        try:
+            GraphNodes.delete().where(GraphNodes.concept.in_(concepts)).execute()
+        except Exception as e:
+            logger.error(f"[同步] 删除节点失败: {e}")
+            raise
+
+    async def _async_batch_create_edges(self, edges_data):
+        """异步批量创建边"""
+        try:
+            GraphEdges.insert_many(edges_data).execute()
+        except Exception as e:
+            logger.error(f"[同步] 批量创建边失败: {e}")
+            raise
+
+    async def _async_batch_update_edges(self, edges_data):
+        """异步批量更新边"""
+        try:
+            for edge_data in edges_data:
+                GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ['source', 'target']}).where(
+                    (GraphEdges.source == edge_data['source']) &
+                    (GraphEdges.target == edge_data['target'])
+                ).execute()
+        except Exception as e:
+            logger.error(f"[同步] 批量更新边失败: {e}")
+            raise
+
+    async def _async_delete_edges(self, edge_keys):
+        """异步删除边"""
+        try:
+            for source, target in edge_keys:
+                GraphEdges.delete().where(
+                    (GraphEdges.source == source) &
+                    (GraphEdges.target == target)
+                ).execute()
+        except Exception as e:
+            logger.error(f"[同步] 删除边失败: {e}")
+            raise
 
     def sync_memory_from_db(self):
         """从数据库同步数据到内存中的图结构"""
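The new _async_* helpers above lean on peewee's insert_many() so that writes happen as one INSERT per batch of dicts rather than one row at a time, while updates still go row by row through GraphNodes.update(...).where(...). A self-contained toy demonstration of the batching pattern, using a hypothetical minimal model (the real GraphNodes has more fields):

# Toy demo of the insert_many batching pattern; the model here is a stand-in, not GraphNodes.
from peewee import FloatField, Model, SqliteDatabase, TextField

db = SqliteDatabase(":memory:")

class Node(Model):
    concept = TextField(unique=True)
    memory_items = TextField()
    last_modified = FloatField()

    class Meta:
        database = db

db.connect()
db.create_tables([Node])

rows = [
    {"concept": f"concept_{i}", "memory_items": "[]", "last_modified": 0.0}
    for i in range(250)
]
batch_size = 100
for i in range(0, len(rows), batch_size):
    # One INSERT statement per batch of up to 100 rows instead of 250 separate INSERTs.
    Node.insert_many(rows[i:i + batch_size]).execute()

print(Node.select().count())  # 250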
@@ -1111,7 +1398,7 @@ class ParahippocampalGyrus:
         input_text = await build_readable_messages(
             messages,
             merge_messages=True, # 合并连续消息
-            timestamp_mode="normal", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
+            timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
             replace_bot_name=False, # 保留原始用户名
         )
 
@@ -1120,7 +1407,11 @@ class ParahippocampalGyrus:
             logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。")
             return set(), {}
 
-        logger.debug(f"用于压缩的格式化文本:\n{input_text}")
+        current_YMD_time = datetime.datetime.now().strftime("%Y-%m-%d")
+        current_YMD_time_str = f"当前日期: {current_YMD_time}"
+        input_text = f"{current_YMD_time_str}\n{input_text}"
+
+        logger.debug(f"记忆来源:\n{input_text}")
 
         # 2. 使用LLM提取关键主题
         topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
@@ -1191,7 +1482,7 @@ class ParahippocampalGyrus:
         return compressed_memory, similar_topics_dict
 
     async def operation_build_memory(self):
-        logger.debug("------------------------------------开始构建记忆--------------------------------------")
+        logger.info("------------------------------------开始构建记忆--------------------------------------")
         start_time = time.time()
         memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
         all_added_nodes = []
@@ -1199,19 +1490,16 @@ class ParahippocampalGyrus:
         all_added_edges = []
         for i, messages in enumerate(memory_samples, 1):
             all_topics = []
-            progress = (i / len(memory_samples)) * 100
-            bar_length = 30
-            filled_length = int(bar_length * i // len(memory_samples))
-            bar = "█" * filled_length + "-" * (bar_length - filled_length)
-            logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
-
             compress_rate = global_config.memory.memory_compress_rate
             try:
                 compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
             except Exception as e:
                 logger.error(f"压缩记忆时发生错误: {e}")
                 continue
-            logger.debug(f"压缩后记忆数量: {compressed_memory},似曾相识的话题: {similar_topics_dict}")
+            for topic, memory in compressed_memory:
+                logger.info(f"取得记忆: {topic} - {memory}")
+            for topic, similar_topics in similar_topics_dict.items():
+                logger.debug(f"相似话题: {topic} - {similar_topics}")
 
             current_time = datetime.datetime.now().timestamp()
             logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}")
@@ -1246,8 +1534,18 @@ class ParahippocampalGyrus:
                     all_added_edges.append(f"{topic1}-{topic2}")
                     self.memory_graph.connect_dot(topic1, topic2)
 
+
+            progress = (i / len(memory_samples)) * 100
+            bar_length = 30
+            filled_length = int(bar_length * i // len(memory_samples))
+            bar = "█" * filled_length + "-" * (bar_length - filled_length)
+            logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
+
+        if all_added_nodes:
             logger.success(f"更新记忆: {', '.join(all_added_nodes)}")
+        if all_added_edges:
             logger.debug(f"强化连接: {', '.join(all_added_edges)}")
+        if all_connected_nodes:
             logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}")
 
         await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
@@ -342,7 +342,7 @@ async def _build_readable_messages_internal(
         # 使用指定的 timestamp_mode 格式化时间
         readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
 
-        header = f"{readable_time}{merged['name']} 说:"
+        header = f"{readable_time}, {merged['name']} :"
         output_lines.append(header)
         # 将内容合并,并添加缩进
         for line in merged["content"]:
|
|||||||
Reference in New Issue
Block a user