v0.3.2: Reworked the .env config logic, plus memory optimizations

v0.3.2
Reworked the .env config loading logic
Memory optimizations
Kuuki-reading ("读空气") optimization
SengokuCola
2025-03-02 15:00:12 +08:00
parent 31659497f0
commit 1cd7f80937
24 changed files with 538 additions and 317 deletions
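The config change boils down to a two-stage dotenv load in bot.py. A minimal sketch of the new flow (condensed from the diff below; the actual code special-cases .env.dev and .env.prod and logs before exiting):

```python
import os
from dotenv import load_dotenv

load_dotenv(".env")                       # base file; must define ENVIRONMENT
env_file = f".env.{os.getenv('ENVIRONMENT')}"
if os.path.exists(env_file):
    load_dotenv(env_file, override=True)  # env-specific values win over .env
else:
    raise SystemExit(f"{env_file} not found; set ENVIRONMENT in .env to prod")
```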

View File

@@ -1,4 +1,3 @@
-ENVIRONMENT=dev
 HOST=127.0.0.1
 PORT=8080
@@ -11,15 +10,15 @@ PLUGINS=["src2.plugins.chat"]
 MONGODB_HOST=127.0.0.1
 MONGODB_PORT=27017
 DATABASE_NAME=MegBot
 MONGODB_USERNAME = "" # 默认空值
 MONGODB_PASSWORD = "" # 默认空值
 MONGODB_AUTH_SOURCE = "" # 默认空值
-#api配置项
+#key and url
+CHAT_ANY_WHERE_KEY=
 SILICONFLOW_KEY=
+CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
 SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
 DEEP_SEEK_KEY=
 DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1

.gitignore vendored
View File

@@ -14,6 +14,7 @@ reasoning_content.bat
 reasoning_window.bat
 queue_update.txt
 memory_graph.gml
+.env.dev
 # Byte-compiled / optimized / DLL files

View File

@@ -28,7 +28,7 @@
 > ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次
 > ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语
-关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言!
+关于麦麦的开发和建议相关的讨论群:766798517(不建议发布无关消息)这里不会有麦麦发言!
 ## 开发计划TODOLIST
@@ -41,16 +41,13 @@
 - config自动生成和检测
 - log别用print
 - 给发送消息写专门的类
+- 改进表情包发送逻辑l
-<div align="center">
-  <img src="docs/qq.png" width="300" />
-</div>
 ## 📚 详细文档
 - [项目详细介绍和架构说明](docs/doc1.md) - 包含完整的项目结构、文件说明和核心功能实现细节(由claude-3.5-sonnet生成)
-### 安装方法(还没测试好,现在部署可能遇到未知问题!!!!)
+### 安装方法(还没测试好,随时outdated ,现在部署可能遇到未知问题!!!!)
 #### Linux 使用 Docker Compose 部署
 获取项目根目录中的```docker-compose.yml```文件,运行以下命令

bot.py
View File

@@ -4,27 +4,59 @@ from nonebot.adapters.onebot.v11 import Adapter
 from dotenv import load_dotenv
 from loguru import logger
-# 加载全局环境变量
-root_dir = os.path.dirname(os.path.abspath(__file__))
-env_path=os.path.join(root_dir, "config",'.env')
+'''彩蛋'''
+from colorama import init, Fore
+init()
+text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
+rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
+rainbow_text = ""
+for i, char in enumerate(text):
+    rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
+print(rainbow_text)
+'''彩蛋'''
-logger.info(f"尝试从 {env_path} 加载环境变量配置")
-if os.path.exists(env_path):
-    load_dotenv(env_path)
-    logger.success("成功加载环境变量配置")
+# 首先加载基础环境变量
+if os.path.exists(".env"):
+    load_dotenv(".env")
+    logger.success("成功加载基础环境变量配置")
 else:
-    logger.error(f"环境变量配置文件不存在: {env_path}")
+    logger.error("基础环境变量配置文件 .env 不存在")
+    exit(1)
+
+# 根据 ENVIRONMENT 加载对应的环境配置
+env = os.getenv("ENVIRONMENT")
+env_file = f".env.{env}"
+if env_file == ".env.dev" and os.path.exists(env_file):
+    logger.success("加载开发环境变量配置")
+    load_dotenv(env_file, override=True)  # override=True 允许覆盖已存在的环境变量
+elif env_file == ".env.prod" and os.path.exists(env_file):
+    logger.success("加载环境变量配置")
+    load_dotenv(env_file, override=True)  # override=True 允许覆盖已存在的环境变量
+else:
+    logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
+    exit(1)
+
+# 初始化 NoneBot
 nonebot.init(
-    # napcat 默认使用 8080 端口
-    websocket_port=8080,
-    # 设置日志级别
+    # 从环境变量中读取配置
+    websocket_port=os.getenv("PORT", 8080),
+    host=os.getenv("HOST", "127.0.0.1"),
     log_level="INFO",
-    # 设置超级用户
-    superusers={"你的QQ号"},
-    # TODO: 这样写会忽略环境变量需要优化 https://nonebot.dev/docs/appendices/config
-    _env_file=env_path
+    # 添加自定义配置
+    mongodb_host=os.getenv("MONGODB_HOST", "127.0.0.1"),
+    mongodb_port=os.getenv("MONGODB_PORT", 27017),
+    database_name=os.getenv("DATABASE_NAME", "MegBot"),
+    mongodb_username=os.getenv("MONGODB_USERNAME", ""),
+    mongodb_password=os.getenv("MONGODB_PASSWORD", ""),
+    mongodb_auth_source=os.getenv("MONGODB_AUTH_SOURCE", ""),
+    # API相关配置
+    chat_any_where_key=os.getenv("CHAT_ANY_WHERE_KEY", ""),
+    siliconflow_key=os.getenv("SILICONFLOW_KEY", ""),
+    chat_any_where_base_url=os.getenv("CHAT_ANY_WHERE_BASE_URL", "https://api.chatanywhere.tech/v1"),
+    siliconflow_base_url=os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1/"),
+    deep_seek_key=os.getenv("DEEP_SEEK_KEY", ""),
+    deep_seek_base_url=os.getenv("DEEP_SEEK_BASE_URL", "https://api.deepseek.com/v1"),
+    # 插件配置
+    plugins=os.getenv("PLUGINS", ["src2.plugins.chat"])
 )
 # 注册适配器
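The per-file refactor in the rest of this commit leans on a NoneBot2 behavior worth spelling out: extra keyword arguments passed to nonebot.init() are stored on the global driver config, so modules can read them via get_driver().config instead of calling os.getenv(). A minimal sketch (values are placeholders):

```python
import nonebot

# Custom kwargs to init() become attributes of the driver config.
nonebot.init(mongodb_host="127.0.0.1", siliconflow_key="sk-placeholder")

config = nonebot.get_driver().config
print(config.mongodb_host)     # "127.0.0.1"
print(config.siliconflow_key)  # "sk-placeholder"
```

One caveat visible above: os.getenv() returns a string whenever the variable is set, which is why the call sites below still wrap numeric values such as mongodb_port in int(...).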

View File

@@ -3,3 +3,4 @@ cd .
 REM 执行nb run命令
 nb run
+pause

View File

@@ -11,16 +11,18 @@ from .relationship_manager import relationship_manager
 from ..schedule.schedule_generator import bot_schedule
 from .willing_manager import willing_manager
 # 获取驱动器
 driver = get_driver()
+config = driver.config
 Database.initialize(
-    host= os.getenv("MONGODB_HOST"),
-    port= int(os.getenv("MONGODB_PORT")),
-    db_name= os.getenv("DATABASE_NAME"),
-    username= os.getenv("MONGODB_USERNAME"),
-    password= os.getenv("MONGODB_PASSWORD"),
-    auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    host= config.mongodb_host,
+    port= int(config.mongodb_port),
+    db_name= config.database_name,
+    username= config.mongodb_username,
+    password= config.mongodb_password,
+    auth_source= config.mongodb_auth_source
 )
 print("\033[1;32m[初始化数据库完成]\033[0m")
@@ -37,7 +39,7 @@ emoji_manager.initialize()
 print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
 # 创建机器人实例
-chat_bot = ChatBot(global_config)
+chat_bot = ChatBot()
 # 注册消息处理器
 group_msg = on_message()
 # 创建定时任务

View File

@@ -18,10 +18,9 @@ from .utils import is_mentioned_bot_in_txt, calculate_typing_time
 from ..memory_system.memory import memory_graph
 class ChatBot:
-    def __init__(self, config: BotConfig):
-        self.config = config
+    def __init__(self):
         self.storage = MessageStorage()
-        self.gpt = LLMResponseGenerator(config)
+        self.gpt = LLMResponseGenerator()
         self.bot = None  # bot 实例引用
         self._started = False
@@ -39,11 +38,11 @@ class ChatBot:
     async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
         """处理收到的群消息"""
-        if event.group_id not in self.config.talk_allowed_groups:
+        if event.group_id not in global_config.talk_allowed_groups:
             return
         self.bot = bot  # 更新 bot 实例
-        if event.user_id in self.config.ban_user_id:
+        if event.user_id in global_config.ban_user_id:
             return
         # 打印原始消息内容
@@ -120,7 +119,7 @@ class ChatBot:
             event.group_id,
             topic[0] if topic else None,
             is_mentioned,
-            self.config,
+            global_config,
             event.user_id,
             message.is_emoji,
             interested_rate
@@ -144,9 +143,13 @@ class ChatBot:
         # 如果生成了回复,发送并记录
+
+        '''
+        生成回复后的内容
+        '''
         if response:
-            message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id)
+            message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
             accu_typing_time = 0
             for msg in response:
                 print(f"当前消息: {msg}")
@@ -157,7 +160,7 @@ class ChatBot:
                 bot_message = Message(
                     group_id=event.group_id,
-                    user_id=self.config.BOT_QQ,
+                    user_id=global_config.BOT_QQ,
                     message_id=think_id,
                     message_based_id=event.message_id,
                     raw_message=msg,
@@ -174,7 +177,7 @@ class ChatBot:
             bot_response_time = tinking_time_point
-            if random() < self.config.emoji_chance:
+            if random() < global_config.emoji_chance:
                 emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
                 if emoji_path:
                     emoji_cq = CQCode.create_emoji_cq(emoji_path)
@@ -186,7 +189,7 @@ class ChatBot:
                     bot_message = Message(
                         group_id=event.group_id,
-                        user_id=self.config.BOT_QQ,
+                        user_id=global_config.BOT_QQ,
                         message_id=0,
                         raw_message=emoji_cq,
                         plain_text=emoji_cq,

View File

@@ -7,6 +7,7 @@ import configparser
 import tomli
 import sys
 from loguru import logger
+from nonebot import get_driver
@@ -111,7 +112,6 @@ class BotConfig:
 # 获取配置文件路径
 bot_config_path = BotConfig.get_default_config_path()
 config_dir = os.path.dirname(bot_config_path)
-env_path = os.path.join(config_dir, '.env')
 logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
 global_config = BotConfig.load_config(config_path=bot_config_path)
@@ -126,10 +126,11 @@ class LLMConfig:
     DEEP_SEEK_BASE_URL: str = None
 llm_config = LLMConfig()
-llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
-llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
-llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
-llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
+config = get_driver().config
+llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
+llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
+llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
+llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
 if not global_config.enable_advance_output:

View File

@@ -7,7 +7,7 @@ from PIL import Image
 import os
 from random import random
 from nonebot.adapters.onebot.v11 import Bot
-from .config import global_config, llm_config
+from .config import global_config
 import time
 import asyncio
 from .utils_image import storage_image,storage_emoji
@@ -16,6 +16,10 @@ from .utils_user import get_user_nickname
 #包含CQ码类
 import urllib3
 from urllib3.util import create_urllib3_context
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 # TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
 ctx = create_urllib3_context()
@@ -179,7 +183,7 @@ class CQCode:
         """调用AI接口获取表情包描述"""
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+            "Authorization": f"Bearer {config.siliconflow_key}"
         }
         payload = {
@@ -206,7 +210,7 @@ class CQCode:
         }
         response = requests.post(
-            f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+            f"{config.siliconflow_base_url}chat/completions",
             headers=headers,
             json=payload,
             timeout=30
@@ -224,7 +228,7 @@ class CQCode:
         """调用AI接口获取普通图片描述"""
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+            "Authorization": f"Bearer {config.siliconflow_key}"
         }
         payload = {
@@ -251,7 +255,7 @@ class CQCode:
         }
         response = requests.post(
-            f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+            f"{config.siliconflow_base_url}chat/completions",
             headers=headers,
             json=payload,
             timeout=30

View File

@@ -10,10 +10,14 @@ import hashlib
 from datetime import datetime
 import base64
 import shutil
-from .config import global_config, llm_config
 import asyncio
 import time
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 class EmojiManager:
     _instance = None
@@ -93,7 +97,7 @@ class EmojiManager:
         # 准备请求数据
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+            "Authorization": f"Bearer {config.siliconflow_key}"
         }
         payload = {
@@ -115,7 +119,7 @@ class EmojiManager:
         async with aiohttp.ClientSession() as session:
             async with session.post(
-                f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+                f"{config.siliconflow_base_url}chat/completions",
                 headers=headers,
                 json=payload
             ) as response:
@@ -249,7 +253,7 @@ class EmojiManager:
         async with aiohttp.ClientSession() as session:
             headers = {
                 "Content-Type": "application/json",
-                "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
+                "Authorization": f"Bearer {config.siliconflow_key}"
             }
             payload = {
@@ -276,7 +280,7 @@ class EmojiManager:
             }
             async with session.post(
-                f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
+                f"{config.siliconflow_base_url}chat/completions",
                 headers=headers,
                 json=payload
             ) as response:

View File

@@ -1,34 +1,34 @@
 from typing import Dict, Any, List, Optional, Union, Tuple
 from openai import OpenAI
 import asyncio
-import requests
 from functools import partial
 from .message import Message
-from .config import BotConfig, global_config
+from .config import global_config
 from ...common.database import Database
 import random
 import time
-import os
 import numpy as np
 from .relationship_manager import relationship_manager
-from ..schedule.schedule_generator import bot_schedule
 from .prompt_builder import prompt_builder
-from .config import llm_config, global_config
+from .config import global_config
 from .utils import process_llm_response
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 class LLMResponseGenerator:
-    def __init__(self, config: BotConfig):
-        self.config = config
-        if self.config.API_USING == "siliconflow":
+    def __init__(self):
+        if global_config.API_USING == "siliconflow":
             self.client = OpenAI(
-                api_key=llm_config.SILICONFLOW_API_KEY,
-                base_url=llm_config.SILICONFLOW_BASE_URL
+                api_key=config.siliconflow_key,
+                base_url=config.siliconflow_base_url
             )
-        elif self.config.API_USING == "deepseek":
+        elif global_config.API_USING == "deepseek":
             self.client = OpenAI(
-                api_key=llm_config.DEEP_SEEK_API_KEY,
-                base_url=llm_config.DEEP_SEEK_BASE_URL
+                api_key=config.deep_seek_key,
+                base_url=config.deep_seek_base_url
             )
         self.db = Database.get_instance()
@@ -52,6 +52,7 @@ class LLMResponseGenerator:
         else:
             self.current_model_type = 'r1_distill'  # 默认使用 R1-Distill
         print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
+
         if self.current_model_type == 'r1':
             model_response = await self._generate_r1_response(message)
@@ -84,7 +85,8 @@ class LLMResponseGenerator:
         else:
             relationship_value = 0.0
-        # 构建prompt
+
+        ''' 构建prompt '''
         prompt,prompt_check = prompt_builder._build_prompt(
             message_txt=message.processed_plain_text,
             sender_name=sender_name,
@@ -92,6 +94,7 @@ class LLMResponseGenerator:
             group_id=message.group_id
         )
+
         # 设置默认参数
         default_params = {
             "model": model_name,
@@ -113,6 +116,7 @@ class LLMResponseGenerator:
         if model_params:
             default_params.update(model_params)
+
         def create_completion():
             return self.client.chat.completions.create(**default_params)
@@ -122,6 +126,7 @@ class LLMResponseGenerator:
         loop = asyncio.get_event_loop()
         # 读空气模块
+        air = 0
         reasoning_content_check=''
         content_check=''
         if global_config.enable_kuuki_read:
@@ -135,21 +140,26 @@ class LLMResponseGenerator:
             content_check = response_check.choices[0].message.content
             print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
             if 'yes' not in content_check.lower():
-                self.db.db.reasoning_logs.insert_one({
-                    'time': time.time(),
-                    'group_id': message.group_id,
-                    'user': sender_name,
-                    'message': message.processed_plain_text,
-                    'model': model_name,
-                    'reasoning_check': reasoning_content_check,
-                    'response_check': content_check,
-                    'reasoning': "",
-                    'response': "",
-                    'prompt': prompt,
-                    'prompt_check': prompt_check,
-                    'model_params': default_params
-                })
-                return None
+                air = 1
+        #稀释读空气的判定
+        if air == 1 and random.random() < 0.3:
+            self.db.db.reasoning_logs.insert_one({
+                'time': time.time(),
+                'group_id': message.group_id,
+                'user': sender_name,
+                'message': message.processed_plain_text,
+                'model': model_name,
+                'reasoning_check': reasoning_content_check,
+                'response_check': content_check,
+                'reasoning': "",
+                'response': "",
+                'prompt': prompt,
+                'prompt_check': prompt_check,
+                'model_params': default_params
+            })
+            return None
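The kuuki ("读空气", read-the-air) change above is subtle: a "no" verdict from the checker used to always silence the bot, whereas now it is only honored 30% of the time. A condensed sketch of the new gate (the 0.3 constant is taken straight from the diff; the function name is illustrative):

```python
import random

def should_stay_silent(checker_said_no: bool, dilution: float = 0.3) -> bool:
    """True -> skip replying. Pre-commit behavior was simply `checker_said_no`."""
    return checker_said_no and random.random() < dilution
```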
@@ -193,7 +203,7 @@ class LLMResponseGenerator:
     async def _generate_r1_response(self, message: Message) -> Optional[str]:
         """使用 DeepSeek-R1 模型生成回复"""
-        if self.config.API_USING == "deepseek":
+        if global_config.API_USING == "deepseek":
             return await self._generate_base_response(
                 message,
                 "deepseek-reasoner",
@@ -208,7 +218,7 @@ class LLMResponseGenerator:
     async def _generate_v3_response(self, message: Message) -> Optional[str]:
         """使用 DeepSeek-V3 模型生成回复"""
-        if self.config.API_USING == "deepseek":
+        if global_config.API_USING == "deepseek":
             return await self._generate_base_response(
                 message,
                 "deepseek-chat",
@@ -259,7 +269,7 @@ class LLMResponseGenerator:
         messages = [{"role": "user", "content": prompt}]
         loop = asyncio.get_event_loop()
-        if self.config.API_USING == "deepseek":
+        if global_config.API_USING == "deepseek":
             model = "deepseek-chat"
         else:
             model = "Pro/deepseek-ai/DeepSeek-V3"
@@ -296,4 +306,4 @@ class LLMResponseGenerator:
         return processed_response, emotion_tags
 # 创建全局实例
-llm_response = LLMResponseGenerator(global_config)
+llm_response = LLMResponseGenerator()

View File

@@ -66,12 +66,15 @@ class PromptBuilder:
             overlapping_second_layer.update(overlap)
         # 合并所有需要的记忆
-        if all_first_layer_items:
-            print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
-        if overlapping_second_layer:
-            print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
-        all_memories = all_first_layer_items + list(overlapping_second_layer)
+        # if all_first_layer_items:
+        #     print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
+        # if overlapping_second_layer:
+        #     print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
+        # 使用集合去重
+        all_memories = list(set(all_first_layer_items) | set(overlapping_second_layer))
+        if all_memories:
+            print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}")
         if all_memories:  # 只在列表非空时选择随机项
             random_item = choice(all_memories)
@@ -181,6 +184,10 @@ class PromptBuilder:
         prompt += f"{prompt_ger}\n"
         prompt += f"{extra_info}\n"
+
+        '''读空气prompt处理'''
+
         activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
         prompt_personality_check = ''
         extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
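The memory-merge fix above swaps list concatenation for a set union. A toy illustration of the difference (values are made up; note that a set union also discards ordering, which is harmless here since a random element is drawn immediately afterwards):

```python
first_layer = ["天气", "午饭", "天气"]   # duplicates possible
second_layer = ["天气", "游戏"]

# Old: concatenation keeps duplicates, biasing the later random pick
print(first_layer + list(second_layer))

# New: each memory appears exactly once
print(list(set(first_layer) | set(second_layer)))
```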

View File

@@ -1,14 +1,17 @@
 from typing import Optional, Dict, List
 from openai import OpenAI
 from .message import Message
-from .config import global_config, llm_config
 import jieba
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 class TopicIdentifier:
     def __init__(self):
         self.client = OpenAI(
-            api_key=llm_config.SILICONFLOW_API_KEY,
-            base_url=llm_config.SILICONFLOW_BASE_URL
+            api_key=config.siliconflow_key,
+            base_url=config.siliconflow_base_url
         )
     def identify_topic_llm(self, text: str) -> Optional[str]:

View File

@@ -4,11 +4,15 @@ from typing import List
 from .message import Message
 import requests
 import numpy as np
-from .config import llm_config, global_config
+from .config import global_config
 import re
 from typing import Dict
 from collections import Counter
 import math
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 def combine_messages(messages: List[Message]) -> str:
@@ -64,7 +68,7 @@ def get_embedding(text):
         "encoding_format": "float"
     }
     headers = {
-        "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
+        "Authorization": f"Bearer {config.siliconflow_key}",
         "Content-Type": "application/json"
     }

View File

@@ -7,6 +7,10 @@ from ...common.database import Database
 import zlib  # 用于 CRC32
 import base64
 from .config import global_config
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
@@ -37,12 +41,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
     # 连接数据库
     db = Database(
-        host= os.getenv("MONGODB_HOST"),
-        port= int(os.getenv("MONGODB_PORT")),
-        db_name= os.getenv("DATABASE_NAME"),
-        username= os.getenv("MONGODB_USERNAME"),
-        password= os.getenv("MONGODB_PASSWORD"),
-        auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+        host= config.mongodb_host,
+        port= int(config.mongodb_port),
+        db_name= config.database_name,
+        username= config.mongodb_username,
+        password= config.mongodb_password,
+        auth_source=config.mongodb_auth_source
     )
     # 检查是否已存在相同哈希值的图片

View File

@@ -58,8 +58,8 @@ class WillingManager:
         if group_id in config.talk_frequency_down_groups:
             reply_probability = reply_probability / 3.5
-        if is_mentioned_bot and user_id == int(964959351):
-            reply_probability = 1
+        # if is_mentioned_bot and user_id == int(1026294844):
+        #     reply_probability = 1
         return reply_probability

View File

@@ -3,6 +3,10 @@ import sys
 import numpy as np
 import requests
 import time
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 # 添加项目根目录到 Python 路径
 root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -13,12 +17,12 @@ from src.plugins.chat.config import llm_config
 # 直接配置数据库连接信息
 Database.initialize(
-    host= os.getenv("MONGODB_HOST"),
-    port= int(os.getenv("MONGODB_PORT")),
-    db_name= os.getenv("DATABASE_NAME"),
-    username= os.getenv("MONGODB_USERNAME"),
-    password= os.getenv("MONGODB_PASSWORD"),
-    auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    host= config.mongodb_host,
+    port= int(config.mongodb_port),
+    db_name= config.database_name,
+    username= config.mongodb_username,
+    password= config.mongodb_password,
+    auth_source= config.mongodb_auth_source
 )
 class KnowledgeLibrary:

View File

@@ -168,10 +168,12 @@ def main():
     memory_graph.load_graph_from_db()
     # 展示两种不同的可视化方式
     print("\n按连接数量着色的图谱:")
-    visualize_graph(memory_graph, color_by_memory=False)
+    # visualize_graph(memory_graph, color_by_memory=False)
+    visualize_graph_lite(memory_graph, color_by_memory=False)
     print("\n按记忆数量着色的图谱:")
-    visualize_graph(memory_graph, color_by_memory=True)
+    # visualize_graph(memory_graph, color_by_memory=True)
+    visualize_graph_lite(memory_graph, color_by_memory=True)
     # memory_graph.save_graph_to_db()
@@ -262,7 +264,89 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
     plt.title(title, fontsize=16, fontfamily='SimHei')
     plt.show()
+
+def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
+    # 设置中文字体
+    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
+    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
+
+    G = memory_graph.G
+
+    # 创建一个新图用于可视化
+    H = G.copy()
+
+    # 移除只有一条记忆的节点和连接数少于3的节点
+    nodes_to_remove = []
+    for node in H.nodes():
+        memory_items = H.nodes[node].get('memory_items', [])
+        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+        degree = H.degree(node)
+        if memory_count <= 2 or degree <= 2:
+            nodes_to_remove.append(node)
+    H.remove_nodes_from(nodes_to_remove)
+
+    # 如果过滤后没有节点,则返回
+    if len(H.nodes()) == 0:
+        print("过滤后没有符合条件的节点可显示")
+        return
+
+    # 保存图到本地
+    nx.write_gml(H, "memory_graph.gml")  # 保存为 GML 格式
+
+    # 根据连接条数或记忆数量设置节点颜色
+    node_colors = []
+    nodes = list(H.nodes())  # 获取图中实际的节点列表
+    if color_by_memory:
+        # 计算每个节点的记忆数量
+        memory_counts = []
+        for node in nodes:
+            memory_items = H.nodes[node].get('memory_items', [])
+            if isinstance(memory_items, list):
+                count = len(memory_items)
+            else:
+                count = 1 if memory_items else 0
+            memory_counts.append(count)
+        max_memories = max(memory_counts) if memory_counts else 1
+        for count in memory_counts:
+            # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
+            if max_memories > 0:
+                intensity = min(1.0, count / max_memories)
+                color = (intensity, 0, 1.0 - intensity)  # 从蓝色渐变到红色
+            else:
+                color = (0, 0, 1)  # 如果没有记忆,则为蓝色
+            node_colors.append(color)
+    else:
+        # 使用原来的连接数量着色方案
+        max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
+        for node in nodes:
+            degree = H.degree(node)
+            if max_degree > 0:
+                red = min(1.0, degree / max_degree)
+                blue = 1.0 - red
+                color = (red, 0, blue)
+            else:
+                color = (0, 0, 1)
+            node_colors.append(color)
+
+    # 绘制图形
+    plt.figure(figsize=(12, 8))
+    pos = nx.spring_layout(H, k=1, iterations=50)
+    nx.draw(H, pos,
+            with_labels=True,
+            node_color=node_colors,
+            node_size=2000,
+            font_size=10,
+            font_family='SimHei',
+            font_weight='bold')
+    title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
+    plt.title(title, fontsize=16, fontfamily='SimHei')
+    plt.show()
+
 if __name__ == "__main__":
     main()

View File

@@ -2,14 +2,18 @@ import os
 import requests
 from typing import Tuple, Union
 import time
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 class LLMModel:
     # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
     def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
         self.model_name = model_name
         self.params = kwargs
-        self.api_key = os.getenv("SILICONFLOW_KEY")
-        self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+        self.api_key = config.siliconflow_key
+        self.base_url = config.siliconflow_base_url
     def generate_response(self, prompt: str) -> Tuple[str, str]:
         """根据输入的提示生成模型的响应"""
View File

@@ -3,14 +3,18 @@ import requests
 from typing import Tuple, Union
 import time
 from ..chat.config import BotConfig
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 class LLMModel:
     # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
     def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
         self.model_name = model_name
         self.params = kwargs
-        self.api_key = os.getenv("SILICONFLOW_KEY")
-        self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+        self.api_key = config.siliconflow_key
+        self.base_url = config.siliconflow_base_url
         if not self.api_key or not self.base_url:
             raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")

View File

@@ -198,8 +198,6 @@ class Hippocampus:
         time_frequency = {'near':1,'mid':2,'far':2}
         memory_sample = self.get_memory_sample(chat_size,time_frequency)
         # print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
-
-
         for i, input_text in enumerate(memory_sample, 1):
             #加载进度可视化
             progress = (i / len(memory_sample)) * 100
@@ -207,24 +205,25 @@ class Hippocampus:
             filled_length = int(bar_length * i // len(memory_sample))
             bar = '█' * filled_length + '-' * (bar_length - filled_length)
             print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
-            # 生成压缩后记忆
-            first_memory = set()
-            first_memory = self.memory_compress(input_text, 2.5)
-            # 延时防止访问超频
-            # time.sleep(5)
-            #将记忆加入到图谱中
-            for topic, memory in first_memory:
-                topics = segment_text(topic)
-                print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
-                for split_topic in topics:
-                    self.memory_graph.add_dot(split_topic,memory)
-                for split_topic in topics:
-                    for other_split_topic in topics:
-                        if split_topic != other_split_topic:
-                            self.memory_graph.connect_dot(split_topic, other_split_topic)
+            if input_text:
+                # 生成压缩后记忆
+                first_memory = set()
+                first_memory = self.memory_compress(input_text, 2.5)
+                # 延时防止访问超频
+                # time.sleep(5)
+                #将记忆加入到图谱中
+                for topic, memory in first_memory:
+                    topics = segment_text(topic)
+                    print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
+                    for split_topic in topics:
+                        self.memory_graph.add_dot(split_topic,memory)
+                    for split_topic in topics:
+                        for other_split_topic in topics:
+                            if split_topic != other_split_topic:
+                                self.memory_graph.connect_dot(split_topic, other_split_topic)
+            else:
+                print(f"空消息 跳过")
         self.memory_graph.save_graph_to_db()
     def memory_compress(self, input_text, rate=1):
         information_content = calculate_information_content(input_text)
@@ -260,16 +259,19 @@ def topic_what(text, topic):
     return prompt
+
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 start_time = time.time()
 Database.initialize(
-    host= os.getenv("MONGODB_HOST"),
-    port= int(os.getenv("MONGODB_PORT")),
-    db_name= os.getenv("DATABASE_NAME"),
-    username= os.getenv("MONGODB_USERNAME"),
-    password= os.getenv("MONGODB_PASSWORD"),
-    auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    host= config.mongodb_host,
+    port= int(config.mongodb_port),
+    db_name= config.database_name,
+    username= config.mongodb_username,
+    password= config.mongodb_password,
+    auth_source= config.mongodb_auth_source
 )
 #创建记忆图
 memory_graph = Memory_graph()
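get_memory_sample draws its chats from three recency buckets. A self-contained sketch of just the sampling policy (window bounds and per-bucket counts are copied from the diff; names are illustrative):

```python
import random
import time

# Seconds back from "now": 短期 1h, 中期 1-4h, 长期 4-24h.
WINDOWS = {
    'near': (1, 3600),
    'mid': (3600, 3600 * 4),
    'far': (3600 * 4, 3600 * 24),
}

def sample_timestamps(time_frequency):
    now = time.time()
    return [now - random.randint(*WINDOWS[bucket])
            for bucket, n in time_frequency.items()
            for _ in range(n)]

print(sample_timestamps({'near': 1, 'mid': 2, 'far': 2}))  # 5 timestamps
```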

View File

@@ -14,6 +14,37 @@ sys.path.append("C:/GitHub/MaiMBot")  # 添加项目根目录到 Python 路径
 from src.common.database import Database  # 使用正确的导入语法
 from src.plugins.memory_system.llm_module import LLMModel
+
+def calculate_information_content(text):
+    """计算文本的信息量(熵)"""
+    # 统计字符频率
+    char_count = Counter(text)
+    total_chars = len(text)
+    # 计算熵
+    entropy = 0
+    for count in char_count.values():
+        probability = count / total_chars
+        entropy -= probability * math.log2(probability)
+    return entropy
+
+def get_cloest_chat_from_db(db, length: int, timestamp: str):
+    """从数据库中获取最接近指定时间戳的聊天记录"""
+    chat_text = ''
+    closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
+    if closest_record:
+        closest_time = closest_record['time']
+        group_id = closest_record['group_id']  # 获取groupid
+        # 获取该时间戳之后的length条消息,且groupid相同
+        chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
+        for record in chat_record:
+            time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
+            chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
+        return chat_text
+    return ''
+
 class Memory_graph:
     def __init__(self):
         self.G = nx.Graph()  # 使用 networkx 的图结构
@@ -102,7 +133,8 @@ class Memory_graph:
         # 从数据库中根据时间戳获取离其最近的聊天记录
         chat_text = ''
         closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])  # 调试输出
-        print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
+
+        # print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
         if closest_record:
             closest_time = closest_record['time']
@@ -110,8 +142,9 @@ class Memory_graph:
             # 获取该时间戳之后的length条消息,且groupid相同
             chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
             for record in chat_record:
-                time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
-                chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'  # 添加发送者和时间信息
+                if record:
+                    time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
+                    chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'  # 添加发送者和时间信息
             return chat_text
         return []  # 如果没有找到记录,返回空列表
@@ -187,155 +220,80 @@ class Memory_graph:
         for edge in edges:
             self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
-def calculate_information_content(text):
-    """计算文本的信息量(熵)"""
-    # 统计字符频率
-    char_count = Counter(text)
-    total_chars = len(text)
-    # 计算熵
-    entropy = 0
-    for count in char_count.values():
-        probability = count / total_chars
-        entropy -= probability * math.log2(probability)
-    return entropy
-
-# Database.initialize(
-#     global_config.MONGODB_HOST,
-#     global_config.MONGODB_PORT,
-#     global_config.DATABASE_NAME
-# )
-# memory_graph = Memory_graph()
-# llm_model = LLMModel()
-# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
-# memory_graph.load_graph_from_db()
-
-def main():
-    # 初始化数据库
-    Database.initialize(
-        host= os.getenv("MONGODB_HOST"),
-        port= int(os.getenv("MONGODB_PORT")),
-        db_name= os.getenv("DATABASE_NAME"),
-        username= os.getenv("MONGODB_USERNAME"),
-        password= os.getenv("MONGODB_PASSWORD"),
-        auth_source=os.getenv("MONGODB_AUTH_SOURCE")
-    )
-    memory_graph = Memory_graph()
-    # 创建LLM模型实例
-    llm_model = LLMModel()
-    llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
-    # 使用当前时间戳进行测试
-    current_timestamp = datetime.datetime.now().timestamp()
-    chat_text = []
-    chat_size =25
-    for _ in range(30):  # 循环10次
-        random_time = current_timestamp - random.randint(1, 3600*10)  # 随机时间
-        print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
-        chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
-        chat_text.append(chat_)  # 拼接所有text
-        # time.sleep(1)
-    for i, input_text in enumerate(chat_text, 1):
-        progress = (i / len(chat_text)) * 100
-        bar_length = 30
-        filled_length = int(bar_length * i // len(chat_text))
-        bar = '█' * filled_length + '-' * (bar_length - filled_length)
-        print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})")
-        # print(input_text)
-        first_memory = set()
-        first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
-        # time.sleep(5)
-        #将记忆加入到图谱中
-        for topic, memory in first_memory:
-            topics = segment_text(topic)
-            print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
-            for split_topic in topics:
-                memory_graph.add_dot(split_topic,memory)
-            for split_topic in topics:
-                for other_split_topic in topics:
-                    if split_topic != other_split_topic:
-                        memory_graph.connect_dot(split_topic, other_split_topic)
-    # memory_graph.store_memory()
-    # 展示两种不同的可视化方式
-    print("\n按连接数量着色的图谱:")
-    visualize_graph(memory_graph, color_by_memory=False)
-    print("\n按记忆数量着色的图谱:")
-    visualize_graph(memory_graph, color_by_memory=True)
-    memory_graph.save_graph_to_db()
-    # memory_graph.load_graph_from_db()
-    while True:
-        query = input("请输入新的查询概念(输入'退出'以结束):")
-        if query.lower() == '退出':
-            break
-        items_list = memory_graph.get_related_item(query)
-        if items_list:
-            # print(items_list)
-            for memory_item in items_list:
-                print(memory_item)
-        else:
-            print("未找到相关记忆。")
-    while True:
-        query = input("请输入问题:")
-        if query.lower() == '退出':
-            break
-        topic_prompt = find_topic(query, 3)
-        topic_response = llm_model.generate_response(topic_prompt)
-        # 检查 topic_response 是否为元组
-        if isinstance(topic_response, tuple):
-            topics = topic_response[0].split(",")  # 假设第一个元素是我们需要的字符串
-        else:
-            topics = topic_response.split(",")
-        print(topics)
-        for keyword in topics:
-            items_list = memory_graph.get_related_item(keyword)
-            if items_list:
-                print(items_list)
-
-def memory_compress(input_text, llm_model, llm_model_small, rate=1):
-    information_content = calculate_information_content(input_text)
-    print(f"文本的信息量(熵): {information_content:.4f} bits")
-    topic_num = max(1, min(5, int(information_content * rate / 4)))
-    print(topic_num)
-    topic_prompt = find_topic(input_text, topic_num)
-    topic_response = llm_model.generate_response(topic_prompt)
-    # 检查 topic_response 是否为元组
-    if isinstance(topic_response, tuple):
-        topics = topic_response[0].split(",")  # 假设第一个元素是我们需要的字符串
-    else:
-        topics = topic_response.split(",")
-    print(topics)
-    compressed_memory = set()
-    for topic in topics:
-        topic_what_prompt = topic_what(input_text,topic)
-        topic_what_response = llm_model_small.generate_response(topic_what_prompt)
-        compressed_memory.add((topic.strip(), topic_what_response[0]))  # 将话题和记忆作为元组存储
-    return compressed_memory
+# 海马体
+class Hippocampus:
+    def __init__(self,memory_graph:Memory_graph):
+        self.memory_graph = memory_graph
+        self.llm_model = LLMModel()
+        self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
+
+    def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
+        current_timestamp = datetime.datetime.now().timestamp()
+        chat_text = []
+        #短期:1h 中期:4h 长期:24h
+        for _ in range(time_frequency.get('near')):  # 循环10次
+            random_time = current_timestamp - random.randint(1, 3600)  # 随机时间
+            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+            chat_text.append(chat_)
+        for _ in range(time_frequency.get('mid')):  # 循环10次
+            random_time = current_timestamp - random.randint(3600, 3600*4)  # 随机时间
+            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+            chat_text.append(chat_)
+        for _ in range(time_frequency.get('far')):  # 循环10次
+            random_time = current_timestamp - random.randint(3600*4, 3600*24)  # 随机时间
+            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
+            chat_text.append(chat_)
+        return chat_text
+
+    def build_memory(self,chat_size=12):
+        #最近消息获取频率
+        time_frequency = {'near':1,'mid':2,'far':2}
+        memory_sample = self.get_memory_sample(chat_size,time_frequency)
+
+        #加载进度可视化
+        for i, input_text in enumerate(memory_sample, 1):
+            progress = (i / len(memory_sample)) * 100
+            bar_length = 30
+            filled_length = int(bar_length * i // len(memory_sample))
+            bar = '█' * filled_length + '-' * (bar_length - filled_length)
+            print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
+            # print(f"第{i}条消息: {input_text}")
+            if input_text:
+                # 生成压缩后记忆
+                first_memory = set()
+                first_memory = self.memory_compress(input_text, 2.5)
+                #将记忆加入到图谱中
+                for topic, memory in first_memory:
+                    topics = segment_text(topic)
+                    print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
+                    for split_topic in topics:
+                        self.memory_graph.add_dot(split_topic,memory)
+                    for split_topic in topics:
+                        for other_split_topic in topics:
+                            if split_topic != other_split_topic:
+                                self.memory_graph.connect_dot(split_topic, other_split_topic)
+            else:
+                print(f"空消息 跳过")
+        self.memory_graph.save_graph_to_db()
+
+    def memory_compress(self, input_text, rate=1):
+        information_content = calculate_information_content(input_text)
+        print(f"文本的信息量(熵): {information_content:.4f} bits")
+        topic_num = max(1, min(5, int(information_content * rate / 4)))
+        topic_prompt = find_topic(input_text, topic_num)
+        topic_response = self.llm_model.generate_response(topic_prompt)
+        # 检查 topic_response 是否为元组
+        if isinstance(topic_response, tuple):
+            topics = topic_response[0].split(",")  # 假设第一个元素是我们需要的字符串
+        else:
+            topics = topic_response.split(",")
+        compressed_memory = set()
+        for topic in topics:
+            topic_what_prompt = topic_what(input_text,topic)
+            topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
+            compressed_memory.add((topic.strip(), topic_what_response[0]))  # 将话题和记忆作为元组存储
+        return compressed_memory
 def segment_text(text):
     seg_text = list(jieba.cut(text))
@@ -356,18 +314,37 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
     G = memory_graph.G
+
+    # 创建一个新图用于可视化
+    H = G.copy()
+
+    # 移除只有一条记忆的节点和连接数少于3的节点
+    nodes_to_remove = []
+    for node in H.nodes():
+        memory_items = H.nodes[node].get('memory_items', [])
+        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
+        degree = H.degree(node)
+        if memory_count <= 1 or degree <= 2:
+            nodes_to_remove.append(node)
+    H.remove_nodes_from(nodes_to_remove)
+
+    # 如果过滤后没有节点,则返回
+    if len(H.nodes()) == 0:
+        print("过滤后没有符合条件的节点可显示")
+        return
+
     # 保存图到本地
-    nx.write_gml(G, "memory_graph.gml")  # 保存为 GML 格式
+    nx.write_gml(H, "memory_graph.gml")  # 保存为 GML 格式
     # 根据连接条数或记忆数量设置节点颜色
     node_colors = []
-    nodes = list(G.nodes())  # 获取图中实际的节点列表
+    nodes = list(H.nodes())  # 获取图中实际的节点列表
     if color_by_memory:
         # 计算每个节点的记忆数量
         memory_counts = []
         for node in nodes:
-            memory_items = G.nodes[node].get('memory_items', [])
+            memory_items = H.nodes[node].get('memory_items', [])
             if isinstance(memory_items, list):
                 count = len(memory_items)
             else:
@@ -385,9 +362,9 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
             node_colors.append(color)
     else:
         # 使用原来的连接数量着色方案
-        max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
+        max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
         for node in nodes:
-            degree = G.degree(node)
+            degree = H.degree(node)
            if max_degree > 0:
                 red = min(1.0, degree / max_degree)
                 blue = 1.0 - red
@@ -398,8 +375,8 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
     # 绘制图形
     plt.figure(figsize=(12, 8))
-    pos = nx.spring_layout(G, k=1, iterations=50)
-    nx.draw(G, pos,
+    pos = nx.spring_layout(H, k=1, iterations=50)
+    nx.draw(H, pos,
             with_labels=True,
             node_color=node_colors,
             node_size=2000,
@@ -411,6 +388,71 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
     plt.title(title, fontsize=16, fontfamily='SimHei')
     plt.show()
+
+def main():
+    # 初始化数据库
+    Database.initialize(
+        host= os.getenv("MONGODB_HOST"),
+        port= int(os.getenv("MONGODB_PORT")),
+        db_name= os.getenv("DATABASE_NAME"),
+        username= os.getenv("MONGODB_USERNAME"),
+        password= os.getenv("MONGODB_PASSWORD"),
+        auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    )
+    start_time = time.time()
+    # 创建记忆图
+    memory_graph = Memory_graph()
+    # 加载数据库中存储的记忆图
+    memory_graph.load_graph_from_db()
+    # 创建海马体
+    hippocampus = Hippocampus(memory_graph)
+    end_time = time.time()
+    print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
+
+    # 构建记忆
+    hippocampus.build_memory(chat_size=25)
+
+    # 展示两种不同的可视化方式
+    print("\n按连接数量着色的图谱:")
+    visualize_graph(memory_graph, color_by_memory=False)
+    print("\n按记忆数量着色的图谱:")
+    visualize_graph(memory_graph, color_by_memory=True)
+
+    # 交互式查询
+    while True:
+        query = input("请输入新的查询概念(输入'退出'以结束):")
+        if query.lower() == '退出':
+            break
+        items_list = memory_graph.get_related_item(query)
+        if items_list:
+            for memory_item in items_list:
+                print(memory_item)
+        else:
+            print("未找到相关记忆。")
+
+    while True:
+        query = input("请输入问题:")
+        if query.lower() == '退出':
+            break
+        topic_prompt = find_topic(query, 3)
+        topic_response = hippocampus.llm_model.generate_response(topic_prompt)
+        # 检查 topic_response 是否为元组
+        if isinstance(topic_response, tuple):
+            topics = topic_response[0].split(",")  # 假设第一个元素是我们需要的字符串
+        else:
+            topics = topic_response.split(",")
+        print(topics)
+        for keyword in topics:
+            items_list = memory_graph.get_related_item(keyword)
+            if items_list:
+                print(items_list)
+
 if __name__ == "__main__":
     main()
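memory_compress sizes its topic list from the character entropy of the input: topic_num = max(1, min(5, int(entropy * rate / 4))). A worked example (the sample string below has 16 distinct characters, so its entropy is exactly log2(16) = 4 bits and, with the rate of 2.5 that build_memory passes in, topic_num comes out to 2):

```python
import math
from collections import Counter

def calculate_information_content(text):
    """Character-level Shannon entropy, as in the diff above."""
    counts = Counter(text)
    total = len(text)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

entropy = calculate_information_content("今天中午吃什么?食堂的面条还不错")
topic_num = max(1, min(5, int(entropy * 2.5 / 4)))  # rate=2.5, as in build_memory
print(f"entropy={entropy:.2f} bits -> topic_num={topic_num}")  # 4.00 bits -> 2
```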

View File

@@ -4,14 +4,19 @@ from typing import List, Dict
 from .schedule_llm_module import LLMModel
 from ...common.database import Database  # 使用正确的导入语法
 from ..chat.config import global_config
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
+
 Database.initialize(
-    host= os.getenv("MONGODB_HOST"),
-    port= int(os.getenv("MONGODB_PORT")),
-    db_name= os.getenv("DATABASE_NAME"),
-    username= os.getenv("MONGODB_USERNAME"),
-    password= os.getenv("MONGODB_PASSWORD"),
-    auth_source=os.getenv("MONGODB_AUTH_SOURCE")
+    host= config.mongodb_host,
+    port= int(config.mongodb_port),
+    db_name= config.database_name,
+    username= config.mongodb_username,
+    password= config.mongodb_password,
+    auth_source= config.mongodb_auth_source
 )
 class ScheduleGenerator:

View File

@@ -1,20 +1,24 @@
 import os
 import requests
 from typing import Tuple, Union
+from nonebot import get_driver
+
+driver = get_driver()
+config = driver.config
 class LLMModel:
     # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
     def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs):
         if api_using == "deepseek":
-            self.api_key = os.getenv("DEEP_SEEK_KEY")
-            self.base_url = os.getenv("DEEP_SEEK_BASE_URL")
+            self.api_key = config.deep_seek_key
+            self.base_url = config.deep_seek_base_url
             if model_name != "Pro/deepseek-ai/DeepSeek-R1":
                 self.model_name = model_name
             else:
                 self.model_name = "deepseek-reasoner"
         else:
-            self.api_key = os.getenv("SILICONFLOW_KEY")
-            self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+            self.api_key = config.siliconflow_key
+            self.base_url = config.siliconflow_base_url
             self.model_name = model_name
         self.params = kwargs