reformat: 格式化memory.py

This commit is contained in:
AL76
2025-03-15 02:46:38 +08:00
parent f3fef69968
commit 77df50e666

View File

@@ -27,6 +27,7 @@ logger = log_module.setup_logger(LogClassification.MEMORY)
logger.info("初始化记忆系统") logger.info("初始化记忆系统")
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
@@ -210,7 +211,8 @@ class Hippocampus:
# 成功抽取短期消息样本 # 成功抽取短期消息样本
# 数据写回:增加记忆次数 # 数据写回:增加记忆次数
for message in messages: for message in messages:
db.messages.update_one({"_id": message["_id"]}, {"$set": {"memorized_times": message["memorized_times"] + 1}}) db.messages.update_one({"_id": message["_id"]},
{"$set": {"memorized_times": message["memorized_times"] + 1}})
return messages return messages
try_count += 1 try_count += 1
# 三次尝试均失败 # 三次尝试均失败
@@ -612,7 +614,7 @@ class Hippocampus:
edge_data = self.memory_graph.G[source][target] edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get('last_modified') last_modified = edge_data.get('last_modified')
if current_time - last_modified > 3600*global_config.memory_forget_time: if current_time - last_modified > 3600 * global_config.memory_forget_time:
current_strength = edge_data.get('strength', 1) current_strength = edge_data.get('strength', 1)
new_strength = current_strength - 1 new_strength = current_strength - 1
@@ -632,7 +634,7 @@ class Hippocampus:
node_data = self.memory_graph.G.nodes[node] node_data = self.memory_graph.G.nodes[node]
last_modified = node_data.get('last_modified', current_time) last_modified = node_data.get('last_modified', current_time)
if current_time - last_modified > 3600*24: if current_time - last_modified > 3600 * 24:
memory_items = node_data.get('memory_items', []) memory_items = node_data.get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
@@ -943,6 +945,7 @@ def segment_text(text):
seg_text = list(jieba.cut(text)) seg_text = list(jieba.cut(text))
return seg_text return seg_text
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config