Merge remote-tracking branch 'upstream/debug'
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
import jieba
|
||||
from llm_module import LLMModel
|
||||
@@ -157,9 +158,12 @@ class Memory_graph:
|
||||
def main():
|
||||
# 初始化数据库
|
||||
Database.initialize(
|
||||
"127.0.0.1",
|
||||
27017,
|
||||
"MegBot"
|
||||
host= os.getenv("MONGODB_HOST"),
|
||||
port= int(os.getenv("MONGODB_PORT")),
|
||||
db_name= os.getenv("DATABASE_NAME"),
|
||||
username= os.getenv("MONGODB_USERNAME"),
|
||||
password= os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||
)
|
||||
|
||||
memory_graph = Memory_graph()
|
||||
@@ -168,10 +172,12 @@ def main():
|
||||
memory_graph.load_graph_from_db()
|
||||
# 展示两种不同的可视化方式
|
||||
print("\n按连接数量着色的图谱:")
|
||||
visualize_graph(memory_graph, color_by_memory=False)
|
||||
# visualize_graph(memory_graph, color_by_memory=False)
|
||||
visualize_graph_lite(memory_graph, color_by_memory=False)
|
||||
|
||||
print("\n按记忆数量着色的图谱:")
|
||||
visualize_graph(memory_graph, color_by_memory=True)
|
||||
# visualize_graph(memory_graph, color_by_memory=True)
|
||||
visualize_graph_lite(memory_graph, color_by_memory=True)
|
||||
|
||||
# memory_graph.save_graph_to_db()
|
||||
|
||||
@@ -262,7 +268,89 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||
|
||||
G = memory_graph.G
|
||||
|
||||
# 创建一个新图用于可视化
|
||||
H = G.copy()
|
||||
|
||||
# 移除只有一条记忆的节点和连接数少于3的节点
|
||||
nodes_to_remove = []
|
||||
for node in H.nodes():
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
degree = H.degree(node)
|
||||
if memory_count <= 2 or degree <= 2:
|
||||
nodes_to_remove.append(node)
|
||||
|
||||
H.remove_nodes_from(nodes_to_remove)
|
||||
|
||||
# 如果过滤后没有节点,则返回
|
||||
if len(H.nodes()) == 0:
|
||||
print("过滤后没有符合条件的节点可显示")
|
||||
return
|
||||
|
||||
# 保存图到本地
|
||||
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
||||
|
||||
# 根据连接条数或记忆数量设置节点颜色
|
||||
node_colors = []
|
||||
nodes = list(H.nodes()) # 获取图中实际的节点列表
|
||||
|
||||
if color_by_memory:
|
||||
# 计算每个节点的记忆数量
|
||||
memory_counts = []
|
||||
for node in nodes:
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
if isinstance(memory_items, list):
|
||||
count = len(memory_items)
|
||||
else:
|
||||
count = 1 if memory_items else 0
|
||||
memory_counts.append(count)
|
||||
max_memories = max(memory_counts) if memory_counts else 1
|
||||
|
||||
for count in memory_counts:
|
||||
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
||||
if max_memories > 0:
|
||||
intensity = min(1.0, count / max_memories)
|
||||
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
||||
else:
|
||||
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
||||
node_colors.append(color)
|
||||
else:
|
||||
# 使用原来的连接数量着色方案
|
||||
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
|
||||
for node in nodes:
|
||||
degree = H.degree(node)
|
||||
if max_degree > 0:
|
||||
red = min(1.0, degree / max_degree)
|
||||
blue = 1.0 - red
|
||||
color = (red, 0, blue)
|
||||
else:
|
||||
color = (0, 0, 1)
|
||||
node_colors.append(color)
|
||||
|
||||
# 绘制图形
|
||||
plt.figure(figsize=(12, 8))
|
||||
pos = nx.spring_layout(H, k=1, iterations=50)
|
||||
nx.draw(H, pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=2000,
|
||||
font_size=10,
|
||||
font_family='SimHei',
|
||||
font_weight='bold')
|
||||
|
||||
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,19 +1,19 @@
|
||||
import os
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from typing import Tuple, Union
|
||||
import time
|
||||
from nonebot import get_driver
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
class LLMModel:
|
||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
||||
self.model_name = model_name
|
||||
self.params = kwargs
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||
self.api_key = config.siliconflow_key
|
||||
self.base_url = config.siliconflow_base_url
|
||||
|
||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||
"""根据输入的提示生成模型的响应"""
|
||||
|
||||
@@ -1,30 +1,20 @@
|
||||
import os
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from typing import Tuple, Union
|
||||
import time
|
||||
from ..chat.config import BotConfig
|
||||
from nonebot import get_driver
|
||||
|
||||
# 获取当前文件的绝对路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||
env_path = os.path.join(root_dir, 'config', '.env')
|
||||
|
||||
# 加载环境变量
|
||||
print(f"尝试从 {env_path} 加载环境变量配置")
|
||||
if os.path.exists(env_path):
|
||||
load_dotenv(env_path)
|
||||
print("成功加载环境变量配置")
|
||||
else:
|
||||
print(f"环境变量配置文件不存在: {env_path}")
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
class LLMModel:
|
||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
||||
self.model_name = model_name
|
||||
self.params = kwargs
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||
self.api_key = config.siliconflow_key
|
||||
self.base_url = config.siliconflow_base_url
|
||||
|
||||
if not self.api_key or not self.base_url:
|
||||
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import jieba
|
||||
from .llm_module import LLMModel
|
||||
import networkx as nx
|
||||
@@ -197,8 +198,6 @@ class Hippocampus:
|
||||
time_frequency = {'near':1,'mid':2,'far':2}
|
||||
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
||||
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
|
||||
|
||||
|
||||
for i, input_text in enumerate(memory_sample, 1):
|
||||
#加载进度可视化
|
||||
progress = (i / len(memory_sample)) * 100
|
||||
@@ -206,26 +205,25 @@ class Hippocampus:
|
||||
filled_length = int(bar_length * i // len(memory_sample))
|
||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||
|
||||
# 生成压缩后记忆
|
||||
first_memory = set()
|
||||
first_memory = self.memory_compress(input_text, 2.5)
|
||||
# 延时防止访问超频
|
||||
# time.sleep(60)
|
||||
#将记忆加入到图谱中
|
||||
for topic, memory in first_memory:
|
||||
topics = segment_text(topic)
|
||||
if '[' in topic or topic=='':
|
||||
continue
|
||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||
for split_topic in topics:
|
||||
self.memory_graph.add_dot(split_topic,memory)
|
||||
for split_topic in topics:
|
||||
for other_split_topic in topics:
|
||||
if split_topic != other_split_topic:
|
||||
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
||||
|
||||
self.memory_graph.save_graph_to_db()
|
||||
if input_text:
|
||||
# 生成压缩后记忆
|
||||
first_memory = set()
|
||||
first_memory = self.memory_compress(input_text, 2.5)
|
||||
# 延时防止访问超频
|
||||
# time.sleep(5)
|
||||
#将记忆加入到图谱中
|
||||
for topic, memory in first_memory:
|
||||
topics = segment_text(topic)
|
||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||
for split_topic in topics:
|
||||
self.memory_graph.add_dot(split_topic,memory)
|
||||
for split_topic in topics:
|
||||
for other_split_topic in topics:
|
||||
if split_topic != other_split_topic:
|
||||
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
||||
else:
|
||||
print(f"空消息 跳过")
|
||||
self.memory_graph.save_graph_to_db()
|
||||
|
||||
def memory_compress(self, input_text, rate=1):
|
||||
information_content = calculate_information_content(input_text)
|
||||
@@ -263,13 +261,19 @@ def topic_what(text, topic):
|
||||
return prompt
|
||||
|
||||
|
||||
|
||||
from nonebot import get_driver
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
Database.initialize(
|
||||
global_config.MONGODB_HOST,
|
||||
global_config.MONGODB_PORT,
|
||||
global_config.DATABASE_NAME
|
||||
host= config.mongodb_host,
|
||||
port= int(config.mongodb_port),
|
||||
db_name= config.database_name,
|
||||
username= config.mongodb_username,
|
||||
password= config.mongodb_password,
|
||||
auth_source=config.mongodb_auth_source
|
||||
)
|
||||
#创建记忆图
|
||||
memory_graph = Memory_graph()
|
||||
|
||||
@@ -9,12 +9,42 @@ import datetime
|
||||
import random
|
||||
import time
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
# from chat.config import global_config
|
||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||
from src.common.database import Database # 使用正确的导入语法
|
||||
from src.plugins.memory_system.llm_module import LLMModel
|
||||
|
||||
|
||||
def calculate_information_content(text):
|
||||
"""计算文本的信息量(熵)"""
|
||||
# 统计字符频率
|
||||
char_count = Counter(text)
|
||||
total_chars = len(text)
|
||||
|
||||
# 计算熵
|
||||
entropy = 0
|
||||
for count in char_count.values():
|
||||
probability = count / total_chars
|
||||
entropy -= probability * math.log2(probability)
|
||||
|
||||
return entropy
|
||||
|
||||
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
"""从数据库中获取最接近指定时间戳的聊天记录"""
|
||||
chat_text = ''
|
||||
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||
|
||||
if closest_record:
|
||||
closest_time = closest_record['time']
|
||||
group_id = closest_record['group_id'] # 获取groupid
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||
for record in chat_record:
|
||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
|
||||
return chat_text
|
||||
|
||||
return ''
|
||||
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
@@ -103,7 +133,8 @@ class Memory_graph:
|
||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||
chat_text = ''
|
||||
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||
|
||||
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||
|
||||
if closest_record:
|
||||
closest_time = closest_record['time']
|
||||
@@ -192,166 +223,80 @@ class Memory_graph:
|
||||
for edge in edges:
|
||||
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
|
||||
|
||||
def calculate_information_content(text):
|
||||
|
||||
"""计算文本的信息量(熵)"""
|
||||
# 统计字符频率
|
||||
char_count = Counter(text)
|
||||
total_chars = len(text)
|
||||
|
||||
# 计算熵
|
||||
entropy = 0
|
||||
for count in char_count.values():
|
||||
probability = count / total_chars
|
||||
entropy -= probability * math.log2(probability)
|
||||
|
||||
return entropy
|
||||
|
||||
|
||||
# Database.initialize(
|
||||
# global_config.MONGODB_HOST,
|
||||
# global_config.MONGODB_PORT,
|
||||
# global_config.DATABASE_NAME
|
||||
# )
|
||||
# memory_graph = Memory_graph()
|
||||
|
||||
# llm_model = LLMModel()
|
||||
# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||
|
||||
# memory_graph.load_graph_from_db()
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
# 获取当前文件的绝对路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||
env_path = os.path.join(root_dir, 'config', '.env')
|
||||
|
||||
# 加载环境变量
|
||||
print(f"尝试从 {env_path} 加载环境变量配置")
|
||||
if os.path.exists(env_path):
|
||||
load_dotenv(env_path)
|
||||
print("成功加载环境变量配置")
|
||||
else:
|
||||
print(f"环境变量配置文件不存在: {env_path}")
|
||||
|
||||
# 初始化数据库
|
||||
Database.initialize(
|
||||
"127.0.0.1",
|
||||
27017,
|
||||
"MegBot"
|
||||
)
|
||||
|
||||
memory_graph = Memory_graph()
|
||||
# 创建LLM模型实例
|
||||
llm_model = LLMModel()
|
||||
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||
|
||||
# 使用当前时间戳进行测试
|
||||
current_timestamp = datetime.datetime.now().timestamp()
|
||||
chat_text = []
|
||||
|
||||
chat_size =25
|
||||
|
||||
for _ in range(30): # 循环10次
|
||||
random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间
|
||||
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
||||
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
||||
chat_text.append(chat_) # 拼接所有text
|
||||
# time.sleep(1)
|
||||
|
||||
|
||||
|
||||
for i, input_text in enumerate(chat_text, 1):
|
||||
# 海马体
|
||||
class Hippocampus:
|
||||
def __init__(self,memory_graph:Memory_graph):
|
||||
self.memory_graph = memory_graph
|
||||
self.llm_model = LLMModel()
|
||||
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||
|
||||
progress = (i / len(chat_text)) * 100
|
||||
bar_length = 30
|
||||
filled_length = int(bar_length * i // len(chat_text))
|
||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})")
|
||||
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
||||
current_timestamp = datetime.datetime.now().timestamp()
|
||||
chat_text = []
|
||||
#短期:1h 中期:4h 长期:24h
|
||||
for _ in range(time_frequency.get('near')): # 循环10次
|
||||
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
|
||||
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
chat_text.append(chat_)
|
||||
for _ in range(time_frequency.get('mid')): # 循环10次
|
||||
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
|
||||
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
chat_text.append(chat_)
|
||||
for _ in range(time_frequency.get('far')): # 循环10次
|
||||
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
|
||||
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
chat_text.append(chat_)
|
||||
return chat_text
|
||||
|
||||
def build_memory(self,chat_size=12):
|
||||
#最近消息获取频率
|
||||
time_frequency = {'near':1,'mid':2,'far':2}
|
||||
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
||||
|
||||
# print(input_text)
|
||||
first_memory = set()
|
||||
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
|
||||
# time.sleep(5)
|
||||
|
||||
#将记忆加入到图谱中
|
||||
for topic, memory in first_memory:
|
||||
# continue
|
||||
topics = segment_text(topic)
|
||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||
for split_topic in topics:
|
||||
memory_graph.add_dot(split_topic,memory)
|
||||
for split_topic in topics:
|
||||
for other_split_topic in topics:
|
||||
if split_topic != other_split_topic:
|
||||
memory_graph.connect_dot(split_topic, other_split_topic)
|
||||
|
||||
# memory_graph.store_memory()
|
||||
|
||||
# 展示两种不同的可视化方式
|
||||
print("\n按连接数量着色的图谱:")
|
||||
visualize_graph(memory_graph, color_by_memory=False)
|
||||
|
||||
print("\n按记忆数量着色的图谱:")
|
||||
visualize_graph(memory_graph, color_by_memory=True)
|
||||
|
||||
memory_graph.save_graph_to_db()
|
||||
# memory_graph.load_graph_from_db()
|
||||
|
||||
while True:
|
||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||
if query.lower() == '退出':
|
||||
break
|
||||
items_list = memory_graph.get_related_item(query)
|
||||
if items_list:
|
||||
# print(items_list)
|
||||
for memory_item in items_list:
|
||||
print(memory_item)
|
||||
else:
|
||||
print("未找到相关记忆。")
|
||||
#加载进度可视化
|
||||
for i, input_text in enumerate(memory_sample, 1):
|
||||
progress = (i / len(memory_sample)) * 100
|
||||
bar_length = 30
|
||||
filled_length = int(bar_length * i // len(memory_sample))
|
||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||
# print(f"第{i}条消息: {input_text}")
|
||||
if input_text:
|
||||
# 生成压缩后记忆
|
||||
first_memory = set()
|
||||
first_memory = self.memory_compress(input_text, 2.5)
|
||||
#将记忆加入到图谱中
|
||||
for topic, memory in first_memory:
|
||||
topics = segment_text(topic)
|
||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||
for split_topic in topics:
|
||||
self.memory_graph.add_dot(split_topic,memory)
|
||||
for split_topic in topics:
|
||||
for other_split_topic in topics:
|
||||
if split_topic != other_split_topic:
|
||||
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
||||
else:
|
||||
print(f"空消息 跳过")
|
||||
|
||||
while True:
|
||||
query = input("请输入问题:")
|
||||
|
||||
if query.lower() == '退出':
|
||||
break
|
||||
|
||||
topic_prompt = find_topic(query, 3)
|
||||
topic_response = llm_model.generate_response(topic_prompt)
|
||||
self.memory_graph.save_graph_to_db()
|
||||
|
||||
def memory_compress(self, input_text, rate=1):
|
||||
information_content = calculate_information_content(input_text)
|
||||
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
||||
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
||||
topic_prompt = find_topic(input_text, topic_num)
|
||||
topic_response = self.llm_model.generate_response(topic_prompt)
|
||||
# 检查 topic_response 是否为元组
|
||||
if isinstance(topic_response, tuple):
|
||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||
else:
|
||||
topics = topic_response.split(",")
|
||||
print(topics)
|
||||
|
||||
for keyword in topics:
|
||||
items_list = memory_graph.get_related_item(keyword)
|
||||
if items_list:
|
||||
print(items_list)
|
||||
|
||||
def memory_compress(input_text, llm_model, llm_model_small, rate=1):
|
||||
information_content = calculate_information_content(input_text)
|
||||
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
||||
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
||||
print(topic_num)
|
||||
topic_prompt = find_topic(input_text, topic_num)
|
||||
topic_response = llm_model.generate_response(topic_prompt)
|
||||
# 检查 topic_response 是否为元组
|
||||
if isinstance(topic_response, tuple):
|
||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||
else:
|
||||
topics = topic_response.split(",")
|
||||
print(topics)
|
||||
compressed_memory = set()
|
||||
for topic in topics:
|
||||
topic_what_prompt = topic_what(input_text,topic)
|
||||
topic_what_response = llm_model_small.generate_response(topic_what_prompt)
|
||||
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
||||
return compressed_memory
|
||||
|
||||
compressed_memory = set()
|
||||
for topic in topics:
|
||||
topic_what_prompt = topic_what(input_text,topic)
|
||||
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
|
||||
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
||||
return compressed_memory
|
||||
|
||||
def segment_text(text):
|
||||
seg_text = list(jieba.cut(text))
|
||||
@@ -372,18 +317,37 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
|
||||
G = memory_graph.G
|
||||
|
||||
# 创建一个新图用于可视化
|
||||
H = G.copy()
|
||||
|
||||
# 移除只有一条记忆的节点和连接数少于3的节点
|
||||
nodes_to_remove = []
|
||||
for node in H.nodes():
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
degree = H.degree(node)
|
||||
if memory_count <= 1 or degree <= 2:
|
||||
nodes_to_remove.append(node)
|
||||
|
||||
H.remove_nodes_from(nodes_to_remove)
|
||||
|
||||
# 如果过滤后没有节点,则返回
|
||||
if len(H.nodes()) == 0:
|
||||
print("过滤后没有符合条件的节点可显示")
|
||||
return
|
||||
|
||||
# 保存图到本地
|
||||
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
|
||||
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
||||
|
||||
# 根据连接条数或记忆数量设置节点颜色
|
||||
node_colors = []
|
||||
nodes = list(G.nodes()) # 获取图中实际的节点列表
|
||||
nodes = list(H.nodes()) # 获取图中实际的节点列表
|
||||
|
||||
if color_by_memory:
|
||||
# 计算每个节点的记忆数量
|
||||
memory_counts = []
|
||||
for node in nodes:
|
||||
memory_items = G.nodes[node].get('memory_items', [])
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
if isinstance(memory_items, list):
|
||||
count = len(memory_items)
|
||||
else:
|
||||
@@ -401,9 +365,9 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
node_colors.append(color)
|
||||
else:
|
||||
# 使用原来的连接数量着色方案
|
||||
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
|
||||
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
|
||||
for node in nodes:
|
||||
degree = G.degree(node)
|
||||
degree = H.degree(node)
|
||||
if max_degree > 0:
|
||||
red = min(1.0, degree / max_degree)
|
||||
blue = 1.0 - red
|
||||
@@ -414,8 +378,8 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
|
||||
# 绘制图形
|
||||
plt.figure(figsize=(12, 8))
|
||||
pos = nx.spring_layout(G, k=1, iterations=50)
|
||||
nx.draw(G, pos,
|
||||
pos = nx.spring_layout(H, k=1, iterations=50)
|
||||
nx.draw(H, pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=2000,
|
||||
@@ -427,6 +391,71 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||
plt.show()
|
||||
|
||||
def main():
|
||||
# 初始化数据库
|
||||
Database.initialize(
|
||||
host= os.getenv("MONGODB_HOST"),
|
||||
port= int(os.getenv("MONGODB_PORT")),
|
||||
db_name= os.getenv("DATABASE_NAME"),
|
||||
username= os.getenv("MONGODB_USERNAME"),
|
||||
password= os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 创建记忆图
|
||||
memory_graph = Memory_graph()
|
||||
# 加载数据库中存储的记忆图
|
||||
memory_graph.load_graph_from_db()
|
||||
# 创建海马体
|
||||
hippocampus = Hippocampus(memory_graph)
|
||||
|
||||
end_time = time.time()
|
||||
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
||||
|
||||
# 构建记忆
|
||||
hippocampus.build_memory(chat_size=25)
|
||||
|
||||
# 展示两种不同的可视化方式
|
||||
print("\n按连接数量着色的图谱:")
|
||||
visualize_graph(memory_graph, color_by_memory=False)
|
||||
|
||||
print("\n按记忆数量着色的图谱:")
|
||||
visualize_graph(memory_graph, color_by_memory=True)
|
||||
|
||||
# 交互式查询
|
||||
while True:
|
||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||
if query.lower() == '退出':
|
||||
break
|
||||
items_list = memory_graph.get_related_item(query)
|
||||
if items_list:
|
||||
for memory_item in items_list:
|
||||
print(memory_item)
|
||||
else:
|
||||
print("未找到相关记忆。")
|
||||
|
||||
while True:
|
||||
query = input("请输入问题:")
|
||||
|
||||
if query.lower() == '退出':
|
||||
break
|
||||
|
||||
topic_prompt = find_topic(query, 3)
|
||||
topic_response = hippocampus.llm_model.generate_response(topic_prompt)
|
||||
# 检查 topic_response 是否为元组
|
||||
if isinstance(topic_response, tuple):
|
||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||
else:
|
||||
topics = topic_response.split(",")
|
||||
print(topics)
|
||||
|
||||
for keyword in topics:
|
||||
items_list = memory_graph.get_related_item(keyword)
|
||||
if items_list:
|
||||
print(items_list)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user