From 50c1765b81331afad227b1a0e4c965cb798e36c8 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sun, 2 Mar 2025 00:14:25 +0800 Subject: [PATCH] =?UTF-8?q?v0.3.1=20=E5=AE=9E=E8=A3=85=E4=BA=86=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E7=B3=BB=E7=BB=9F=E5=92=8C=E8=87=AA=E5=8A=A8=E5=8F=91?= =?UTF-8?q?=E8=A8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 哈哈哈 --- .gitignore | 2 +- README.md | 14 +- bot_config.toml => config/bot_config_toml | 5 + env.example => config/env.example | 0 src/plugins/chat/__init__.py | 43 +- src/plugins/chat/bot.py | 17 +- src/plugins/chat/config.py | 53 ++- src/plugins/chat/llm_generator.py | 4 +- src/plugins/chat/message.py | 10 +- src/plugins/chat/message_send_control.py | 10 +- src/plugins/chat/prompt_builder.py | 10 +- src/plugins/chat/utils.py | 35 ++ src/plugins/chat/utils_image.py | 16 +- src/plugins/chat/willing_manager.py | 27 +- .../knowledege/knowledge_library.py | 0 src/plugins/memory_system/draw_memory.py | 264 ++++++++++++ .../memory_system/llm_module_memory_make.py | 82 ++++ src/plugins/memory_system/memory.py | 377 +++++++----------- .../{memory copy.py => memory_make.py} | 90 ++++- 19 files changed, 732 insertions(+), 327 deletions(-) rename bot_config.toml => config/bot_config_toml (95%) rename env.example => config/env.example (100%) rename src/plugins/{chat => }/knowledege/knowledge_library.py (100%) create mode 100644 src/plugins/memory_system/draw_memory.py create mode 100644 src/plugins/memory_system/llm_module_memory_make.py rename src/plugins/memory_system/{memory copy.py => memory_make.py} (82%) diff --git a/.gitignore b/.gitignore index a70c66cdf..265108181 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ mongodb/ NapCat.Framework.Windows.Once/ log/ src/plugins/memory -src/plugins/chat/bot_config.toml +config/bot_config.toml /test message_queue_content.txt message_queue_content.bat diff --git a/README.md b/README.md index 2366fe87b..a85fcc4e8 100644 --- a/README.md +++ b/README.md @@ -16,11 +16,19 @@ 基于llm、napcat、nonebot和mongodb的专注于群聊天的qqbot +
+  麦麦演示视频
+  👆 点击观看麦麦演示视频 👆
+
+ > ⚠️ **警告**:代码可能随时更改,目前版本不一定是稳定版本 > ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次 > ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语( -关于麦麦的开发和部署相关的讨论群(不建议发布无关消息)这里不会有麦麦发言! +关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言! ## 开发计划TODO:LIST @@ -29,6 +37,10 @@ - 对思考链长度限制 - 修复已知bug - 完善文档 +- 修复转发 +- config自动生成和检测 +- log别用print +- 给发送消息写专门的类
diff --git a/bot_config.toml b/config/bot_config_toml similarity index 95% rename from bot_config.toml rename to config/bot_config_toml index 6730f0481..b5011c7f9 100644 --- a/bot_config.toml +++ b/config/bot_config_toml @@ -29,6 +29,11 @@ model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率 model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率 model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率 +[memory] +build_memory_interval = 300 # 记忆构建间隔 + + + [others] enable_advance_output = true # 开启后输出更多日志,false关闭true开启 diff --git a/env.example b/config/env.example similarity index 100% rename from env.example rename to config/env.example diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index a2b54eaa5..1c25a24f1 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -1,3 +1,4 @@ +from loguru import logger from nonebot import on_message, on_command, require, get_driver from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment from nonebot.typing import T_State @@ -10,9 +11,6 @@ from .relationship_manager import relationship_manager from ..schedule.schedule_generator import bot_schedule from .willing_manager import willing_manager -from ..memory_system.memory import memory_graph - - # 获取驱动器 driver = get_driver() @@ -21,10 +19,7 @@ Database.initialize( global_config.MONGODB_PORT, global_config.DATABASE_NAME ) - -print("\033[1;32m[初始化配置和数据库完成]\033[0m") - - +print("\033[1;32m[初始化数据库完成]\033[0m") # 导入其他模块 @@ -32,6 +27,7 @@ from .bot import ChatBot from .emoji_manager import emoji_manager from .message_send_control import message_sender from .relationship_manager import relationship_manager +from ..memory_system.memory import memory_graph,hippocampus # 初始化表情管理器 emoji_manager.initialize() @@ -39,21 +35,26 @@ emoji_manager.initialize() print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m") # 创建机器人实例 chat_bot = ChatBot(global_config) - # 注册消息处理器 group_msg = on_message() - # 创建定时任务 scheduler = require("nonebot_plugin_apscheduler").scheduler -# 启动后台任务 + + @driver.on_startup async def start_background_tasks(): """启动后台任务""" # 只启动表情包管理任务 asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL)) - bot_schedule.print_schedule() + +@driver.on_startup +async def init_relationships(): + """在 NoneBot2 启动时初始化关系管理器""" + print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...") + await relationship_manager.load_all_relationships() + asyncio.create_task(relationship_manager._start_relationship_manager()) @driver.on_bot_connect async def _(bot: Bot): @@ -68,19 +69,23 @@ async def _(bot: Bot): print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m") # 启动消息发送控制任务 -@driver.on_startup -async def init_relationships(): - """在 NoneBot2 启动时初始化关系管理器""" - print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...") - await relationship_manager.load_all_relationships() - asyncio.create_task(relationship_manager._start_relationship_manager()) - @group_msg.handle() async def _(bot: Bot, event: GroupMessageEvent, state: T_State): await chat_bot.handle_message(event, bot) - + +''' @scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships") async def monitor_relationships(): """每15秒打印一次关系数据""" relationship_manager.print_all_relationships() +''' +# 添加build_memory定时任务 +@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory") +async def build_memory_task(): + """每30秒执行一次记忆构建""" + print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...") + hippocampus.build_memory(chat_size=12) + 
print("\033[1;32m[记忆构建]\033[0m 记忆构建完成") + + diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 09ee2f063..1b5201645 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -83,7 +83,7 @@ class ChatBot: await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info) await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5) - print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}") + # print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}") message = Message( @@ -100,14 +100,19 @@ class ChatBot: topic = topic_identifier.identify_topic_jieba(message.processed_plain_text) print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}") + all_num = 0 + interested_num = 0 if topic: for current_topic in topic: + all_num += 1 first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) if first_layer_items: - print(f"\033[1;32m[记忆检索-bot]\033[0m 有印象:{current_topic}") + interested_num += 1 + print(f"\033[1;32m[前额叶]\033[0m 对|{current_topic}|有印象") + interested_rate = interested_num / all_num if all_num > 0 else 0 + await self.storage.store_message(message, topic[0] if topic else None) - is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text) @@ -117,7 +122,8 @@ class ChatBot: is_mentioned, self.config, event.user_id, - message.is_emoji + message.is_emoji, + interested_rate ) current_willing = willing_manager.get_willing(event.group_id) @@ -188,7 +194,8 @@ class ChatBot: user_nickname=global_config.BOT_NICKNAME, group_name=message.group_name, time=bot_response_time, - is_emoji=True + is_emoji=True, + translate_cq=False ) message_sender.send_temp_container.add_message(bot_message) diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index 69e59ed5b..05d492789 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -6,6 +6,8 @@ import logging import configparser import tomli import sys +from loguru import logger +from dotenv import load_dotenv @@ -21,7 +23,7 @@ class BotConfig: MONGODB_PASSWORD: Optional[str] = None # 默认空值 MONGODB_AUTH_SOURCE: Optional[str] = None # 默认空值 - BOT_QQ: Optional[int] = None + BOT_QQ: Optional[int] = 1 BOT_NICKNAME: Optional[str] = None # 消息处理相关配置 @@ -35,6 +37,7 @@ class BotConfig: talk_frequency_down_groups = set() ban_user_id = set() + build_memory_interval: int = 60 # 记忆构建间隔(秒) EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) @@ -45,9 +48,21 @@ class BotConfig: enable_advance_output: bool = False # 是否启用高级输出 + @staticmethod + def get_default_config_path() -> str: + """获取默认配置文件路径""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) + config_dir = os.path.join(root_dir, 'config') + return os.path.join(config_dir, 'bot_config.toml') + @classmethod - def load_config(cls, config_path: str = "bot_config.toml") -> "BotConfig": + def load_config(cls, config_path: str = None) -> "BotConfig": """从TOML配置文件加载配置""" + if config_path is None: + config_path = cls.get_default_config_path() + logger.info(f"使用默认配置文件路径: {config_path}") + config = cls() if os.path.exists(config_path): with open(config_path, "rb") as f: @@ -93,6 +108,10 @@ class BotConfig: config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE) config.emoji_chance = 
msg_config.get("emoji_chance", config.emoji_chance) + if "memory" in toml_dict: + memory_config = toml_dict["memory"] + config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval) + # 群组配置 if "groups" in toml_dict: groups_config = toml_dict["groups"] @@ -104,16 +123,26 @@ class BotConfig: others_config = toml_dict["others"] config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output) - print(f"\033[1;32m成功加载配置文件: {config_path}\033[0m") + logger.success(f"成功加载配置文件: {config_path}") return config -global_config = BotConfig.load_config(".bot_config.toml") +# 获取配置文件路径 +bot_config_path = BotConfig.get_default_config_path() +config_dir = os.path.dirname(bot_config_path) +env_path = os.path.join(config_dir, '.env') -from dotenv import load_dotenv -current_dir = os.path.dirname(os.path.abspath(__file__)) -root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) -load_dotenv(os.path.join(root_dir, '.env')) +logger.info(f"尝试从 {bot_config_path} 加载机器人配置") +global_config = BotConfig.load_config(config_path=bot_config_path) + +# 加载环境变量 + +logger.info(f"尝试从 {env_path} 加载环境变量配置") +if os.path.exists(env_path): + load_dotenv(env_path) + logger.success("成功加载环境变量配置") +else: + logger.error(f"环境变量配置文件不存在: {env_path}") @dataclass class LLMConfig: @@ -132,9 +161,5 @@ llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL') if not global_config.enable_advance_output: - # 只降低日志级别而不是完全移除 - logger.remove() - logger.add(sys.stderr, level="WARNING") # 添加一个只输出 WARNING 及以上级别的处理器 - - # 设置 nonebot 的日志级别 - logging.getLogger('nonebot').setLevel(logging.WARNING) + # logger.remove() + pass diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index bb68d3618..2ea4d7f24 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -4,7 +4,7 @@ import asyncio import requests from functools import partial from .message import Message -from .config import BotConfig +from .config import BotConfig, global_config from ...common.database import Database import random import time @@ -255,4 +255,4 @@ class LLMResponseGenerator: return processed_response, emotion_tags # 创建全局实例 -llm_response = LLMResponseGenerator(config=BotConfig()) \ No newline at end of file +llm_response = LLMResponseGenerator(global_config) \ No newline at end of file diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index 2e91f530e..f5ea0db0d 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -6,17 +6,13 @@ import os from datetime import datetime from ...common.database import Database from PIL import Image -from .config import BotConfig, global_config +from .config import global_config import urllib3 from .utils_user import get_user_nickname from .utils_cq import parse_cq_code from .cq_code import cq_code_tool,CQCode Message = ForwardRef('Message') # 添加这行 - -# 加载配置 -bot_config = BotConfig.load_config() - # 禁用SSL警告 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -48,6 +44,8 @@ class Message: is_emoji: bool = False # 是否是表情包 has_emoji: bool = False # 是否包含表情包 + + translate_cq: bool = True # 是否翻译cq码 reply_benefits: float = 0.0 @@ -99,7 +97,7 @@ class Message: - cq_code_list:分割出的聊天对象,包括文本和CQ码 - trans_list:翻译后的对象列表 """ - print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}") + # print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}") cq_code_dict_list = [] trans_list = [] diff --git a/src/plugins/chat/message_send_control.py 
b/src/plugins/chat/message_send_control.py index cb45b3132..0ddb79c5f 100644 --- a/src/plugins/chat/message_send_control.py +++ b/src/plugins/chat/message_send_control.py @@ -208,7 +208,15 @@ class MessageSendControl: print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}秒") current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}") - await self.storage.store_message(message, None) + + if message.is_emoji: + message.processed_plain_text = "[表情包]" + await self.storage.store_message(message, None) + else: + await self.storage.store_message(message, None) + + + queue.update_send_time() if queue.has_messages(): await asyncio.sleep( diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index 4e72c6304..da9037cfa 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -53,8 +53,8 @@ class PromptBuilder: # 遍历所有topic for current_topic in topic: first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) - if first_layer_items: - print(f"\033[1;32m[pb记忆检索]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}") + # if first_layer_items: + # print(f"\033[1;32m[前额叶]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}") # 记录第一层数据 all_first_layer_items.extend(first_layer_items) @@ -68,14 +68,14 @@ class PromptBuilder: # 找到重叠的记忆 overlap = set(second_layer_items) & set(other_second_layer) if overlap: - print(f"\033[1;32m[pb记忆检索]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}") + # print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}") overlapping_second_layer.update(overlap) # 合并所有需要的记忆 if all_first_layer_items: - print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆1: {all_first_layer_items}") + print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}") if overlapping_second_layer: - print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}") + print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}") all_memories = all_first_layer_items + list(overlapping_second_layer) diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 58e2280cc..4e2235805 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -7,6 +7,8 @@ import numpy as np from .config import llm_config, global_config import re from typing import Dict +from collections import Counter +import math def combine_messages(messages: List[Message]) -> str: @@ -81,6 +83,39 @@ def cosine_similarity(v1, v2): norm2 = np.linalg.norm(v2) return dot_product / (norm1 * norm2) +def calculate_information_content(text): + """计算文本的信息量(熵)""" + # 统计字符频率 + char_count = Counter(text) + total_chars = len(text) + + # 计算熵 + entropy = 0 + for count in char_count.values(): + probability = count / total_chars + entropy -= probability * math.log2(probability) + + return entropy + +def get_cloest_chat_from_db(db, length: int, timestamp: str): + # 从数据库中根据时间戳获取离其最近的聊天记录 + chat_text = '' + closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 + # print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") + + if closest_record: + closest_time = closest_record['time'] + group_id = closest_record['group_id'] # 获取groupid + # 获取该时间戳之后的length条消息,且groupid相同 + chat_record = list(db.db.messages.find({"time": 
{"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length)) + for record in chat_record: + time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) + chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息 + return chat_text + + return [] # 如果没有找到记录,返回空列表 + + def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: """从数据库获取群组最近的消息记录 diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index e1a882341..9fe2c40cc 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -4,11 +4,9 @@ import hashlib import time import os from ...common.database import Database -from .config import BotConfig import zlib # 用于 CRC32 import base64 - -bot_config = BotConfig.load_config() +from .config import global_config def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes: @@ -39,12 +37,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes: # 连接数据库 db = Database( - host=bot_config.MONGODB_HOST, - port=bot_config.MONGODB_PORT, - db_name=bot_config.DATABASE_NAME, - username=bot_config.MONGODB_USERNAME, - password=bot_config.MONGODB_PASSWORD, - auth_source=bot_config.MONGODB_AUTH_SOURCE + host=global_config.MONGODB_HOST, + port=global_config.MONGODB_PORT, + db_name=global_config.DATABASE_NAME, + username=global_config.MONGODB_USERNAME, + password=global_config.MONGODB_PASSWORD, + auth_source=global_config.MONGODB_AUTH_SOURCE ) # 检查是否已存在相同哈希值的图片 diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py index df41ba42f..037c2d517 100644 --- a/src/plugins/chat/willing_manager.py +++ b/src/plugins/chat/willing_manager.py @@ -22,22 +22,31 @@ class WillingManager: """设置指定群组的回复意愿""" self.group_reply_willing[group_id] = willing - def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False) -> float: + def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float: """改变指定群组的回复意愿并返回回复概率""" current_willing = self.group_reply_willing.get(group_id, 0) - if topic and current_willing < 1: - current_willing += 0.2 - elif topic: - current_willing += 0.05 + print(f"初始意愿: {current_willing}") + + # if topic and current_willing < 1: + # current_willing += 0.2 + # elif topic: + # current_willing += 0.05 if is_mentioned_bot and current_willing < 1.0: current_willing += 0.9 + print(f"被提及, 当前意愿: {current_willing}") elif is_mentioned_bot: current_willing += 0.05 + print(f"被重复提及, 当前意愿: {current_willing}") if is_emoji: - current_willing *= 0.2 + current_willing *= 0.15 + print(f"表情包, 当前意愿: {current_willing}") + + if interested_rate > 0.6: + print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}") + current_willing += interested_rate-0.45 self.group_reply_willing[group_id] = min(current_willing, 3.0) @@ -55,15 +64,15 @@ class WillingManager: return reply_probability def change_reply_willing_sent(self, group_id: int): - """发送消息后降低群组的回复意愿""" + """开始思考后降低群组的回复意愿""" current_willing = self.group_reply_willing.get(group_id, 0) - self.group_reply_willing[group_id] = max(0, current_willing - 1.8) + self.group_reply_willing[group_id] = max(0, current_willing - 2) def change_reply_willing_after_sent(self, group_id: int): """发送消息后提高群组的回复意愿""" current_willing = 
self.group_reply_willing.get(group_id, 0) if current_willing < 1: - self.group_reply_willing[group_id] = min(1, current_willing + 0.4) + self.group_reply_willing[group_id] = min(1, current_willing + 0.3) async def ensure_started(self): """确保衰减任务已启动""" diff --git a/src/plugins/chat/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py similarity index 100% rename from src/plugins/chat/knowledege/knowledge_library.py rename to src/plugins/knowledege/knowledge_library.py diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py new file mode 100644 index 000000000..651d5fbca --- /dev/null +++ b/src/plugins/memory_system/draw_memory.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- +import sys +import jieba +from llm_module import LLMModel +import networkx as nx +import matplotlib.pyplot as plt +import math +from collections import Counter +import datetime +import random +import time +# from chat.config import global_config +import sys +sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 +from src.common.database import Database # 使用正确的导入语法 + +class Memory_graph: + def __init__(self): + self.G = nx.Graph() # 使用 networkx 的图结构 + self.db = Database.get_instance() + + def connect_dot(self, concept1, concept2): + self.G.add_edge(concept1, concept2) + + def add_dot(self, concept, memory): + if concept in self.G: + # 如果节点已存在,将新记忆添加到现有列表中 + if 'memory_items' in self.G.nodes[concept]: + if not isinstance(self.G.nodes[concept]['memory_items'], list): + # 如果当前不是列表,将其转换为列表 + self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']] + self.G.nodes[concept]['memory_items'].append(memory) + else: + self.G.nodes[concept]['memory_items'] = [memory] + else: + # 如果是新节点,创建新的记忆列表 + self.G.add_node(concept, memory_items=[memory]) + + def get_dot(self, concept): + # 检查节点是否存在于图中 + if concept in self.G: + # 从图中获取节点数据 + node_data = self.G.nodes[concept] + # print(node_data) + # 创建新的Memory_dot对象 + return concept,node_data + return None + + def get_related_item(self, topic, depth=1): + if topic not in self.G: + return [], [] + + first_layer_items = [] + second_layer_items = [] + + # 获取相邻节点 + neighbors = list(self.G.neighbors(topic)) + # print(f"第一层: {topic}") + + # 获取当前节点的记忆项 + node_data = self.get_dot(topic) + if node_data: + concept, data = node_data + if 'memory_items' in data: + memory_items = data['memory_items'] + if isinstance(memory_items, list): + first_layer_items.extend(memory_items) + else: + first_layer_items.append(memory_items) + + # 只在depth=2时获取第二层记忆 + if depth >= 2: + # 获取相邻节点的记忆项 + for neighbor in neighbors: + # print(f"第二层: {neighbor}") + node_data = self.get_dot(neighbor) + if node_data: + concept, data = node_data + if 'memory_items' in data: + memory_items = data['memory_items'] + if isinstance(memory_items, list): + second_layer_items.extend(memory_items) + else: + second_layer_items.append(memory_items) + + return first_layer_items, second_layer_items + + def store_memory(self): + for node in self.G.nodes(): + dot_data = { + "concept": node + } + self.db.db.store_memory_dots.insert_one(dot_data) + + @property + def dots(self): + # 返回所有节点对应的 Memory_dot 对象 + return [self.get_dot(node) for node in self.G.nodes()] + + + def get_random_chat_from_db(self, length: int, timestamp: str): + # 从数据库中根据时间戳获取离其最近的聊天记录 + chat_text = '' + closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 + print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', 
time.localtime(int(closest_record['time'])))}") + + if closest_record: + closest_time = closest_record['time'] + group_id = closest_record['group_id'] # 获取groupid + # 获取该时间戳之后的length条消息,且groupid相同 + chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length)) + for record in chat_record: + time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) + chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息 + return chat_text + + return [] # 如果没有找到记录,返回空列表 + + def save_graph_to_db(self): + # 清空现有的图数据 + self.db.db.graph_data.delete_many({}) + # 保存节点 + for node in self.G.nodes(data=True): + node_data = { + 'concept': node[0], + 'memory_items': node[1].get('memory_items', []) # 默认为空列表 + } + self.db.db.graph_data.nodes.insert_one(node_data) + # 保存边 + for edge in self.G.edges(): + edge_data = { + 'source': edge[0], + 'target': edge[1] + } + self.db.db.graph_data.edges.insert_one(edge_data) + + def load_graph_from_db(self): + # 清空当前图 + self.G.clear() + # 加载节点 + nodes = self.db.db.graph_data.nodes.find() + for node in nodes: + memory_items = node.get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + self.G.add_node(node['concept'], memory_items=memory_items) + # 加载边 + edges = self.db.db.graph_data.edges.find() + for edge in edges: + self.G.add_edge(edge['source'], edge['target']) + + +def main(): + # 初始化数据库 + Database.initialize( + "127.0.0.1", + 27017, + "MegBot" + ) + + memory_graph = Memory_graph() + # 创建LLM模型实例 + + memory_graph.load_graph_from_db() + # 展示两种不同的可视化方式 + print("\n按连接数量着色的图谱:") + visualize_graph(memory_graph, color_by_memory=False) + + print("\n按记忆数量着色的图谱:") + visualize_graph(memory_graph, color_by_memory=True) + + # memory_graph.save_graph_to_db() + + while True: + query = input("请输入新的查询概念(输入'退出'以结束):") + if query.lower() == '退出': + break + items_list = memory_graph.get_related_item(query) + if items_list: + # print(items_list) + for memory_item in items_list: + print(memory_item) + else: + print("未找到相关记忆。") + + +def segment_text(text): + seg_text = list(jieba.cut(text)) + return seg_text + +def find_topic(text, topic_num): + prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。' + return prompt + +def topic_what(text, topic): + prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好' + return prompt + +def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): + # 设置中文字体 + plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 + plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 + + G = memory_graph.G + + # 保存图到本地 + nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式 + + # 根据连接条数或记忆数量设置节点颜色 + node_colors = [] + nodes = list(G.nodes()) # 获取图中实际的节点列表 + + if color_by_memory: + # 计算每个节点的记忆数量 + memory_counts = [] + for node in nodes: + memory_items = G.nodes[node].get('memory_items', []) + if isinstance(memory_items, list): + count = len(memory_items) + else: + count = 1 if memory_items else 0 + memory_counts.append(count) + max_memories = max(memory_counts) if memory_counts else 1 + + for count in memory_counts: + # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少 + if max_memories > 0: + intensity = min(1.0, count / max_memories) + color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色 + else: + color = (0, 0, 1) # 如果没有记忆,则为蓝色 + node_colors.append(color) + else: + # 使用原来的连接数量着色方案 
+ max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1 + for node in nodes: + degree = G.degree(node) + if max_degree > 0: + red = min(1.0, degree / max_degree) + blue = 1.0 - red + color = (red, 0, blue) + else: + color = (0, 0, 1) + node_colors.append(color) + + # 绘制图形 + plt.figure(figsize=(12, 8)) + pos = nx.spring_layout(G, k=1, iterations=50) + nx.draw(G, pos, + with_labels=True, + node_color=node_colors, + node_size=2000, + font_size=10, + font_family='SimHei', + font_weight='bold') + + title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色') + plt.title(title, fontsize=16, fontfamily='SimHei') + plt.show() + +if __name__ == "__main__": + main() + + diff --git a/src/plugins/memory_system/llm_module_memory_make.py b/src/plugins/memory_system/llm_module_memory_make.py new file mode 100644 index 000000000..1abfdb2c6 --- /dev/null +++ b/src/plugins/memory_system/llm_module_memory_make.py @@ -0,0 +1,82 @@ +import os +import requests +from dotenv import load_dotenv +from typing import Tuple, Union +import time +from ..chat.config import BotConfig + +# 获取当前文件的绝对路径 +current_dir = os.path.dirname(os.path.abspath(__file__)) +root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) +env_path = os.path.join(root_dir, 'config', '.env') + +# 加载环境变量 +print(f"尝试从 {env_path} 加载环境变量配置") +if os.path.exists(env_path): + load_dotenv(env_path) + print("成功加载环境变量配置") +else: + print(f"环境变量配置文件不存在: {env_path}") + +class LLMModel: + # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs): + def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs): + self.model_name = model_name + self.params = kwargs + self.api_key = os.getenv("SILICONFLOW_KEY") + self.base_url = os.getenv("SILICONFLOW_BASE_URL") + + if not self.api_key or not self.base_url: + raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置") + + print(f"API URL: {self.base_url}") # 打印 base_url 用于调试 + + def generate_response(self, prompt: str) -> Tuple[str, str]: + """根据输入的提示生成模型的响应""" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # 构建请求体 + data = { + "model": self.model_name, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.5, + **self.params + } + + # 发送请求到完整的chat/completions端点 + api_url = f"{self.base_url.rstrip('/')}/chat/completions" + + max_retries = 3 + base_wait_time = 15 # 基础等待时间(秒) + + for retry in range(max_retries): + try: + response = requests.post(api_url, headers=headers, json=data) + + if response.status_code == 429: + wait_time = base_wait_time * (2 ** retry) # 指数退避 + print(f"遇到请求限制(429),等待{wait_time}秒后重试...") + time.sleep(wait_time) + continue + + response.raise_for_status() # 检查其他响应状态 + + result = response.json() + if "choices" in result and len(result["choices"]) > 0: + content = result["choices"][0]["message"]["content"] + reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") + return content, reasoning_content + return "没有返回结果", "" + + except requests.exceptions.RequestException as e: + if retry < max_retries - 1: # 如果还有重试机会 + wait_time = base_wait_time * (2 ** retry) + print(f"请求失败,等待{wait_time}秒后重试... 
错误: {str(e)}") + time.sleep(wait_time) + else: + return f"请求失败: {str(e)}", "" + + return "达到最大重试次数,请求仍然失败", "" \ No newline at end of file diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index 3f216997f..af6aab39a 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -import sys import jieba from .llm_module import LLMModel import networkx as nx @@ -11,8 +10,8 @@ import random import time from ..chat.config import global_config import sys -sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 -from src.common.database import Database # 使用正确的导入语法 +from ...common.database import Database # 使用正确的导入语法 +from ..chat.utils import calculate_information_content, get_cloest_chat_from_db class Memory_graph: def __init__(self): @@ -85,54 +84,66 @@ class Memory_graph: return first_layer_items, second_layer_items - def store_memory(self): - for node in self.G.nodes(): - dot_data = { - "concept": node - } - self.db.db.store_memory_dots.insert_one(dot_data) - @property def dots(self): # 返回所有节点对应的 Memory_dot 对象 return [self.get_dot(node) for node in self.G.nodes()] - - - def get_random_chat_from_db(self, length: int, timestamp: str): - # 从数据库中根据时间戳获取离其最近的聊天记录 - chat_text = '' - closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 - print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") - - if closest_record: - closest_time = closest_record['time'] - group_id = closest_record['group_id'] # 获取groupid - # 获取该时间戳之后的length条消息,且groupid相同 - chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length)) - for record in chat_record: - time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) - chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息 - return chat_text - - return [] # 如果没有找到记录,返回空列表 def save_graph_to_db(self): - # 清空现有的图数据 - self.db.db.graph_data.delete_many({}) # 保存节点 for node in self.G.nodes(data=True): - node_data = { - 'concept': node[0], - 'memory_items': node[1].get('memory_items', []) # 默认为空列表 - } - self.db.db.graph_data.nodes.insert_one(node_data) + concept = node[0] + memory_items = node[1].get('memory_items', []) + + # 查找是否存在同名节点 + existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept}) + if existing_node: + # 如果存在,合并memory_items并去重 + existing_items = existing_node.get('memory_items', []) + if not isinstance(existing_items, list): + existing_items = [existing_items] if existing_items else [] + + # 合并并去重 + all_items = list(set(existing_items + memory_items)) + + # 更新节点 + self.db.db.graph_data.nodes.update_one( + {'concept': concept}, + {'$set': {'memory_items': all_items}} + ) + else: + # 如果不存在,创建新节点 + node_data = { + 'concept': concept, + 'memory_items': memory_items + } + self.db.db.graph_data.nodes.insert_one(node_data) + # 保存边 for edge in self.G.edges(): - edge_data = { - 'source': edge[0], - 'target': edge[1] - } - self.db.db.graph_data.edges.insert_one(edge_data) + source, target = edge + + # 查找是否存在同样的边 + existing_edge = self.db.db.graph_data.edges.find_one({ + 'source': source, + 'target': target + }) + + if existing_edge: + # 如果存在,增加num属性 + num = existing_edge.get('num', 1) + 1 + self.db.db.graph_data.edges.update_one( + {'source': source, 'target': target}, + {'$set': 
{'num': num}} + ) + else: + # 如果不存在,创建新边 + edge_data = { + 'source': source, + 'target': target, + 'num': 1 + } + self.db.db.graph_data.edges.insert_one(edge_data) def load_graph_from_db(self): # 清空当前图 @@ -147,150 +158,92 @@ class Memory_graph: # 加载边 edges = self.db.db.graph_data.edges.find() for edge in edges: - self.G.add_edge(edge['source'], edge['target']) - -def calculate_information_content(text): - - """计算文本的信息量(熵)""" - # 统计字符频率 - char_count = Counter(text) - total_chars = len(text) - - # 计算熵 - entropy = 0 - for count in char_count.values(): - probability = count / total_chars - entropy -= probability * math.log2(probability) - - return entropy - - -start_time = time.time() - -Database.initialize( - global_config.MONGODB_HOST, - global_config.MONGODB_PORT, - global_config.DATABASE_NAME -) -memory_graph = Memory_graph() - -llm_model = LLMModel() -llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") - -memory_graph.load_graph_from_db() - -end_time = time.time() -print(f"加载海马体耗时: {end_time - start_time:.2f} 秒") + self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1)) -def main(): - # 初始化数据库 - Database.initialize( - "127.0.0.1", - 27017, - "MegBot" - ) - - memory_graph = Memory_graph() - # 创建LLM模型实例 - llm_model = LLMModel() - llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") - - # 使用当前时间戳进行测试 - current_timestamp = datetime.datetime.now().timestamp() - chat_text = [] - - chat_size =40 - - for _ in range(100): # 循环10次 - random_time = current_timestamp - random.randint(1, 3600*39) # 随机时间 - print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") - chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time) - chat_text.append(chat_) # 拼接所有text - time.sleep(5) - - for input_text in chat_text: - print(input_text) - first_memory = set() - first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5) +# 海马体 +class Hippocampus: + def __init__(self,memory_graph:Memory_graph): + self.memory_graph = memory_graph + self.llm_model = LLMModel() + self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") - #将记忆加入到图谱中 - for topic, memory in first_memory: - topics = segment_text(topic) - print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") - for split_topic in topics: - memory_graph.add_dot(split_topic,memory) - for split_topic in topics: - for other_split_topic in topics: - if split_topic != other_split_topic: - memory_graph.connect_dot(split_topic, other_split_topic) + def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}): + current_timestamp = datetime.datetime.now().timestamp() + chat_text = [] + #短期:1h 中期:4h 长期:24h + for _ in range(time_frequency.get('near')): # 循环10次 + random_time = current_timestamp - random.randint(1, 3600) # 随机时间 + # print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") + chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + chat_text.append(chat_) + for _ in range(time_frequency.get('mid')): # 循环10次 + random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间 + # print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") + chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + chat_text.append(chat_) + for _ in range(time_frequency.get('far')): # 循环10次 + random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间 + # 
print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") + chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + chat_text.append(chat_) + return chat_text - # memory_graph.store_memory() - - # 展示两种不同的可视化方式 - print("\n按连接数量着色的图谱:") - visualize_graph(memory_graph, color_by_memory=False) - - print("\n按记忆数量着色的图谱:") - visualize_graph(memory_graph, color_by_memory=True) - - memory_graph.save_graph_to_db() - # memory_graph.load_graph_from_db() - - while True: - query = input("请输入新的查询概念(输入'退出'以结束):") - if query.lower() == '退出': - break - items_list = memory_graph.get_related_item(query) - if items_list: - # print(items_list) - for memory_item in items_list: - print(memory_item) - else: - print("未找到相关记忆。") + def build_memory(self,chat_size=12): + #最近消息获取频率 + time_frequency = {'near':1,'mid':2,'far':2} + memory_sample = self.get_memory_sample(chat_size,time_frequency) + # print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}") + + + for i, input_text in enumerate(memory_sample, 1): + #加载进度可视化 + progress = (i / len(memory_sample)) * 100 + bar_length = 30 + filled_length = int(bar_length * i // len(memory_sample)) + bar = '█' * filled_length + '-' * (bar_length - filled_length) + print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})") - while True: - query = input("请输入问题:") - - if query.lower() == '退出': - break - - topic_prompt = find_topic(query, 3) - topic_response = llm_model.generate_response(topic_prompt) + # 生成压缩后记忆 + first_memory = set() + first_memory = self.memory_compress(input_text, 2.5) + # 延时防止访问超频 + # time.sleep(5) + #将记忆加入到图谱中 + for topic, memory in first_memory: + topics = segment_text(topic) + print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") + for split_topic in topics: + self.memory_graph.add_dot(split_topic,memory) + for split_topic in topics: + for other_split_topic in topics: + if split_topic != other_split_topic: + self.memory_graph.connect_dot(split_topic, other_split_topic) + + self.memory_graph.save_graph_to_db() + + def memory_compress(self, input_text, rate=1): + information_content = calculate_information_content(input_text) + print(f"文本的信息量(熵): {information_content:.4f} bits") + topic_num = max(1, min(5, int(information_content * rate / 4))) + # print(topic_num) + topic_prompt = find_topic(input_text, topic_num) + topic_response = self.llm_model.generate_response(topic_prompt) # 检查 topic_response 是否为元组 if isinstance(topic_response, tuple): topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 else: topics = topic_response.split(",") - print(topics) - - for keyword in topics: - items_list = memory_graph.get_related_item(keyword) - if items_list: - print(items_list) - -def memory_compress(input_text, llm_model, llm_model_small, rate=1): - information_content = calculate_information_content(input_text) - print(f"文本的信息量(熵): {information_content:.4f} bits") - topic_num = max(1, min(5, int(information_content * rate / 4))) - print(topic_num) - topic_prompt = find_topic(input_text, topic_num) - topic_response = llm_model.generate_response(topic_prompt) - # 检查 topic_response 是否为元组 - if isinstance(topic_response, tuple): - topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 - else: - topics = topic_response.split(",") - print(topics) - compressed_memory = set() - for topic in topics: - topic_what_prompt = topic_what(input_text,topic) - topic_what_response = llm_model_small.generate_response(topic_what_prompt) - compressed_memory.add((topic.strip(), 
topic_what_response[0])) # 将话题和记忆作为元组存储 - return compressed_memory + # print(topics) + compressed_memory = set() + for topic in topics: + topic_what_prompt = topic_what(input_text,topic) + topic_what_response = self.llm_model_small.generate_response(topic_what_prompt) + compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储 + return compressed_memory def segment_text(text): @@ -305,69 +258,21 @@ def topic_what(text, topic): prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好' return prompt -def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): - # 设置中文字体 - plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 - plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 - - G = memory_graph.G - - # 保存图到本地 - nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式 - - # 根据连接条数或记忆数量设置节点颜色 - node_colors = [] - nodes = list(G.nodes()) # 获取图中实际的节点列表 - - if color_by_memory: - # 计算每个节点的记忆数量 - memory_counts = [] - for node in nodes: - memory_items = G.nodes[node].get('memory_items', []) - if isinstance(memory_items, list): - count = len(memory_items) - else: - count = 1 if memory_items else 0 - memory_counts.append(count) - max_memories = max(memory_counts) if memory_counts else 1 - - for count in memory_counts: - # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少 - if max_memories > 0: - intensity = min(1.0, count / max_memories) - color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色 - else: - color = (0, 0, 1) # 如果没有记忆,则为蓝色 - node_colors.append(color) - else: - # 使用原来的连接数量着色方案 - max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1 - for node in nodes: - degree = G.degree(node) - if max_degree > 0: - red = min(1.0, degree / max_degree) - blue = 1.0 - red - color = (red, 0, blue) - else: - color = (0, 0, 1) - node_colors.append(color) - - # 绘制图形 - plt.figure(figsize=(12, 8)) - pos = nx.spring_layout(G, k=1, iterations=50) - nx.draw(G, pos, - with_labels=True, - node_color=node_colors, - node_size=2000, - font_size=10, - font_family='SimHei', - font_weight='bold') - - title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色') - plt.title(title, fontsize=16, fontfamily='SimHei') - plt.show() - -if __name__ == "__main__": - main() +start_time = time.time() + +Database.initialize( + global_config.MONGODB_HOST, + global_config.MONGODB_PORT, + global_config.DATABASE_NAME +) +#创建记忆图 +memory_graph = Memory_graph() +#加载数据库中存储的记忆图 +memory_graph.load_graph_from_db() +#创建海马体 +hippocampus = Hippocampus(memory_graph) + +end_time = time.time() +print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m") \ No newline at end of file diff --git a/src/plugins/memory_system/memory copy.py b/src/plugins/memory_system/memory_make.py similarity index 82% rename from src/plugins/memory_system/memory copy.py rename to src/plugins/memory_system/memory_make.py index 07dea2a8b..244838e21 100644 --- a/src/plugins/memory_system/memory copy.py +++ b/src/plugins/memory_system/memory_make.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import sys import jieba -from llm_module import LLMModel import networkx as nx import matplotlib.pyplot as plt import math @@ -9,10 +8,12 @@ from collections import Counter import datetime import random import time +import os +from dotenv import load_dotenv # from chat.config import global_config -import sys sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 from src.common.database import Database # 使用正确的导入语法 +from src.plugins.memory_system.llm_module import LLMModel class Memory_graph: def __init__(self): @@ -117,22 
+118,60 @@ class Memory_graph: return [] # 如果没有找到记录,返回空列表 def save_graph_to_db(self): - # 清空现有的图数据 - self.db.db.graph_data.delete_many({}) # 保存节点 for node in self.G.nodes(data=True): - node_data = { - 'concept': node[0], - 'memory_items': node[1].get('memory_items', []) # 默认为空列表 - } - self.db.db.graph_data.nodes.insert_one(node_data) + concept = node[0] + memory_items = node[1].get('memory_items', []) + + # 查找是否存在同名节点 + existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept}) + if existing_node: + # 如果存在,合并memory_items并去重 + existing_items = existing_node.get('memory_items', []) + if not isinstance(existing_items, list): + existing_items = [existing_items] if existing_items else [] + + # 合并并去重 + all_items = list(set(existing_items + memory_items)) + + # 更新节点 + self.db.db.graph_data.nodes.update_one( + {'concept': concept}, + {'$set': {'memory_items': all_items}} + ) + else: + # 如果不存在,创建新节点 + node_data = { + 'concept': concept, + 'memory_items': memory_items + } + self.db.db.graph_data.nodes.insert_one(node_data) + # 保存边 for edge in self.G.edges(): - edge_data = { - 'source': edge[0], - 'target': edge[1] - } - self.db.db.graph_data.edges.insert_one(edge_data) + source, target = edge + + # 查找是否存在同样的边 + existing_edge = self.db.db.graph_data.edges.find_one({ + 'source': source, + 'target': target + }) + + if existing_edge: + # 如果存在,增加num属性 + num = existing_edge.get('num', 1) + 1 + self.db.db.graph_data.edges.update_one( + {'source': source, 'target': target}, + {'$set': {'num': num}} + ) + else: + # 如果不存在,创建新边 + edge_data = { + 'source': source, + 'target': target, + 'num': 1 + } + self.db.db.graph_data.edges.insert_one(edge_data) def load_graph_from_db(self): # 清空当前图 @@ -147,7 +186,7 @@ class Memory_graph: # 加载边 edges = self.db.db.graph_data.edges.find() for edge in edges: - self.G.add_edge(edge['source'], edge['target']) + self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1)) def calculate_information_content(text): @@ -180,6 +219,19 @@ def calculate_information_content(text): def main(): + # 获取当前文件的绝对路径 + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) + env_path = os.path.join(root_dir, 'config', '.env') + + # 加载环境变量 + print(f"尝试从 {env_path} 加载环境变量配置") + if os.path.exists(env_path): + load_dotenv(env_path) + print("成功加载环境变量配置") + else: + print(f"环境变量配置文件不存在: {env_path}") + # 初始化数据库 Database.initialize( "127.0.0.1", @@ -196,10 +248,10 @@ def main(): current_timestamp = datetime.datetime.now().timestamp() chat_text = [] - chat_size =20 + chat_size =25 - for _ in range(10): # 循环10次 - random_time = current_timestamp - random.randint(1, 3600*3) # 随机时间 + for _ in range(30): # 循环10次 + random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间 print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time) chat_text.append(chat_) # 拼接所有text @@ -218,7 +270,7 @@ def main(): # print(input_text) first_memory = set() first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5) - time.sleep(5) + # time.sleep(5) #将记忆加入到图谱中 for topic, memory in first_memory: