diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index 7a3e2c758..a2b54eaa5 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -10,6 +10,8 @@ from .relationship_manager import relationship_manager from ..schedule.schedule_generator import bot_schedule from .willing_manager import willing_manager +from ..memory_system.memory import memory_graph + # 获取驱动器 driver = get_driver() @@ -23,6 +25,8 @@ Database.initialize( print("\033[1;32m[初始化配置和数据库完成]\033[0m") + + # 导入其他模块 from .bot import ChatBot from .emoji_manager import emoji_manager diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index efa8e1014..09ee2f063 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -5,7 +5,7 @@ from .storage import MessageStorage from .llm_generator import LLMResponseGenerator from .message_stream import MessageStream, MessageStreamContainer from .topic_identifier import topic_identifier -from random import random +from random import random, choice from .emoji_manager import emoji_manager # 导入表情包管理器 import time import os @@ -15,6 +15,7 @@ from .message import Message_Thinking # 导入 Message_Thinking 类 from .relationship_manager import relationship_manager from .willing_manager import willing_manager # 导入意愿管理器 from .utils import is_mentioned_bot_in_txt, calculate_typing_time +from ..memory_system.memory import memory_graph class ChatBot: def __init__(self, config: BotConfig): @@ -99,6 +100,11 @@ class ChatBot: topic = topic_identifier.identify_topic_jieba(message.processed_plain_text) print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}") + if topic: + for current_topic in topic: + first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) + if first_layer_items: + print(f"\033[1;32m[记忆检索-bot]\033[0m 有印象:{current_topic}") await self.storage.store_message(message, topic[0] if topic else None) diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index b9965470c..f34317c92 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -133,8 +133,8 @@ llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL') if not global_config.enable_advance_output: logger.remove() - logging.getLogger('nonebot').handlers.clear() - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.WARNING) # 只输出 WARNING 及以上级别 - logging.getLogger('nonebot').addHandler(console_handler) - logging.getLogger('nonebot').setLevel(logging.WARNING) + # logging.getLogger('nonebot').handlers.clear() + # console_handler = logging.StreamHandler() + # console_handler.setLevel(logging.WARNING) # 只输出 WARNING 及以上级别 + # logging.getLogger('nonebot').addHandler(console_handler) + # logging.getLogger('nonebot').setLevel(logging.WARNING) diff --git a/src/plugins/chat/knowledege/knowledge_library.py b/src/plugins/chat/knowledege/knowledge_library.py new file mode 100644 index 000000000..40756b413 --- /dev/null +++ b/src/plugins/chat/knowledege/knowledge_library.py @@ -0,0 +1,186 @@ +import os +import sys +import numpy as np +import requests +import time + +# 添加项目根目录到 Python 路径 +root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +sys.path.append(root_path) + +from src.common.database import Database +from src.plugins.chat.config import llm_config + +# 直接配置数据库连接信息 +Database.initialize( + "127.0.0.1", # MongoDB 主机 + 27017, # MongoDB 端口 + "MegBot" # 数据库名称 +) + +class KnowledgeLibrary: + def __init__(self): + self.db = Database.get_instance() + self.raw_info_dir = "data/raw_info" + self._ensure_dirs() + + def _ensure_dirs(self): + """确保必要的目录存在""" + os.makedirs(self.raw_info_dir, exist_ok=True) + + def get_embedding(self, text: str) -> list: + """获取文本的embedding向量""" + url = "https://api.siliconflow.cn/v1/embeddings" + payload = { + "model": "BAAI/bge-m3", + "input": text, + "encoding_format": "float" + } + headers = { + "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}", + "Content-Type": "application/json" + } + + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + print(f"获取embedding失败: {response.text}") + return None + + return response.json()['data'][0]['embedding'] + + def process_files(self): + """处理raw_info目录下的所有txt文件""" + for filename in os.listdir(self.raw_info_dir): + if filename.endswith('.txt'): + file_path = os.path.join(self.raw_info_dir, filename) + self.process_single_file(file_path) + + def process_single_file(self, file_path: str): + """处理单个文件""" + try: + # 检查文件是否已处理 + if self.db.db.processed_files.find_one({"file_path": file_path}): + print(f"文件已处理过,跳过: {file_path}") + return + + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + # 按1024字符分段 + segments = [content[i:i+300] for i in range(0, len(content), 300)] + + # 处理每个分段 + for segment in segments: + if not segment.strip(): # 跳过空段 + continue + + # 获取embedding + embedding = self.get_embedding(segment) + if not embedding: + continue + + # 存储到数据库 + doc = { + "content": segment, + "embedding": embedding, + "file_path": file_path, + "segment_length": len(segment) + } + + # 使用文本内容的哈希值作为唯一标识 + content_hash = hash(segment) + + # 更新或插入文档 + self.db.db.knowledges.update_one( + {"content_hash": content_hash}, + {"$set": doc}, + upsert=True + ) + + # 记录文件已处理 + self.db.db.processed_files.insert_one({ + "file_path": file_path, + "processed_time": time.time() + }) + + print(f"成功处理文件: {file_path}") + + except Exception as e: + print(f"处理文件 {file_path} 时出错: {str(e)}") + + def search_similar_segments(self, query: str, limit: int = 5) -> list: + """搜索与查询文本相似的片段""" + query_embedding = self.get_embedding(query) + if not query_embedding: + return [] + + # 使用余弦相似度计算 + pipeline = [ + { + "$addFields": { + "dotProduct": { + "$reduce": { + "input": {"$range": [0, {"$size": "$embedding"}]}, + "initialValue": 0, + "in": { + "$add": [ + "$$value", + {"$multiply": [ + {"$arrayElemAt": ["$embedding", "$$this"]}, + {"$arrayElemAt": [query_embedding, "$$this"]} + ]} + ] + } + } + }, + "magnitude1": { + "$sqrt": { + "$reduce": { + "input": "$embedding", + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} + } + } + }, + "magnitude2": { + "$sqrt": { + "$reduce": { + "input": query_embedding, + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} + } + } + } + } + }, + { + "$addFields": { + "similarity": { + "$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}] + } + } + }, + {"$sort": {"similarity": -1}}, + {"$limit": limit}, + {"$project": {"content": 1, "similarity": 1, "file_path": 1}} + ] + + results = list(self.db.db.knowledges.aggregate(pipeline)) + return results + +# 创建单例实例 +knowledge_library = KnowledgeLibrary() + +if __name__ == "__main__": + # 测试知识库功能 + print("开始处理知识库文件...") + knowledge_library.process_files() + + # 测试搜索功能 + test_query = "麦麦评价一下僕と花" + print(f"\n搜索与'{test_query}'相似的内容:") + results = knowledge_library.search_similar_segments(test_query) + for result in results: + print(f"相似度: {result['similarity']:.4f}") + print(f"内容: {result['content'][:100]}...") + print("-" * 50) diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index ac865d9ef..0116969a7 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -6,6 +6,9 @@ import os from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text from ...common.database import Database from .config import global_config +from .topic_identifier import topic_identifier +from ..memory_system.memory import memory_graph +from random import choice # 获取当前文件的绝对路径 current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -35,6 +38,59 @@ class PromptBuilder: Returns: str: 构建好的prompt """ + + + memory_prompt = '' + start_time = time.time() # 记录开始时间 + topic = topic_identifier.identify_topic_jieba(message_txt) + # print(f"\033[1;32m[pb主题识别]\033[0m 主题: {topic}") + + all_first_layer_items = [] # 存储所有第一层记忆 + all_second_layer_items = {} # 用字典存储每个topic的第二层记忆 + overlapping_second_layer = set() # 存储重叠的第二层记忆 + + if topic: + # 遍历所有topic + for current_topic in topic: + first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) + if first_layer_items: + print(f"\033[1;32m[pb记忆检索]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}") + + # 记录第一层数据 + all_first_layer_items.extend(first_layer_items) + + # 记录第二层数据 + all_second_layer_items[current_topic] = second_layer_items + + # 检查是否有重叠的第二层数据 + for other_topic, other_second_layer in all_second_layer_items.items(): + if other_topic != current_topic: + # 找到重叠的记忆 + overlap = set(second_layer_items) & set(other_second_layer) + if overlap: + print(f"\033[1;32m[pb记忆检索]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}") + overlapping_second_layer.update(overlap) + + # 合并所有需要的记忆 + if all_first_layer_items: + print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆1: {all_first_layer_items}") + if overlapping_second_layer: + print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}") + + all_memories = all_first_layer_items + list(overlapping_second_layer) + + if all_memories: # 只在列表非空时选择随机项 + random_item = choice(all_memories) + memory_prompt = f"看到这些聊天,你想起来{random_item}\n" + else: + memory_prompt = "" # 如果没有记忆,则返回空字符串 + + end_time = time.time() # 记录结束时间 + print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}秒") # 输出耗时 + + + + #先禁用关系 if 0 > 30: relation_prompt = "关系特别特别好,你很喜欢喜欢他" @@ -55,12 +111,17 @@ class PromptBuilder: prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' #知识构建 + start_time = time.time() + prompt_info = '' promt_info_prompt = '' prompt_info = self.get_prompt_info(message_txt,threshold=0.5) if prompt_info: prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n''' promt_info_prompt = '你有一些[知识],在上面可以参考。' + + end_time = time.time() + print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒") # print(f"\033[1;34m[调试]\033[0m 获取知识库内容结果: {prompt_info}") @@ -69,11 +130,13 @@ class PromptBuilder: chat_talking_prompt = '' if group_id: chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) + + chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") #激活prompt构建 activate_prompt = '' - activate_prompt = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2}。" + activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2}。" #检测机器人相关词汇 bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人'] @@ -87,13 +150,12 @@ class PromptBuilder: prompt_personality = '' personality_choice = random.random() if personality_choice < 4/6: # 第一种人格 - prompt_personality = f'''你的网名叫{global_config.BOT_NICKNAME},是一个学习地质的女大学生,喜欢摄影,你会刷贴吧,你正在浏览qq群,{promt_info_prompt}, - {activate_prompt} + prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},是一个学习地质的女大学生,喜欢摄影,你会刷贴吧,你正在浏览qq群,{promt_info_prompt}, 现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt} 请注意把握群里的聊天内容,不要回复的太有条理,可以有个性。''' elif personality_choice < 1: # 第二种人格 - prompt_personality = f'''你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt}, - {activate_prompt} + prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt}, + 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt} 请你表达自己的见解和观点。可以有个性。''' @@ -108,7 +170,7 @@ class PromptBuilder: #额外信息要求 - extra_info = '''但是记得回复平淡一些,简短一些,不要过多提及自身的背景, 记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' + extra_info = '''但是记得回复平淡一些,简短一些,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' @@ -116,7 +178,10 @@ class PromptBuilder: prompt = "" prompt += f"{prompt_info}\n" prompt += f"{prompt_date}\n" - prompt += f"{chat_talking_prompt}\n" + prompt += f"{chat_talking_prompt}\n" + + # prompt += f"{memory_prompt}\n" + # prompt += f"{activate_prompt}\n" prompt += f"{prompt_personality}\n" prompt += f"{prompt_ger}\n" diff --git a/src/plugins/memory_system/llm_module.py b/src/plugins/memory_system/llm_module.py index a5516012f..fa879afdc 100644 --- a/src/plugins/memory_system/llm_module.py +++ b/src/plugins/memory_system/llm_module.py @@ -2,6 +2,7 @@ import os import requests from dotenv import load_dotenv from typing import Tuple, Union +import time # 加载环境变量 load_dotenv() @@ -32,16 +33,34 @@ class LLMModel: # 发送请求到完整的chat/completions端点 api_url = f"{self.base_url.rstrip('/')}/chat/completions" - try: - response = requests.post(api_url, headers=headers, json=data) - response.raise_for_status() # 检查响应状态 - - result = response.json() - if "choices" in result and len(result["choices"]) > 0: - content = result["choices"][0]["message"]["content"] - reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") - return content, reasoning_content # 返回内容和推理内容 - return "没有返回结果", "" # 返回两个值 - - except requests.exceptions.RequestException as e: - return f"请求失败: {str(e)}", "" # 返回错误信息和空字符串 \ No newline at end of file + max_retries = 3 + base_wait_time = 15 # 基础等待时间(秒) + + for retry in range(max_retries): + try: + response = requests.post(api_url, headers=headers, json=data) + + if response.status_code == 429: + wait_time = base_wait_time * (2 ** retry) # 指数退避 + print(f"遇到请求限制(429),等待{wait_time}秒后重试...") + time.sleep(wait_time) + continue + + response.raise_for_status() # 检查其他响应状态 + + result = response.json() + if "choices" in result and len(result["choices"]) > 0: + content = result["choices"][0]["message"]["content"] + reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") + return content, reasoning_content + return "没有返回结果", "" + + except requests.exceptions.RequestException as e: + if retry < max_retries - 1: # 如果还有重试机会 + wait_time = base_wait_time * (2 ** retry) + print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + time.sleep(wait_time) + else: + return f"请求失败: {str(e)}", "" + + return "达到最大重试次数,请求仍然失败", "" \ No newline at end of file diff --git a/src/plugins/memory_system/memory copy.py b/src/plugins/memory_system/memory copy.py new file mode 100644 index 000000000..074a95b19 --- /dev/null +++ b/src/plugins/memory_system/memory copy.py @@ -0,0 +1,376 @@ +# -*- coding: utf-8 -*- +import sys +import jieba +from llm_module import LLMModel +import networkx as nx +import matplotlib.pyplot as plt +import math +from collections import Counter +import datetime +import random +import time +# from chat.config import global_config +import sys +sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 +from src.common.database import Database # 使用正确的导入语法 + +class Memory_graph: + def __init__(self): + self.G = nx.Graph() # 使用 networkx 的图结构 + self.db = Database.get_instance() + + def connect_dot(self, concept1, concept2): + self.G.add_edge(concept1, concept2) + + def add_dot(self, concept, memory): + if concept in self.G: + # 如果节点已存在,将新记忆添加到现有列表中 + if 'memory_items' in self.G.nodes[concept]: + if not isinstance(self.G.nodes[concept]['memory_items'], list): + # 如果当前不是列表,将其转换为列表 + self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']] + self.G.nodes[concept]['memory_items'].append(memory) + else: + self.G.nodes[concept]['memory_items'] = [memory] + else: + # 如果是新节点,创建新的记忆列表 + self.G.add_node(concept, memory_items=[memory]) + + def get_dot(self, concept): + # 检查节点是否存在于图中 + if concept in self.G: + # 从图中获取节点数据 + node_data = self.G.nodes[concept] + # print(node_data) + # 创建新的Memory_dot对象 + return concept,node_data + return None + + def get_related_item(self, topic, depth=1): + if topic not in self.G: + return [], [] + + first_layer_items = [] + second_layer_items = [] + + # 获取相邻节点 + neighbors = list(self.G.neighbors(topic)) + # print(f"第一层: {topic}") + + # 获取当前节点的记忆项 + node_data = self.get_dot(topic) + if node_data: + concept, data = node_data + if 'memory_items' in data: + memory_items = data['memory_items'] + if isinstance(memory_items, list): + first_layer_items.extend(memory_items) + else: + first_layer_items.append(memory_items) + + # 只在depth=2时获取第二层记忆 + if depth >= 2: + # 获取相邻节点的记忆项 + for neighbor in neighbors: + # print(f"第二层: {neighbor}") + node_data = self.get_dot(neighbor) + if node_data: + concept, data = node_data + if 'memory_items' in data: + memory_items = data['memory_items'] + if isinstance(memory_items, list): + second_layer_items.extend(memory_items) + else: + second_layer_items.append(memory_items) + + return first_layer_items, second_layer_items + + def store_memory(self): + for node in self.G.nodes(): + dot_data = { + "concept": node + } + self.db.db.store_memory_dots.insert_one(dot_data) + + @property + def dots(self): + # 返回所有节点对应的 Memory_dot 对象 + return [self.get_dot(node) for node in self.G.nodes()] + + + def get_random_chat_from_db(self, length: int, timestamp: str): + # 从数据库中根据时间戳获取离其最近的聊天记录 + chat_text = '' + closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 + print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") + + if closest_record: + closest_time = closest_record['time'] + group_id = closest_record['group_id'] # 获取groupid + # 获取该时间戳之后的length条消息,且groupid相同 + chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length)) + for record in chat_record: + time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) + chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息 + return chat_text + + return [] # 如果没有找到记录,返回空列表 + + def save_graph_to_db(self): + # 清空现有的图数据 + self.db.db.graph_data.delete_many({}) + # 保存节点 + for node in self.G.nodes(data=True): + node_data = { + 'concept': node[0], + 'memory_items': node[1].get('memory_items', []) # 默认为空列表 + } + self.db.db.graph_data.nodes.insert_one(node_data) + # 保存边 + for edge in self.G.edges(): + edge_data = { + 'source': edge[0], + 'target': edge[1] + } + self.db.db.graph_data.edges.insert_one(edge_data) + + def load_graph_from_db(self): + # 清空当前图 + self.G.clear() + # 加载节点 + nodes = self.db.db.graph_data.nodes.find() + for node in nodes: + memory_items = node.get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + self.G.add_node(node['concept'], memory_items=memory_items) + # 加载边 + edges = self.db.db.graph_data.edges.find() + for edge in edges: + self.G.add_edge(edge['source'], edge['target']) + +def calculate_information_content(text): + + """计算文本的信息量(熵)""" + # 统计字符频率 + char_count = Counter(text) + total_chars = len(text) + + # 计算熵 + entropy = 0 + for count in char_count.values(): + probability = count / total_chars + entropy -= probability * math.log2(probability) + + return entropy + + +# Database.initialize( +# global_config.MONGODB_HOST, +# global_config.MONGODB_PORT, +# global_config.DATABASE_NAME +# ) +# memory_graph = Memory_graph() + +# llm_model = LLMModel() +# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") + +# memory_graph.load_graph_from_db() + + + +def main(): + # 初始化数据库 + Database.initialize( + "127.0.0.1", + 27017, + "MegBot" + ) + + memory_graph = Memory_graph() + # 创建LLM模型实例 + llm_model = LLMModel() + llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") + + # 使用当前时间戳进行测试 + current_timestamp = datetime.datetime.now().timestamp() + chat_text = [] + + chat_size =40 + + for _ in range(100): # 循环10次 + random_time = current_timestamp - random.randint(1, 3600*39) # 随机时间 + print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") + chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time) + chat_text.append(chat_) # 拼接所有text + # time.sleep(1) + + + + for i, input_text in enumerate(chat_text, 1): + + progress = (i / len(chat_text)) * 100 + bar_length = 30 + filled_length = int(bar_length * i // len(chat_text)) + bar = '█' * filled_length + '-' * (bar_length - filled_length) + print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})") + + # print(input_text) + first_memory = set() + first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5) + time.sleep(5) + + #将记忆加入到图谱中 + for topic, memory in first_memory: + topics = segment_text(topic) + print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") + for split_topic in topics: + memory_graph.add_dot(split_topic,memory) + for split_topic in topics: + for other_split_topic in topics: + if split_topic != other_split_topic: + memory_graph.connect_dot(split_topic, other_split_topic) + + # memory_graph.store_memory() + + # 展示两种不同的可视化方式 + print("\n按连接数量着色的图谱:") + visualize_graph(memory_graph, color_by_memory=False) + + print("\n按记忆数量着色的图谱:") + visualize_graph(memory_graph, color_by_memory=True) + + memory_graph.save_graph_to_db() + # memory_graph.load_graph_from_db() + + while True: + query = input("请输入新的查询概念(输入'退出'以结束):") + if query.lower() == '退出': + break + items_list = memory_graph.get_related_item(query) + if items_list: + # print(items_list) + for memory_item in items_list: + print(memory_item) + else: + print("未找到相关记忆。") + + while True: + query = input("请输入问题:") + + if query.lower() == '退出': + break + + topic_prompt = find_topic(query, 3) + topic_response = llm_model.generate_response(topic_prompt) + # 检查 topic_response 是否为元组 + if isinstance(topic_response, tuple): + topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 + else: + topics = topic_response.split(",") + print(topics) + + for keyword in topics: + items_list = memory_graph.get_related_item(keyword) + if items_list: + print(items_list) + +def memory_compress(input_text, llm_model, llm_model_small, rate=1): + information_content = calculate_information_content(input_text) + print(f"文本的信息量(熵): {information_content:.4f} bits") + topic_num = max(1, min(5, int(information_content * rate / 4))) + print(topic_num) + topic_prompt = find_topic(input_text, topic_num) + topic_response = llm_model.generate_response(topic_prompt) + # 检查 topic_response 是否为元组 + if isinstance(topic_response, tuple): + topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 + else: + topics = topic_response.split(",") + print(topics) + compressed_memory = set() + for topic in topics: + topic_what_prompt = topic_what(input_text,topic) + topic_what_response = llm_model_small.generate_response(topic_what_prompt) + compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储 + return compressed_memory + + +def segment_text(text): + seg_text = list(jieba.cut(text)) + return seg_text + +def find_topic(text, topic_num): + prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。' + return prompt + +def topic_what(text, topic): + prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好' + return prompt + +def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): + # 设置中文字体 + plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 + plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 + + G = memory_graph.G + + # 保存图到本地 + nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式 + + # 根据连接条数或记忆数量设置节点颜色 + node_colors = [] + nodes = list(G.nodes()) # 获取图中实际的节点列表 + + if color_by_memory: + # 计算每个节点的记忆数量 + memory_counts = [] + for node in nodes: + memory_items = G.nodes[node].get('memory_items', []) + if isinstance(memory_items, list): + count = len(memory_items) + else: + count = 1 if memory_items else 0 + memory_counts.append(count) + max_memories = max(memory_counts) if memory_counts else 1 + + for count in memory_counts: + # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少 + if max_memories > 0: + intensity = min(1.0, count / max_memories) + color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色 + else: + color = (0, 0, 1) # 如果没有记忆,则为蓝色 + node_colors.append(color) + else: + # 使用原来的连接数量着色方案 + max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1 + for node in nodes: + degree = G.degree(node) + if max_degree > 0: + red = min(1.0, degree / max_degree) + blue = 1.0 - red + color = (red, 0, blue) + else: + color = (0, 0, 1) + node_colors.append(color) + + # 绘制图形 + plt.figure(figsize=(12, 8)) + pos = nx.spring_layout(G, k=1, iterations=50) + nx.draw(G, pos, + with_labels=True, + node_color=node_colors, + node_size=2000, + font_size=10, + font_family='SimHei', + font_weight='bold') + + title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色') + plt.title(title, fontsize=16, fontfamily='SimHei') + plt.show() + +if __name__ == "__main__": + main() + + diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index d8f644d7c..3f216997f 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import sys import jieba -from llm_module import LLMModel +from .llm_module import LLMModel import networkx as nx import matplotlib.pyplot as plt import math @@ -9,9 +9,9 @@ from collections import Counter import datetime import random import time - +from ..chat.config import global_config import sys -sys.path.append("C:/GitHub/MegMeg-bot") # 添加项目根目录到 Python 路径 +sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 from src.common.database import Database # 使用正确的导入语法 class Memory_graph: @@ -23,44 +23,67 @@ class Memory_graph: self.G.add_edge(concept1, concept2) def add_dot(self, concept, memory): - self.G.add_node(concept, memory_items=memory) + if concept in self.G: + # 如果节点已存在,将新记忆添加到现有列表中 + if 'memory_items' in self.G.nodes[concept]: + if not isinstance(self.G.nodes[concept]['memory_items'], list): + # 如果当前不是列表,将其转换为列表 + self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']] + self.G.nodes[concept]['memory_items'].append(memory) + else: + self.G.nodes[concept]['memory_items'] = [memory] + else: + # 如果是新节点,创建新的记忆列表 + self.G.add_node(concept, memory_items=[memory]) def get_dot(self, concept): # 检查节点是否存在于图中 if concept in self.G: # 从图中获取节点数据 node_data = self.G.nodes[concept] - print(node_data) + # print(node_data) # 创建新的Memory_dot对象 return concept,node_data return None def get_related_item(self, topic, depth=1): if topic not in self.G: - return set() + return [], [] - items_set = set() + first_layer_items = [] + second_layer_items = [] + # 获取相邻节点 neighbors = list(self.G.neighbors(topic)) - print(f"第一层: {topic}") + # print(f"第一层: {topic}") # 获取当前节点的记忆项 node_data = self.get_dot(topic) if node_data: concept, data = node_data if 'memory_items' in data: - items_set.add(data['memory_items']) + memory_items = data['memory_items'] + if isinstance(memory_items, list): + first_layer_items.extend(memory_items) + else: + first_layer_items.append(memory_items) - # 获取相邻节点的记忆项 - for neighbor in neighbors: - print(f"第二层: {neighbor}") - node_data = self.get_dot(neighbor) - if node_data: - concept, data = node_data - if 'memory_items' in data: - items_set.add(data['memory_items']) + # 只在depth=2时获取第二层记忆 + if depth >= 2: + # 获取相邻节点的记忆项 + for neighbor in neighbors: + # print(f"第二层: {neighbor}") + node_data = self.get_dot(neighbor) + if node_data: + concept, data = node_data + if 'memory_items' in data: + memory_items = data['memory_items'] + if isinstance(memory_items, list): + second_layer_items.extend(memory_items) + else: + second_layer_items.append(memory_items) - return items_set + return first_layer_items, second_layer_items def store_memory(self): for node in self.G.nodes(): @@ -100,7 +123,7 @@ class Memory_graph: for node in self.G.nodes(data=True): node_data = { 'concept': node[0], - 'memory_items': node[1].get('memory_items', None) + 'memory_items': node[1].get('memory_items', []) # 默认为空列表 } self.db.db.graph_data.nodes.insert_one(node_data) # 保存边 @@ -117,7 +140,10 @@ class Memory_graph: # 加载节点 nodes = self.db.db.graph_data.nodes.find() for node in nodes: - self.G.add_node(node['concept'], memory_items=node['memory_items']) + memory_items = node.get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + self.G.add_node(node['concept'], memory_items=memory_items) # 加载边 edges = self.db.db.graph_data.edges.find() for edge in edges: @@ -138,6 +164,26 @@ def calculate_information_content(text): return entropy + +start_time = time.time() + +Database.initialize( + global_config.MONGODB_HOST, + global_config.MONGODB_PORT, + global_config.DATABASE_NAME +) +memory_graph = Memory_graph() + +llm_model = LLMModel() +llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") + +memory_graph.load_graph_from_db() + +end_time = time.time() +print(f"加载海马体耗时: {end_time - start_time:.2f} 秒") + + + def main(): # 初始化数据库 Database.initialize( @@ -155,13 +201,14 @@ def main(): current_timestamp = datetime.datetime.now().timestamp() chat_text = [] - chat_size =30 + chat_size =40 - for _ in range(60): # 循环10次 - random_time = current_timestamp - random.randint(1, 3600*3) # 随机时间 + for _ in range(100): # 循环10次 + random_time = current_timestamp - random.randint(1, 3600*39) # 随机时间 print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time) chat_text.append(chat_) # 拼接所有text + time.sleep(5) @@ -173,7 +220,7 @@ def main(): #将记忆加入到图谱中 for topic, memory in first_memory: topics = segment_text(topic) - print(f"话题: {topic},节点: {topics}, 记忆: {memory}") + print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") for split_topic in topics: memory_graph.add_dot(split_topic,memory) for split_topic in topics: @@ -182,7 +229,13 @@ def main(): memory_graph.connect_dot(split_topic, other_split_topic) # memory_graph.store_memory() - visualize_graph(memory_graph) + + # 展示两种不同的可视化方式 + print("\n按连接数量着色的图谱:") + visualize_graph(memory_graph, color_by_memory=False) + + print("\n按记忆数量着色的图谱:") + visualize_graph(memory_graph, color_by_memory=True) memory_graph.save_graph_to_db() # memory_graph.load_graph_from_db() @@ -252,45 +305,66 @@ def topic_what(text, topic): prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好' return prompt -def visualize_graph(memory_graph: Memory_graph): +def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): # 设置中文字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 G = memory_graph.G - # 保存图到本地 nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式 - # 根据连接条数设置节点颜色 + # 根据连接条数或记忆数量设置节点颜色 node_colors = [] nodes = list(G.nodes()) # 获取图中实际的节点列表 - max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1 # 获取最大连接数 - for node in nodes: - degree = G.degree(node) # 获取节点的度 - # 计算颜色,使用渐变效果 - if max_degree > 0: - red = min(1.0, degree / max_degree) # 红色分量随连接数增加而增加 - blue = 1.0 - red # 蓝色分量随连接数增加而减少 - color = (red, 0, blue) - else: - color = (0, 0, 1) # 如果没有连接,则为蓝色 - node_colors.append(color) + if color_by_memory: + # 计算每个节点的记忆数量 + memory_counts = [] + for node in nodes: + memory_items = G.nodes[node].get('memory_items', []) + if isinstance(memory_items, list): + count = len(memory_items) + else: + count = 1 if memory_items else 0 + memory_counts.append(count) + max_memories = max(memory_counts) if memory_counts else 1 + + for count in memory_counts: + # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少 + if max_memories > 0: + intensity = min(1.0, count / max_memories) + color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色 + else: + color = (0, 0, 1) # 如果没有记忆,则为蓝色 + node_colors.append(color) + else: + # 使用原来的连接数量着色方案 + max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1 + for node in nodes: + degree = G.degree(node) + if max_degree > 0: + red = min(1.0, degree / max_degree) + blue = 1.0 - red + color = (red, 0, blue) + else: + color = (0, 0, 1) + node_colors.append(color) # 绘制图形 plt.figure(figsize=(12, 8)) - pos = nx.spring_layout(G, k=1, iterations=50) # 使用弹簧布局,调整参数使布局更合理 + pos = nx.spring_layout(G, k=1, iterations=50) nx.draw(G, pos, with_labels=True, node_color=node_colors, node_size=2000, font_size=10, - font_family='SimHei', # 设置节点标签的字体 + font_family='SimHei', font_weight='bold') - plt.title('记忆图谱可视化', fontsize=16, fontfamily='SimHei') + title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色') + plt.title(title, fontsize=16, fontfamily='SimHei') plt.show() if __name__ == "__main__":