reformat: 格式化memory.py

This commit is contained in:
AL76
2025-03-15 02:46:38 +08:00
parent f3fef69968
commit 77df50e666

View File

@@ -27,6 +27,7 @@ logger = log_module.setup_logger(LogClassification.MEMORY)
logger.info("初始化记忆系统") logger.info("初始化记忆系统")
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
@@ -35,9 +36,9 @@ class Memory_graph:
# 避免自连接 # 避免自连接
if concept1 == concept2: if concept1 == concept2:
return return
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
# 如果边已存在,增加 strength # 如果边已存在,增加 strength
if self.G.has_edge(concept1, concept2): if self.G.has_edge(concept1, concept2):
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1 self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
@@ -45,14 +46,14 @@ class Memory_graph:
self.G[concept1][concept2]['last_modified'] = current_time self.G[concept1][concept2]['last_modified'] = current_time
else: else:
# 如果是新边,初始化 strength 为 1 # 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2, self.G.add_edge(concept1, concept2,
strength=1, strength=1,
created_time=current_time, # 添加创建时间 created_time=current_time, # 添加创建时间
last_modified=current_time) # 添加最后修改时间 last_modified=current_time) # 添加最后修改时间
def add_dot(self, concept, memory): def add_dot(self, concept, memory):
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
if concept in self.G: if concept in self.G:
if 'memory_items' in self.G.nodes[concept]: if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list): if not isinstance(self.G.nodes[concept]['memory_items'], list):
@@ -68,10 +69,10 @@ class Memory_graph:
self.G.nodes[concept]['last_modified'] = current_time self.G.nodes[concept]['last_modified'] = current_time
else: else:
# 如果是新节点,创建新的记忆列表 # 如果是新节点,创建新的记忆列表
self.G.add_node(concept, self.G.add_node(concept,
memory_items=[memory], memory_items=[memory],
created_time=current_time, # 添加创建时间 created_time=current_time, # 添加创建时间
last_modified=current_time) # 添加最后修改时间 last_modified=current_time) # 添加最后修改时间
def get_dot(self, concept): def get_dot(self, concept):
# 检查节点是否存在于图中 # 检查节点是否存在于图中
@@ -210,12 +211,13 @@ class Hippocampus:
# 成功抽取短期消息样本 # 成功抽取短期消息样本
# 数据写回:增加记忆次数 # 数据写回:增加记忆次数
for message in messages: for message in messages:
db.messages.update_one({"_id": message["_id"]}, {"$set": {"memorized_times": message["memorized_times"] + 1}}) db.messages.update_one({"_id": message["_id"]},
{"$set": {"memorized_times": message["memorized_times"] + 1}})
return messages return messages
try_count += 1 try_count += 1
# 三次尝试均失败 # 三次尝试均失败
return None return None
def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}): def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
"""获取记忆样本 """获取记忆样本
@@ -225,7 +227,7 @@ class Hippocampus:
# 硬编码:每条消息最大记忆次数 # 硬编码:每条消息最大记忆次数
# 如有需求可写入global_config # 如有需求可写入global_config
max_memorized_time_per_msg = 3 max_memorized_time_per_msg = 3
current_timestamp = datetime.datetime.now().timestamp() current_timestamp = datetime.datetime.now().timestamp()
chat_samples = [] chat_samples = []
@@ -324,20 +326,20 @@ class Hippocampus:
# 为每个话题查找相似的已存在主题 # 为每个话题查找相似的已存在主题
existing_topics = list(self.memory_graph.G.nodes()) existing_topics = list(self.memory_graph.G.nodes())
similar_topics = [] similar_topics = []
for existing_topic in existing_topics: for existing_topic in existing_topics:
topic_words = set(jieba.cut(topic)) topic_words = set(jieba.cut(topic))
existing_words = set(jieba.cut(existing_topic)) existing_words = set(jieba.cut(existing_topic))
all_words = topic_words | existing_words all_words = topic_words | existing_words
v1 = [1 if word in topic_words else 0 for word in all_words] v1 = [1 if word in topic_words else 0 for word in all_words]
v2 = [1 if word in existing_words else 0 for word in all_words] v2 = [1 if word in existing_words else 0 for word in all_words]
similarity = cosine_similarity(v1, v2) similarity = cosine_similarity(v1, v2)
if similarity >= 0.6: if similarity >= 0.6:
similar_topics.append((existing_topic, similarity)) similar_topics.append((existing_topic, similarity))
similar_topics.sort(key=lambda x: x[1], reverse=True) similar_topics.sort(key=lambda x: x[1], reverse=True)
similar_topics = similar_topics[:5] similar_topics = similar_topics[:5]
similar_topics_dict[topic] = similar_topics similar_topics_dict[topic] = similar_topics
@@ -358,7 +360,7 @@ class Hippocampus:
async def operation_build_memory(self, chat_size=20): async def operation_build_memory(self, chat_size=20):
time_frequency = {'near': 1, 'mid': 4, 'far': 4} time_frequency = {'near': 1, 'mid': 4, 'far': 4}
memory_samples = self.get_memory_sample(chat_size, time_frequency) memory_samples = self.get_memory_sample(chat_size, time_frequency)
for i, messages in enumerate(memory_samples, 1): for i, messages in enumerate(memory_samples, 1):
all_topics = [] all_topics = []
# 加载进度可视化 # 加载进度可视化
@@ -371,14 +373,14 @@ class Hippocampus:
compress_rate = global_config.memory_compress_rate compress_rate = global_config.memory_compress_rate
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}") logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
for topic, memory in compressed_memory: for topic, memory in compressed_memory:
logger.info(f"添加节点: {topic}") logger.info(f"添加节点: {topic}")
self.memory_graph.add_dot(topic, memory) self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) all_topics.append(topic)
# 连接相似的已存在主题 # 连接相似的已存在主题
if topic in similar_topics_dict: if topic in similar_topics_dict:
similar_topics = similar_topics_dict[topic] similar_topics = similar_topics_dict[topic]
@@ -386,11 +388,11 @@ class Hippocampus:
if topic != similar_topic: if topic != similar_topic:
strength = int(similarity * 10) strength = int(similarity * 10)
logger.info(f"连接相似节点: {topic}{similar_topic} (强度: {strength})") logger.info(f"连接相似节点: {topic}{similar_topic} (强度: {strength})")
self.memory_graph.G.add_edge(topic, similar_topic, self.memory_graph.G.add_edge(topic, similar_topic,
strength=strength, strength=strength,
created_time=current_time, created_time=current_time,
last_modified=current_time) last_modified=current_time)
# 连接同批次的相关话题 # 连接同批次的相关话题
for i in range(len(all_topics)): for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)): for j in range(i + 1, len(all_topics)):
@@ -416,7 +418,7 @@ class Hippocampus:
# 计算内存中节点的特征值 # 计算内存中节点的特征值
memory_hash = self.calculate_node_hash(concept, memory_items) memory_hash = self.calculate_node_hash(concept, memory_items)
# 获取时间信息 # 获取时间信息
created_time = data.get('created_time', datetime.datetime.now().timestamp()) created_time = data.get('created_time', datetime.datetime.now().timestamp())
last_modified = data.get('last_modified', datetime.datetime.now().timestamp()) last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
@@ -466,7 +468,7 @@ class Hippocampus:
edge_hash = self.calculate_edge_hash(source, target) edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target) edge_key = (source, target)
strength = data.get('strength', 1) strength = data.get('strength', 1)
# 获取边的时间信息 # 获取边的时间信息
created_time = data.get('created_time', datetime.datetime.now().timestamp()) created_time = data.get('created_time', datetime.datetime.now().timestamp())
last_modified = data.get('last_modified', datetime.datetime.now().timestamp()) last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
@@ -499,7 +501,7 @@ class Hippocampus:
"""从数据库同步数据到内存中的图结构""" """从数据库同步数据到内存中的图结构"""
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
need_update = False need_update = False
# 清空当前图 # 清空当前图
self.memory_graph.G.clear() self.memory_graph.G.clear()
@@ -510,7 +512,7 @@ class Hippocampus:
memory_items = node.get('memory_items', []) memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
# 检查时间字段是否存在 # 检查时间字段是否存在
if 'created_time' not in node or 'last_modified' not in node: if 'created_time' not in node or 'last_modified' not in node:
need_update = True need_update = True
@@ -520,22 +522,22 @@ class Hippocampus:
update_data['created_time'] = current_time update_data['created_time'] = current_time
if 'last_modified' not in node: if 'last_modified' not in node:
update_data['last_modified'] = current_time update_data['last_modified'] = current_time
db.graph_data.nodes.update_one( db.graph_data.nodes.update_one(
{'concept': concept}, {'concept': concept},
{'$set': update_data} {'$set': update_data}
) )
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段") logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间) # 获取时间信息(如果不存在则使用当前时间)
created_time = node.get('created_time', current_time) created_time = node.get('created_time', current_time)
last_modified = node.get('last_modified', current_time) last_modified = node.get('last_modified', current_time)
# 添加节点到图中 # 添加节点到图中
self.memory_graph.G.add_node(concept, self.memory_graph.G.add_node(concept,
memory_items=memory_items, memory_items=memory_items,
created_time=created_time, created_time=created_time,
last_modified=last_modified) last_modified=last_modified)
# 从数据库加载所有边 # 从数据库加载所有边
edges = list(db.graph_data.edges.find()) edges = list(db.graph_data.edges.find())
@@ -543,7 +545,7 @@ class Hippocampus:
source = edge['source'] source = edge['source']
target = edge['target'] target = edge['target']
strength = edge.get('strength', 1) strength = edge.get('strength', 1)
# 检查时间字段是否存在 # 检查时间字段是否存在
if 'created_time' not in edge or 'last_modified' not in edge: if 'created_time' not in edge or 'last_modified' not in edge:
need_update = True need_update = True
@@ -553,24 +555,24 @@ class Hippocampus:
update_data['created_time'] = current_time update_data['created_time'] = current_time
if 'last_modified' not in edge: if 'last_modified' not in edge:
update_data['last_modified'] = current_time update_data['last_modified'] = current_time
db.graph_data.edges.update_one( db.graph_data.edges.update_one(
{'source': source, 'target': target}, {'source': source, 'target': target},
{'$set': update_data} {'$set': update_data}
) )
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段") logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间) # 获取时间信息(如果不存在则使用当前时间)
created_time = edge.get('created_time', current_time) created_time = edge.get('created_time', current_time)
last_modified = edge.get('last_modified', current_time) last_modified = edge.get('last_modified', current_time)
# 只有当源节点和目标节点都存在时才添加边 # 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G: if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, self.memory_graph.G.add_edge(source, target,
strength=strength, strength=strength,
created_time=created_time, created_time=created_time,
last_modified=last_modified) last_modified=last_modified)
if need_update: if need_update:
logger.success("[数据库] 已为缺失的时间字段进行补充") logger.success("[数据库] 已为缺失的时间字段进行补充")
@@ -578,44 +580,44 @@ class Hippocampus:
"""随机选择图中一定比例的节点和边进行检查,根据时间条件决定是否遗忘""" """随机选择图中一定比例的节点和边进行检查,根据时间条件决定是否遗忘"""
# 检查数据库是否为空 # 检查数据库是否为空
# logger.remove() # logger.remove()
logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:") logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")
# logger.info(f"- Logger名称: {logger.name}") # logger.info(f"- Logger名称: {logger.name}")
logger.info(f"- Logger等级: {logger.level}") logger.info(f"- Logger等级: {logger.level}")
# logger.info(f"- Logger处理器: {[handler.__class__.__name__ for handler in logger.handlers]}") # logger.info(f"- Logger处理器: {[handler.__class__.__name__ for handler in logger.handlers]}")
# logger2 = setup_logger(LogModule.MEMORY) # logger2 = setup_logger(LogModule.MEMORY)
# logger2.info(f"[遗忘] 开始检查数据库... 当前Logger信息:") # logger2.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")
# logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:") # logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")
all_nodes = list(self.memory_graph.G.nodes()) all_nodes = list(self.memory_graph.G.nodes())
all_edges = list(self.memory_graph.G.edges()) all_edges = list(self.memory_graph.G.edges())
if not all_nodes and not all_edges: if not all_nodes and not all_edges:
logger.info("[遗忘] 记忆图为空,无需进行遗忘操作") logger.info("[遗忘] 记忆图为空,无需进行遗忘操作")
return return
check_nodes_count = max(1, int(len(all_nodes) * percentage)) check_nodes_count = max(1, int(len(all_nodes) * percentage))
check_edges_count = max(1, int(len(all_edges) * percentage)) check_edges_count = max(1, int(len(all_edges) * percentage))
nodes_to_check = random.sample(all_nodes, check_nodes_count) nodes_to_check = random.sample(all_nodes, check_nodes_count)
edges_to_check = random.sample(all_edges, check_edges_count) edges_to_check = random.sample(all_edges, check_edges_count)
edge_changes = {'weakened': 0, 'removed': 0} edge_changes = {'weakened': 0, 'removed': 0}
node_changes = {'reduced': 0, 'removed': 0} node_changes = {'reduced': 0, 'removed': 0}
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
# 检查并遗忘连接 # 检查并遗忘连接
logger.info("[遗忘] 开始检查连接...") logger.info("[遗忘] 开始检查连接...")
for source, target in edges_to_check: for source, target in edges_to_check:
edge_data = self.memory_graph.G[source][target] edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get('last_modified') last_modified = edge_data.get('last_modified')
if current_time - last_modified > 3600*global_config.memory_forget_time: if current_time - last_modified > 3600 * global_config.memory_forget_time:
current_strength = edge_data.get('strength', 1) current_strength = edge_data.get('strength', 1)
new_strength = current_strength - 1 new_strength = current_strength - 1
if new_strength <= 0: if new_strength <= 0:
self.memory_graph.G.remove_edge(source, target) self.memory_graph.G.remove_edge(source, target)
edge_changes['removed'] += 1 edge_changes['removed'] += 1
@@ -625,23 +627,23 @@ class Hippocampus:
edge_data['last_modified'] = current_time edge_data['last_modified'] = current_time
edge_changes['weakened'] += 1 edge_changes['weakened'] += 1
logger.info(f"[遗忘] 连接减弱: {source} -> {target} (强度: {current_strength} -> {new_strength})") logger.info(f"[遗忘] 连接减弱: {source} -> {target} (强度: {current_strength} -> {new_strength})")
# 检查并遗忘话题 # 检查并遗忘话题
logger.info("[遗忘] 开始检查节点...") logger.info("[遗忘] 开始检查节点...")
for node in nodes_to_check: for node in nodes_to_check:
node_data = self.memory_graph.G.nodes[node] node_data = self.memory_graph.G.nodes[node]
last_modified = node_data.get('last_modified', current_time) last_modified = node_data.get('last_modified', current_time)
if current_time - last_modified > 3600*24: if current_time - last_modified > 3600 * 24:
memory_items = node_data.get('memory_items', []) memory_items = node_data.get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
if memory_items: if memory_items:
current_count = len(memory_items) current_count = len(memory_items)
removed_item = random.choice(memory_items) removed_item = random.choice(memory_items)
memory_items.remove(removed_item) memory_items.remove(removed_item)
if memory_items: if memory_items:
self.memory_graph.G.nodes[node]['memory_items'] = memory_items self.memory_graph.G.nodes[node]['memory_items'] = memory_items
self.memory_graph.G.nodes[node]['last_modified'] = current_time self.memory_graph.G.nodes[node]['last_modified'] = current_time
@@ -651,7 +653,7 @@ class Hippocampus:
self.memory_graph.G.remove_node(node) self.memory_graph.G.remove_node(node)
node_changes['removed'] += 1 node_changes['removed'] += 1
logger.info(f"[遗忘] 节点移除: {node}") logger.info(f"[遗忘] 节点移除: {node}")
if any(count > 0 for count in edge_changes.values()) or any(count > 0 for count in node_changes.values()): if any(count > 0 for count in edge_changes.values()) or any(count > 0 for count in node_changes.values()):
self.sync_memory_to_db() self.sync_memory_to_db()
logger.info("[遗忘] 统计信息:") logger.info("[遗忘] 统计信息:")
@@ -943,6 +945,7 @@ def segment_text(text):
seg_text = list(jieba.cut(text)) seg_text = list(jieba.cut(text))
return seg_text return seg_text
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config