diff --git a/config/env.example b/.env.prod
similarity index 73%
rename from config/env.example
rename to .env.prod
index 9988d58f3..3d795978c 100644
--- a/config/env.example
+++ b/.env.prod
@@ -1,4 +1,3 @@
-ENVIRONMENT=dev
HOST=127.0.0.1
PORT=8080
@@ -11,15 +10,15 @@ PLUGINS=["src2.plugins.chat"]
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
+
MONGODB_USERNAME = "" # 默认空值
MONGODB_PASSWORD = "" # 默认空值
MONGODB_AUTH_SOURCE = "" # 默认空值
-#api配置项
+#key and url
+CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY=
+CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=
-DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
-
-
-
+DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 265108181..c19b9ce33 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ reasoning_content.bat
reasoning_window.bat
queue_update.txt
memory_graph.gml
+.env.dev
# Byte-compiled / optimized / DLL files
diff --git a/README.md b/README.md
index a85fcc4e8..1310d4879 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@
> ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次
> ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语(
-关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言!
+关于麦麦的开发和建议相关的讨论群:766798517(不建议发布无关消息)这里不会有麦麦发言!
## 开发计划TODO:LIST
@@ -41,16 +41,13 @@
- config自动生成和检测
- log别用print
- 给发送消息写专门的类
+- 改进表情包发送逻辑l
-
-

-
-
## 📚 详细文档
- [项目详细介绍和架构说明](docs/doc1.md) - 包含完整的项目结构、文件说明和核心功能实现细节(由claude-3.5-sonnet生成)
-### 安装方法(还没测试好,现在部署可能遇到未知问题!!!!)
+### 安装方法(还没测试好,随时outdated ,现在部署可能遇到未知问题!!!!)
#### Linux 使用 Docker Compose 部署
获取项目根目录中的```docker-compose.yml```文件,运行以下命令
diff --git a/bot.py b/bot.py
index 8abdbbbe9..8741eca7f 100644
--- a/bot.py
+++ b/bot.py
@@ -4,27 +4,59 @@ from nonebot.adapters.onebot.v11 import Adapter
from dotenv import load_dotenv
from loguru import logger
- # 加载全局环境变量
-root_dir = os.path.dirname(os.path.abspath(__file__))
-env_path=os.path.join(root_dir, "config",'.env')
+'''彩蛋'''
+from colorama import init, Fore
+init()
+text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
+rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
+rainbow_text = ""
+for i, char in enumerate(text):
+ rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
+print(rainbow_text)
+'''彩蛋'''
-logger.info(f"尝试从 {env_path} 加载环境变量配置")
-if os.path.exists(env_path):
- load_dotenv(env_path)
- logger.success("成功加载环境变量配置")
+# 首先加载基础环境变量
+if os.path.exists(".env"):
+ load_dotenv(".env")
+ logger.success("成功加载基础环境变量配置")
else:
- logger.error(f"环境变量配置文件不存在: {env_path}")
+ logger.error("基础环境变量配置文件 .env 不存在")
+ exit(1)
+# 根据 ENVIRONMENT 加载对应的环境配置
+env = os.getenv("ENVIRONMENT")
+env_file = f".env.{env}"
+
+if env_file == ".env.dev" and os.path.exists(env_file):
+ logger.success("加载开发环境变量配置")
+ load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
+elif env_file == ".env.prod" and os.path.exists(env_file):
+ logger.success("加载环境变量配置")
+ load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
+else:
+ logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
+ exit(1)
-# 初始化 NoneBot
nonebot.init(
- # napcat 默认使用 8080 端口
- websocket_port=8080,
- # 设置日志级别
+ # 从环境变量中读取配置
+ websocket_port=os.getenv("PORT", 8080),
+ host=os.getenv("HOST", "127.0.0.1"),
log_level="INFO",
- # 设置超级用户
- superusers={"你的QQ号"},
- # TODO: 这样写会忽略环境变量需要优化 https://nonebot.dev/docs/appendices/config
- _env_file=env_path
+ # 添加自定义配置
+ mongodb_host=os.getenv("MONGODB_HOST", "127.0.0.1"),
+ mongodb_port=os.getenv("MONGODB_PORT", 27017),
+ database_name=os.getenv("DATABASE_NAME", "MegBot"),
+ mongodb_username=os.getenv("MONGODB_USERNAME", ""),
+ mongodb_password=os.getenv("MONGODB_PASSWORD", ""),
+ mongodb_auth_source=os.getenv("MONGODB_AUTH_SOURCE", ""),
+ # API相关配置
+ chat_any_where_key=os.getenv("CHAT_ANY_WHERE_KEY", ""),
+ siliconflow_key=os.getenv("SILICONFLOW_KEY", ""),
+ chat_any_where_base_url=os.getenv("CHAT_ANY_WHERE_BASE_URL", "https://api.chatanywhere.tech/v1"),
+ siliconflow_base_url=os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1/"),
+ deep_seek_key=os.getenv("DEEP_SEEK_KEY", ""),
+ deep_seek_base_url=os.getenv("DEEP_SEEK_BASE_URL", "https://api.deepseek.com/v1"),
+ # 插件配置
+ plugins=os.getenv("PLUGINS", ["src2.plugins.chat"])
)
# 注册适配器
diff --git a/run_maimai.bat b/run_maimai.bat
index 0e1bd7eb6..702d39edc 100644
--- a/run_maimai.bat
+++ b/run_maimai.bat
@@ -2,4 +2,5 @@ call conda activate niuniu
cd .
REM 执行nb run命令
-nb run
\ No newline at end of file
+nb run
+pause
\ No newline at end of file
diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py
index 6fc896ac8..0f81d30d6 100644
--- a/src/plugins/chat/__init__.py
+++ b/src/plugins/chat/__init__.py
@@ -11,16 +11,18 @@ from .relationship_manager import relationship_manager
from ..schedule.schedule_generator import bot_schedule
from .willing_manager import willing_manager
+
# 获取驱动器
driver = get_driver()
+config = driver.config
Database.initialize(
- host= os.getenv("MONGODB_HOST"),
- port= int(os.getenv("MONGODB_PORT")),
- db_name= os.getenv("DATABASE_NAME"),
- username= os.getenv("MONGODB_USERNAME"),
- password= os.getenv("MONGODB_PASSWORD"),
- auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+ host= config.mongodb_host,
+ port= int(config.mongodb_port),
+ db_name= config.database_name,
+ username= config.mongodb_username,
+ password= config.mongodb_password,
+ auth_source= config.mongodb_auth_source
)
print("\033[1;32m[初始化数据库完成]\033[0m")
@@ -37,7 +39,7 @@ emoji_manager.initialize()
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
# 创建机器人实例
-chat_bot = ChatBot(global_config)
+chat_bot = ChatBot()
# 注册消息处理器
group_msg = on_message()
# 创建定时任务
diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py
index 1b5201645..e68ae93f3 100644
--- a/src/plugins/chat/bot.py
+++ b/src/plugins/chat/bot.py
@@ -18,10 +18,9 @@ from .utils import is_mentioned_bot_in_txt, calculate_typing_time
from ..memory_system.memory import memory_graph
class ChatBot:
- def __init__(self, config: BotConfig):
- self.config = config
+ def __init__(self):
self.storage = MessageStorage()
- self.gpt = LLMResponseGenerator(config)
+ self.gpt = LLMResponseGenerator()
self.bot = None # bot 实例引用
self._started = False
@@ -39,11 +38,11 @@ class ChatBot:
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的群消息"""
- if event.group_id not in self.config.talk_allowed_groups:
+ if event.group_id not in global_config.talk_allowed_groups:
return
self.bot = bot # 更新 bot 实例
- if event.user_id in self.config.ban_user_id:
+ if event.user_id in global_config.ban_user_id:
return
# 打印原始消息内容
@@ -120,7 +119,7 @@ class ChatBot:
event.group_id,
topic[0] if topic else None,
is_mentioned,
- self.config,
+ global_config,
event.user_id,
message.is_emoji,
interested_rate
@@ -143,10 +142,14 @@ class ChatBot:
response, emotion = await self.gpt.generate_response(message)
# 如果生成了回复,发送并记录
-
+
+ '''
+ 生成回复后的内容
+
+ '''
if response:
- message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id)
+ message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
accu_typing_time = 0
for msg in response:
print(f"当前消息: {msg}")
@@ -157,7 +160,7 @@ class ChatBot:
bot_message = Message(
group_id=event.group_id,
- user_id=self.config.BOT_QQ,
+ user_id=global_config.BOT_QQ,
message_id=think_id,
message_based_id=event.message_id,
raw_message=msg,
@@ -174,7 +177,7 @@ class ChatBot:
bot_response_time = tinking_time_point
- if random() < self.config.emoji_chance:
+ if random() < global_config.emoji_chance:
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
if emoji_path:
emoji_cq = CQCode.create_emoji_cq(emoji_path)
@@ -186,7 +189,7 @@ class ChatBot:
bot_message = Message(
group_id=event.group_id,
- user_id=self.config.BOT_QQ,
+ user_id=global_config.BOT_QQ,
message_id=0,
raw_message=emoji_cq,
plain_text=emoji_cq,
diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py
index 8c04e1126..0232219f8 100644
--- a/src/plugins/chat/config.py
+++ b/src/plugins/chat/config.py
@@ -7,6 +7,7 @@ import configparser
import tomli
import sys
from loguru import logger
+from nonebot import get_driver
@@ -111,7 +112,6 @@ class BotConfig:
# 获取配置文件路径
bot_config_path = BotConfig.get_default_config_path()
config_dir = os.path.dirname(bot_config_path)
-env_path = os.path.join(config_dir, '.env')
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
global_config = BotConfig.load_config(config_path=bot_config_path)
@@ -126,10 +126,11 @@ class LLMConfig:
DEEP_SEEK_BASE_URL: str = None
llm_config = LLMConfig()
-llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
-llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
-llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
-llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
+config = get_driver().config
+llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
+llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
+llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
+llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
if not global_config.enable_advance_output:
diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py
index 92ca20bd8..ae5d8a257 100644
--- a/src/plugins/chat/cq_code.py
+++ b/src/plugins/chat/cq_code.py
@@ -7,7 +7,7 @@ from PIL import Image
import os
from random import random
from nonebot.adapters.onebot.v11 import Bot
-from .config import global_config, llm_config
+from .config import global_config
import time
import asyncio
from .utils_image import storage_image,storage_emoji
@@ -16,6 +16,10 @@ from .utils_user import get_user_nickname
#包含CQ码类
import urllib3
from urllib3.util import create_urllib3_context
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
ctx = create_urllib3_context()
@@ -179,7 +183,7 @@ class CQCode:
"""调用AI接口获取表情包描述"""
headers = {
"Content-Type": "application/json",
- "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+ "Authorization": f"Bearer {config.siliconflow_key}"
}
payload = {
@@ -206,7 +210,7 @@ class CQCode:
}
response = requests.post(
- f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+ f"{config.siliconflow_base_url}chat/completions",
headers=headers,
json=payload,
timeout=30
@@ -224,7 +228,7 @@ class CQCode:
"""调用AI接口获取普通图片描述"""
headers = {
"Content-Type": "application/json",
- "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+ "Authorization": f"Bearer {config.siliconflow_key}"
}
payload = {
@@ -251,7 +255,7 @@ class CQCode:
}
response = requests.post(
- f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+ f"{config.siliconflow_base_url}chat/completions",
headers=headers,
json=payload,
timeout=30
diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py
index a4352758d..c8c9dc814 100644
--- a/src/plugins/chat/emoji_manager.py
+++ b/src/plugins/chat/emoji_manager.py
@@ -10,10 +10,14 @@ import hashlib
from datetime import datetime
import base64
import shutil
-from .config import global_config, llm_config
import asyncio
import time
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
+
class EmojiManager:
_instance = None
@@ -93,7 +97,7 @@ class EmojiManager:
# 准备请求数据
headers = {
"Content-Type": "application/json",
- "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+ "Authorization": f"Bearer {config.siliconflow_key}"
}
payload = {
@@ -115,7 +119,7 @@ class EmojiManager:
async with aiohttp.ClientSession() as session:
async with session.post(
- f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+ f"{config.siliconflow_base_url}chat/completions",
headers=headers,
json=payload
) as response:
@@ -249,7 +253,7 @@ class EmojiManager:
async with aiohttp.ClientSession() as session:
headers = {
"Content-Type": "application/json",
- "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+ "Authorization": f"Bearer {config.siliconflow_key}"
}
payload = {
@@ -276,7 +280,7 @@ class EmojiManager:
}
async with session.post(
- f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+ f"{config.siliconflow_base_url}chat/completions",
headers=headers,
json=payload
) as response:
diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py
index 5ed12c9b6..bfff1d474 100644
--- a/src/plugins/chat/llm_generator.py
+++ b/src/plugins/chat/llm_generator.py
@@ -1,34 +1,34 @@
from typing import Dict, Any, List, Optional, Union, Tuple
from openai import OpenAI
import asyncio
-import requests
from functools import partial
from .message import Message
-from .config import BotConfig, global_config
+from .config import global_config
from ...common.database import Database
import random
import time
-import os
import numpy as np
from .relationship_manager import relationship_manager
-from ..schedule.schedule_generator import bot_schedule
from .prompt_builder import prompt_builder
-from .config import llm_config, global_config
+from .config import global_config
from .utils import process_llm_response
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
class LLMResponseGenerator:
- def __init__(self, config: BotConfig):
- self.config = config
- if self.config.API_USING == "siliconflow":
+ def __init__(self):
+ if global_config.API_USING == "siliconflow":
self.client = OpenAI(
- api_key=llm_config.SILICONFLOW_API_KEY,
- base_url=llm_config.SILICONFLOW_BASE_URL
+ api_key=config.siliconflow_key,
+ base_url=config.siliconflow_base_url
)
- elif self.config.API_USING == "deepseek":
+ elif global_config.API_USING == "deepseek":
self.client = OpenAI(
- api_key=llm_config.DEEP_SEEK_API_KEY,
- base_url=llm_config.DEEP_SEEK_BASE_URL
+ api_key=config.deep_seek_key,
+ base_url=config.deep_seek_base_url
)
self.db = Database.get_instance()
@@ -52,6 +52,7 @@ class LLMResponseGenerator:
else:
self.current_model_type = 'r1_distill' # 默认使用 R1-Distill
+
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
if self.current_model_type == 'r1':
model_response = await self._generate_r1_response(message)
@@ -83,8 +84,9 @@ class LLMResponseGenerator:
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
else:
relationship_value = 0.0
+
- # 构建prompt
+ ''' 构建prompt '''
prompt,prompt_check = prompt_builder._build_prompt(
message_txt=message.processed_plain_text,
sender_name=sender_name,
@@ -92,6 +94,7 @@ class LLMResponseGenerator:
group_id=message.group_id
)
+
# 设置默认参数
default_params = {
"model": model_name,
@@ -113,6 +116,7 @@ class LLMResponseGenerator:
if model_params:
default_params.update(model_params)
+
def create_completion():
return self.client.chat.completions.create(**default_params)
@@ -122,6 +126,7 @@ class LLMResponseGenerator:
loop = asyncio.get_event_loop()
# 读空气模块
+ air = 0
reasoning_content_check=''
content_check=''
if global_config.enable_kuuki_read:
@@ -135,21 +140,26 @@ class LLMResponseGenerator:
content_check = response_check.choices[0].message.content
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
if 'yes' not in content_check.lower():
- self.db.db.reasoning_logs.insert_one({
- 'time': time.time(),
- 'group_id': message.group_id,
- 'user': sender_name,
- 'message': message.processed_plain_text,
- 'model': model_name,
- 'reasoning_check': reasoning_content_check,
- 'response_check': content_check,
- 'reasoning': "",
- 'response': "",
- 'prompt': prompt,
- 'prompt_check': prompt_check,
- 'model_params': default_params
- })
- return None
+ air = 1
+ #稀释读空气的判定
+ if air == 1 and random.random() < 0.3:
+ self.db.db.reasoning_logs.insert_one({
+ 'time': time.time(),
+ 'group_id': message.group_id,
+ 'user': sender_name,
+ 'message': message.processed_plain_text,
+ 'model': model_name,
+ 'reasoning_check': reasoning_content_check,
+ 'response_check': content_check,
+ 'reasoning': "",
+ 'response': "",
+ 'prompt': prompt,
+ 'prompt_check': prompt_check,
+ 'model_params': default_params
+ })
+ return None
+
+
@@ -193,7 +203,7 @@ class LLMResponseGenerator:
async def _generate_r1_response(self, message: Message) -> Optional[str]:
"""使用 DeepSeek-R1 模型生成回复"""
- if self.config.API_USING == "deepseek":
+ if global_config.API_USING == "deepseek":
return await self._generate_base_response(
message,
"deepseek-reasoner",
@@ -208,7 +218,7 @@ class LLMResponseGenerator:
async def _generate_v3_response(self, message: Message) -> Optional[str]:
"""使用 DeepSeek-V3 模型生成回复"""
- if self.config.API_USING == "deepseek":
+ if global_config.API_USING == "deepseek":
return await self._generate_base_response(
message,
"deepseek-chat",
@@ -259,7 +269,7 @@ class LLMResponseGenerator:
messages = [{"role": "user", "content": prompt}]
loop = asyncio.get_event_loop()
- if self.config.API_USING == "deepseek":
+ if global_config.API_USING == "deepseek":
model = "deepseek-chat"
else:
model = "Pro/deepseek-ai/DeepSeek-V3"
@@ -296,4 +306,4 @@ class LLMResponseGenerator:
return processed_response, emotion_tags
# 创建全局实例
-llm_response = LLMResponseGenerator(global_config)
\ No newline at end of file
+llm_response = LLMResponseGenerator()
\ No newline at end of file
diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py
index 49402dbfc..ba0e9b4cc 100644
--- a/src/plugins/chat/prompt_builder.py
+++ b/src/plugins/chat/prompt_builder.py
@@ -66,12 +66,15 @@ class PromptBuilder:
overlapping_second_layer.update(overlap)
# 合并所有需要的记忆
- if all_first_layer_items:
- print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
- if overlapping_second_layer:
- print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
+ # if all_first_layer_items:
+ # print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
+ # if overlapping_second_layer:
+ # print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
- all_memories = all_first_layer_items + list(overlapping_second_layer)
+ # 使用集合去重
+ all_memories = list(set(all_first_layer_items) | set(overlapping_second_layer))
+ if all_memories:
+ print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}")
if all_memories: # 只在列表非空时选择随机项
random_item = choice(all_memories)
@@ -179,7 +182,11 @@ class PromptBuilder:
# prompt += f"{activate_prompt}\n"
prompt += f"{prompt_personality}\n"
prompt += f"{prompt_ger}\n"
- prompt += f"{extra_info}\n"
+ prompt += f"{extra_info}\n"
+
+
+
+ '''读空气prompt处理'''
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt_personality_check = ''
diff --git a/src/plugins/chat/topic_identifier.py b/src/plugins/chat/topic_identifier.py
index 75593667b..81956ddc1 100644
--- a/src/plugins/chat/topic_identifier.py
+++ b/src/plugins/chat/topic_identifier.py
@@ -1,14 +1,17 @@
from typing import Optional, Dict, List
from openai import OpenAI
from .message import Message
-from .config import global_config, llm_config
import jieba
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
class TopicIdentifier:
def __init__(self):
self.client = OpenAI(
- api_key=llm_config.SILICONFLOW_API_KEY,
- base_url=llm_config.SILICONFLOW_BASE_URL
+ api_key=config.siliconflow_key,
+ base_url=config.siliconflow_base_url
)
def identify_topic_llm(self, text: str) -> Optional[str]:
diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py
index 4e2235805..91018ee08 100644
--- a/src/plugins/chat/utils.py
+++ b/src/plugins/chat/utils.py
@@ -4,11 +4,15 @@ from typing import List
from .message import Message
import requests
import numpy as np
-from .config import llm_config, global_config
+from .config import global_config
import re
from typing import Dict
from collections import Counter
import math
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
def combine_messages(messages: List[Message]) -> str:
@@ -64,7 +68,7 @@ def get_embedding(text):
"encoding_format": "float"
}
headers = {
- "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
+ "Authorization": f"Bearer {config.siliconflow_key}",
"Content-Type": "application/json"
}
diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py
index 097294010..68b2fa7f0 100644
--- a/src/plugins/chat/utils_image.py
+++ b/src/plugins/chat/utils_image.py
@@ -7,6 +7,10 @@ from ...common.database import Database
import zlib # 用于 CRC32
import base64
from .config import global_config
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
@@ -37,12 +41,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
# 连接数据库
db = Database(
- host= os.getenv("MONGODB_HOST"),
- port= int(os.getenv("MONGODB_PORT")),
- db_name= os.getenv("DATABASE_NAME"),
- username= os.getenv("MONGODB_USERNAME"),
- password= os.getenv("MONGODB_PASSWORD"),
- auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+ host= config.mongodb_host,
+ port= int(config.mongodb_port),
+ db_name= config.database_name,
+ username= config.mongodb_username,
+ password= config.mongodb_password,
+ auth_source=config.mongodb_auth_source
)
# 检查是否已存在相同哈希值的图片
diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py
index 8f3734a4f..e35743577 100644
--- a/src/plugins/chat/willing_manager.py
+++ b/src/plugins/chat/willing_manager.py
@@ -58,8 +58,8 @@ class WillingManager:
if group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / 3.5
- if is_mentioned_bot and user_id == int(964959351):
- reply_probability = 1
+ # if is_mentioned_bot and user_id == int(1026294844):
+ # reply_probability = 1
return reply_probability
diff --git a/src/plugins/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py
index cdb591dee..d8c2e1482 100644
--- a/src/plugins/knowledege/knowledge_library.py
+++ b/src/plugins/knowledege/knowledge_library.py
@@ -3,6 +3,10 @@ import sys
import numpy as np
import requests
import time
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -13,12 +17,12 @@ from src.plugins.chat.config import llm_config
# 直接配置数据库连接信息
Database.initialize(
- host= os.getenv("MONGODB_HOST"),
- port= int(os.getenv("MONGODB_PORT")),
- db_name= os.getenv("DATABASE_NAME"),
- username= os.getenv("MONGODB_USERNAME"),
- password= os.getenv("MONGODB_PASSWORD"),
- auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+ host= config.mongodb_host,
+ port= int(config.mongodb_port),
+ db_name= config.database_name,
+ username= config.mongodb_username,
+ password= config.mongodb_password,
+ auth_source=config.mongodb_auth_source
)
class KnowledgeLibrary:
diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py
index a24f95a76..6b5dcd716 100644
--- a/src/plugins/memory_system/draw_memory.py
+++ b/src/plugins/memory_system/draw_memory.py
@@ -168,10 +168,12 @@ def main():
memory_graph.load_graph_from_db()
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
- visualize_graph(memory_graph, color_by_memory=False)
+ # visualize_graph(memory_graph, color_by_memory=False)
+ visualize_graph_lite(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
- visualize_graph(memory_graph, color_by_memory=True)
+ # visualize_graph(memory_graph, color_by_memory=True)
+ visualize_graph_lite(memory_graph, color_by_memory=True)
# memory_graph.save_graph_to_db()
@@ -262,7 +264,89 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
-if __name__ == "__main__":
- main()
+def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
+ # 设置中文字体
+ plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
+ plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
+ G = memory_graph.G
+
+ # 创建一个新图用于可视化
+ H = G.copy()
+
+ # 移除只有一条记忆的节点和连接数少于3的节点
+ nodes_to_remove = []
+ for node in H.nodes():
+ memory_items = H.nodes[node].get('memory_items', [])
+ memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+ degree = H.degree(node)
+ if memory_count <= 2 or degree <= 2:
+ nodes_to_remove.append(node)
+
+ H.remove_nodes_from(nodes_to_remove)
+
+ # 如果过滤后没有节点,则返回
+ if len(H.nodes()) == 0:
+ print("过滤后没有符合条件的节点可显示")
+ return
+
+ # 保存图到本地
+ nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
+
+ # 根据连接条数或记忆数量设置节点颜色
+ node_colors = []
+ nodes = list(H.nodes()) # 获取图中实际的节点列表
+
+ if color_by_memory:
+ # 计算每个节点的记忆数量
+ memory_counts = []
+ for node in nodes:
+ memory_items = H.nodes[node].get('memory_items', [])
+ if isinstance(memory_items, list):
+ count = len(memory_items)
+ else:
+ count = 1 if memory_items else 0
+ memory_counts.append(count)
+ max_memories = max(memory_counts) if memory_counts else 1
+
+ for count in memory_counts:
+ # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
+ if max_memories > 0:
+ intensity = min(1.0, count / max_memories)
+ color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
+ else:
+ color = (0, 0, 1) # 如果没有记忆,则为蓝色
+ node_colors.append(color)
+ else:
+ # 使用原来的连接数量着色方案
+ max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
+ for node in nodes:
+ degree = H.degree(node)
+ if max_degree > 0:
+ red = min(1.0, degree / max_degree)
+ blue = 1.0 - red
+ color = (red, 0, blue)
+ else:
+ color = (0, 0, 1)
+ node_colors.append(color)
+
+ # 绘制图形
+ plt.figure(figsize=(12, 8))
+ pos = nx.spring_layout(H, k=1, iterations=50)
+ nx.draw(H, pos,
+ with_labels=True,
+ node_color=node_colors,
+ node_size=2000,
+ font_size=10,
+ font_family='SimHei',
+ font_weight='bold')
+
+ title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
+ plt.title(title, fontsize=16, fontfamily='SimHei')
+ plt.show()
+
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/plugins/memory_system/llm_module.py b/src/plugins/memory_system/llm_module.py
index 4c9174dac..bd7f60dc3 100644
--- a/src/plugins/memory_system/llm_module.py
+++ b/src/plugins/memory_system/llm_module.py
@@ -2,14 +2,18 @@ import os
import requests
from typing import Tuple, Union
import time
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
- self.api_key = os.getenv("SILICONFLOW_KEY")
- self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+ self.api_key = config.siliconflow_key
+ self.base_url = config.siliconflow_base_url
def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的响应"""
diff --git a/src/plugins/memory_system/llm_module_memory_make.py b/src/plugins/memory_system/llm_module_memory_make.py
index 07cecae9d..04ab6dbc6 100644
--- a/src/plugins/memory_system/llm_module_memory_make.py
+++ b/src/plugins/memory_system/llm_module_memory_make.py
@@ -3,14 +3,18 @@ import requests
from typing import Tuple, Union
import time
from ..chat.config import BotConfig
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
- self.api_key = os.getenv("SILICONFLOW_KEY")
- self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+ self.api_key = config.siliconflow_key
+ self.base_url = config.siliconflow_base_url
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py
index 4deb28d63..f2b162afb 100644
--- a/src/plugins/memory_system/memory.py
+++ b/src/plugins/memory_system/memory.py
@@ -198,8 +198,6 @@ class Hippocampus:
time_frequency = {'near':1,'mid':2,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
-
-
for i, input_text in enumerate(memory_sample, 1):
#加载进度可视化
progress = (i / len(memory_sample)) * 100
@@ -207,24 +205,25 @@ class Hippocampus:
filled_length = int(bar_length * i // len(memory_sample))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
-
- # 生成压缩后记忆
- first_memory = set()
- first_memory = self.memory_compress(input_text, 2.5)
- # 延时防止访问超频
- # time.sleep(5)
- #将记忆加入到图谱中
- for topic, memory in first_memory:
- topics = segment_text(topic)
- print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
- for split_topic in topics:
- self.memory_graph.add_dot(split_topic,memory)
- for split_topic in topics:
- for other_split_topic in topics:
- if split_topic != other_split_topic:
- self.memory_graph.connect_dot(split_topic, other_split_topic)
-
- self.memory_graph.save_graph_to_db()
+ if input_text:
+ # 生成压缩后记忆
+ first_memory = set()
+ first_memory = self.memory_compress(input_text, 2.5)
+ # 延时防止访问超频
+ # time.sleep(5)
+ #将记忆加入到图谱中
+ for topic, memory in first_memory:
+ topics = segment_text(topic)
+ print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
+ for split_topic in topics:
+ self.memory_graph.add_dot(split_topic,memory)
+ for split_topic in topics:
+ for other_split_topic in topics:
+ if split_topic != other_split_topic:
+ self.memory_graph.connect_dot(split_topic, other_split_topic)
+ else:
+ print(f"空消息 跳过")
+ self.memory_graph.save_graph_to_db()
def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
@@ -260,16 +259,19 @@ def topic_what(text, topic):
return prompt
-
+from nonebot import get_driver
+driver = get_driver()
+config = driver.config
+
start_time = time.time()
Database.initialize(
- host= os.getenv("MONGODB_HOST"),
- port= int(os.getenv("MONGODB_PORT")),
- db_name= os.getenv("DATABASE_NAME"),
- username= os.getenv("MONGODB_USERNAME"),
- password= os.getenv("MONGODB_PASSWORD"),
- auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+ host= config.mongodb_host,
+ port= int(config.mongodb_port),
+ db_name= config.database_name,
+ username= config.mongodb_username,
+ password= config.mongodb_password,
+ auth_source=config.mongodb_auth_source
)
#创建记忆图
memory_graph = Memory_graph()
diff --git a/src/plugins/memory_system/memory_make.py b/src/plugins/memory_system/memory_make.py
index 0c1c0d219..7fb8af15a 100644
--- a/src/plugins/memory_system/memory_make.py
+++ b/src/plugins/memory_system/memory_make.py
@@ -13,7 +13,38 @@ import os
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
from src.plugins.memory_system.llm_module import LLMModel
-
+
+def calculate_information_content(text):
+ """计算文本的信息量(熵)"""
+ # 统计字符频率
+ char_count = Counter(text)
+ total_chars = len(text)
+
+ # 计算熵
+ entropy = 0
+ for count in char_count.values():
+ probability = count / total_chars
+ entropy -= probability * math.log2(probability)
+
+ return entropy
+
+def get_cloest_chat_from_db(db, length: int, timestamp: str):
+ """从数据库中获取最接近指定时间戳的聊天记录"""
+ chat_text = ''
+ closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
+
+ if closest_record:
+ closest_time = closest_record['time']
+ group_id = closest_record['group_id'] # 获取groupid
+ # 获取该时间戳之后的length条消息,且groupid相同
+ chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
+ for record in chat_record:
+ time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
+ chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
+ return chat_text
+
+ return ''
+
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
@@ -102,7 +133,8 @@ class Memory_graph:
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
- print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
+
+ # print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
@@ -110,8 +142,9 @@ class Memory_graph:
# 获取该时间戳之后的length条消息,且groupid相同
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
- time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
- chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
+ if record:
+ time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
+ chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
@@ -187,155 +220,80 @@ class Memory_graph:
for edge in edges:
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
-def calculate_information_content(text):
-
- """计算文本的信息量(熵)"""
- # 统计字符频率
- char_count = Counter(text)
- total_chars = len(text)
-
- # 计算熵
- entropy = 0
- for count in char_count.values():
- probability = count / total_chars
- entropy -= probability * math.log2(probability)
-
- return entropy
-
-
-# Database.initialize(
-# global_config.MONGODB_HOST,
-# global_config.MONGODB_PORT,
-# global_config.DATABASE_NAME
-# )
-# memory_graph = Memory_graph()
-
-# llm_model = LLMModel()
-# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
-
-# memory_graph.load_graph_from_db()
-
-
-
-def main():
- # 初始化数据库
- Database.initialize(
- host= os.getenv("MONGODB_HOST"),
- port= int(os.getenv("MONGODB_PORT")),
- db_name= os.getenv("DATABASE_NAME"),
- username= os.getenv("MONGODB_USERNAME"),
- password= os.getenv("MONGODB_PASSWORD"),
- auth_source=os.getenv("MONGODB_AUTH_SOURCE")
- )
-
- memory_graph = Memory_graph()
- # 创建LLM模型实例
- llm_model = LLMModel()
- llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
-
- # 使用当前时间戳进行测试
- current_timestamp = datetime.datetime.now().timestamp()
- chat_text = []
-
- chat_size =25
-
- for _ in range(30): # 循环10次
- random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间
- print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
- chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
- chat_text.append(chat_) # 拼接所有text
- # time.sleep(1)
-
-
-
- for i, input_text in enumerate(chat_text, 1):
+# 海马体
+class Hippocampus:
+ def __init__(self,memory_graph:Memory_graph):
+ self.memory_graph = memory_graph
+ self.llm_model = LLMModel()
+ self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
- progress = (i / len(chat_text)) * 100
- bar_length = 30
- filled_length = int(bar_length * i // len(chat_text))
- bar = '█' * filled_length + '-' * (bar_length - filled_length)
- print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})")
+ def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
+ current_timestamp = datetime.datetime.now().timestamp()
+ chat_text = []
+ #短期:1h 中期:4h 长期:24h
+ for _ in range(time_frequency.get('near')): # 循环10次
+ random_time = current_timestamp - random.randint(1, 3600) # 随机时间
+ chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+ chat_text.append(chat_)
+ for _ in range(time_frequency.get('mid')): # 循环10次
+ random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
+ chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+ chat_text.append(chat_)
+ for _ in range(time_frequency.get('far')): # 循环10次
+ random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
+ chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+ chat_text.append(chat_)
+ return chat_text
+
+ def build_memory(self,chat_size=12):
+ #最近消息获取频率
+ time_frequency = {'near':1,'mid':2,'far':2}
+ memory_sample = self.get_memory_sample(chat_size,time_frequency)
- # print(input_text)
- first_memory = set()
- first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
- # time.sleep(5)
-
- #将记忆加入到图谱中
- for topic, memory in first_memory:
- topics = segment_text(topic)
- print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
- for split_topic in topics:
- memory_graph.add_dot(split_topic,memory)
- for split_topic in topics:
- for other_split_topic in topics:
- if split_topic != other_split_topic:
- memory_graph.connect_dot(split_topic, other_split_topic)
-
- # memory_graph.store_memory()
-
- # 展示两种不同的可视化方式
- print("\n按连接数量着色的图谱:")
- visualize_graph(memory_graph, color_by_memory=False)
-
- print("\n按记忆数量着色的图谱:")
- visualize_graph(memory_graph, color_by_memory=True)
-
- memory_graph.save_graph_to_db()
- # memory_graph.load_graph_from_db()
-
- while True:
- query = input("请输入新的查询概念(输入'退出'以结束):")
- if query.lower() == '退出':
- break
- items_list = memory_graph.get_related_item(query)
- if items_list:
- # print(items_list)
- for memory_item in items_list:
- print(memory_item)
- else:
- print("未找到相关记忆。")
+ #加载进度可视化
+ for i, input_text in enumerate(memory_sample, 1):
+ progress = (i / len(memory_sample)) * 100
+ bar_length = 30
+ filled_length = int(bar_length * i // len(memory_sample))
+ bar = '█' * filled_length + '-' * (bar_length - filled_length)
+ print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
+ # print(f"第{i}条消息: {input_text}")
+ if input_text:
+ # 生成压缩后记忆
+ first_memory = set()
+ first_memory = self.memory_compress(input_text, 2.5)
+ #将记忆加入到图谱中
+ for topic, memory in first_memory:
+ topics = segment_text(topic)
+ print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
+ for split_topic in topics:
+ self.memory_graph.add_dot(split_topic,memory)
+ for split_topic in topics:
+ for other_split_topic in topics:
+ if split_topic != other_split_topic:
+ self.memory_graph.connect_dot(split_topic, other_split_topic)
+ else:
+ print(f"空消息 跳过")
- while True:
- query = input("请输入问题:")
-
- if query.lower() == '退出':
- break
-
- topic_prompt = find_topic(query, 3)
- topic_response = llm_model.generate_response(topic_prompt)
+ self.memory_graph.save_graph_to_db()
+
+ def memory_compress(self, input_text, rate=1):
+ information_content = calculate_information_content(input_text)
+ print(f"文本的信息量(熵): {information_content:.4f} bits")
+ topic_num = max(1, min(5, int(information_content * rate / 4)))
+ topic_prompt = find_topic(input_text, topic_num)
+ topic_response = self.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
- print(topics)
-
- for keyword in topics:
- items_list = memory_graph.get_related_item(keyword)
- if items_list:
- print(items_list)
-
-def memory_compress(input_text, llm_model, llm_model_small, rate=1):
- information_content = calculate_information_content(input_text)
- print(f"文本的信息量(熵): {information_content:.4f} bits")
- topic_num = max(1, min(5, int(information_content * rate / 4)))
- print(topic_num)
- topic_prompt = find_topic(input_text, topic_num)
- topic_response = llm_model.generate_response(topic_prompt)
- # 检查 topic_response 是否为元组
- if isinstance(topic_response, tuple):
- topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
- else:
- topics = topic_response.split(",")
- print(topics)
- compressed_memory = set()
- for topic in topics:
- topic_what_prompt = topic_what(input_text,topic)
- topic_what_response = llm_model_small.generate_response(topic_what_prompt)
- compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
- return compressed_memory
-
+ compressed_memory = set()
+ for topic in topics:
+ topic_what_prompt = topic_what(input_text,topic)
+ topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
+ compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
+ return compressed_memory
def segment_text(text):
seg_text = list(jieba.cut(text))
@@ -356,18 +314,37 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
G = memory_graph.G
+ # 创建一个新图用于可视化
+ H = G.copy()
+
+ # 移除只有一条记忆的节点和连接数少于3的节点
+ nodes_to_remove = []
+ for node in H.nodes():
+ memory_items = H.nodes[node].get('memory_items', [])
+ memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+ degree = H.degree(node)
+ if memory_count <= 1 or degree <= 2:
+ nodes_to_remove.append(node)
+
+ H.remove_nodes_from(nodes_to_remove)
+
+ # 如果过滤后没有节点,则返回
+ if len(H.nodes()) == 0:
+ print("过滤后没有符合条件的节点可显示")
+ return
+
# 保存图到本地
- nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
+ nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
node_colors = []
- nodes = list(G.nodes()) # 获取图中实际的节点列表
+ nodes = list(H.nodes()) # 获取图中实际的节点列表
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
- memory_items = G.nodes[node].get('memory_items', [])
+ memory_items = H.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
@@ -385,9 +362,9 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
- max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
+ max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
for node in nodes:
- degree = G.degree(node)
+ degree = H.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
@@ -398,8 +375,8 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
# 绘制图形
plt.figure(figsize=(12, 8))
- pos = nx.spring_layout(G, k=1, iterations=50)
- nx.draw(G, pos,
+ pos = nx.spring_layout(H, k=1, iterations=50)
+ nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
@@ -411,6 +388,71 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
+def main():
+ # 初始化数据库
+ Database.initialize(
+ host= os.getenv("MONGODB_HOST"),
+ port= int(os.getenv("MONGODB_PORT")),
+ db_name= os.getenv("DATABASE_NAME"),
+ username= os.getenv("MONGODB_USERNAME"),
+ password= os.getenv("MONGODB_PASSWORD"),
+ auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+ )
+
+ start_time = time.time()
+
+ # 创建记忆图
+ memory_graph = Memory_graph()
+ # 加载数据库中存储的记忆图
+ memory_graph.load_graph_from_db()
+ # 创建海马体
+ hippocampus = Hippocampus(memory_graph)
+
+ end_time = time.time()
+ print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
+
+ # 构建记忆
+ hippocampus.build_memory(chat_size=25)
+
+ # 展示两种不同的可视化方式
+ print("\n按连接数量着色的图谱:")
+ visualize_graph(memory_graph, color_by_memory=False)
+
+ print("\n按记忆数量着色的图谱:")
+ visualize_graph(memory_graph, color_by_memory=True)
+
+ # 交互式查询
+ while True:
+ query = input("请输入新的查询概念(输入'退出'以结束):")
+ if query.lower() == '退出':
+ break
+ items_list = memory_graph.get_related_item(query)
+ if items_list:
+ for memory_item in items_list:
+ print(memory_item)
+ else:
+ print("未找到相关记忆。")
+
+ while True:
+ query = input("请输入问题:")
+
+ if query.lower() == '退出':
+ break
+
+ topic_prompt = find_topic(query, 3)
+ topic_response = hippocampus.llm_model.generate_response(topic_prompt)
+ # 检查 topic_response 是否为元组
+ if isinstance(topic_response, tuple):
+ topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
+ else:
+ topics = topic_response.split(",")
+ print(topics)
+
+ for keyword in topics:
+ items_list = memory_graph.get_related_item(keyword)
+ if items_list:
+ print(items_list)
+
if __name__ == "__main__":
main()
diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py
index 097a89de8..a33c4b279 100644
--- a/src/plugins/schedule/schedule_generator.py
+++ b/src/plugins/schedule/schedule_generator.py
@@ -4,14 +4,19 @@ from typing import List, Dict
from .schedule_llm_module import LLMModel
from ...common.database import Database # 使用正确的导入语法
from ..chat.config import global_config
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
+
Database.initialize(
- host= os.getenv("MONGODB_HOST"),
- port= int(os.getenv("MONGODB_PORT")),
- db_name= os.getenv("DATABASE_NAME"),
- username= os.getenv("MONGODB_USERNAME"),
- password= os.getenv("MONGODB_PASSWORD"),
- auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+ host= config.mongodb_host,
+ port= int(config.mongodb_port),
+ db_name= config.database_name,
+ username= config.mongodb_username,
+ password= config.mongodb_password,
+ auth_source=config.mongodb_auth_source
)
class ScheduleGenerator:
diff --git a/src/plugins/schedule/schedule_llm_module.py b/src/plugins/schedule/schedule_llm_module.py
index ebf039c7e..408e7d546 100644
--- a/src/plugins/schedule/schedule_llm_module.py
+++ b/src/plugins/schedule/schedule_llm_module.py
@@ -1,20 +1,24 @@
import os
import requests
from typing import Tuple, Union
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs):
if api_using == "deepseek":
- self.api_key = os.getenv("DEEP_SEEK_KEY")
- self.base_url = os.getenv("DEEP_SEEK_BASE_URL")
+ self.api_key = config.deep_seek_key
+ self.base_url = config.deep_seek_base_url
if model_name != "Pro/deepseek-ai/DeepSeek-R1":
self.model_name = model_name
else:
self.model_name = "deepseek-reasoner"
else:
- self.api_key = os.getenv("SILICONFLOW_KEY")
- self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+ self.api_key = config.siliconflow_key
+ self.base_url = config.siliconflow_base_url
self.model_name = model_name
self.params = kwargs