Merge branch 'main' of https://github.com/wangw571/MaiBot-Fiao-Edition
@@ -15,10 +15,10 @@ logger = get_module_logger("draw_memory")
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)

from src.common.database import db  # 使用正确的导入语法
from src.common.database import db  # noqa: E402

# 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), ".env.dev")
load_dotenv(env_path)

@@ -32,13 +32,13 @@ class Memory_graph:
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list):
if "memory_items" in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]["memory_items"], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
self.G.nodes[concept]['memory_items'].append(memory)
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
self.G.nodes[concept]["memory_items"].append(memory)
else:
self.G.nodes[concept]['memory_items'] = [memory]
self.G.nodes[concept]["memory_items"] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
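The add_dot hunk above only changes quoting style; the normalization logic is unchanged. As a minimal self-contained sketch (illustrative only, not the project's actual module), the memory_items handling reduces to:

# Sketch of the normalization performed by Memory_graph.add_dot (illustrative).
import networkx as nx

G = nx.Graph()

def add_dot(concept, memory):
    if concept in G:
        items = G.nodes[concept].get("memory_items")
        if items is None:
            # Node exists but has no memories yet.
            G.nodes[concept]["memory_items"] = [memory]
        else:
            if not isinstance(items, list):
                # Older nodes may store a single string; wrap it in a list first.
                items = [items]
            items.append(memory)
            G.nodes[concept]["memory_items"] = items
    else:
        # New node: start a fresh memory list.
        G.add_node(concept, memory_items=[memory])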
@@ -68,8 +68,8 @@ class Memory_graph:
|
||||
node_data = self.get_dot(topic)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
if 'memory_items' in data:
|
||||
memory_items = data['memory_items']
|
||||
if "memory_items" in data:
|
||||
memory_items = data["memory_items"]
|
||||
if isinstance(memory_items, list):
|
||||
first_layer_items.extend(memory_items)
|
||||
else:
|
||||
@@ -83,8 +83,8 @@ class Memory_graph:
|
||||
node_data = self.get_dot(neighbor)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
if 'memory_items' in data:
|
||||
memory_items = data['memory_items']
|
||||
if "memory_items" in data:
|
||||
memory_items = data["memory_items"]
|
||||
if isinstance(memory_items, list):
|
||||
second_layer_items.extend(memory_items)
|
||||
else:
|
||||
@@ -94,9 +94,7 @@ class Memory_graph:
|
||||
|
||||
def store_memory(self):
|
||||
for node in self.G.nodes():
|
||||
dot_data = {
|
||||
"concept": node
|
||||
}
|
||||
dot_data = {"concept": node}
|
||||
db.store_memory_dots.insert_one(dot_data)
|
||||
|
||||
@property
|
||||
@@ -106,25 +104,27 @@ class Memory_graph:
|
||||
|
||||
def get_random_chat_from_db(self, length: int, timestamp: str):
|
||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||
chat_text = ''
|
||||
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||
chat_text = ""
|
||||
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)]) # 调试输出
|
||||
logger.info(
|
||||
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}"
|
||||
)
|
||||
|
||||
if closest_record:
|
||||
closest_time = closest_record['time']
|
||||
group_id = closest_record['group_id'] # 获取groupid
|
||||
closest_time = closest_record["time"]
|
||||
group_id = closest_record["group_id"] # 获取groupid
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
chat_record = list(
|
||||
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
|
||||
length))
|
||||
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort("time", 1).limit(length)
|
||||
)
|
||||
for record in chat_record:
|
||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(record["time"])))
|
||||
try:
|
||||
displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
|
||||
except:
|
||||
displayname = record["user_nickname"] or "用户" + str(record["user_id"])
|
||||
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
||||
except (KeyError, TypeError):
|
||||
# 处理缺少键或类型错误的情况
|
||||
displayname = record.get("user_nickname", "") or "用户" + str(record.get("user_id", "未知"))
|
||||
chat_text += f"[{time_str}] {displayname}: {record['processed_plain_text']}\n" # 添加发送者和时间信息
|
||||
return chat_text
|
||||
|
||||
return [] # 如果没有找到记录,返回空列表
|
||||
@@ -135,16 +135,13 @@ class Memory_graph:
|
||||
# 保存节点
|
||||
for node in self.G.nodes(data=True):
|
||||
node_data = {
|
||||
'concept': node[0],
|
||||
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
||||
"concept": node[0],
|
||||
"memory_items": node[1].get("memory_items", []), # 默认为空列表
|
||||
}
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
# 保存边
|
||||
for edge in self.G.edges():
|
||||
edge_data = {
|
||||
'source': edge[0],
|
||||
'target': edge[1]
|
||||
}
|
||||
edge_data = {"source": edge[0], "target": edge[1]}
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
|
||||
def load_graph_from_db(self):
|
||||
@@ -153,14 +150,14 @@ class Memory_graph:
|
||||
# 加载节点
|
||||
nodes = db.graph_data.nodes.find()
|
||||
for node in nodes:
|
||||
memory_items = node.get('memory_items', [])
|
||||
memory_items = node.get("memory_items", [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
self.G.add_node(node['concept'], memory_items=memory_items)
|
||||
self.G.add_node(node["concept"], memory_items=memory_items)
|
||||
# 加载边
|
||||
edges = db.graph_data.edges.find()
|
||||
for edge in edges:
|
||||
self.G.add_edge(edge['source'], edge['target'])
|
||||
self.G.add_edge(edge["source"], edge["target"])
|
||||
|
||||
|
||||
def main():
|
||||
@@ -172,7 +169,7 @@ def main():
|
||||
|
||||
while True:
|
||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||
if query.lower() == '退出':
|
||||
if query.lower() == "退出":
|
||||
break
|
||||
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
|
||||
if first_layer_items or second_layer_items:
|
||||
@@ -192,19 +189,25 @@ def segment_text(text):
|
||||
|
||||
|
||||
def find_topic(text, topic_num):
|
||||
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
|
||||
prompt = (
|
||||
f"这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。"
|
||||
f"只需要列举{topic_num}个话题就好,不要告诉我其他内容。"
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
def topic_what(text, topic):
|
||||
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
||||
prompt = (
|
||||
f"这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。"
|
||||
f"只输出这句话就好"
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei"] # 用来正常显示中文标签
|
||||
plt.rcParams["axes.unicode_minus"] = False # 用来正常显示负号
|
||||
|
||||
G = memory_graph.G
|
||||
|
||||
@@ -214,7 +217,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
# 移除只有一条记忆的节点和连接数少于3的节点
|
||||
nodes_to_remove = []
|
||||
for node in H.nodes():
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
memory_items = H.nodes[node].get("memory_items", [])
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
degree = H.degree(node)
|
||||
if memory_count < 3 or degree < 2: # 改为小于2而不是小于等于2
|
||||
@@ -239,7 +242,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
max_memories = 1
|
||||
max_degree = 1
|
||||
for node in nodes:
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
memory_items = H.nodes[node].get("memory_items", [])
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
degree = H.degree(node)
|
||||
max_memories = max(max_memories, memory_count)
|
||||
@@ -248,7 +251,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
# 计算每个节点的大小和颜色
|
||||
for node in nodes:
|
||||
# 计算节点大小(基于记忆数量)
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
memory_items = H.nodes[node].get("memory_items", [])
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
# 使用指数函数使变化更明显
|
||||
ratio = memory_count / max_memories
|
||||
@@ -269,19 +272,22 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
# 绘制图形
|
||||
plt.figure(figsize=(12, 8))
|
||||
pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开
|
||||
nx.draw(H, pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=node_sizes,
|
||||
font_size=10,
|
||||
font_family='SimHei',
|
||||
font_weight='bold',
|
||||
edge_color='gray',
|
||||
width=0.5,
|
||||
alpha=0.9)
|
||||
nx.draw(
|
||||
H,
|
||||
pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=node_sizes,
|
||||
font_size=10,
|
||||
font_family="SimHei",
|
||||
font_weight="bold",
|
||||
edge_color="gray",
|
||||
width=0.5,
|
||||
alpha=0.9,
|
||||
)
|
||||
|
||||
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
|
||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||
title = "记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数"
|
||||
plt.title(title, fontsize=16, fontfamily="SimHei")
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
@@ -5,17 +5,18 @@ import time
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
from rich.console import Console
|
||||
from memory_manual_build import Memory_graph, Hippocampus # 海马体和记忆图
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
'''
|
||||
"""
|
||||
我想 总有那么一个瞬间
|
||||
你会想和某天才变态少女助手一样
|
||||
往Bot的海马体里插上几个电极 不是吗
|
||||
|
||||
Let's do some dirty job.
|
||||
'''
|
||||
"""
|
||||
|
||||
# 获取当前文件的目录
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
@@ -28,11 +29,10 @@ env_path = project_root / ".env.dev"
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
from src.common.database import db
|
||||
from src.plugins.memory_system.offline_llm import LLMModel
|
||||
from src.common.logger import get_module_logger # noqa E402
|
||||
from src.common.database import db # noqa E402
|
||||
|
||||
logger = get_module_logger('mem_alter')
|
||||
logger = get_module_logger("mem_alter")
|
||||
console = Console()
|
||||
|
||||
# 加载环境变量
|
||||
@@ -43,13 +43,12 @@ else:
|
||||
logger.warning(f"未找到环境变量文件: {env_path}")
|
||||
logger.info("将使用默认配置")
|
||||
|
||||
from memory_manual_build import Memory_graph, Hippocampus #海马体和记忆图
|
||||
|
||||
# 查询节点信息
|
||||
def query_mem_info(memory_graph: Memory_graph):
|
||||
while True:
|
||||
query = input("\n请输入新的查询概念(输入'退出'以结束):")
|
||||
if query.lower() == '退出':
|
||||
if query.lower() == "退出":
|
||||
break
|
||||
|
||||
items_list = memory_graph.get_related_item(query)
|
||||
@@ -71,42 +70,40 @@ def query_mem_info(memory_graph: Memory_graph):
|
||||
else:
|
||||
print("未找到相关记忆。")
|
||||
|
||||
|
||||
# 增加概念节点
|
||||
def add_mem_node(hippocampus: Hippocampus):
|
||||
while True:
|
||||
concept = input("请输入节点概念名:\n")
|
||||
result = db.graph_data.nodes.count_documents({'concept': concept})
|
||||
result = db.graph_data.nodes.count_documents({"concept": concept})
|
||||
|
||||
if result != 0:
|
||||
console.print("[yellow]已存在名为“{concept}”的节点,行为已取消[/yellow]")
|
||||
continue
|
||||
|
||||
|
||||
memory_items = list()
|
||||
while True:
|
||||
context = input("请输入节点描述信息(输入'终止'以结束)")
|
||||
if context.lower() == "终止": break
|
||||
if context.lower() == "终止":
|
||||
break
|
||||
memory_items.append(context)
|
||||
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
hippocampus.memory_graph.G.add_node(concept,
|
||||
memory_items=memory_items,
|
||||
created_time=current_time,
|
||||
last_modified=current_time)
|
||||
hippocampus.memory_graph.G.add_node(
|
||||
concept, memory_items=memory_items, created_time=current_time, last_modified=current_time
|
||||
)
|
||||
|
||||
|
||||
# 删除概念节点(及连接到它的边)
|
||||
def remove_mem_node(hippocampus: Hippocampus):
|
||||
concept = input("请输入节点概念名:\n")
|
||||
result = db.graph_data.nodes.count_documents({'concept': concept})
|
||||
result = db.graph_data.nodes.count_documents({"concept": concept})
|
||||
|
||||
if result == 0:
|
||||
console.print(f"[red]不存在名为“{concept}”的节点[/red]")
|
||||
|
||||
edges = db.graph_data.edges.find({
|
||||
'$or': [
|
||||
{'source': concept},
|
||||
{'target': concept}
|
||||
]
|
||||
})
|
||||
|
||||
edges = db.graph_data.edges.find({"$or": [{"source": concept}, {"target": concept}]})
|
||||
|
||||
for edge in edges:
|
||||
console.print(f"[yellow]存在边“{edge['source']} -> {edge['target']}”, 请慎重考虑[/yellow]")
|
||||
|
||||
@@ -116,41 +113,50 @@ def remove_mem_node(hippocampus: Hippocampus):
|
||||
hippocampus.memory_graph.G.remove_node(concept)
|
||||
else:
|
||||
logger.info("[green]删除操作已取消[/green]")
|
||||
|
||||
|
||||
# 增加节点间边
|
||||
def add_mem_edge(hippocampus: Hippocampus):
|
||||
while True:
|
||||
source = input("请输入 **第一个节点** 名称(输入'退出'以结束):\n")
|
||||
if source.lower() == "退出": break
|
||||
if db.graph_data.nodes.count_documents({'concept': source}) == 0:
|
||||
if source.lower() == "退出":
|
||||
break
|
||||
if db.graph_data.nodes.count_documents({"concept": source}) == 0:
|
||||
console.print(f"[yellow]“{source}”节点不存在,操作已取消。[/yellow]")
|
||||
continue
|
||||
|
||||
target = input("请输入 **第二个节点** 名称:\n")
|
||||
if db.graph_data.nodes.count_documents({'concept': target}) == 0:
|
||||
if db.graph_data.nodes.count_documents({"concept": target}) == 0:
|
||||
console.print(f"[yellow]“{target}”节点不存在,操作已取消。[/yellow]")
|
||||
continue
|
||||
|
||||
|
||||
if source == target:
|
||||
console.print(f"[yellow]试图创建“{source} <-> {target}”自环,操作已取消。[/yellow]")
|
||||
continue
|
||||
|
||||
hippocampus.memory_graph.connect_dot(source, target)
|
||||
edge = hippocampus.memory_graph.G.get_edge_data(source, target)
|
||||
if edge['strength'] == 1:
|
||||
if edge["strength"] == 1:
|
||||
console.print(f"[green]成功创建边“{source} <-> {target}”,默认权重1[/green]")
|
||||
else:
|
||||
console.print(f"[yellow]边“{source} <-> {target}”已存在,更新权重: {edge['strength']-1} <-> {edge['strength']}[/yellow]")
|
||||
console.print(
|
||||
f"[yellow]边“{source} <-> {target}”已存在,"
|
||||
f"更新权重: {edge['strength'] - 1} <-> {edge['strength']}[/yellow]"
|
||||
)
|
||||
|
||||
|
||||
# 删除节点间边
|
||||
def remove_mem_edge(hippocampus: Hippocampus):
|
||||
while True:
|
||||
source = input("请输入 **第一个节点** 名称(输入'退出'以结束):\n")
|
||||
if source.lower() == "退出": break
|
||||
if db.graph_data.nodes.count_documents({'concept': source}) == 0:
|
||||
if source.lower() == "退出":
|
||||
break
|
||||
if db.graph_data.nodes.count_documents({"concept": source}) == 0:
|
||||
console.print("[yellow]“{source}”节点不存在,操作已取消。[/yellow]")
|
||||
continue
|
||||
|
||||
target = input("请输入 **第二个节点** 名称:\n")
|
||||
if db.graph_data.nodes.count_documents({'concept': target}) == 0:
|
||||
if db.graph_data.nodes.count_documents({"concept": target}) == 0:
|
||||
console.print("[yellow]“{target}”节点不存在,操作已取消。[/yellow]")
|
||||
continue
|
||||
|
||||
@@ -168,12 +174,14 @@ def remove_mem_edge(hippocampus: Hippocampus):
|
||||
hippocampus.memory_graph.G.remove_edge(source, target)
|
||||
console.print(f"[green]边“{source} <-> {target}”已删除。[green]")
|
||||
|
||||
|
||||
# 修改节点信息
|
||||
def alter_mem_node(hippocampus: Hippocampus):
|
||||
batchEnviroment = dict()
|
||||
while True:
|
||||
concept = input("请输入节点概念名(输入'终止'以结束):\n")
|
||||
if concept.lower() == "终止": break
|
||||
if concept.lower() == "终止":
|
||||
break
|
||||
_, node = hippocampus.memory_graph.get_dot(concept)
|
||||
if node is None:
|
||||
console.print(f"[yellow]“{concept}”节点不存在,操作已取消。[/yellow]")
|
||||
@@ -182,43 +190,60 @@ def alter_mem_node(hippocampus: Hippocampus):
|
||||
console.print("[yellow]注意,请确保你知道自己在做什么[/yellow]")
|
||||
console.print("[yellow]你将获得一个执行任意代码的环境[/yellow]")
|
||||
console.print("[red]你已经被警告过了。[/red]\n")
|
||||
|
||||
nodeEnviroment = {"concept": '<节点名>', 'memory_items': '<记忆文本数组>'}
|
||||
console.print("[green]环境变量中会有env与batchEnv两个dict, env在切换节点时会清空, batchEnv在操作终止时才会清空[/green]")
|
||||
console.print(f"[green] env 会被初始化为[/green]\n{nodeEnviroment}\n[green]且会在用户代码执行完毕后被提交 [/green]")
|
||||
console.print("[yellow]为便于书写临时脚本,请手动在输入代码通过Ctrl+C等方式触发KeyboardInterrupt来结束代码执行[/yellow]")
|
||||
|
||||
|
||||
node_environment = {"concept": "<节点名>", "memory_items": "<记忆文本数组>"}
|
||||
console.print(
|
||||
"[green]环境变量中会有env与batchEnv两个dict, env在切换节点时会清空, batchEnv在操作终止时才会清空[/green]"
|
||||
)
|
||||
console.print(
|
||||
f"[green] env 会被初始化为[/green]\n{node_environment}\n[green]且会在用户代码执行完毕后被提交 [/green]"
|
||||
)
|
||||
console.print(
|
||||
"[yellow]为便于书写临时脚本,请手动在输入代码通过Ctrl+C等方式触发KeyboardInterrupt来结束代码执行[/yellow]"
|
||||
)
|
||||
|
||||
# 拷贝数据以防操作炸了
|
||||
nodeEnviroment = dict(node)
|
||||
nodeEnviroment['concept'] = concept
|
||||
node_environment = dict(node)
|
||||
node_environment["concept"] = concept
|
||||
|
||||
while True:
|
||||
userexec = lambda script, env, batchEnv: eval(script)
|
||||
|
||||
def user_exec(script, env, batch_env):
|
||||
return eval(script, env, batch_env)
|
||||
|
||||
try:
|
||||
command = console.input()
|
||||
except KeyboardInterrupt:
|
||||
# 稍微防一下小天才
|
||||
try:
|
||||
if isinstance(nodeEnviroment['memory_items'], list):
|
||||
node['memory_items'] = nodeEnviroment['memory_items']
|
||||
if isinstance(node_environment["memory_items"], list):
|
||||
node["memory_items"] = node_environment["memory_items"]
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
except:
|
||||
console.print("[red]我不知道你做了什么,但显然nodeEnviroment['memory_items']已经不是个数组了,操作已取消[/red]")
|
||||
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[red]我不知道你做了什么,但显然nodeEnviroment['memory_items']已经不是个数组了,"
|
||||
f"操作已取消: {str(e)}[/red]"
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
userexec(command, nodeEnviroment, batchEnviroment)
|
||||
user_exec(command, node_environment, batchEnviroment)
|
||||
except Exception as e:
|
||||
console.print(e)
|
||||
console.print("[red]自定义代码执行时发生异常,已捕获,请重试(可通过 console.print(locals()) 检查环境状态)[/red]")
|
||||
console.print(
|
||||
"[red]自定义代码执行时发生异常,已捕获,请重试(可通过 console.print(locals()) 检查环境状态)[/red]"
|
||||
)
|
||||
|
||||
|
||||
# 修改边信息
|
||||
def alter_mem_edge(hippocampus: Hippocampus):
|
||||
batchEnviroment = dict()
|
||||
while True:
|
||||
source = input("请输入 **第一个节点** 名称(输入'终止'以结束):\n")
|
||||
if source.lower() == "终止": break
|
||||
if source.lower() == "终止":
|
||||
break
|
||||
if hippocampus.memory_graph.get_dot(source) is None:
|
||||
console.print(f"[yellow]“{source}”节点不存在,操作已取消。[/yellow]")
|
||||
continue
|
||||
@@ -237,38 +262,51 @@ def alter_mem_edge(hippocampus: Hippocampus):
|
||||
console.print("[yellow]你将获得一个执行任意代码的环境[/yellow]")
|
||||
console.print("[red]你已经被警告过了。[/red]\n")
|
||||
|
||||
edgeEnviroment = {"source": '<节点名>', "target": '<节点名>', 'strength': '<强度值,装在一个list里>'}
|
||||
console.print("[green]环境变量中会有env与batchEnv两个dict, env在切换节点时会清空, batchEnv在操作终止时才会清空[/green]")
|
||||
console.print(f"[green] env 会被初始化为[/green]\n{edgeEnviroment}\n[green]且会在用户代码执行完毕后被提交 [/green]")
|
||||
console.print("[yellow]为便于书写临时脚本,请手动在输入代码通过Ctrl+C等方式触发KeyboardInterrupt来结束代码执行[/yellow]")
|
||||
|
||||
edgeEnviroment = {"source": "<节点名>", "target": "<节点名>", "strength": "<强度值,装在一个list里>"}
|
||||
console.print(
|
||||
"[green]环境变量中会有env与batchEnv两个dict, env在切换节点时会清空, batchEnv在操作终止时才会清空[/green]"
|
||||
)
|
||||
console.print(
|
||||
f"[green] env 会被初始化为[/green]\n{edgeEnviroment}\n[green]且会在用户代码执行完毕后被提交 [/green]"
|
||||
)
|
||||
console.print(
|
||||
"[yellow]为便于书写临时脚本,请手动在输入代码通过Ctrl+C等方式触发KeyboardInterrupt来结束代码执行[/yellow]"
|
||||
)
|
||||
|
||||
# 拷贝数据以防操作炸了
|
||||
edgeEnviroment['strength'] = [edge["strength"]]
|
||||
edgeEnviroment['source'] = source
|
||||
edgeEnviroment['target'] = target
|
||||
edgeEnviroment["strength"] = [edge["strength"]]
|
||||
edgeEnviroment["source"] = source
|
||||
edgeEnviroment["target"] = target
|
||||
|
||||
while True:
|
||||
userexec = lambda script, env, batchEnv: eval(script)
|
||||
|
||||
def user_exec(script, env, batch_env):
|
||||
return eval(script, env, batch_env)
|
||||
|
||||
try:
|
||||
command = console.input()
|
||||
except KeyboardInterrupt:
|
||||
# 稍微防一下小天才
|
||||
try:
|
||||
if isinstance(edgeEnviroment['strength'][0], int):
|
||||
edge['strength'] = edgeEnviroment['strength'][0]
|
||||
if isinstance(edgeEnviroment["strength"][0], int):
|
||||
edge["strength"] = edgeEnviroment["strength"][0]
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
except:
|
||||
console.print("[red]我不知道你做了什么,但显然edgeEnviroment['strength']已经不是个int了,操作已取消[/red]")
|
||||
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[red]我不知道你做了什么,但显然edgeEnviroment['strength']已经不是个int了,"
|
||||
f"操作已取消: {str(e)}[/red]"
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
userexec(command, edgeEnviroment, batchEnviroment)
|
||||
user_exec(command, edgeEnviroment, batchEnviroment)
|
||||
except Exception as e:
|
||||
console.print(e)
|
||||
console.print("[red]自定义代码执行时发生异常,已捕获,请重试(可通过 console.print(locals()) 检查环境状态)[/red]")
|
||||
|
||||
console.print(
|
||||
"[red]自定义代码执行时发生异常,已捕获,请重试(可通过 console.print(locals()) 检查环境状态)[/red]"
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -288,10 +326,17 @@ async def main():
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = int(input("请输入操作类型\n0 -> 查询节点; 1 -> 增加节点; 2 -> 移除节点; 3 -> 增加边; 4 -> 移除边;\n5 -> 修改节点; 6 -> 修改边; 其他任意输入 -> 退出\n"))
|
||||
except:
|
||||
query = int(
|
||||
input(
|
||||
"""请输入操作类型
|
||||
0 -> 查询节点; 1 -> 增加节点; 2 -> 移除节点; 3 -> 增加边; 4 -> 移除边;
|
||||
5 -> 修改节点; 6 -> 修改边; 其他任意输入 -> 退出
|
||||
"""
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
query = -1
|
||||
|
||||
|
||||
if query == 0:
|
||||
query_mem_info(memory_graph)
|
||||
elif query == 1:
|
||||
@@ -308,12 +353,12 @@ async def main():
|
||||
alter_mem_edge(hippocampus)
|
||||
else:
|
||||
print("已结束操作")
|
||||
break
|
||||
break
|
||||
|
||||
hippocampus.sync_memory_to_db()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -23,7 +23,7 @@ from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
|
||||
memory_config = LogConfig(
|
||||
# 使用海马体专用样式
|
||||
console_format=MEMORY_STYLE_CONFIG["console_format"],
|
||||
file_format=MEMORY_STYLE_CONFIG["file_format"]
|
||||
file_format=MEMORY_STYLE_CONFIG["file_format"],
|
||||
)
|
||||
|
||||
logger = get_module_logger("memory_system", config=memory_config)
|
||||
@@ -42,38 +42,43 @@ class Memory_graph:
|
||||
|
||||
# 如果边已存在,增加 strength
|
||||
if self.G.has_edge(concept1, concept2):
|
||||
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
|
||||
self.G[concept1][concept2]["strength"] = self.G[concept1][concept2].get("strength", 1) + 1
|
||||
# 更新最后修改时间
|
||||
self.G[concept1][concept2]['last_modified'] = current_time
|
||||
self.G[concept1][concept2]["last_modified"] = current_time
|
||||
else:
|
||||
# 如果是新边,初始化 strength 为 1
|
||||
self.G.add_edge(concept1, concept2,
|
||||
strength=1,
|
||||
created_time=current_time, # 添加创建时间
|
||||
last_modified=current_time) # 添加最后修改时间
|
||||
self.G.add_edge(
|
||||
concept1,
|
||||
concept2,
|
||||
strength=1,
|
||||
created_time=current_time, # 添加创建时间
|
||||
last_modified=current_time,
|
||||
) # 添加最后修改时间
|
||||
|
||||
def add_dot(self, concept, memory):
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
|
||||
if concept in self.G:
|
||||
if 'memory_items' in self.G.nodes[concept]:
|
||||
if not isinstance(self.G.nodes[concept]['memory_items'], list):
|
||||
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
|
||||
self.G.nodes[concept]['memory_items'].append(memory)
|
||||
if "memory_items" in self.G.nodes[concept]:
|
||||
if not isinstance(self.G.nodes[concept]["memory_items"], list):
|
||||
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
|
||||
self.G.nodes[concept]["memory_items"].append(memory)
|
||||
# 更新最后修改时间
|
||||
self.G.nodes[concept]['last_modified'] = current_time
|
||||
self.G.nodes[concept]["last_modified"] = current_time
|
||||
else:
|
||||
self.G.nodes[concept]['memory_items'] = [memory]
|
||||
self.G.nodes[concept]["memory_items"] = [memory]
|
||||
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
|
||||
if 'created_time' not in self.G.nodes[concept]:
|
||||
self.G.nodes[concept]['created_time'] = current_time
|
||||
self.G.nodes[concept]['last_modified'] = current_time
|
||||
if "created_time" not in self.G.nodes[concept]:
|
||||
self.G.nodes[concept]["created_time"] = current_time
|
||||
self.G.nodes[concept]["last_modified"] = current_time
|
||||
else:
|
||||
# 如果是新节点,创建新的记忆列表
|
||||
self.G.add_node(concept,
|
||||
memory_items=[memory],
|
||||
created_time=current_time, # 添加创建时间
|
||||
last_modified=current_time) # 添加最后修改时间
|
||||
self.G.add_node(
|
||||
concept,
|
||||
memory_items=[memory],
|
||||
created_time=current_time, # 添加创建时间
|
||||
last_modified=current_time,
|
||||
) # 添加最后修改时间
|
||||
|
||||
def get_dot(self, concept):
|
||||
# 检查节点是否存在于图中
|
||||
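For reference, the reformatted connect_dot branch above keeps the same behaviour: repeated connections strengthen an existing edge instead of duplicating it, and timestamps are refreshed. A small standalone sketch (illustrative, not the project's code):

import datetime
import networkx as nx

G = nx.Graph()

def connect_dot(concept1, concept2):
    now = datetime.datetime.now().timestamp()
    if G.has_edge(concept1, concept2):
        # Existing edge: bump strength and refresh the modification time.
        G[concept1][concept2]["strength"] = G[concept1][concept2].get("strength", 1) + 1
        G[concept1][concept2]["last_modified"] = now
    else:
        # New edge starts at strength 1 and records both timestamps.
        G.add_edge(concept1, concept2, strength=1, created_time=now, last_modified=now)

connect_dot("cat", "pet")
connect_dot("cat", "pet")
assert G["cat"]["pet"]["strength"] == 2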
@@ -97,8 +102,8 @@ class Memory_graph:
|
||||
node_data = self.get_dot(topic)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
if 'memory_items' in data:
|
||||
memory_items = data['memory_items']
|
||||
if "memory_items" in data:
|
||||
memory_items = data["memory_items"]
|
||||
if isinstance(memory_items, list):
|
||||
first_layer_items.extend(memory_items)
|
||||
else:
|
||||
@@ -111,8 +116,8 @@ class Memory_graph:
|
||||
node_data = self.get_dot(neighbor)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
if 'memory_items' in data:
|
||||
memory_items = data['memory_items']
|
||||
if "memory_items" in data:
|
||||
memory_items = data["memory_items"]
|
||||
if isinstance(memory_items, list):
|
||||
second_layer_items.extend(memory_items)
|
||||
else:
|
||||
@@ -134,8 +139,8 @@ class Memory_graph:
|
||||
node_data = self.G.nodes[topic]
|
||||
|
||||
# 如果节点存在memory_items
|
||||
if 'memory_items' in node_data:
|
||||
memory_items = node_data['memory_items']
|
||||
if "memory_items" in node_data:
|
||||
memory_items = node_data["memory_items"]
|
||||
|
||||
# 确保memory_items是列表
|
||||
if not isinstance(memory_items, list):
|
||||
@@ -149,7 +154,7 @@ class Memory_graph:
|
||||
|
||||
# 更新节点的记忆项
|
||||
if memory_items:
|
||||
self.G.nodes[topic]['memory_items'] = memory_items
|
||||
self.G.nodes[topic]["memory_items"] = memory_items
|
||||
else:
|
||||
# 如果没有记忆项了,删除整个节点
|
||||
self.G.remove_node(topic)
|
||||
@@ -163,12 +168,14 @@ class Memory_graph:
|
||||
class Hippocampus:
|
||||
def __init__(self, memory_graph: Memory_graph):
|
||||
self.memory_graph = memory_graph
|
||||
self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5,request_type = 'topic')
|
||||
self.llm_summary_by_topic = LLM_request(model=global_config.llm_summary_by_topic, temperature=0.5,request_type = 'topic')
|
||||
self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5, request_type="topic")
|
||||
self.llm_summary_by_topic = LLM_request(
|
||||
model=global_config.llm_summary_by_topic, temperature=0.5, request_type="topic"
|
||||
)
|
||||
|
||||
def get_all_node_names(self) -> list:
|
||||
"""获取记忆图中所有节点的名字列表
|
||||
|
||||
|
||||
Returns:
|
||||
list: 包含所有节点名字的列表
|
||||
"""
|
||||
@@ -193,10 +200,10 @@ class Hippocampus:
|
||||
- target_timestamp: 目标时间戳
|
||||
- chat_size: 抽取的消息数量
|
||||
- max_memorized_time_per_msg: 每条消息的最大记忆次数
|
||||
|
||||
|
||||
Returns:
|
||||
- list: 抽取出的消息记录列表
|
||||
|
||||
|
||||
"""
|
||||
try_count = 0
|
||||
# 最多尝试三次抽取
|
||||
@@ -212,29 +219,32 @@ class Hippocampus:
|
||||
# 成功抽取短期消息样本
|
||||
# 数据写回:增加记忆次数
|
||||
for message in messages:
|
||||
db.messages.update_one({"_id": message["_id"]},
|
||||
{"$set": {"memorized_times": message["memorized_times"] + 1}})
|
||||
db.messages.update_one(
|
||||
{"_id": message["_id"]}, {"$set": {"memorized_times": message["memorized_times"] + 1}}
|
||||
)
|
||||
return messages
|
||||
try_count += 1
|
||||
# 三次尝试均失败
|
||||
return None
|
||||
|
||||
def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
|
||||
def get_memory_sample(self, chat_size=20, time_frequency=None):
|
||||
"""获取记忆样本
|
||||
|
||||
|
||||
Returns:
|
||||
list: 消息记录列表,每个元素是一个消息记录字典列表
|
||||
"""
|
||||
# 硬编码:每条消息最大记忆次数
|
||||
# 如有需求可写入global_config
|
||||
if time_frequency is None:
|
||||
time_frequency = {"near": 2, "mid": 4, "far": 3}
|
||||
max_memorized_time_per_msg = 3
|
||||
|
||||
current_timestamp = datetime.datetime.now().timestamp()
|
||||
chat_samples = []
|
||||
|
||||
# 短期:1h 中期:4h 长期:24h
|
||||
logger.debug(f"正在抽取短期消息样本")
|
||||
for i in range(time_frequency.get('near')):
|
||||
logger.debug("正在抽取短期消息样本")
|
||||
for i in range(time_frequency.get("near")):
|
||||
random_time = current_timestamp - random.randint(1, 3600)
|
||||
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
|
||||
if messages:
|
||||
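The signature change in this hunk (from time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3} to time_frequency=None with an in-body default) avoids Python's shared mutable-default pitfall. A minimal illustration with a toy function rather than the real method:

# Mutable defaults are created once, at definition time, and shared between calls;
# mutating them leaks state across callers.
def sample_bad(freq={"near": 2}):
    freq["near"] += 1
    return freq["near"]

print(sample_bad(), sample_bad())  # 3 4 -- the same dict object is reused

# The pattern adopted in the diff: default to None and build a fresh dict per call.
def sample_good(freq=None):
    if freq is None:
        freq = {"near": 2}
    freq["near"] += 1
    return freq["near"]

print(sample_good(), sample_good())  # 3 3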
@@ -243,8 +253,8 @@ class Hippocampus:
|
||||
else:
|
||||
logger.warning(f"第{i}次短期消息样本抽取失败")
|
||||
|
||||
logger.debug(f"正在抽取中期消息样本")
|
||||
for i in range(time_frequency.get('mid')):
|
||||
logger.debug("正在抽取中期消息样本")
|
||||
for i in range(time_frequency.get("mid")):
|
||||
random_time = current_timestamp - random.randint(3600, 3600 * 4)
|
||||
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
|
||||
if messages:
|
||||
@@ -253,8 +263,8 @@ class Hippocampus:
|
||||
else:
|
||||
logger.warning(f"第{i}次中期消息样本抽取失败")
|
||||
|
||||
logger.debug(f"正在抽取长期消息样本")
|
||||
for i in range(time_frequency.get('far')):
|
||||
logger.debug("正在抽取长期消息样本")
|
||||
for i in range(time_frequency.get("far")):
|
||||
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
|
||||
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
|
||||
if messages:
|
||||
@@ -267,7 +277,7 @@ class Hippocampus:
|
||||
|
||||
async def memory_compress(self, messages: list, compress_rate=0.1):
|
||||
"""压缩消息记录为记忆
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (压缩记忆集合, 相似主题字典)
|
||||
"""
|
||||
@@ -278,8 +288,8 @@ class Hippocampus:
|
||||
input_text = ""
|
||||
time_info = ""
|
||||
# 计算最早和最晚时间
|
||||
earliest_time = min(msg['time'] for msg in messages)
|
||||
latest_time = max(msg['time'] for msg in messages)
|
||||
earliest_time = min(msg["time"] for msg in messages)
|
||||
latest_time = max(msg["time"] for msg in messages)
|
||||
|
||||
earliest_dt = datetime.datetime.fromtimestamp(earliest_time)
|
||||
latest_dt = datetime.datetime.fromtimestamp(latest_time)
|
||||
@@ -304,8 +314,11 @@ class Hippocampus:
|
||||
|
||||
# 过滤topics
|
||||
filter_keywords = global_config.memory_ban_words
|
||||
topics = [topic.strip() for topic in
|
||||
topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
|
||||
topics = [
|
||||
topic.strip()
|
||||
for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||
if topic.strip()
|
||||
]
|
||||
filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)]
|
||||
|
||||
logger.info(f"过滤后话题: {filtered_topics}")
|
||||
@@ -350,16 +363,17 @@ class Hippocampus:
|
||||
def calculate_topic_num(self, text, compress_rate):
|
||||
"""计算文本的话题数量"""
|
||||
information_content = calculate_information_content(text)
|
||||
topic_by_length = text.count('\n') * compress_rate
|
||||
topic_by_length = text.count("\n") * compress_rate
|
||||
topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
|
||||
topic_num = int((topic_by_length + topic_by_information_content) / 2)
|
||||
logger.debug(
|
||||
f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, "
|
||||
f"topic_num: {topic_num}")
|
||||
f"topic_num: {topic_num}"
|
||||
)
|
||||
return topic_num
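A worked example of the topic-count formula above, with assumed inputs (the numbers are illustrative; calculate_information_content is not shown in this diff):

# Assume a 30-line chat sample, compress_rate = 0.1, and information content 4.5.
information_content = 4.5                       # hypothetical return of calculate_information_content(text)
topic_by_length = 30 * 0.1                      # text.count("\n") * compress_rate -> 3.0
topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))  # -> 3
topic_num = int((topic_by_length + topic_by_information_content) / 2)              # -> 3
print(topic_num)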
|
||||
|
||||
async def operation_build_memory(self, chat_size=20):
|
||||
time_frequency = {'near': 1, 'mid': 4, 'far': 4}
|
||||
time_frequency = {"near": 1, "mid": 4, "far": 4}
|
||||
memory_samples = self.get_memory_sample(chat_size, time_frequency)
|
||||
|
||||
for i, messages in enumerate(memory_samples, 1):
|
||||
@@ -368,7 +382,7 @@ class Hippocampus:
|
||||
progress = (i / len(memory_samples)) * 100
|
||||
bar_length = 30
|
||||
filled_length = int(bar_length * i // len(memory_samples))
|
||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||
bar = "█" * filled_length + "-" * (bar_length - filled_length)
|
||||
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
|
||||
|
||||
compress_rate = global_config.memory_compress_rate
|
||||
@@ -389,10 +403,13 @@ class Hippocampus:
|
||||
if topic != similar_topic:
|
||||
strength = int(similarity * 10)
|
||||
logger.info(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})")
|
||||
self.memory_graph.G.add_edge(topic, similar_topic,
|
||||
strength=strength,
|
||||
created_time=current_time,
|
||||
last_modified=current_time)
|
||||
self.memory_graph.G.add_edge(
|
||||
topic,
|
||||
similar_topic,
|
||||
strength=strength,
|
||||
created_time=current_time,
|
||||
last_modified=current_time,
|
||||
)
|
||||
|
||||
# 连接同批次的相关话题
|
||||
for i in range(len(all_topics)):
|
||||
@@ -409,11 +426,11 @@ class Hippocampus:
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 转换数据库节点为字典格式,方便查找
|
||||
db_nodes_dict = {node['concept']: node for node in db_nodes}
|
||||
db_nodes_dict = {node["concept"]: node for node in db_nodes}
|
||||
|
||||
# 检查并更新节点
|
||||
for concept, data in memory_nodes:
|
||||
memory_items = data.get('memory_items', [])
|
||||
memory_items = data.get("memory_items", [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
|
||||
@@ -421,34 +438,36 @@ class Hippocampus:
|
||||
memory_hash = self.calculate_node_hash(concept, memory_items)
|
||||
|
||||
# 获取时间信息
|
||||
created_time = data.get('created_time', datetime.datetime.now().timestamp())
|
||||
last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
|
||||
created_time = data.get("created_time", datetime.datetime.now().timestamp())
|
||||
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
|
||||
|
||||
if concept not in db_nodes_dict:
|
||||
# 数据库中缺少的节点,添加
|
||||
node_data = {
|
||||
'concept': concept,
|
||||
'memory_items': memory_items,
|
||||
'hash': memory_hash,
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
"concept": concept,
|
||||
"memory_items": memory_items,
|
||||
"hash": memory_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
else:
|
||||
# 获取数据库中节点的特征值
|
||||
db_node = db_nodes_dict[concept]
|
||||
db_hash = db_node.get('hash', None)
|
||||
db_hash = db_node.get("hash", None)
|
||||
|
||||
# 如果特征值不同,则更新节点
|
||||
if db_hash != memory_hash:
|
||||
db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'memory_items': memory_items,
|
||||
'hash': memory_hash,
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
}}
|
||||
{"concept": concept},
|
||||
{
|
||||
"$set": {
|
||||
"memory_items": memory_items,
|
||||
"hash": memory_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# 处理边的信息
|
||||
@@ -458,44 +477,43 @@ class Hippocampus:
|
||||
# 创建边的哈希值字典
|
||||
db_edge_dict = {}
|
||||
for edge in db_edges:
|
||||
edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
|
||||
db_edge_dict[(edge['source'], edge['target'])] = {
|
||||
'hash': edge_hash,
|
||||
'strength': edge.get('strength', 1)
|
||||
}
|
||||
edge_hash = self.calculate_edge_hash(edge["source"], edge["target"])
|
||||
db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "strength": edge.get("strength", 1)}
|
||||
|
||||
# 检查并更新边
|
||||
for source, target, data in memory_edges:
|
||||
edge_hash = self.calculate_edge_hash(source, target)
|
||||
edge_key = (source, target)
|
||||
strength = data.get('strength', 1)
|
||||
strength = data.get("strength", 1)
|
||||
|
||||
# 获取边的时间信息
|
||||
created_time = data.get('created_time', datetime.datetime.now().timestamp())
|
||||
last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
|
||||
created_time = data.get("created_time", datetime.datetime.now().timestamp())
|
||||
last_modified = data.get("last_modified", datetime.datetime.now().timestamp())
|
||||
|
||||
if edge_key not in db_edge_dict:
|
||||
# 添加新边
|
||||
edge_data = {
|
||||
'source': source,
|
||||
'target': target,
|
||||
'strength': strength,
|
||||
'hash': edge_hash,
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
else:
|
||||
# 检查边的特征值是否变化
|
||||
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||
if db_edge_dict[edge_key]["hash"] != edge_hash:
|
||||
db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {
|
||||
'hash': edge_hash,
|
||||
'strength': strength,
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
}}
|
||||
{"source": source, "target": target},
|
||||
{
|
||||
"$set": {
|
||||
"hash": edge_hash,
|
||||
"strength": strength,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def sync_memory_from_db(self):
|
||||
@@ -509,70 +527,62 @@ class Hippocampus:
|
||||
# 从数据库加载所有节点
|
||||
nodes = list(db.graph_data.nodes.find())
|
||||
for node in nodes:
|
||||
concept = node['concept']
|
||||
memory_items = node.get('memory_items', [])
|
||||
concept = node["concept"]
|
||||
memory_items = node.get("memory_items", [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
|
||||
# 检查时间字段是否存在
|
||||
if 'created_time' not in node or 'last_modified' not in node:
|
||||
if "created_time" not in node or "last_modified" not in node:
|
||||
need_update = True
|
||||
# 更新数据库中的节点
|
||||
update_data = {}
|
||||
if 'created_time' not in node:
|
||||
update_data['created_time'] = current_time
|
||||
if 'last_modified' not in node:
|
||||
update_data['last_modified'] = current_time
|
||||
if "created_time" not in node:
|
||||
update_data["created_time"] = current_time
|
||||
if "last_modified" not in node:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': update_data}
|
||||
)
|
||||
db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data})
|
||||
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = node.get('created_time', current_time)
|
||||
last_modified = node.get('last_modified', current_time)
|
||||
created_time = node.get("created_time", current_time)
|
||||
last_modified = node.get("last_modified", current_time)
|
||||
|
||||
# 添加节点到图中
|
||||
self.memory_graph.G.add_node(concept,
|
||||
memory_items=memory_items,
|
||||
created_time=created_time,
|
||||
last_modified=last_modified)
|
||||
self.memory_graph.G.add_node(
|
||||
concept, memory_items=memory_items, created_time=created_time, last_modified=last_modified
|
||||
)
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = list(db.graph_data.edges.find())
|
||||
for edge in edges:
|
||||
source = edge['source']
|
||||
target = edge['target']
|
||||
strength = edge.get('strength', 1)
|
||||
source = edge["source"]
|
||||
target = edge["target"]
|
||||
strength = edge.get("strength", 1)
|
||||
|
||||
# 检查时间字段是否存在
|
||||
if 'created_time' not in edge or 'last_modified' not in edge:
|
||||
if "created_time" not in edge or "last_modified" not in edge:
|
||||
need_update = True
|
||||
# 更新数据库中的边
|
||||
update_data = {}
|
||||
if 'created_time' not in edge:
|
||||
update_data['created_time'] = current_time
|
||||
if 'last_modified' not in edge:
|
||||
update_data['last_modified'] = current_time
|
||||
if "created_time" not in edge:
|
||||
update_data["created_time"] = current_time
|
||||
if "last_modified" not in edge:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': update_data}
|
||||
)
|
||||
db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data})
|
||||
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = edge.get('created_time', current_time)
|
||||
last_modified = edge.get('last_modified', current_time)
|
||||
created_time = edge.get("created_time", current_time)
|
||||
last_modified = edge.get("last_modified", current_time)
|
||||
|
||||
# 只有当源节点和目标节点都存在时才添加边
|
||||
if source in self.memory_graph.G and target in self.memory_graph.G:
|
||||
self.memory_graph.G.add_edge(source, target,
|
||||
strength=strength,
|
||||
created_time=created_time,
|
||||
last_modified=last_modified)
|
||||
self.memory_graph.G.add_edge(
|
||||
source, target, strength=strength, created_time=created_time, last_modified=last_modified
|
||||
)
|
||||
|
||||
if need_update:
|
||||
logger.success("[数据库] 已为缺失的时间字段进行补充")
|
||||
@@ -582,7 +592,7 @@ class Hippocampus:
|
||||
# 检查数据库是否为空
|
||||
# logger.remove()
|
||||
|
||||
logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")
|
||||
logger.info("[遗忘] 开始检查数据库... 当前Logger信息:")
|
||||
# logger.info(f"- Logger名称: {logger.name}")
|
||||
logger.info(f"- Logger等级: {logger.level}")
|
||||
# logger.info(f"- Logger处理器: {[handler.__class__.__name__ for handler in logger.handlers]}")
|
||||
@@ -604,8 +614,8 @@ class Hippocampus:
|
||||
nodes_to_check = random.sample(all_nodes, check_nodes_count)
|
||||
edges_to_check = random.sample(all_edges, check_edges_count)
|
||||
|
||||
edge_changes = {'weakened': 0, 'removed': 0}
|
||||
node_changes = {'reduced': 0, 'removed': 0}
|
||||
edge_changes = {"weakened": 0, "removed": 0}
|
||||
node_changes = {"reduced": 0, "removed": 0}
|
||||
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
|
||||
@@ -613,30 +623,30 @@ class Hippocampus:
|
||||
logger.info("[遗忘] 开始检查连接...")
|
||||
for source, target in edges_to_check:
|
||||
edge_data = self.memory_graph.G[source][target]
|
||||
last_modified = edge_data.get('last_modified')
|
||||
last_modified = edge_data.get("last_modified")
|
||||
|
||||
if current_time - last_modified > 3600 * global_config.memory_forget_time:
|
||||
current_strength = edge_data.get('strength', 1)
|
||||
current_strength = edge_data.get("strength", 1)
|
||||
new_strength = current_strength - 1
|
||||
|
||||
if new_strength <= 0:
|
||||
self.memory_graph.G.remove_edge(source, target)
|
||||
edge_changes['removed'] += 1
|
||||
edge_changes["removed"] += 1
|
||||
logger.info(f"[遗忘] 连接移除: {source} -> {target}")
|
||||
else:
|
||||
edge_data['strength'] = new_strength
|
||||
edge_data['last_modified'] = current_time
|
||||
edge_changes['weakened'] += 1
|
||||
edge_data["strength"] = new_strength
|
||||
edge_data["last_modified"] = current_time
|
||||
edge_changes["weakened"] += 1
|
||||
logger.info(f"[遗忘] 连接减弱: {source} -> {target} (强度: {current_strength} -> {new_strength})")
|
||||
|
||||
# 检查并遗忘话题
|
||||
logger.info("[遗忘] 开始检查节点...")
|
||||
for node in nodes_to_check:
|
||||
node_data = self.memory_graph.G.nodes[node]
|
||||
last_modified = node_data.get('last_modified', current_time)
|
||||
last_modified = node_data.get("last_modified", current_time)
|
||||
|
||||
if current_time - last_modified > 3600 * 24:
|
||||
memory_items = node_data.get('memory_items', [])
|
||||
memory_items = node_data.get("memory_items", [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
|
||||
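The edge-forgetting rule in this hunk can be summarised in a few lines; the threshold comes from global_config.memory_forget_time (in hours), which is assumed to be 24 for this sketch:

# Decay sketch: edges untouched for longer than the forget window lose one point
# of strength per check and are removed once strength reaches zero.
FORGET_HOURS = 24  # assumed value of global_config.memory_forget_time

def decay_edge(edge_data, now):
    if now - edge_data.get("last_modified", now) > 3600 * FORGET_HOURS:
        new_strength = edge_data.get("strength", 1) - 1
        if new_strength <= 0:
            return None                      # caller removes the edge
        edge_data["strength"] = new_strength
        edge_data["last_modified"] = now
    return edge_data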
@@ -646,13 +656,13 @@ class Hippocampus:
|
||||
memory_items.remove(removed_item)
|
||||
|
||||
if memory_items:
|
||||
self.memory_graph.G.nodes[node]['memory_items'] = memory_items
|
||||
self.memory_graph.G.nodes[node]['last_modified'] = current_time
|
||||
node_changes['reduced'] += 1
|
||||
self.memory_graph.G.nodes[node]["memory_items"] = memory_items
|
||||
self.memory_graph.G.nodes[node]["last_modified"] = current_time
|
||||
node_changes["reduced"] += 1
|
||||
logger.info(f"[遗忘] 记忆减少: {node} (数量: {current_count} -> {len(memory_items)})")
|
||||
else:
|
||||
self.memory_graph.G.remove_node(node)
|
||||
node_changes['removed'] += 1
|
||||
node_changes["removed"] += 1
|
||||
logger.info(f"[遗忘] 节点移除: {node}")
|
||||
|
||||
if any(count > 0 for count in edge_changes.values()) or any(count > 0 for count in node_changes.values()):
|
||||
@@ -666,7 +676,7 @@ class Hippocampus:
|
||||
async def merge_memory(self, topic):
|
||||
"""对指定话题的记忆进行合并压缩"""
|
||||
# 获取节点的记忆项
|
||||
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
|
||||
memory_items = self.memory_graph.G.nodes[topic].get("memory_items", [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
|
||||
@@ -695,13 +705,13 @@ class Hippocampus:
|
||||
logger.info(f"[合并] 添加压缩记忆: {compressed_memory}")
|
||||
|
||||
# 更新节点的记忆项
|
||||
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
|
||||
self.memory_graph.G.nodes[topic]["memory_items"] = memory_items
|
||||
logger.debug(f"[合并] 完成记忆合并,当前记忆数量: {len(memory_items)}")
|
||||
|
||||
async def operation_merge_memory(self, percentage=0.1):
|
||||
"""
|
||||
随机检查一定比例的节点,对内容数量超过100的节点进行记忆合并
|
||||
|
||||
|
||||
Args:
|
||||
percentage: 要检查的节点比例,默认为0.1(10%)
|
||||
"""
|
||||
@@ -715,7 +725,7 @@ class Hippocampus:
|
||||
merged_nodes = []
|
||||
for node in nodes_to_check:
|
||||
# 获取节点的内容条数
|
||||
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
|
||||
memory_items = self.memory_graph.G.nodes[node].get("memory_items", [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
content_count = len(memory_items)
|
||||
@@ -734,38 +744,47 @@ class Hippocampus:
|
||||
logger.debug("本次检查没有需要合并的节点")
|
||||
|
||||
def find_topic_llm(self, text, topic_num):
|
||||
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
|
||||
prompt = (
|
||||
f"这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
||||
f"用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
||||
)
|
||||
return prompt
|
||||
|
||||
def topic_what(self, text, topic, time_info):
|
||||
prompt = f'这是一段文字,{time_info}:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
|
||||
prompt = (
|
||||
f'这是一段文字,{time_info}:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
|
||||
f"可以包含时间和人物,以及具体的观点。只输出这句话就好"
|
||||
)
|
||||
return prompt
|
||||
|
||||
async def _identify_topics(self, text: str) -> list:
|
||||
"""从文本中识别可能的主题
|
||||
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
|
||||
Returns:
|
||||
list: 识别出的主题列表
|
||||
"""
|
||||
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 5))
|
||||
# print(f"话题: {topics_response[0]}")
|
||||
topics = [topic.strip() for topic in
|
||||
topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
|
||||
topics = [
|
||||
topic.strip()
|
||||
for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||
if topic.strip()
|
||||
]
|
||||
# print(f"话题: {topics}")
|
||||
|
||||
return topics
|
||||
|
||||
def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list:
|
||||
"""查找与给定主题相似的记忆主题
|
||||
|
||||
|
||||
Args:
|
||||
topics: 主题列表
|
||||
similarity_threshold: 相似度阈值
|
||||
debug_info: 调试信息前缀
|
||||
|
||||
|
||||
Returns:
|
||||
list: (主题, 相似度) 元组列表
|
||||
"""
|
||||
@@ -794,7 +813,6 @@ class Hippocampus:
|
||||
if similarity >= similarity_threshold:
|
||||
has_similar_topic = True
|
||||
if debug_info:
|
||||
# print(f"\033[1;32m[{debug_info}]\033[0m 找到相似主题: {topic} -> {memory_topic} (相似度: {similarity:.2f})")
|
||||
pass
|
||||
all_similar_topics.append((memory_topic, similarity))
|
||||
|
||||
@@ -806,11 +824,11 @@ class Hippocampus:
|
||||
|
||||
def _get_top_topics(self, similar_topics: list, max_topics: int = 5) -> list:
|
||||
"""获取相似度最高的主题
|
||||
|
||||
|
||||
Args:
|
||||
similar_topics: (主题, 相似度) 元组列表
|
||||
max_topics: 最大主题数量
|
||||
|
||||
|
||||
Returns:
|
||||
list: (主题, 相似度) 元组列表
|
||||
"""
|
||||
@@ -826,7 +844,7 @@ class Hippocampus:
|
||||
|
||||
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
|
||||
"""计算输入文本对记忆的激活程度"""
|
||||
logger.info(f"[激活] 识别主题: {await self._identify_topics(text)}")
|
||||
logger.info(f"识别主题: {await self._identify_topics(text)}")
|
||||
|
||||
# 识别主题
|
||||
identified_topics = await self._identify_topics(text)
|
||||
@@ -835,9 +853,7 @@ class Hippocampus:
|
||||
|
||||
# 查找相似主题
|
||||
all_similar_topics = self._find_similar_topics(
|
||||
identified_topics,
|
||||
similarity_threshold=similarity_threshold,
|
||||
debug_info="激活"
|
||||
identified_topics, similarity_threshold=similarity_threshold, debug_info="激活"
|
||||
)
|
||||
|
||||
if not all_similar_topics:
|
||||
@@ -850,24 +866,23 @@ class Hippocampus:
|
||||
if len(top_topics) == 1:
|
||||
topic, score = top_topics[0]
|
||||
# 获取主题内容数量并计算惩罚系数
|
||||
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
|
||||
memory_items = self.memory_graph.G.nodes[topic].get("memory_items", [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
content_count = len(memory_items)
|
||||
penalty = 1.0 / (1 + math.log(content_count + 1))
|
||||
|
||||
activation = int(score * 50 * penalty)
|
||||
logger.info(
|
||||
f"[激活] 单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}")
|
||||
logger.info(f"单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}")
|
||||
return activation
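For the single-topic branch above, the penalty and activation reduce to a short formula; with assumed numbers (similarity 0.6, 9 stored memories on the matched node):

import math

score = 0.6            # assumed similarity of the single matched topic
content_count = 9      # assumed number of memory_items on that node

penalty = 1.0 / (1 + math.log(content_count + 1))   # 1 / (1 + ln 10) ≈ 0.303
activation = int(score * 50 * penalty)              # int(0.6 * 50 * 0.303) = 9
print(activation)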
|
||||
|
||||
# 计算关键词匹配率,同时考虑内容数量
|
||||
matched_topics = set()
|
||||
topic_similarities = {}
|
||||
|
||||
for memory_topic, similarity in top_topics:
|
||||
for memory_topic, _similarity in top_topics:
|
||||
# 计算内容数量惩罚
|
||||
memory_items = self.memory_graph.G.nodes[memory_topic].get('memory_items', [])
|
||||
memory_items = self.memory_graph.G.nodes[memory_topic].get("memory_items", [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
content_count = len(memory_items)
|
||||
@@ -886,7 +901,6 @@ class Hippocampus:
|
||||
adjusted_sim = sim * penalty
|
||||
topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim)
|
||||
# logger.debug(
|
||||
# f"[激活] 主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})")
|
||||
|
||||
# 计算主题匹配率和平均相似度
|
||||
topic_match = len(matched_topics) / len(identified_topics)
|
||||
@@ -894,22 +908,20 @@ class Hippocampus:
|
||||
|
||||
# 计算最终激活值
|
||||
activation = int((topic_match + average_similarities) / 2 * 100)
|
||||
logger.info(
|
||||
f"[激活] 匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}")
|
||||
logger.info(f"匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}")
|
||||
|
||||
return activation
|
||||
|
||||
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4,
|
||||
max_memory_num: int = 5) -> list:
|
||||
async def get_relevant_memories(
|
||||
self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, max_memory_num: int = 5
|
||||
) -> list:
|
||||
"""根据输入文本获取相关的记忆内容"""
|
||||
# 识别主题
|
||||
identified_topics = await self._identify_topics(text)
|
||||
|
||||
# 查找相似主题
|
||||
all_similar_topics = self._find_similar_topics(
|
||||
identified_topics,
|
||||
similarity_threshold=similarity_threshold,
|
||||
debug_info="记忆检索"
|
||||
identified_topics, similarity_threshold=similarity_threshold, debug_info="记忆检索"
|
||||
)
|
||||
|
||||
# 获取最相关的主题
|
||||
@@ -926,15 +938,11 @@ class Hippocampus:
|
||||
first_layer = random.sample(first_layer, max_memory_num // 2)
|
||||
# 为每条记忆添加来源主题和相似度信息
|
||||
for memory in first_layer:
|
||||
relevant_memories.append({
|
||||
'topic': topic,
|
||||
'similarity': score,
|
||||
'content': memory
|
||||
})
|
||||
relevant_memories.append({"topic": topic, "similarity": score, "content": memory})
|
||||
|
||||
# 如果记忆数量超过5个,随机选择5个
|
||||
# 按相似度排序
|
||||
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
relevant_memories.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
if len(relevant_memories) > max_memory_num:
|
||||
relevant_memories = random.sample(relevant_memories, max_memory_num)
|
||||
@@ -961,4 +969,3 @@ hippocampus.sync_memory_from_db()
|
||||
|
||||
end_time = time.time()
|
||||
logger.success(f"加载海马体耗时: {end_time - start_time:.2f} 秒")
|
||||
|
||||
|
||||
File diff suppressed because it is too large (2 files)
@@ -9,120 +9,115 @@ from src.common.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger("offline_llm")
|
||||
|
||||
|
||||
class LLMModel:
|
||||
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
|
||||
self.model_name = model_name
|
||||
self.params = kwargs
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||
|
||||
|
||||
if not self.api_key or not self.base_url:
|
||||
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
||||
|
||||
|
||||
logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url
|
||||
|
||||
def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]:
|
||||
"""根据输入的提示生成模型的响应"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.5,
|
||||
**self.params
|
||||
**self.params,
|
||||
}
|
||||
|
||||
|
||||
# 发送请求到完整的 chat/completions 端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||
|
||||
|
||||
max_retries = 3
|
||||
base_wait_time = 15 # 基础等待时间(秒)
|
||||
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
|
||||
|
||||
if response.status_code == 429:
|
||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||
wait_time = base_wait_time * (2**retry) # 指数退避
|
||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
|
||||
response.raise_for_status() # 检查其他响应状态
|
||||
|
||||
|
||||
result = response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
return content, reasoning_content
|
||||
return "没有返回结果", ""
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
wait_time = base_wait_time * (2**retry)
|
||||
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"请求失败: {str(e)}")
|
||||
return f"请求失败: {str(e)}", ""
|
||||
|
||||
|
||||
logger.error("达到最大重试次数,请求仍然失败")
|
||||
return "达到最大重试次数,请求仍然失败", ""
|
||||
|
||||
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
|
||||
"""异步方式根据输入的提示生成模型的响应"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.5,
|
||||
**self.params
|
||||
**self.params,
|
||||
}
|
||||
|
||||
|
||||
# 发送请求到完整的 chat/completions 端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||
|
||||
|
||||
max_retries = 3
|
||||
base_wait_time = 15
|
||||
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
async with session.post(api_url, headers=headers, json=data) as response:
|
||||
if response.status == 429:
|
||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||
wait_time = base_wait_time * (2**retry) # 指数退避
|
||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
|
||||
response.raise_for_status() # 检查其他响应状态
|
||||
|
||||
|
||||
result = await response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
return content, reasoning_content
|
||||
return "没有返回结果", ""
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
wait_time = base_wait_time * (2**retry)
|
||||
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"请求失败: {str(e)}")
|
||||
return f"请求失败: {str(e)}", ""
|
||||
|
||||
|
||||
logger.error("达到最大重试次数,请求仍然失败")
|
||||
return "达到最大重试次数,请求仍然失败", ""
|
||||
|
||||
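A minimal usage sketch for the LLMModel client shown above. It assumes SILICONFLOW_KEY and SILICONFLOW_BASE_URL are set in .env.dev and that the package is importable from the project root; with base_wait_time = 15 and three retries, the 429 back-off waits are 15 s, 30 s and 60 s.

# Illustrative only -- the env file path and prompt are placeholders.
from dotenv import load_dotenv

load_dotenv(".env.dev")  # provides SILICONFLOW_KEY / SILICONFLOW_BASE_URL

from src.plugins.memory_system.offline_llm import LLMModel  # noqa: E402

model = LLMModel(model_name="deepseek-ai/DeepSeek-V3")
content, reasoning = model.generate_response("Summarize the role of the hippocampus in one sentence.")
print(content)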