diff --git a/config/env.example b/.env.prod
similarity index 73%
rename from config/env.example
rename to .env.prod
index 9988d58f3..3d795978c 100644
--- a/config/env.example
+++ b/.env.prod
@@ -1,4 +1,3 @@
-ENVIRONMENT=dev
 HOST=127.0.0.1
 PORT=8080
 
@@ -11,15 +10,15 @@ PLUGINS=["src2.plugins.chat"]
 MONGODB_HOST=127.0.0.1
 MONGODB_PORT=27017
 DATABASE_NAME=MegBot
+
 MONGODB_USERNAME = "" # 默认空值
 MONGODB_PASSWORD = "" # 默认空值
 MONGODB_AUTH_SOURCE = "" # 默认空值
 
-#api配置项
+#key and url
+CHAT_ANY_WHERE_KEY=
 SILICONFLOW_KEY=
+CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
 SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
 DEEP_SEEK_KEY=
-DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
-
-
-
+DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 265108181..c19b9ce33 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ reasoning_content.bat
 reasoning_window.bat
 queue_update.txt
 memory_graph.gml
+.env.dev
 
 
 # Byte-compiled / optimized / DLL files
diff --git a/README.md b/README.md
index a85fcc4e8..1310d4879 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@
 > ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次
 > ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语(
 
-关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言!
+关于麦麦的开发和建议相关的讨论群:766798517(不建议发布无关消息)这里不会有麦麦发言!
 
 ## 开发计划TODO:LIST
 
@@ -41,16 +41,13 @@
 - config自动生成和检测
 - log别用print
 - 给发送消息写专门的类
+- 改进表情包发送逻辑
-
-
-
-
 
 ## 📚 详细文档
 - [项目详细介绍和架构说明](docs/doc1.md) - 包含完整的项目结构、文件说明和核心功能实现细节(由claude-3.5-sonnet生成)
 
-### 安装方法(还没测试好,现在部署可能遇到未知问题!!!!)
+### 安装方法(还没测试好,随时outdated,现在部署可能遇到未知问题!!!!)
 
 #### Linux 使用 Docker Compose 部署
 获取项目根目录中的```docker-compose.yml```文件,运行以下命令
diff --git a/bot.py b/bot.py
index 8abdbbbe9..8741eca7f 100644
--- a/bot.py
+++ b/bot.py
@@ -4,27 +4,59 @@ from nonebot.adapters.onebot.v11 import Adapter
 from dotenv import load_dotenv
 from loguru import logger
 
-
-# 加载全局环境变量
-root_dir = os.path.dirname(os.path.abspath(__file__))
-env_path=os.path.join(root_dir, "config",'.env')
+'''彩蛋'''
+from colorama import init, Fore
+init()
+text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
+rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
+rainbow_text = ""
+for i, char in enumerate(text):
+    rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
+print(rainbow_text)
+'''彩蛋'''
 
-logger.info(f"尝试从 {env_path} 加载环境变量配置")
-if os.path.exists(env_path):
-    load_dotenv(env_path)
-    logger.success("成功加载环境变量配置")
+# 首先加载基础环境变量
+if os.path.exists(".env"):
+    load_dotenv(".env")
+    logger.success("成功加载基础环境变量配置")
 else:
-    logger.error(f"环境变量配置文件不存在: {env_path}")
+    logger.error("基础环境变量配置文件 .env 不存在")
+    exit(1)
 
+# 根据 ENVIRONMENT 加载对应的环境配置
+env = os.getenv("ENVIRONMENT")
+env_file = f".env.{env}"
+
+if env_file == ".env.dev" and os.path.exists(env_file):
+    logger.success("加载开发环境变量配置")
+    load_dotenv(env_file, override=True)  # override=True 允许覆盖已存在的环境变量
+elif env_file == ".env.prod" and os.path.exists(env_file):
+    logger.success("加载环境变量配置")
+    load_dotenv(env_file, override=True)  # override=True 允许覆盖已存在的环境变量
+else:
+    logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
+    exit(1)
 
-# 初始化 NoneBot
 nonebot.init(
-    # napcat 默认使用 8080 端口
-    websocket_port=8080,
-    # 设置日志级别
+    # 从环境变量中读取配置
+    websocket_port=os.getenv("PORT", 8080),
+    host=os.getenv("HOST", "127.0.0.1"),
     log_level="INFO",
-    # 设置超级用户
-    superusers={"你的QQ号"},
-    # TODO: 这样写会忽略环境变量需要优化 https://nonebot.dev/docs/appendices/config
-    _env_file=env_path
+    # 添加自定义配置
+    mongodb_host=os.getenv("MONGODB_HOST", "127.0.0.1"),
+    mongodb_port=os.getenv("MONGODB_PORT", 27017),
+    database_name=os.getenv("DATABASE_NAME", "MegBot"),
+    mongodb_username=os.getenv("MONGODB_USERNAME", ""),
+    mongodb_password=os.getenv("MONGODB_PASSWORD", ""),
+    mongodb_auth_source=os.getenv("MONGODB_AUTH_SOURCE", ""),
+    # API相关配置
+    chat_any_where_key=os.getenv("CHAT_ANY_WHERE_KEY", ""),
+    siliconflow_key=os.getenv("SILICONFLOW_KEY", ""),
+    chat_any_where_base_url=os.getenv("CHAT_ANY_WHERE_BASE_URL", "https://api.chatanywhere.tech/v1"),
+    siliconflow_base_url=os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1/"),
+    deep_seek_key=os.getenv("DEEP_SEEK_KEY", ""),
+    deep_seek_base_url=os.getenv("DEEP_SEEK_BASE_URL", "https://api.deepseek.com/v1"),
+    # 插件配置
+    plugins=os.getenv("PLUGINS", ["src2.plugins.chat"])
 )
 
 # 注册适配器
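Note on the bot.py hunk above: the patch now loads `.env` first and then layers `.env.{ENVIRONMENT}` on top with `override=True`, so environment-specific values win over the shared defaults. A minimal standalone sketch of the same pattern, assuming python-dotenv and the file names used in the patch:

```python
# Sketch of the layered env loading introduced in bot.py (assumes python-dotenv).
import os
import sys

from dotenv import load_dotenv


def load_env_layers() -> str:
    # The base file holds ENVIRONMENT plus shared defaults.
    if not os.path.exists(".env"):
        sys.exit("base config file .env not found")
    load_dotenv(".env")

    # The environment-specific file may override any base value.
    env = os.getenv("ENVIRONMENT")
    env_file = f".env.{env}"
    if env not in ("dev", "prod") or not os.path.exists(env_file):
        sys.exit(f"{env_file} not found; set ENVIRONMENT in .env to dev or prod")
    load_dotenv(env_file, override=True)  # later file wins over earlier values
    return env


if __name__ == "__main__":
    print(f"loaded environment: {load_env_layers()}")
```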
diff --git a/run_maimai.bat b/run_maimai.bat
index 0e1bd7eb6..702d39edc 100644
--- a/run_maimai.bat
+++ b/run_maimai.bat
@@ -2,4 +2,5 @@
 call conda activate niuniu
 cd .
 REM 执行nb run命令
-nb run
\ No newline at end of file
+nb run
+pause
\ No newline at end of file
diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py
index 6fc896ac8..0f81d30d6 100644
--- a/src/plugins/chat/__init__.py
+++ b/src/plugins/chat/__init__.py
@@ -11,16 +11,18 @@ from .relationship_manager import relationship_manager
 from ..schedule.schedule_generator import bot_schedule
 from .willing_manager import willing_manager
 
+
 # 获取驱动器
 driver = get_driver()
+config = driver.config
 
 Database.initialize(
-    host= os.getenv("MONGODB_HOST"),
-    port= int(os.getenv("MONGODB_PORT")),
-    db_name= os.getenv("DATABASE_NAME"),
-    username= os.getenv("MONGODB_USERNAME"),
-    password= os.getenv("MONGODB_PASSWORD"),
-    auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    host= config.mongodb_host,
+    port= int(config.mongodb_port),
+    db_name= config.database_name,
+    username= config.mongodb_username,
+    password= config.mongodb_password,
+    auth_source= config.mongodb_auth_source
 )
 print("\033[1;32m[初始化数据库完成]\033[0m")
@@ -37,7 +39,7 @@ emoji_manager.initialize()
 
 print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
 # 创建机器人实例
-chat_bot = ChatBot(global_config)
+chat_bot = ChatBot()
 # 注册消息处理器
 group_msg = on_message()
 # 创建定时任务
diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py
index 1b5201645..e68ae93f3 100644
--- a/src/plugins/chat/bot.py
+++ b/src/plugins/chat/bot.py
@@ -18,10 +18,9 @@ from .utils import is_mentioned_bot_in_txt, calculate_typing_time
 from ..memory_system.memory import memory_graph
 
 class ChatBot:
-    def __init__(self, config: BotConfig):
-        self.config = config
+    def __init__(self):
         self.storage = MessageStorage()
-        self.gpt = LLMResponseGenerator(config)
+        self.gpt = LLMResponseGenerator()
         self.bot = None  # bot 实例引用
 
         self._started = False
@@ -39,11 +38,11 @@ class ChatBot:
 
     async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
         """处理收到的群消息"""
-        if event.group_id not in self.config.talk_allowed_groups:
+        if event.group_id not in global_config.talk_allowed_groups:
             return
 
         self.bot = bot  # 更新 bot 实例
-        if event.user_id in self.config.ban_user_id:
+        if event.user_id in global_config.ban_user_id:
             return
 
         # 打印原始消息内容
@@ -120,7 +119,7 @@ class ChatBot:
             event.group_id,
             topic[0] if topic else None,
             is_mentioned,
-            self.config,
+            global_config,
             event.user_id,
             message.is_emoji,
             interested_rate
@@ -143,10 +142,14 @@ class ChatBot:
             response, emotion = await self.gpt.generate_response(message)
 
         # 如果生成了回复,发送并记录
-
+
+        '''
+        生成回复后的内容
+
+        '''
         if response:
-            message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id)
+            message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
             accu_typing_time = 0
 
             for msg in response:
                 print(f"当前消息: {msg}")
@@ -157,7 +160,7 @@ class ChatBot:
 
                 bot_message = Message(
                     group_id=event.group_id,
-                    user_id=self.config.BOT_QQ,
+                    user_id=global_config.BOT_QQ,
                     message_id=think_id,
                     message_based_id=event.message_id,
                     raw_message=msg,
@@ -174,7 +177,7 @@ class ChatBot:
 
         bot_response_time = tinking_time_point
 
-        if random() < self.config.emoji_chance:
+        if random() < global_config.emoji_chance:
             emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
             if emoji_path:
                 emoji_cq = CQCode.create_emoji_cq(emoji_path)
@@ -186,7 +189,7 @@ class ChatBot:
 
                 bot_message = Message(
                     group_id=event.group_id,
-                    user_id=self.config.BOT_QQ,
+                    user_id=global_config.BOT_QQ,
                     message_id=0,
                     raw_message=emoji_cq,
                     plain_text=emoji_cq,
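The `ChatBot()` and `LLMResponseGenerator()` constructors above drop the injected `BotConfig` in favor of module-level `global_config` and NoneBot's driver config. Extra keyword arguments passed to `nonebot.init()` surface as attributes on `get_driver().config`, which is the mechanism every plugin below relies on. A small illustrative accessor (the `get_conf` helper is hypothetical, not part of this patch):

```python
# Sketch: custom kwargs given to nonebot.init(...) are stored on the driver's
# config object; getattr with a default guards against a missing field.
from nonebot import get_driver


def get_conf(name: str, default=None):
    """Read a custom setting off the NoneBot driver config, with a fallback."""
    return getattr(get_driver().config, name, default)


# Usage inside a plugin, mirroring fields this PR introduces:
#   mongodb_port = int(get_conf("mongodb_port", 27017))
#   siliconflow_key = get_conf("siliconflow_key", "")
```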
diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py
index 8c04e1126..0232219f8 100644
--- a/src/plugins/chat/config.py
+++ b/src/plugins/chat/config.py
@@ -7,6 +7,7 @@ import configparser
 import tomli
 import sys
 from loguru import logger
+from nonebot import get_driver
 
 
@@ -111,7 +112,6 @@ class BotConfig:
 # 获取配置文件路径
 bot_config_path = BotConfig.get_default_config_path()
 config_dir = os.path.dirname(bot_config_path)
-env_path = os.path.join(config_dir, '.env')
 
 logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
 global_config = BotConfig.load_config(config_path=bot_config_path)
@@ -126,10 +126,11 @@ class LLMConfig:
     DEEP_SEEK_BASE_URL: str = None
 
 llm_config = LLMConfig()
-llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
-llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
-llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
-llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
+config = get_driver().config
+llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
+llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
+llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
+llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
 
 
 if not global_config.enable_advance_output:
diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py
index 92ca20bd8..ae5d8a257 100644
--- a/src/plugins/chat/cq_code.py
+++ b/src/plugins/chat/cq_code.py
@@ -7,7 +7,7 @@ from PIL import Image
 import os
 from random import random
 from nonebot.adapters.onebot.v11 import Bot
-from .config import global_config, llm_config
+from .config import global_config
 import time
 import asyncio
 from .utils_image import storage_image,storage_emoji
@@ -16,6 +16,10 @@ from .utils_user import get_user_nickname
 #包含CQ码类
 import urllib3
 from urllib3.util import create_urllib3_context
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 
 # TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
 ctx = create_urllib3_context()
@@ -179,7 +183,7 @@ class CQCode:
         """调用AI接口获取表情包描述"""
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+            "Authorization": f"Bearer {config.siliconflow_key}"
         }
 
         payload = {
@@ -206,7 +210,7 @@ class CQCode:
         }
 
         response = requests.post(
-            f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+            f"{config.siliconflow_base_url}chat/completions",
             headers=headers,
             json=payload,
             timeout=30
@@ -224,7 +228,7 @@ class CQCode:
         """调用AI接口获取普通图片描述"""
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+            "Authorization": f"Bearer {config.siliconflow_key}"
         }
 
         payload = {
@@ -251,7 +255,7 @@ class CQCode:
         }
 
         response = requests.post(
-            f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+            f"{config.siliconflow_base_url}chat/completions",
             headers=headers,
             json=payload,
             timeout=30
diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py
index a4352758d..c8c9dc814 100644
--- a/src/plugins/chat/emoji_manager.py
+++ b/src/plugins/chat/emoji_manager.py
@@ -10,10 +10,14 @@ import hashlib
 from datetime import datetime
 import base64
 import shutil
-from .config import global_config, llm_config
 import asyncio
 import time
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
+
 
 class EmojiManager:
     _instance = None
@@ -93,7 +97,7 @@ class EmojiManager:
             # 准备请求数据
             headers = {
                 "Content-Type": "application/json",
-                "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+                "Authorization": f"Bearer {config.siliconflow_key}"
             }
 
             payload = {
@@ -115,7 +119,7 @@ class EmojiManager:
 
             async with aiohttp.ClientSession() as session:
                 async with session.post(
-                    f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+                    f"{config.siliconflow_base_url}chat/completions",
                     headers=headers,
                     json=payload
                 ) as response:
@@ -249,7 +253,7 @@ class EmojiManager:
             async with aiohttp.ClientSession() as session:
                 headers = {
                     "Content-Type": "application/json",
-                    "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+                    "Authorization": f"Bearer {config.siliconflow_key}"
                 }
 
                 payload = {
@@ -276,7 +280,7 @@ class EmojiManager:
                 }
 
                 async with session.post(
-                    f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+                    f"{config.siliconflow_base_url}chat/completions",
                     headers=headers,
                     json=payload
                 ) as response:
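All of the call sites above build the request URL as `f"{base_url}chat/completions"`, so the configured `siliconflow_base_url` must keep its trailing slash (`https://api.siliconflow.cn/v1/`). A self-contained sketch of the async request pattern used in `emoji_manager.py`; the API key, model name, and prompt text here are placeholders, not values from the patch:

```python
# Sketch of the async chat-completions call pattern (assumes aiohttp and an
# OpenAI-compatible endpoint such as SiliconFlow's).
import asyncio
import aiohttp

BASE_URL = "https://api.siliconflow.cn/v1/"  # must end with "/", the code
API_KEY = "sk-placeholder"                   # appends "chat/completions"


async def describe_image(image_base64: str) -> str:
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    payload = {
        "model": "placeholder-vision-model",  # assumption: any vision model
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
                {"type": "text", "text": "Describe the emotion of this sticker."},
            ],
        }],
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{BASE_URL}chat/completions",
                                headers=headers, json=payload) as response:
            data = await response.json()
            return data["choices"][0]["message"]["content"]

# Driven with: asyncio.run(describe_image(some_base64_string))
```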
读空气结果为{content_check}") if 'yes' not in content_check.lower(): - self.db.db.reasoning_logs.insert_one({ - 'time': time.time(), - 'group_id': message.group_id, - 'user': sender_name, - 'message': message.processed_plain_text, - 'model': model_name, - 'reasoning_check': reasoning_content_check, - 'response_check': content_check, - 'reasoning': "", - 'response': "", - 'prompt': prompt, - 'prompt_check': prompt_check, - 'model_params': default_params - }) - return None + air = 1 + #稀释读空气的判定 + if air == 1 and random.random() < 0.3: + self.db.db.reasoning_logs.insert_one({ + 'time': time.time(), + 'group_id': message.group_id, + 'user': sender_name, + 'message': message.processed_plain_text, + 'model': model_name, + 'reasoning_check': reasoning_content_check, + 'response_check': content_check, + 'reasoning': "", + 'response': "", + 'prompt': prompt, + 'prompt_check': prompt_check, + 'model_params': default_params + }) + return None + + @@ -193,7 +203,7 @@ class LLMResponseGenerator: async def _generate_r1_response(self, message: Message) -> Optional[str]: """使用 DeepSeek-R1 模型生成回复""" - if self.config.API_USING == "deepseek": + if global_config.API_USING == "deepseek": return await self._generate_base_response( message, "deepseek-reasoner", @@ -208,7 +218,7 @@ class LLMResponseGenerator: async def _generate_v3_response(self, message: Message) -> Optional[str]: """使用 DeepSeek-V3 模型生成回复""" - if self.config.API_USING == "deepseek": + if global_config.API_USING == "deepseek": return await self._generate_base_response( message, "deepseek-chat", @@ -259,7 +269,7 @@ class LLMResponseGenerator: messages = [{"role": "user", "content": prompt}] loop = asyncio.get_event_loop() - if self.config.API_USING == "deepseek": + if global_config.API_USING == "deepseek": model = "deepseek-chat" else: model = "Pro/deepseek-ai/DeepSeek-V3" @@ -296,4 +306,4 @@ class LLMResponseGenerator: return processed_response, emotion_tags # 创建全局实例 -llm_response = LLMResponseGenerator(global_config) \ No newline at end of file +llm_response = LLMResponseGenerator() \ No newline at end of file diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index 49402dbfc..ba0e9b4cc 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -66,12 +66,15 @@ class PromptBuilder: overlapping_second_layer.update(overlap) # 合并所有需要的记忆 - if all_first_layer_items: - print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}") - if overlapping_second_layer: - print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}") + # if all_first_layer_items: + # print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}") + # if overlapping_second_layer: + # print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}") - all_memories = all_first_layer_items + list(overlapping_second_layer) + # 使用集合去重 + all_memories = list(set(all_first_layer_items) | set(overlapping_second_layer)) + if all_memories: + print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}") if all_memories: # 只在列表非空时选择随机项 random_item = choice(all_memories) @@ -179,7 +182,11 @@ class PromptBuilder: # prompt += f"{activate_prompt}\n" prompt += f"{prompt_personality}\n" prompt += f"{prompt_ger}\n" - prompt += f"{extra_info}\n" + prompt += f"{extra_info}\n" + + + + '''读空气prompt处理''' activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。" prompt_personality_check = '' diff --git 
diff --git a/src/plugins/chat/topic_identifier.py b/src/plugins/chat/topic_identifier.py
index 75593667b..81956ddc1 100644
--- a/src/plugins/chat/topic_identifier.py
+++ b/src/plugins/chat/topic_identifier.py
@@ -1,14 +1,17 @@
 from typing import Optional, Dict, List
 from openai import OpenAI
 from .message import Message
-from .config import global_config, llm_config
 import jieba
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 
 class TopicIdentifier:
     def __init__(self):
         self.client = OpenAI(
-            api_key=llm_config.SILICONFLOW_API_KEY,
-            base_url=llm_config.SILICONFLOW_BASE_URL
+            api_key=config.siliconflow_key,
+            base_url=config.siliconflow_base_url
         )
 
     def identify_topic_llm(self, text: str) -> Optional[str]:
diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py
index 4e2235805..91018ee08 100644
--- a/src/plugins/chat/utils.py
+++ b/src/plugins/chat/utils.py
@@ -4,11 +4,15 @@ from typing import List
 from .message import Message
 import requests
 import numpy as np
-from .config import llm_config, global_config
+from .config import global_config
 import re
 from typing import Dict
 from collections import Counter
 import math
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 
 
 def combine_messages(messages: List[Message]) -> str:
@@ -64,7 +68,7 @@ def get_embedding(text):
         "encoding_format": "float"
     }
     headers = {
-        "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
+        "Authorization": f"Bearer {config.siliconflow_key}",
         "Content-Type": "application/json"
     }
diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py
index 097294010..68b2fa7f0 100644
--- a/src/plugins/chat/utils_image.py
+++ b/src/plugins/chat/utils_image.py
@@ -7,6 +7,10 @@ from ...common.database import Database
 import zlib  # 用于 CRC32
 import base64
 from .config import global_config
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 
 
 def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
@@ -37,12 +41,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
 
 # 连接数据库
 db = Database(
-    host= os.getenv("MONGODB_HOST"),
-    port= int(os.getenv("MONGODB_PORT")),
-    db_name= os.getenv("DATABASE_NAME"),
-    username= os.getenv("MONGODB_USERNAME"),
-    password= os.getenv("MONGODB_PASSWORD"),
-    auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    host= config.mongodb_host,
+    port= int(config.mongodb_port),
+    db_name= config.database_name,
+    username= config.mongodb_username,
+    password= config.mongodb_password,
+    auth_source=config.mongodb_auth_source
 )
 
 # 检查是否已存在相同哈希值的图片
diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py
index 8f3734a4f..e35743577 100644
--- a/src/plugins/chat/willing_manager.py
+++ b/src/plugins/chat/willing_manager.py
@@ -58,8 +58,8 @@ class WillingManager:
         if group_id in config.talk_frequency_down_groups:
             reply_probability = reply_probability / 3.5
 
-        if is_mentioned_bot and user_id == int(964959351):
-            reply_probability = 1
+        # if is_mentioned_bot and user_id == int(1026294844):
+        #     reply_probability = 1
 
         return reply_probability
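For context on the `willing_manager.py` hunk: `reply_probability` is ultimately consumed as a random gate, so dividing by 3.5 in frequency-down groups directly scales how often the bot speaks, and the removed hardcoded user boost no longer forces a reply. An illustrative sketch, with function and argument names made up for the example:

```python
# Sketch of how a willingness value like the one computed in willing_manager.py
# is typically consumed as a probability gate.
import random


def should_reply(reply_probability: float, in_quiet_group: bool = False) -> bool:
    if in_quiet_group:
        reply_probability /= 3.5  # same dampening factor as the diff
    return random.random() < reply_probability


# A 0.7 willingness in a frequency-down group replies roughly 20% of the time.
print(sum(should_reply(0.7, True) for _ in range(10000)) / 10000)
```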
diff --git a/src/plugins/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py
index cdb591dee..d8c2e1482 100644
--- a/src/plugins/knowledege/knowledge_library.py
+++ b/src/plugins/knowledege/knowledge_library.py
@@ -3,6 +3,10 @@ import sys
 import numpy as np
 import requests
 import time
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 
 # 添加项目根目录到 Python 路径
 root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -13,12 +17,12 @@ from src.plugins.chat.config import llm_config
 
 # 直接配置数据库连接信息
 Database.initialize(
-    host= os.getenv("MONGODB_HOST"),
-    port= int(os.getenv("MONGODB_PORT")),
-    db_name= os.getenv("DATABASE_NAME"),
-    username= os.getenv("MONGODB_USERNAME"),
-    password= os.getenv("MONGODB_PASSWORD"),
-    auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    host= config.mongodb_host,
+    port= int(config.mongodb_port),
+    db_name= config.database_name,
+    username= config.mongodb_username,
+    password= config.mongodb_password,
+    auth_source=config.mongodb_auth_source
 )
 
 class KnowledgeLibrary:
diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py
index a24f95a76..6b5dcd716 100644
--- a/src/plugins/memory_system/draw_memory.py
+++ b/src/plugins/memory_system/draw_memory.py
@@ -168,10 +168,12 @@ def main():
     memory_graph.load_graph_from_db()
 
     # 展示两种不同的可视化方式
     print("\n按连接数量着色的图谱:")
-    visualize_graph(memory_graph, color_by_memory=False)
+    # visualize_graph(memory_graph, color_by_memory=False)
+    visualize_graph_lite(memory_graph, color_by_memory=False)
 
     print("\n按记忆数量着色的图谱:")
-    visualize_graph(memory_graph, color_by_memory=True)
+    # visualize_graph(memory_graph, color_by_memory=True)
+    visualize_graph_lite(memory_graph, color_by_memory=True)
 
     # memory_graph.save_graph_to_db()
 
@@ -262,7 +264,89 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
     plt.title(title, fontsize=16, fontfamily='SimHei')
     plt.show()
 
-if __name__ == "__main__":
-    main()
+def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
+    # 设置中文字体
+    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
+    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
+
+    G = memory_graph.G
+
+    # 创建一个新图用于可视化
+    H = G.copy()
+
+    # 移除只有一条记忆的节点和连接数少于3的节点
+    nodes_to_remove = []
+    for node in H.nodes():
+        memory_items = H.nodes[node].get('memory_items', [])
+        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+        degree = H.degree(node)
+        if memory_count <= 2 or degree <= 2:
+            nodes_to_remove.append(node)
+
+    H.remove_nodes_from(nodes_to_remove)
+
+    # 如果过滤后没有节点,则返回
+    if len(H.nodes()) == 0:
+        print("过滤后没有符合条件的节点可显示")
+        return
+
+    # 保存图到本地
+    nx.write_gml(H, "memory_graph.gml")  # 保存为 GML 格式
+
+    # 根据连接条数或记忆数量设置节点颜色
+    node_colors = []
+    nodes = list(H.nodes())  # 获取图中实际的节点列表
+
+    if color_by_memory:
+        # 计算每个节点的记忆数量
+        memory_counts = []
+        for node in nodes:
+            memory_items = H.nodes[node].get('memory_items', [])
+            if isinstance(memory_items, list):
+                count = len(memory_items)
+            else:
+                count = 1 if memory_items else 0
+            memory_counts.append(count)
+        max_memories = max(memory_counts) if memory_counts else 1
+
+        for count in memory_counts:
+            # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
+            if max_memories > 0:
+                intensity = min(1.0, count / max_memories)
+                color = (intensity, 0, 1.0 - intensity)  # 从蓝色渐变到红色
+            else:
+                color = (0, 0, 1)  # 如果没有记忆,则为蓝色
+            node_colors.append(color)
+    else:
+        # 使用原来的连接数量着色方案
+        max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
+        for node in nodes:
+            degree = H.degree(node)
+            if max_degree > 0:
+                red = min(1.0, degree / max_degree)
+                blue = 1.0 - red
+                color = (red, 0, blue)
+            else:
+                color = (0, 0, 1)
+            node_colors.append(color)
+
+    # 绘制图形
+    plt.figure(figsize=(12, 8))
+    pos = nx.spring_layout(H, k=1, iterations=50)
+    nx.draw(H, pos,
+            with_labels=True,
+            node_color=node_colors,
+            node_size=2000,
+            font_size=10,
+            font_family='SimHei',
+            font_weight='bold')
+
+    title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
+    plt.title(title, fontsize=16, fontfamily='SimHei')
+    plt.show()
+
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/plugins/memory_system/llm_module.py b/src/plugins/memory_system/llm_module.py
index 4c9174dac..bd7f60dc3 100644
--- a/src/plugins/memory_system/llm_module.py
+++ b/src/plugins/memory_system/llm_module.py
@@ -2,14 +2,18 @@ import os
 import requests
 from typing import Tuple, Union
 import time
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 
 
 class LLMModel:
     # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
     def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
         self.model_name = model_name
         self.params = kwargs
-        self.api_key = os.getenv("SILICONFLOW_KEY")
-        self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+        self.api_key = config.siliconflow_key
+        self.base_url = config.siliconflow_base_url
 
     def generate_response(self, prompt: str) -> Tuple[str, str]:
         """根据输入的提示生成模型的响应"""
diff --git a/src/plugins/memory_system/llm_module_memory_make.py b/src/plugins/memory_system/llm_module_memory_make.py
index 07cecae9d..04ab6dbc6 100644
--- a/src/plugins/memory_system/llm_module_memory_make.py
+++ b/src/plugins/memory_system/llm_module_memory_make.py
@@ -3,14 +3,18 @@ import requests
 from typing import Tuple, Union
 import time
 from ..chat.config import BotConfig
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 
 
 class LLMModel:
     # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
     def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
         self.model_name = model_name
         self.params = kwargs
-        self.api_key = os.getenv("SILICONFLOW_KEY")
-        self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+        self.api_key = config.siliconflow_key
+        self.base_url = config.siliconflow_base_url
 
         if not self.api_key or not self.base_url:
             raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
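Both `visualize_graph_lite` above and the patched `visualize_graph` in `memory_make.py` below prune a copy of the graph before drawing, so the persisted memory graph itself is left untouched. A condensed, runnable sketch of that filtering step (thresholds follow `visualize_graph_lite`; node names are illustrative):

```python
# Sketch of the display-time pruning: filter a *copy* so the stored graph
# keeps its thin nodes. Requires networkx.
import networkx as nx


def prune_for_display(G: nx.Graph, min_memories: int = 2, min_degree: int = 2) -> nx.Graph:
    H = G.copy()
    to_remove = []
    for node in H.nodes():
        items = H.nodes[node].get("memory_items", [])
        count = len(items) if isinstance(items, list) else (1 if items else 0)
        # Keep only nodes with enough memories AND enough connections;
        # degree is checked before any removal, matching the patch.
        if count <= min_memories or H.degree(node) <= min_degree:
            to_remove.append(node)
    H.remove_nodes_from(to_remove)
    return H


G = nx.Graph()
G.add_node("hub", memory_items=["a", "b", "c"])
G.add_node("isolated", memory_items=["x"])
G.add_edges_from([("hub", "n1"), ("hub", "n2"), ("hub", "n3")])
print(len(prune_for_display(G)))  # only "hub" survives the filter
```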
diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py
index 4deb28d63..f2b162afb 100644
--- a/src/plugins/memory_system/memory.py
+++ b/src/plugins/memory_system/memory.py
@@ -198,8 +198,6 @@ class Hippocampus:
         time_frequency = {'near':1,'mid':2,'far':2}
         memory_sample = self.get_memory_sample(chat_size,time_frequency)
         # print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
-
-
         for i, input_text in enumerate(memory_sample, 1):
             #加载进度可视化
             progress = (i / len(memory_sample)) * 100
@@ -207,24 +205,25 @@ class Hippocampus:
             bar_length = 30
             filled_length = int(bar_length * i // len(memory_sample))
             bar = '█' * filled_length + '-' * (bar_length - filled_length)
            print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
-
-            # 生成压缩后记忆
-            first_memory = set()
-            first_memory = self.memory_compress(input_text, 2.5)
-            # 延时防止访问超频
-            # time.sleep(5)
-            #将记忆加入到图谱中
-            for topic, memory in first_memory:
-                topics = segment_text(topic)
-                print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
-                for split_topic in topics:
-                    self.memory_graph.add_dot(split_topic,memory)
-                for split_topic in topics:
-                    for other_split_topic in topics:
-                        if split_topic != other_split_topic:
-                            self.memory_graph.connect_dot(split_topic, other_split_topic)
-
-        self.memory_graph.save_graph_to_db()
+            if input_text:
+                # 生成压缩后记忆
+                first_memory = set()
+                first_memory = self.memory_compress(input_text, 2.5)
+                # 延时防止访问超频
+                # time.sleep(5)
+                #将记忆加入到图谱中
+                for topic, memory in first_memory:
+                    topics = segment_text(topic)
+                    print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
+                    for split_topic in topics:
+                        self.memory_graph.add_dot(split_topic,memory)
+                    for split_topic in topics:
+                        for other_split_topic in topics:
+                            if split_topic != other_split_topic:
+                                self.memory_graph.connect_dot(split_topic, other_split_topic)
+            else:
+                print("空消息 跳过")
+        self.memory_graph.save_graph_to_db()
 
     def memory_compress(self, input_text, rate=1):
         information_content = calculate_information_content(input_text)
@@ -260,16 +259,19 @@ def topic_what(text, topic):
     return prompt
 
-
+from nonebot import get_driver
+driver = get_driver()
+config = driver.config
+
 start_time = time.time()
 
 Database.initialize(
-    host= os.getenv("MONGODB_HOST"),
-    port= int(os.getenv("MONGODB_PORT")),
-    db_name= os.getenv("DATABASE_NAME"),
-    username= os.getenv("MONGODB_USERNAME"),
-    password= os.getenv("MONGODB_PASSWORD"),
-    auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    host= config.mongodb_host,
+    port= int(config.mongodb_port),
+    db_name= config.database_name,
+    username= config.mongodb_username,
+    password= config.mongodb_password,
+    auth_source=config.mongodb_auth_source
 )
 #创建记忆图
 memory_graph = Memory_graph()
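The memory sampling above hinges on `get_cloest_chat_from_db` (name kept as in the patch): anchor on the newest message at or before a random timestamp, then read the next `length` messages from the same group. A hedged pymongo-style sketch of the same query shape, using the collection and field names this project uses:

```python
# Sketch of the time-anchored chat sampling used by the memory system
# (assumes a pymongo `db` handle with a `messages` collection).
def sample_chat(db, length: int, timestamp: float) -> str:
    # Newest message at or before the anchor timestamp.
    anchor = db.messages.find_one({"time": {"$lte": timestamp}},
                                  sort=[("time", -1)])
    if not anchor:
        return ""
    # The next `length` messages from the same group, oldest first.
    records = db.messages.find(
        {"time": {"$gt": anchor["time"]}, "group_id": anchor["group_id"]}
    ).sort("time", 1).limit(length)
    return "\n".join(r.get("processed_plain_text", "") for r in records)
```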
diff --git a/src/plugins/memory_system/memory_make.py b/src/plugins/memory_system/memory_make.py
index 0c1c0d219..7fb8af15a 100644
--- a/src/plugins/memory_system/memory_make.py
+++ b/src/plugins/memory_system/memory_make.py
@@ -13,7 +13,38 @@ import os
 sys.path.append("C:/GitHub/MaiMBot")  # 添加项目根目录到 Python 路径
 from src.common.database import Database  # 使用正确的导入语法
 from src.plugins.memory_system.llm_module import LLMModel
-
+
+def calculate_information_content(text):
+    """计算文本的信息量(熵)"""
+    # 统计字符频率
+    char_count = Counter(text)
+    total_chars = len(text)
+
+    # 计算熵
+    entropy = 0
+    for count in char_count.values():
+        probability = count / total_chars
+        entropy -= probability * math.log2(probability)
+
+    return entropy
+
+def get_cloest_chat_from_db(db, length: int, timestamp: str):
+    """从数据库中获取最接近指定时间戳的聊天记录"""
+    chat_text = ''
+    closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
+
+    if closest_record:
+        closest_time = closest_record['time']
+        group_id = closest_record['group_id']  # 获取groupid
+        # 获取该时间戳之后的length条消息,且groupid相同
+        chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
+        for record in chat_record:
+            time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
+            chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
+        return chat_text
+
+    return ''
+
 class Memory_graph:
     def __init__(self):
         self.G = nx.Graph()  # 使用 networkx 的图结构
@@ -102,7 +133,8 @@ class Memory_graph:
         # 从数据库中根据时间戳获取离其最近的聊天记录
         chat_text = ''
         closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])  # 调试输出
-        print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
+
+        # print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
 
         if closest_record:
             closest_time = closest_record['time']
@@ -110,8 +142,9 @@ class Memory_graph:
             # 获取该时间戳之后的length条消息,且groupid相同
             chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
             for record in chat_record:
-                time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
-                chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'  # 添加发送者和时间信息
+                if record:
+                    time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
+                    chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'  # 添加发送者和时间信息
             return chat_text
 
         return []  # 如果没有找到记录,返回空列表
@@ -187,155 +220,80 @@ class Memory_graph:
         for edge in edges:
             self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
 
-def calculate_information_content(text):
-    """计算文本的信息量(熵)"""
-    # 统计字符频率
-    char_count = Counter(text)
-    total_chars = len(text)
-
-    # 计算熵
-    entropy = 0
-    for count in char_count.values():
-        probability = count / total_chars
-        entropy -= probability * math.log2(probability)
-
-    return entropy
-
-
-# Database.initialize(
-#     global_config.MONGODB_HOST,
-#     global_config.MONGODB_PORT,
-#     global_config.DATABASE_NAME
-# )
-# memory_graph = Memory_graph()
-
-# llm_model = LLMModel()
-# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
-
-# memory_graph.load_graph_from_db()
-
-
-
-def main():
-    # 初始化数据库
-    Database.initialize(
-        host= os.getenv("MONGODB_HOST"),
-        port= int(os.getenv("MONGODB_PORT")),
-        db_name= os.getenv("DATABASE_NAME"),
-        username= os.getenv("MONGODB_USERNAME"),
-        password= os.getenv("MONGODB_PASSWORD"),
-        auth_source=os.getenv("MONGODB_AUTH_SOURCE")
-    )
-
-    memory_graph = Memory_graph()
-    # 创建LLM模型实例
-    llm_model = LLMModel()
-    llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
-
-    # 使用当前时间戳进行测试
-    current_timestamp = datetime.datetime.now().timestamp()
-    chat_text = []
-
-    chat_size =25
-
-    for _ in range(30):  # 循环10次
-        random_time = current_timestamp - random.randint(1, 3600*10)  # 随机时间
-        print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
-        chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
-        chat_text.append(chat_)  # 拼接所有text
-        # time.sleep(1)
-
-
-
-    for i, input_text in enumerate(chat_text, 1):
+# 海马体
+class Hippocampus:
+    def __init__(self,memory_graph:Memory_graph):
+        self.memory_graph = memory_graph
+        self.llm_model = LLMModel()
+        self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
 
-        progress = (i / len(chat_text)) * 100
-        bar_length = 30
-        filled_length = int(bar_length * i // len(chat_text))
-        bar = '█' * filled_length + '-' * (bar_length - filled_length)
-        print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})")
+    def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
+        current_timestamp = datetime.datetime.now().timestamp()
+        chat_text = []
+        #短期:1h 中期:4h 长期:24h
+        for _ in range(time_frequency.get('near')):  # 循环10次
+            random_time = current_timestamp - random.randint(1, 3600)  # 随机时间
+            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+            chat_text.append(chat_)
+        for _ in range(time_frequency.get('mid')):  # 循环10次
+            random_time = current_timestamp - random.randint(3600, 3600*4)  # 随机时间
+            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+            chat_text.append(chat_)
+        for _ in range(time_frequency.get('far')):  # 循环10次
+            random_time = current_timestamp - random.randint(3600*4, 3600*24)  # 随机时间
+            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+            chat_text.append(chat_)
+        return chat_text
+
+    def build_memory(self,chat_size=12):
+        #最近消息获取频率
+        time_frequency = {'near':1,'mid':2,'far':2}
+        memory_sample = self.get_memory_sample(chat_size,time_frequency)
 
-        # print(input_text)
-        first_memory = set()
-        first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
-        # time.sleep(5)
-
-        #将记忆加入到图谱中
-        for topic, memory in first_memory:
-            topics = segment_text(topic)
-            print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
-            for split_topic in topics:
-                memory_graph.add_dot(split_topic,memory)
-            for split_topic in topics:
-                for other_split_topic in topics:
-                    if split_topic != other_split_topic:
-                        memory_graph.connect_dot(split_topic, other_split_topic)
-
-    # memory_graph.store_memory()
-
-    # 展示两种不同的可视化方式
-    print("\n按连接数量着色的图谱:")
-    visualize_graph(memory_graph, color_by_memory=False)
-
-    print("\n按记忆数量着色的图谱:")
-    visualize_graph(memory_graph, color_by_memory=True)
-
-    memory_graph.save_graph_to_db()
-    # memory_graph.load_graph_from_db()
-
-    while True:
-        query = input("请输入新的查询概念(输入'退出'以结束):")
-        if query.lower() == '退出':
-            break
-        items_list = memory_graph.get_related_item(query)
-        if items_list:
-            # print(items_list)
-            for memory_item in items_list:
-                print(memory_item)
-        else:
-            print("未找到相关记忆。")
+        #加载进度可视化
+        for i, input_text in enumerate(memory_sample, 1):
+            progress = (i / len(memory_sample)) * 100
+            bar_length = 30
+            filled_length = int(bar_length * i // len(memory_sample))
+            bar = '█' * filled_length + '-' * (bar_length - filled_length)
+            print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
+            # print(f"第{i}条消息: {input_text}")
+            if input_text:
+                # 生成压缩后记忆
+                first_memory = set()
+                first_memory = self.memory_compress(input_text, 2.5)
+                #将记忆加入到图谱中
+                for topic, memory in first_memory:
+                    topics = segment_text(topic)
+                    print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
+                    for split_topic in topics:
+                        self.memory_graph.add_dot(split_topic,memory)
+                    for split_topic in topics:
+                        for other_split_topic in topics:
+                            if split_topic != other_split_topic:
+                                self.memory_graph.connect_dot(split_topic, other_split_topic)
+            else:
+                print("空消息 跳过")
 
-    while True:
-        query = input("请输入问题:")
-
-        if query.lower() == '退出':
-            break
-
-        topic_prompt = find_topic(query, 3)
-        topic_response = llm_model.generate_response(topic_prompt)
+        self.memory_graph.save_graph_to_db()
+
+    def memory_compress(self, input_text, rate=1):
+        information_content = calculate_information_content(input_text)
+        print(f"文本的信息量(熵): {information_content:.4f} bits")
+        topic_num = max(1, min(5, int(information_content * rate / 4)))
+        topic_prompt = find_topic(input_text, topic_num)
+        topic_response = self.llm_model.generate_response(topic_prompt)
        # 检查 topic_response 是否为元组
        if isinstance(topic_response, tuple):
            topics = topic_response[0].split(",")  # 假设第一个元素是我们需要的字符串
        else:
            topics = topic_response.split(",")
-        print(topics)
-
-        for keyword in topics:
-            items_list = memory_graph.get_related_item(keyword)
-            if items_list:
-                print(items_list)
-
-def memory_compress(input_text, llm_model, llm_model_small, rate=1):
-    information_content = calculate_information_content(input_text)
-    print(f"文本的信息量(熵): {information_content:.4f} bits")
-    topic_num = max(1, min(5, int(information_content * rate / 4)))
-    print(topic_num)
-    topic_prompt = find_topic(input_text, topic_num)
-    topic_response = llm_model.generate_response(topic_prompt)
-    # 检查 topic_response 是否为元组
-    if isinstance(topic_response, tuple):
-        topics = topic_response[0].split(",")  # 假设第一个元素是我们需要的字符串
-    else:
-        topics = topic_response.split(",")
-    print(topics)
-    compressed_memory = set()
-    for topic in topics:
-        topic_what_prompt = topic_what(input_text,topic)
-        topic_what_response = llm_model_small.generate_response(topic_what_prompt)
-        compressed_memory.add((topic.strip(), topic_what_response[0]))  # 将话题和记忆作为元组存储
-    return compressed_memory
-
+        compressed_memory = set()
+        for topic in topics:
+            topic_what_prompt = topic_what(input_text,topic)
+            topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
+            compressed_memory.add((topic.strip(), topic_what_response[0]))  # 将话题和记忆作为元组存储
+        return compressed_memory
 
 def segment_text(text):
     seg_text = list(jieba.cut(text))
@@ -356,18 +314,37 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
 
     G = memory_graph.G
 
+    # 创建一个新图用于可视化
+    H = G.copy()
+
+    # 移除只有一条记忆的节点和连接数少于3的节点
+    nodes_to_remove = []
+    for node in H.nodes():
+        memory_items = H.nodes[node].get('memory_items', [])
+        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+        degree = H.degree(node)
+        if memory_count <= 1 or degree <= 2:
+            nodes_to_remove.append(node)
+
+    H.remove_nodes_from(nodes_to_remove)
+
+    # 如果过滤后没有节点,则返回
+    if len(H.nodes()) == 0:
+        print("过滤后没有符合条件的节点可显示")
+        return
+
     # 保存图到本地
-    nx.write_gml(G, "memory_graph.gml")  # 保存为 GML 格式
+    nx.write_gml(H, "memory_graph.gml")  # 保存为 GML 格式
 
     # 根据连接条数或记忆数量设置节点颜色
     node_colors = []
-    nodes = list(G.nodes())  # 获取图中实际的节点列表
+    nodes = list(H.nodes())  # 获取图中实际的节点列表
 
     if color_by_memory:
         # 计算每个节点的记忆数量
         memory_counts = []
         for node in nodes:
-            memory_items = G.nodes[node].get('memory_items', [])
+            memory_items = H.nodes[node].get('memory_items', [])
             if isinstance(memory_items, list):
                 count = len(memory_items)
             else:
@@ -385,9 +362,9 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
             node_colors.append(color)
     else:
         # 使用原来的连接数量着色方案
-        max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
+        max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
         for node in nodes:
-            degree = G.degree(node)
+            degree = H.degree(node)
             if max_degree > 0:
                 red = min(1.0, degree / max_degree)
                 blue = 1.0 - red
@@ -398,8 +375,8 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
 
     # 绘制图形
     plt.figure(figsize=(12, 8))
-    pos = nx.spring_layout(G, k=1, iterations=50)
-    nx.draw(G, pos,
+    pos = nx.spring_layout(H, k=1, iterations=50)
+    nx.draw(H, pos,
             with_labels=True,
             node_color=node_colors,
             node_size=2000,
@@ -411,6 +388,71 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
     plt.title(title, fontsize=16, fontfamily='SimHei')
     plt.show()
 
+def main():
+    # 初始化数据库
+    Database.initialize(
+        host= os.getenv("MONGODB_HOST"),
+        port= int(os.getenv("MONGODB_PORT")),
+        db_name= os.getenv("DATABASE_NAME"),
+        username= os.getenv("MONGODB_USERNAME"),
+        password= os.getenv("MONGODB_PASSWORD"),
+        auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    )
+
+    start_time = time.time()
+
+    # 创建记忆图
+    memory_graph = Memory_graph()
+    # 加载数据库中存储的记忆图
+    memory_graph.load_graph_from_db()
+    # 创建海马体
+    hippocampus = Hippocampus(memory_graph)
+
+    end_time = time.time()
+    print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
+
+    # 构建记忆
+    hippocampus.build_memory(chat_size=25)
+
+    # 展示两种不同的可视化方式
+    print("\n按连接数量着色的图谱:")
+    visualize_graph(memory_graph, color_by_memory=False)
+
+    print("\n按记忆数量着色的图谱:")
+    visualize_graph(memory_graph, color_by_memory=True)
+
+    # 交互式查询
+    while True:
+        query = input("请输入新的查询概念(输入'退出'以结束):")
+        if query.lower() == '退出':
+            break
+        items_list = memory_graph.get_related_item(query)
+        if items_list:
+            for memory_item in items_list:
+                print(memory_item)
+        else:
+            print("未找到相关记忆。")
+
+    while True:
+        query = input("请输入问题:")
+
+        if query.lower() == '退出':
+            break
+
+        topic_prompt = find_topic(query, 3)
+        topic_response = hippocampus.llm_model.generate_response(topic_prompt)
+        # 检查 topic_response 是否为元组
+        if isinstance(topic_response, tuple):
+            topics = topic_response[0].split(",")  # 假设第一个元素是我们需要的字符串
+        else:
+            topics = topic_response.split(",")
+        print(topics)
+
+        for keyword in topics:
+            items_list = memory_graph.get_related_item(keyword)
+            if items_list:
+                print(items_list)
+
 
 if __name__ == "__main__":
     main()
diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py
index 097a89de8..a33c4b279 100644
--- a/src/plugins/schedule/schedule_generator.py
+++ b/src/plugins/schedule/schedule_generator.py
@@ -4,14 +4,19 @@ from typing import List, Dict
 from .schedule_llm_module import LLMModel
 from ...common.database import Database  # 使用正确的导入语法
 from ..chat.config import global_config
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
+
 
 Database.initialize(
-    host= os.getenv("MONGODB_HOST"),
-    port= int(os.getenv("MONGODB_PORT")),
-    db_name= os.getenv("DATABASE_NAME"),
-    username= os.getenv("MONGODB_USERNAME"),
-    password= os.getenv("MONGODB_PASSWORD"),
-    auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    host= config.mongodb_host,
+    port= int(config.mongodb_port),
+    db_name= config.database_name,
+    username= config.mongodb_username,
+    password= config.mongodb_password,
+    auth_source=config.mongodb_auth_source
 )
 
 class ScheduleGenerator:
diff --git a/src/plugins/schedule/schedule_llm_module.py b/src/plugins/schedule/schedule_llm_module.py
index ebf039c7e..408e7d546 100644
--- a/src/plugins/schedule/schedule_llm_module.py
+++ b/src/plugins/schedule/schedule_llm_module.py
@@ -1,20 +1,24 @@
 import os
 import requests
 from typing import Tuple, Union
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 
 
 class LLMModel:
     # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
     def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs):
         if api_using == "deepseek":
-            self.api_key = os.getenv("DEEP_SEEK_KEY")
-            self.base_url = os.getenv("DEEP_SEEK_BASE_URL")
+            self.api_key = config.deep_seek_key
+            self.base_url = config.deep_seek_base_url
             if model_name != "Pro/deepseek-ai/DeepSeek-R1":
                 self.model_name = model_name
             else:
                 self.model_name = "deepseek-reasoner"
         else:
-            self.api_key = os.getenv("SILICONFLOW_KEY")
-            self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+            self.api_key = config.siliconflow_key
+            self.base_url = config.siliconflow_base_url
             self.model_name = model_name
         self.params = kwargs