reformat: 格式化memory.py
This commit is contained in:
@@ -27,6 +27,7 @@ logger = log_module.setup_logger(LogClassification.MEMORY)
|
|||||||
|
|
||||||
logger.info("初始化记忆系统")
|
logger.info("初始化记忆系统")
|
||||||
|
|
||||||
|
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
@@ -210,7 +211,8 @@ class Hippocampus:
|
|||||||
# 成功抽取短期消息样本
|
# 成功抽取短期消息样本
|
||||||
# 数据写回:增加记忆次数
|
# 数据写回:增加记忆次数
|
||||||
for message in messages:
|
for message in messages:
|
||||||
db.messages.update_one({"_id": message["_id"]}, {"$set": {"memorized_times": message["memorized_times"] + 1}})
|
db.messages.update_one({"_id": message["_id"]},
|
||||||
|
{"$set": {"memorized_times": message["memorized_times"] + 1}})
|
||||||
return messages
|
return messages
|
||||||
try_count += 1
|
try_count += 1
|
||||||
# 三次尝试均失败
|
# 三次尝试均失败
|
||||||
@@ -612,7 +614,7 @@ class Hippocampus:
|
|||||||
edge_data = self.memory_graph.G[source][target]
|
edge_data = self.memory_graph.G[source][target]
|
||||||
last_modified = edge_data.get('last_modified')
|
last_modified = edge_data.get('last_modified')
|
||||||
|
|
||||||
if current_time - last_modified > 3600*global_config.memory_forget_time:
|
if current_time - last_modified > 3600 * global_config.memory_forget_time:
|
||||||
current_strength = edge_data.get('strength', 1)
|
current_strength = edge_data.get('strength', 1)
|
||||||
new_strength = current_strength - 1
|
new_strength = current_strength - 1
|
||||||
|
|
||||||
@@ -632,7 +634,7 @@ class Hippocampus:
|
|||||||
node_data = self.memory_graph.G.nodes[node]
|
node_data = self.memory_graph.G.nodes[node]
|
||||||
last_modified = node_data.get('last_modified', current_time)
|
last_modified = node_data.get('last_modified', current_time)
|
||||||
|
|
||||||
if current_time - last_modified > 3600*24:
|
if current_time - last_modified > 3600 * 24:
|
||||||
memory_items = node_data.get('memory_items', [])
|
memory_items = node_data.get('memory_items', [])
|
||||||
if not isinstance(memory_items, list):
|
if not isinstance(memory_items, list):
|
||||||
memory_items = [memory_items] if memory_items else []
|
memory_items = [memory_items] if memory_items else []
|
||||||
@@ -943,6 +945,7 @@ def segment_text(text):
|
|||||||
seg_text = list(jieba.cut(text))
|
seg_text = list(jieba.cut(text))
|
||||||
return seg_text
|
return seg_text
|
||||||
|
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user