v0.3.0 记忆和知识库
beta
This commit is contained in:
@@ -10,6 +10,8 @@ from .relationship_manager import relationship_manager
|
|||||||
from ..schedule.schedule_generator import bot_schedule
|
from ..schedule.schedule_generator import bot_schedule
|
||||||
from .willing_manager import willing_manager
|
from .willing_manager import willing_manager
|
||||||
|
|
||||||
|
from ..memory_system.memory import memory_graph
|
||||||
|
|
||||||
|
|
||||||
# 获取驱动器
|
# 获取驱动器
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
@@ -23,6 +25,8 @@ Database.initialize(
|
|||||||
print("\033[1;32m[初始化配置和数据库完成]\033[0m")
|
print("\033[1;32m[初始化配置和数据库完成]\033[0m")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 导入其他模块
|
# 导入其他模块
|
||||||
from .bot import ChatBot
|
from .bot import ChatBot
|
||||||
from .emoji_manager import emoji_manager
|
from .emoji_manager import emoji_manager
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from .storage import MessageStorage
|
|||||||
from .llm_generator import LLMResponseGenerator
|
from .llm_generator import LLMResponseGenerator
|
||||||
from .message_stream import MessageStream, MessageStreamContainer
|
from .message_stream import MessageStream, MessageStreamContainer
|
||||||
from .topic_identifier import topic_identifier
|
from .topic_identifier import topic_identifier
|
||||||
from random import random
|
from random import random, choice
|
||||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
@@ -15,6 +15,7 @@ from .message import Message_Thinking # 导入 Message_Thinking 类
|
|||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from .willing_manager import willing_manager # 导入意愿管理器
|
from .willing_manager import willing_manager # 导入意愿管理器
|
||||||
from .utils import is_mentioned_bot_in_txt, calculate_typing_time
|
from .utils import is_mentioned_bot_in_txt, calculate_typing_time
|
||||||
|
from ..memory_system.memory import memory_graph
|
||||||
|
|
||||||
class ChatBot:
|
class ChatBot:
|
||||||
def __init__(self, config: BotConfig):
|
def __init__(self, config: BotConfig):
|
||||||
@@ -99,6 +100,11 @@ class ChatBot:
|
|||||||
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
||||||
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
|
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
|
||||||
|
|
||||||
|
if topic:
|
||||||
|
for current_topic in topic:
|
||||||
|
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
|
||||||
|
if first_layer_items:
|
||||||
|
print(f"\033[1;32m[记忆检索-bot]\033[0m 有印象:{current_topic}")
|
||||||
|
|
||||||
await self.storage.store_message(message, topic[0] if topic else None)
|
await self.storage.store_message(message, topic[0] if topic else None)
|
||||||
|
|
||||||
|
|||||||
@@ -133,8 +133,8 @@ llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
|
|||||||
if not global_config.enable_advance_output:
|
if not global_config.enable_advance_output:
|
||||||
logger.remove()
|
logger.remove()
|
||||||
|
|
||||||
logging.getLogger('nonebot').handlers.clear()
|
# logging.getLogger('nonebot').handlers.clear()
|
||||||
console_handler = logging.StreamHandler()
|
# console_handler = logging.StreamHandler()
|
||||||
console_handler.setLevel(logging.WARNING) # 只输出 WARNING 及以上级别
|
# console_handler.setLevel(logging.WARNING) # 只输出 WARNING 及以上级别
|
||||||
logging.getLogger('nonebot').addHandler(console_handler)
|
# logging.getLogger('nonebot').addHandler(console_handler)
|
||||||
logging.getLogger('nonebot').setLevel(logging.WARNING)
|
# logging.getLogger('nonebot').setLevel(logging.WARNING)
|
||||||
|
|||||||
186
src/plugins/chat/knowledege/knowledge_library.py
Normal file
186
src/plugins/chat/knowledege/knowledge_library.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 添加项目根目录到 Python 路径
|
||||||
|
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
|
sys.path.append(root_path)
|
||||||
|
|
||||||
|
from src.common.database import Database
|
||||||
|
from src.plugins.chat.config import llm_config
|
||||||
|
|
||||||
|
# 直接配置数据库连接信息
|
||||||
|
Database.initialize(
|
||||||
|
"127.0.0.1", # MongoDB 主机
|
||||||
|
27017, # MongoDB 端口
|
||||||
|
"MegBot" # 数据库名称
|
||||||
|
)
|
||||||
|
|
||||||
|
class KnowledgeLibrary:
|
||||||
|
def __init__(self):
|
||||||
|
self.db = Database.get_instance()
|
||||||
|
self.raw_info_dir = "data/raw_info"
|
||||||
|
self._ensure_dirs()
|
||||||
|
|
||||||
|
def _ensure_dirs(self):
|
||||||
|
"""确保必要的目录存在"""
|
||||||
|
os.makedirs(self.raw_info_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def get_embedding(self, text: str) -> list:
|
||||||
|
"""获取文本的embedding向量"""
|
||||||
|
url = "https://api.siliconflow.cn/v1/embeddings"
|
||||||
|
payload = {
|
||||||
|
"model": "BAAI/bge-m3",
|
||||||
|
"input": text,
|
||||||
|
"encoding_format": "float"
|
||||||
|
}
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, json=payload, headers=headers)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"获取embedding失败: {response.text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return response.json()['data'][0]['embedding']
|
||||||
|
|
||||||
|
def process_files(self):
|
||||||
|
"""处理raw_info目录下的所有txt文件"""
|
||||||
|
for filename in os.listdir(self.raw_info_dir):
|
||||||
|
if filename.endswith('.txt'):
|
||||||
|
file_path = os.path.join(self.raw_info_dir, filename)
|
||||||
|
self.process_single_file(file_path)
|
||||||
|
|
||||||
|
def process_single_file(self, file_path: str):
|
||||||
|
"""处理单个文件"""
|
||||||
|
try:
|
||||||
|
# 检查文件是否已处理
|
||||||
|
if self.db.db.processed_files.find_one({"file_path": file_path}):
|
||||||
|
print(f"文件已处理过,跳过: {file_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# 按1024字符分段
|
||||||
|
segments = [content[i:i+300] for i in range(0, len(content), 300)]
|
||||||
|
|
||||||
|
# 处理每个分段
|
||||||
|
for segment in segments:
|
||||||
|
if not segment.strip(): # 跳过空段
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取embedding
|
||||||
|
embedding = self.get_embedding(segment)
|
||||||
|
if not embedding:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 存储到数据库
|
||||||
|
doc = {
|
||||||
|
"content": segment,
|
||||||
|
"embedding": embedding,
|
||||||
|
"file_path": file_path,
|
||||||
|
"segment_length": len(segment)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 使用文本内容的哈希值作为唯一标识
|
||||||
|
content_hash = hash(segment)
|
||||||
|
|
||||||
|
# 更新或插入文档
|
||||||
|
self.db.db.knowledges.update_one(
|
||||||
|
{"content_hash": content_hash},
|
||||||
|
{"$set": doc},
|
||||||
|
upsert=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录文件已处理
|
||||||
|
self.db.db.processed_files.insert_one({
|
||||||
|
"file_path": file_path,
|
||||||
|
"processed_time": time.time()
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f"成功处理文件: {file_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理文件 {file_path} 时出错: {str(e)}")
|
||||||
|
|
||||||
|
def search_similar_segments(self, query: str, limit: int = 5) -> list:
|
||||||
|
"""搜索与查询文本相似的片段"""
|
||||||
|
query_embedding = self.get_embedding(query)
|
||||||
|
if not query_embedding:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 使用余弦相似度计算
|
||||||
|
pipeline = [
|
||||||
|
{
|
||||||
|
"$addFields": {
|
||||||
|
"dotProduct": {
|
||||||
|
"$reduce": {
|
||||||
|
"input": {"$range": [0, {"$size": "$embedding"}]},
|
||||||
|
"initialValue": 0,
|
||||||
|
"in": {
|
||||||
|
"$add": [
|
||||||
|
"$$value",
|
||||||
|
{"$multiply": [
|
||||||
|
{"$arrayElemAt": ["$embedding", "$$this"]},
|
||||||
|
{"$arrayElemAt": [query_embedding, "$$this"]}
|
||||||
|
]}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"magnitude1": {
|
||||||
|
"$sqrt": {
|
||||||
|
"$reduce": {
|
||||||
|
"input": "$embedding",
|
||||||
|
"initialValue": 0,
|
||||||
|
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"magnitude2": {
|
||||||
|
"$sqrt": {
|
||||||
|
"$reduce": {
|
||||||
|
"input": query_embedding,
|
||||||
|
"initialValue": 0,
|
||||||
|
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$addFields": {
|
||||||
|
"similarity": {
|
||||||
|
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$sort": {"similarity": -1}},
|
||||||
|
{"$limit": limit},
|
||||||
|
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
|
||||||
|
]
|
||||||
|
|
||||||
|
results = list(self.db.db.knowledges.aggregate(pipeline))
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 创建单例实例
|
||||||
|
knowledge_library = KnowledgeLibrary()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试知识库功能
|
||||||
|
print("开始处理知识库文件...")
|
||||||
|
knowledge_library.process_files()
|
||||||
|
|
||||||
|
# 测试搜索功能
|
||||||
|
test_query = "麦麦评价一下僕と花"
|
||||||
|
print(f"\n搜索与'{test_query}'相似的内容:")
|
||||||
|
results = knowledge_library.search_similar_segments(test_query)
|
||||||
|
for result in results:
|
||||||
|
print(f"相似度: {result['similarity']:.4f}")
|
||||||
|
print(f"内容: {result['content'][:100]}...")
|
||||||
|
print("-" * 50)
|
||||||
@@ -6,6 +6,9 @@ import os
|
|||||||
from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text
|
from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
from .topic_identifier import topic_identifier
|
||||||
|
from ..memory_system.memory import memory_graph
|
||||||
|
from random import choice
|
||||||
|
|
||||||
# 获取当前文件的绝对路径
|
# 获取当前文件的绝对路径
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
@@ -35,6 +38,59 @@ class PromptBuilder:
|
|||||||
Returns:
|
Returns:
|
||||||
str: 构建好的prompt
|
str: 构建好的prompt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
memory_prompt = ''
|
||||||
|
start_time = time.time() # 记录开始时间
|
||||||
|
topic = topic_identifier.identify_topic_jieba(message_txt)
|
||||||
|
# print(f"\033[1;32m[pb主题识别]\033[0m 主题: {topic}")
|
||||||
|
|
||||||
|
all_first_layer_items = [] # 存储所有第一层记忆
|
||||||
|
all_second_layer_items = {} # 用字典存储每个topic的第二层记忆
|
||||||
|
overlapping_second_layer = set() # 存储重叠的第二层记忆
|
||||||
|
|
||||||
|
if topic:
|
||||||
|
# 遍历所有topic
|
||||||
|
for current_topic in topic:
|
||||||
|
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
|
||||||
|
if first_layer_items:
|
||||||
|
print(f"\033[1;32m[pb记忆检索]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}")
|
||||||
|
|
||||||
|
# 记录第一层数据
|
||||||
|
all_first_layer_items.extend(first_layer_items)
|
||||||
|
|
||||||
|
# 记录第二层数据
|
||||||
|
all_second_layer_items[current_topic] = second_layer_items
|
||||||
|
|
||||||
|
# 检查是否有重叠的第二层数据
|
||||||
|
for other_topic, other_second_layer in all_second_layer_items.items():
|
||||||
|
if other_topic != current_topic:
|
||||||
|
# 找到重叠的记忆
|
||||||
|
overlap = set(second_layer_items) & set(other_second_layer)
|
||||||
|
if overlap:
|
||||||
|
print(f"\033[1;32m[pb记忆检索]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
|
||||||
|
overlapping_second_layer.update(overlap)
|
||||||
|
|
||||||
|
# 合并所有需要的记忆
|
||||||
|
if all_first_layer_items:
|
||||||
|
print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
||||||
|
if overlapping_second_layer:
|
||||||
|
print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
||||||
|
|
||||||
|
all_memories = all_first_layer_items + list(overlapping_second_layer)
|
||||||
|
|
||||||
|
if all_memories: # 只在列表非空时选择随机项
|
||||||
|
random_item = choice(all_memories)
|
||||||
|
memory_prompt = f"看到这些聊天,你想起来{random_item}\n"
|
||||||
|
else:
|
||||||
|
memory_prompt = "" # 如果没有记忆,则返回空字符串
|
||||||
|
|
||||||
|
end_time = time.time() # 记录结束时间
|
||||||
|
print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}秒") # 输出耗时
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#先禁用关系
|
#先禁用关系
|
||||||
if 0 > 30:
|
if 0 > 30:
|
||||||
relation_prompt = "关系特别特别好,你很喜欢喜欢他"
|
relation_prompt = "关系特别特别好,你很喜欢喜欢他"
|
||||||
@@ -55,12 +111,17 @@ class PromptBuilder:
|
|||||||
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
|
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
|
||||||
|
|
||||||
#知识构建
|
#知识构建
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
prompt_info = ''
|
prompt_info = ''
|
||||||
promt_info_prompt = ''
|
promt_info_prompt = ''
|
||||||
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
|
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
|
||||||
if prompt_info:
|
if prompt_info:
|
||||||
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
||||||
promt_info_prompt = '你有一些[知识],在上面可以参考。'
|
promt_info_prompt = '你有一些[知识],在上面可以参考。'
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒")
|
||||||
# print(f"\033[1;34m[调试]\033[0m 获取知识库内容结果: {prompt_info}")
|
# print(f"\033[1;34m[调试]\033[0m 获取知识库内容结果: {prompt_info}")
|
||||||
|
|
||||||
|
|
||||||
@@ -69,11 +130,13 @@ class PromptBuilder:
|
|||||||
chat_talking_prompt = ''
|
chat_talking_prompt = ''
|
||||||
if group_id:
|
if group_id:
|
||||||
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
|
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
|
||||||
|
|
||||||
|
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
||||||
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
||||||
|
|
||||||
#激活prompt构建
|
#激活prompt构建
|
||||||
activate_prompt = ''
|
activate_prompt = ''
|
||||||
activate_prompt = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2}。"
|
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2}。"
|
||||||
|
|
||||||
#检测机器人相关词汇
|
#检测机器人相关词汇
|
||||||
bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
|
bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
|
||||||
@@ -87,13 +150,12 @@ class PromptBuilder:
|
|||||||
prompt_personality = ''
|
prompt_personality = ''
|
||||||
personality_choice = random.random()
|
personality_choice = random.random()
|
||||||
if personality_choice < 4/6: # 第一种人格
|
if personality_choice < 4/6: # 第一种人格
|
||||||
prompt_personality = f'''你的网名叫{global_config.BOT_NICKNAME},是一个学习地质的女大学生,喜欢摄影,你会刷贴吧,你正在浏览qq群,{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},是一个学习地质的女大学生,喜欢摄影,你会刷贴吧,你正在浏览qq群,{promt_info_prompt},
|
||||||
{activate_prompt}
|
|
||||||
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
||||||
请注意把握群里的聊天内容,不要回复的太有条理,可以有个性。'''
|
请注意把握群里的聊天内容,不要回复的太有条理,可以有个性。'''
|
||||||
elif personality_choice < 1: # 第二种人格
|
elif personality_choice < 1: # 第二种人格
|
||||||
prompt_personality = f'''你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt},
|
||||||
{activate_prompt}
|
|
||||||
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
|
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
|
||||||
请你表达自己的见解和观点。可以有个性。'''
|
请你表达自己的见解和观点。可以有个性。'''
|
||||||
|
|
||||||
@@ -108,7 +170,7 @@ class PromptBuilder:
|
|||||||
|
|
||||||
|
|
||||||
#额外信息要求
|
#额外信息要求
|
||||||
extra_info = '''但是记得回复平淡一些,简短一些,不要过多提及自身的背景, 记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
|
extra_info = '''但是记得回复平淡一些,简短一些,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -117,6 +179,9 @@ class PromptBuilder:
|
|||||||
prompt += f"{prompt_info}\n"
|
prompt += f"{prompt_info}\n"
|
||||||
prompt += f"{prompt_date}\n"
|
prompt += f"{prompt_date}\n"
|
||||||
prompt += f"{chat_talking_prompt}\n"
|
prompt += f"{chat_talking_prompt}\n"
|
||||||
|
|
||||||
|
# prompt += f"{memory_prompt}\n"
|
||||||
|
|
||||||
# prompt += f"{activate_prompt}\n"
|
# prompt += f"{activate_prompt}\n"
|
||||||
prompt += f"{prompt_personality}\n"
|
prompt += f"{prompt_personality}\n"
|
||||||
prompt += f"{prompt_ger}\n"
|
prompt += f"{prompt_ger}\n"
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
import requests
|
import requests
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
import time
|
||||||
|
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -32,16 +33,34 @@ class LLMModel:
|
|||||||
# 发送请求到完整的chat/completions端点
|
# 发送请求到完整的chat/completions端点
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
try:
|
max_retries = 3
|
||||||
response = requests.post(api_url, headers=headers, json=data)
|
base_wait_time = 15 # 基础等待时间(秒)
|
||||||
response.raise_for_status() # 检查响应状态
|
|
||||||
|
|
||||||
result = response.json()
|
for retry in range(max_retries):
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
try:
|
||||||
content = result["choices"][0]["message"]["content"]
|
response = requests.post(api_url, headers=headers, json=data)
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
|
||||||
return content, reasoning_content # 返回内容和推理内容
|
|
||||||
return "没有返回结果", "" # 返回两个值
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
if response.status_code == 429:
|
||||||
return f"请求失败: {str(e)}", "" # 返回错误信息和空字符串
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
376
src/plugins/memory_system/memory copy.py
Normal file
376
src/plugins/memory_system/memory copy.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import sys
|
||||||
|
import jieba
|
||||||
|
from llm_module import LLMModel
|
||||||
|
import networkx as nx
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import math
|
||||||
|
from collections import Counter
|
||||||
|
import datetime
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
# from chat.config import global_config
|
||||||
|
import sys
|
||||||
|
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||||
|
from src.common.database import Database # 使用正确的导入语法
|
||||||
|
|
||||||
|
class Memory_graph:
|
||||||
|
def __init__(self):
|
||||||
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
|
self.db = Database.get_instance()
|
||||||
|
|
||||||
|
def connect_dot(self, concept1, concept2):
|
||||||
|
self.G.add_edge(concept1, concept2)
|
||||||
|
|
||||||
|
def add_dot(self, concept, memory):
|
||||||
|
if concept in self.G:
|
||||||
|
# 如果节点已存在,将新记忆添加到现有列表中
|
||||||
|
if 'memory_items' in self.G.nodes[concept]:
|
||||||
|
if not isinstance(self.G.nodes[concept]['memory_items'], list):
|
||||||
|
# 如果当前不是列表,将其转换为列表
|
||||||
|
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
|
||||||
|
self.G.nodes[concept]['memory_items'].append(memory)
|
||||||
|
else:
|
||||||
|
self.G.nodes[concept]['memory_items'] = [memory]
|
||||||
|
else:
|
||||||
|
# 如果是新节点,创建新的记忆列表
|
||||||
|
self.G.add_node(concept, memory_items=[memory])
|
||||||
|
|
||||||
|
def get_dot(self, concept):
|
||||||
|
# 检查节点是否存在于图中
|
||||||
|
if concept in self.G:
|
||||||
|
# 从图中获取节点数据
|
||||||
|
node_data = self.G.nodes[concept]
|
||||||
|
# print(node_data)
|
||||||
|
# 创建新的Memory_dot对象
|
||||||
|
return concept,node_data
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_related_item(self, topic, depth=1):
|
||||||
|
if topic not in self.G:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
first_layer_items = []
|
||||||
|
second_layer_items = []
|
||||||
|
|
||||||
|
# 获取相邻节点
|
||||||
|
neighbors = list(self.G.neighbors(topic))
|
||||||
|
# print(f"第一层: {topic}")
|
||||||
|
|
||||||
|
# 获取当前节点的记忆项
|
||||||
|
node_data = self.get_dot(topic)
|
||||||
|
if node_data:
|
||||||
|
concept, data = node_data
|
||||||
|
if 'memory_items' in data:
|
||||||
|
memory_items = data['memory_items']
|
||||||
|
if isinstance(memory_items, list):
|
||||||
|
first_layer_items.extend(memory_items)
|
||||||
|
else:
|
||||||
|
first_layer_items.append(memory_items)
|
||||||
|
|
||||||
|
# 只在depth=2时获取第二层记忆
|
||||||
|
if depth >= 2:
|
||||||
|
# 获取相邻节点的记忆项
|
||||||
|
for neighbor in neighbors:
|
||||||
|
# print(f"第二层: {neighbor}")
|
||||||
|
node_data = self.get_dot(neighbor)
|
||||||
|
if node_data:
|
||||||
|
concept, data = node_data
|
||||||
|
if 'memory_items' in data:
|
||||||
|
memory_items = data['memory_items']
|
||||||
|
if isinstance(memory_items, list):
|
||||||
|
second_layer_items.extend(memory_items)
|
||||||
|
else:
|
||||||
|
second_layer_items.append(memory_items)
|
||||||
|
|
||||||
|
return first_layer_items, second_layer_items
|
||||||
|
|
||||||
|
def store_memory(self):
|
||||||
|
for node in self.G.nodes():
|
||||||
|
dot_data = {
|
||||||
|
"concept": node
|
||||||
|
}
|
||||||
|
self.db.db.store_memory_dots.insert_one(dot_data)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dots(self):
|
||||||
|
# 返回所有节点对应的 Memory_dot 对象
|
||||||
|
return [self.get_dot(node) for node in self.G.nodes()]
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_chat_from_db(self, length: int, timestamp: str):
|
||||||
|
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||||
|
chat_text = ''
|
||||||
|
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||||
|
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||||
|
|
||||||
|
if closest_record:
|
||||||
|
closest_time = closest_record['time']
|
||||||
|
group_id = closest_record['group_id'] # 获取groupid
|
||||||
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
|
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||||
|
for record in chat_record:
|
||||||
|
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||||
|
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
||||||
|
return chat_text
|
||||||
|
|
||||||
|
return [] # 如果没有找到记录,返回空列表
|
||||||
|
|
||||||
|
def save_graph_to_db(self):
|
||||||
|
# 清空现有的图数据
|
||||||
|
self.db.db.graph_data.delete_many({})
|
||||||
|
# 保存节点
|
||||||
|
for node in self.G.nodes(data=True):
|
||||||
|
node_data = {
|
||||||
|
'concept': node[0],
|
||||||
|
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
||||||
|
}
|
||||||
|
self.db.db.graph_data.nodes.insert_one(node_data)
|
||||||
|
# 保存边
|
||||||
|
for edge in self.G.edges():
|
||||||
|
edge_data = {
|
||||||
|
'source': edge[0],
|
||||||
|
'target': edge[1]
|
||||||
|
}
|
||||||
|
self.db.db.graph_data.edges.insert_one(edge_data)
|
||||||
|
|
||||||
|
def load_graph_from_db(self):
|
||||||
|
# 清空当前图
|
||||||
|
self.G.clear()
|
||||||
|
# 加载节点
|
||||||
|
nodes = self.db.db.graph_data.nodes.find()
|
||||||
|
for node in nodes:
|
||||||
|
memory_items = node.get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
self.G.add_node(node['concept'], memory_items=memory_items)
|
||||||
|
# 加载边
|
||||||
|
edges = self.db.db.graph_data.edges.find()
|
||||||
|
for edge in edges:
|
||||||
|
self.G.add_edge(edge['source'], edge['target'])
|
||||||
|
|
||||||
|
def calculate_information_content(text):
|
||||||
|
|
||||||
|
"""计算文本的信息量(熵)"""
|
||||||
|
# 统计字符频率
|
||||||
|
char_count = Counter(text)
|
||||||
|
total_chars = len(text)
|
||||||
|
|
||||||
|
# 计算熵
|
||||||
|
entropy = 0
|
||||||
|
for count in char_count.values():
|
||||||
|
probability = count / total_chars
|
||||||
|
entropy -= probability * math.log2(probability)
|
||||||
|
|
||||||
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
|
# Database.initialize(
|
||||||
|
# global_config.MONGODB_HOST,
|
||||||
|
# global_config.MONGODB_PORT,
|
||||||
|
# global_config.DATABASE_NAME
|
||||||
|
# )
|
||||||
|
# memory_graph = Memory_graph()
|
||||||
|
|
||||||
|
# llm_model = LLMModel()
|
||||||
|
# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||||
|
|
||||||
|
# memory_graph.load_graph_from_db()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 初始化数据库
|
||||||
|
Database.initialize(
|
||||||
|
"127.0.0.1",
|
||||||
|
27017,
|
||||||
|
"MegBot"
|
||||||
|
)
|
||||||
|
|
||||||
|
memory_graph = Memory_graph()
|
||||||
|
# 创建LLM模型实例
|
||||||
|
llm_model = LLMModel()
|
||||||
|
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||||
|
|
||||||
|
# 使用当前时间戳进行测试
|
||||||
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
|
chat_text = []
|
||||||
|
|
||||||
|
chat_size =40
|
||||||
|
|
||||||
|
for _ in range(100): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(1, 3600*39) # 随机时间
|
||||||
|
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
||||||
|
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
||||||
|
chat_text.append(chat_) # 拼接所有text
|
||||||
|
# time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
for i, input_text in enumerate(chat_text, 1):
|
||||||
|
|
||||||
|
progress = (i / len(chat_text)) * 100
|
||||||
|
bar_length = 30
|
||||||
|
filled_length = int(bar_length * i // len(chat_text))
|
||||||
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
|
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})")
|
||||||
|
|
||||||
|
# print(input_text)
|
||||||
|
first_memory = set()
|
||||||
|
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
#将记忆加入到图谱中
|
||||||
|
for topic, memory in first_memory:
|
||||||
|
topics = segment_text(topic)
|
||||||
|
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||||
|
for split_topic in topics:
|
||||||
|
memory_graph.add_dot(split_topic,memory)
|
||||||
|
for split_topic in topics:
|
||||||
|
for other_split_topic in topics:
|
||||||
|
if split_topic != other_split_topic:
|
||||||
|
memory_graph.connect_dot(split_topic, other_split_topic)
|
||||||
|
|
||||||
|
# memory_graph.store_memory()
|
||||||
|
|
||||||
|
# 展示两种不同的可视化方式
|
||||||
|
print("\n按连接数量着色的图谱:")
|
||||||
|
visualize_graph(memory_graph, color_by_memory=False)
|
||||||
|
|
||||||
|
print("\n按记忆数量着色的图谱:")
|
||||||
|
visualize_graph(memory_graph, color_by_memory=True)
|
||||||
|
|
||||||
|
memory_graph.save_graph_to_db()
|
||||||
|
# memory_graph.load_graph_from_db()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||||
|
if query.lower() == '退出':
|
||||||
|
break
|
||||||
|
items_list = memory_graph.get_related_item(query)
|
||||||
|
if items_list:
|
||||||
|
# print(items_list)
|
||||||
|
for memory_item in items_list:
|
||||||
|
print(memory_item)
|
||||||
|
else:
|
||||||
|
print("未找到相关记忆。")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
query = input("请输入问题:")
|
||||||
|
|
||||||
|
if query.lower() == '退出':
|
||||||
|
break
|
||||||
|
|
||||||
|
topic_prompt = find_topic(query, 3)
|
||||||
|
topic_response = llm_model.generate_response(topic_prompt)
|
||||||
|
# 检查 topic_response 是否为元组
|
||||||
|
if isinstance(topic_response, tuple):
|
||||||
|
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||||
|
else:
|
||||||
|
topics = topic_response.split(",")
|
||||||
|
print(topics)
|
||||||
|
|
||||||
|
for keyword in topics:
|
||||||
|
items_list = memory_graph.get_related_item(keyword)
|
||||||
|
if items_list:
|
||||||
|
print(items_list)
|
||||||
|
|
||||||
|
def memory_compress(input_text, llm_model, llm_model_small, rate=1):
|
||||||
|
information_content = calculate_information_content(input_text)
|
||||||
|
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
||||||
|
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
||||||
|
print(topic_num)
|
||||||
|
topic_prompt = find_topic(input_text, topic_num)
|
||||||
|
topic_response = llm_model.generate_response(topic_prompt)
|
||||||
|
# 检查 topic_response 是否为元组
|
||||||
|
if isinstance(topic_response, tuple):
|
||||||
|
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||||
|
else:
|
||||||
|
topics = topic_response.split(",")
|
||||||
|
print(topics)
|
||||||
|
compressed_memory = set()
|
||||||
|
for topic in topics:
|
||||||
|
topic_what_prompt = topic_what(input_text,topic)
|
||||||
|
topic_what_response = llm_model_small.generate_response(topic_what_prompt)
|
||||||
|
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
||||||
|
return compressed_memory
|
||||||
|
|
||||||
|
|
||||||
|
def segment_text(text):
|
||||||
|
seg_text = list(jieba.cut(text))
|
||||||
|
return seg_text
|
||||||
|
|
||||||
|
def find_topic(text, topic_num):
|
||||||
|
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def topic_what(text, topic):
|
||||||
|
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||||
|
# 设置中文字体
|
||||||
|
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||||
|
|
||||||
|
G = memory_graph.G
|
||||||
|
|
||||||
|
# 保存图到本地
|
||||||
|
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
|
||||||
|
|
||||||
|
# 根据连接条数或记忆数量设置节点颜色
|
||||||
|
node_colors = []
|
||||||
|
nodes = list(G.nodes()) # 获取图中实际的节点列表
|
||||||
|
|
||||||
|
if color_by_memory:
|
||||||
|
# 计算每个节点的记忆数量
|
||||||
|
memory_counts = []
|
||||||
|
for node in nodes:
|
||||||
|
memory_items = G.nodes[node].get('memory_items', [])
|
||||||
|
if isinstance(memory_items, list):
|
||||||
|
count = len(memory_items)
|
||||||
|
else:
|
||||||
|
count = 1 if memory_items else 0
|
||||||
|
memory_counts.append(count)
|
||||||
|
max_memories = max(memory_counts) if memory_counts else 1
|
||||||
|
|
||||||
|
for count in memory_counts:
|
||||||
|
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
||||||
|
if max_memories > 0:
|
||||||
|
intensity = min(1.0, count / max_memories)
|
||||||
|
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
||||||
|
else:
|
||||||
|
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
||||||
|
node_colors.append(color)
|
||||||
|
else:
|
||||||
|
# 使用原来的连接数量着色方案
|
||||||
|
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
|
||||||
|
for node in nodes:
|
||||||
|
degree = G.degree(node)
|
||||||
|
if max_degree > 0:
|
||||||
|
red = min(1.0, degree / max_degree)
|
||||||
|
blue = 1.0 - red
|
||||||
|
color = (red, 0, blue)
|
||||||
|
else:
|
||||||
|
color = (0, 0, 1)
|
||||||
|
node_colors.append(color)
|
||||||
|
|
||||||
|
# 绘制图形
|
||||||
|
plt.figure(figsize=(12, 8))
|
||||||
|
pos = nx.spring_layout(G, k=1, iterations=50)
|
||||||
|
nx.draw(G, pos,
|
||||||
|
with_labels=True,
|
||||||
|
node_color=node_colors,
|
||||||
|
node_size=2000,
|
||||||
|
font_size=10,
|
||||||
|
font_family='SimHei',
|
||||||
|
font_weight='bold')
|
||||||
|
|
||||||
|
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
||||||
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
import sys
|
||||||
import jieba
|
import jieba
|
||||||
from llm_module import LLMModel
|
from .llm_module import LLMModel
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import math
|
import math
|
||||||
@@ -9,9 +9,9 @@ from collections import Counter
|
|||||||
import datetime
|
import datetime
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
from ..chat.config import global_config
|
||||||
import sys
|
import sys
|
||||||
sys.path.append("C:/GitHub/MegMeg-bot") # 添加项目根目录到 Python 路径
|
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||||
from src.common.database import Database # 使用正确的导入语法
|
from src.common.database import Database # 使用正确的导入语法
|
||||||
|
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
@@ -23,44 +23,67 @@ class Memory_graph:
|
|||||||
self.G.add_edge(concept1, concept2)
|
self.G.add_edge(concept1, concept2)
|
||||||
|
|
||||||
def add_dot(self, concept, memory):
|
def add_dot(self, concept, memory):
|
||||||
self.G.add_node(concept, memory_items=memory)
|
if concept in self.G:
|
||||||
|
# 如果节点已存在,将新记忆添加到现有列表中
|
||||||
|
if 'memory_items' in self.G.nodes[concept]:
|
||||||
|
if not isinstance(self.G.nodes[concept]['memory_items'], list):
|
||||||
|
# 如果当前不是列表,将其转换为列表
|
||||||
|
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
|
||||||
|
self.G.nodes[concept]['memory_items'].append(memory)
|
||||||
|
else:
|
||||||
|
self.G.nodes[concept]['memory_items'] = [memory]
|
||||||
|
else:
|
||||||
|
# 如果是新节点,创建新的记忆列表
|
||||||
|
self.G.add_node(concept, memory_items=[memory])
|
||||||
|
|
||||||
def get_dot(self, concept):
|
def get_dot(self, concept):
|
||||||
# 检查节点是否存在于图中
|
# 检查节点是否存在于图中
|
||||||
if concept in self.G:
|
if concept in self.G:
|
||||||
# 从图中获取节点数据
|
# 从图中获取节点数据
|
||||||
node_data = self.G.nodes[concept]
|
node_data = self.G.nodes[concept]
|
||||||
print(node_data)
|
# print(node_data)
|
||||||
# 创建新的Memory_dot对象
|
# 创建新的Memory_dot对象
|
||||||
return concept,node_data
|
return concept,node_data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_related_item(self, topic, depth=1):
|
def get_related_item(self, topic, depth=1):
|
||||||
if topic not in self.G:
|
if topic not in self.G:
|
||||||
return set()
|
return [], []
|
||||||
|
|
||||||
|
first_layer_items = []
|
||||||
|
second_layer_items = []
|
||||||
|
|
||||||
items_set = set()
|
|
||||||
# 获取相邻节点
|
# 获取相邻节点
|
||||||
neighbors = list(self.G.neighbors(topic))
|
neighbors = list(self.G.neighbors(topic))
|
||||||
print(f"第一层: {topic}")
|
# print(f"第一层: {topic}")
|
||||||
|
|
||||||
# 获取当前节点的记忆项
|
# 获取当前节点的记忆项
|
||||||
node_data = self.get_dot(topic)
|
node_data = self.get_dot(topic)
|
||||||
if node_data:
|
if node_data:
|
||||||
concept, data = node_data
|
concept, data = node_data
|
||||||
if 'memory_items' in data:
|
if 'memory_items' in data:
|
||||||
items_set.add(data['memory_items'])
|
memory_items = data['memory_items']
|
||||||
|
if isinstance(memory_items, list):
|
||||||
|
first_layer_items.extend(memory_items)
|
||||||
|
else:
|
||||||
|
first_layer_items.append(memory_items)
|
||||||
|
|
||||||
# 获取相邻节点的记忆项
|
# 只在depth=2时获取第二层记忆
|
||||||
for neighbor in neighbors:
|
if depth >= 2:
|
||||||
print(f"第二层: {neighbor}")
|
# 获取相邻节点的记忆项
|
||||||
node_data = self.get_dot(neighbor)
|
for neighbor in neighbors:
|
||||||
if node_data:
|
# print(f"第二层: {neighbor}")
|
||||||
concept, data = node_data
|
node_data = self.get_dot(neighbor)
|
||||||
if 'memory_items' in data:
|
if node_data:
|
||||||
items_set.add(data['memory_items'])
|
concept, data = node_data
|
||||||
|
if 'memory_items' in data:
|
||||||
|
memory_items = data['memory_items']
|
||||||
|
if isinstance(memory_items, list):
|
||||||
|
second_layer_items.extend(memory_items)
|
||||||
|
else:
|
||||||
|
second_layer_items.append(memory_items)
|
||||||
|
|
||||||
return items_set
|
return first_layer_items, second_layer_items
|
||||||
|
|
||||||
def store_memory(self):
|
def store_memory(self):
|
||||||
for node in self.G.nodes():
|
for node in self.G.nodes():
|
||||||
@@ -100,7 +123,7 @@ class Memory_graph:
|
|||||||
for node in self.G.nodes(data=True):
|
for node in self.G.nodes(data=True):
|
||||||
node_data = {
|
node_data = {
|
||||||
'concept': node[0],
|
'concept': node[0],
|
||||||
'memory_items': node[1].get('memory_items', None)
|
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
||||||
}
|
}
|
||||||
self.db.db.graph_data.nodes.insert_one(node_data)
|
self.db.db.graph_data.nodes.insert_one(node_data)
|
||||||
# 保存边
|
# 保存边
|
||||||
@@ -117,7 +140,10 @@ class Memory_graph:
|
|||||||
# 加载节点
|
# 加载节点
|
||||||
nodes = self.db.db.graph_data.nodes.find()
|
nodes = self.db.db.graph_data.nodes.find()
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
self.G.add_node(node['concept'], memory_items=node['memory_items'])
|
memory_items = node.get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
self.G.add_node(node['concept'], memory_items=memory_items)
|
||||||
# 加载边
|
# 加载边
|
||||||
edges = self.db.db.graph_data.edges.find()
|
edges = self.db.db.graph_data.edges.find()
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
@@ -138,6 +164,26 @@ def calculate_information_content(text):
|
|||||||
|
|
||||||
return entropy
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
Database.initialize(
|
||||||
|
global_config.MONGODB_HOST,
|
||||||
|
global_config.MONGODB_PORT,
|
||||||
|
global_config.DATABASE_NAME
|
||||||
|
)
|
||||||
|
memory_graph = Memory_graph()
|
||||||
|
|
||||||
|
llm_model = LLMModel()
|
||||||
|
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||||
|
|
||||||
|
memory_graph.load_graph_from_db()
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"加载海马体耗时: {end_time - start_time:.2f} 秒")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# 初始化数据库
|
# 初始化数据库
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
@@ -155,13 +201,14 @@ def main():
|
|||||||
current_timestamp = datetime.datetime.now().timestamp()
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
chat_text = []
|
chat_text = []
|
||||||
|
|
||||||
chat_size =30
|
chat_size =40
|
||||||
|
|
||||||
for _ in range(60): # 循环10次
|
for _ in range(100): # 循环10次
|
||||||
random_time = current_timestamp - random.randint(1, 3600*3) # 随机时间
|
random_time = current_timestamp - random.randint(1, 3600*39) # 随机时间
|
||||||
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
||||||
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
||||||
chat_text.append(chat_) # 拼接所有text
|
chat_text.append(chat_) # 拼接所有text
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -173,7 +220,7 @@ def main():
|
|||||||
#将记忆加入到图谱中
|
#将记忆加入到图谱中
|
||||||
for topic, memory in first_memory:
|
for topic, memory in first_memory:
|
||||||
topics = segment_text(topic)
|
topics = segment_text(topic)
|
||||||
print(f"话题: {topic},节点: {topics}, 记忆: {memory}")
|
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||||
for split_topic in topics:
|
for split_topic in topics:
|
||||||
memory_graph.add_dot(split_topic,memory)
|
memory_graph.add_dot(split_topic,memory)
|
||||||
for split_topic in topics:
|
for split_topic in topics:
|
||||||
@@ -182,7 +229,13 @@ def main():
|
|||||||
memory_graph.connect_dot(split_topic, other_split_topic)
|
memory_graph.connect_dot(split_topic, other_split_topic)
|
||||||
|
|
||||||
# memory_graph.store_memory()
|
# memory_graph.store_memory()
|
||||||
visualize_graph(memory_graph)
|
|
||||||
|
# 展示两种不同的可视化方式
|
||||||
|
print("\n按连接数量着色的图谱:")
|
||||||
|
visualize_graph(memory_graph, color_by_memory=False)
|
||||||
|
|
||||||
|
print("\n按记忆数量着色的图谱:")
|
||||||
|
visualize_graph(memory_graph, color_by_memory=True)
|
||||||
|
|
||||||
memory_graph.save_graph_to_db()
|
memory_graph.save_graph_to_db()
|
||||||
# memory_graph.load_graph_from_db()
|
# memory_graph.load_graph_from_db()
|
||||||
@@ -252,45 +305,66 @@ def topic_what(text, topic):
|
|||||||
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def visualize_graph(memory_graph: Memory_graph):
|
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||||
# 设置中文字体
|
# 设置中文字体
|
||||||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||||
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||||
|
|
||||||
G = memory_graph.G
|
G = memory_graph.G
|
||||||
|
|
||||||
|
|
||||||
# 保存图到本地
|
# 保存图到本地
|
||||||
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
|
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
|
||||||
|
|
||||||
# 根据连接条数设置节点颜色
|
# 根据连接条数或记忆数量设置节点颜色
|
||||||
node_colors = []
|
node_colors = []
|
||||||
nodes = list(G.nodes()) # 获取图中实际的节点列表
|
nodes = list(G.nodes()) # 获取图中实际的节点列表
|
||||||
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1 # 获取最大连接数
|
|
||||||
|
|
||||||
for node in nodes:
|
if color_by_memory:
|
||||||
degree = G.degree(node) # 获取节点的度
|
# 计算每个节点的记忆数量
|
||||||
# 计算颜色,使用渐变效果
|
memory_counts = []
|
||||||
if max_degree > 0:
|
for node in nodes:
|
||||||
red = min(1.0, degree / max_degree) # 红色分量随连接数增加而增加
|
memory_items = G.nodes[node].get('memory_items', [])
|
||||||
blue = 1.0 - red # 蓝色分量随连接数增加而减少
|
if isinstance(memory_items, list):
|
||||||
color = (red, 0, blue)
|
count = len(memory_items)
|
||||||
else:
|
else:
|
||||||
color = (0, 0, 1) # 如果没有连接,则为蓝色
|
count = 1 if memory_items else 0
|
||||||
node_colors.append(color)
|
memory_counts.append(count)
|
||||||
|
max_memories = max(memory_counts) if memory_counts else 1
|
||||||
|
|
||||||
|
for count in memory_counts:
|
||||||
|
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
||||||
|
if max_memories > 0:
|
||||||
|
intensity = min(1.0, count / max_memories)
|
||||||
|
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
||||||
|
else:
|
||||||
|
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
||||||
|
node_colors.append(color)
|
||||||
|
else:
|
||||||
|
# 使用原来的连接数量着色方案
|
||||||
|
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
|
||||||
|
for node in nodes:
|
||||||
|
degree = G.degree(node)
|
||||||
|
if max_degree > 0:
|
||||||
|
red = min(1.0, degree / max_degree)
|
||||||
|
blue = 1.0 - red
|
||||||
|
color = (red, 0, blue)
|
||||||
|
else:
|
||||||
|
color = (0, 0, 1)
|
||||||
|
node_colors.append(color)
|
||||||
|
|
||||||
# 绘制图形
|
# 绘制图形
|
||||||
plt.figure(figsize=(12, 8))
|
plt.figure(figsize=(12, 8))
|
||||||
pos = nx.spring_layout(G, k=1, iterations=50) # 使用弹簧布局,调整参数使布局更合理
|
pos = nx.spring_layout(G, k=1, iterations=50)
|
||||||
nx.draw(G, pos,
|
nx.draw(G, pos,
|
||||||
with_labels=True,
|
with_labels=True,
|
||||||
node_color=node_colors,
|
node_color=node_colors,
|
||||||
node_size=2000,
|
node_size=2000,
|
||||||
font_size=10,
|
font_size=10,
|
||||||
font_family='SimHei', # 设置节点标签的字体
|
font_family='SimHei',
|
||||||
font_weight='bold')
|
font_weight='bold')
|
||||||
|
|
||||||
plt.title('记忆图谱可视化', fontsize=16, fontfamily='SimHei')
|
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
||||||
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user