v0.3.2 更改了.env config的逻辑和memory优化
v0.3.2 更改了.env config的逻辑 memory优化 读空气优化
This commit is contained in:
@@ -1,4 +1,3 @@
|
|||||||
ENVIRONMENT=dev
|
|
||||||
HOST=127.0.0.1
|
HOST=127.0.0.1
|
||||||
PORT=8080
|
PORT=8080
|
||||||
|
|
||||||
@@ -11,15 +10,15 @@ PLUGINS=["src2.plugins.chat"]
|
|||||||
MONGODB_HOST=127.0.0.1
|
MONGODB_HOST=127.0.0.1
|
||||||
MONGODB_PORT=27017
|
MONGODB_PORT=27017
|
||||||
DATABASE_NAME=MegBot
|
DATABASE_NAME=MegBot
|
||||||
|
|
||||||
MONGODB_USERNAME = "" # 默认空值
|
MONGODB_USERNAME = "" # 默认空值
|
||||||
MONGODB_PASSWORD = "" # 默认空值
|
MONGODB_PASSWORD = "" # 默认空值
|
||||||
MONGODB_AUTH_SOURCE = "" # 默认空值
|
MONGODB_AUTH_SOURCE = "" # 默认空值
|
||||||
|
|
||||||
#api配置项
|
#key and url
|
||||||
|
CHAT_ANY_WHERE_KEY=
|
||||||
SILICONFLOW_KEY=
|
SILICONFLOW_KEY=
|
||||||
|
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
||||||
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
DEEP_SEEK_KEY=
|
DEEP_SEEK_KEY=
|
||||||
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -14,6 +14,7 @@ reasoning_content.bat
|
|||||||
reasoning_window.bat
|
reasoning_window.bat
|
||||||
queue_update.txt
|
queue_update.txt
|
||||||
memory_graph.gml
|
memory_graph.gml
|
||||||
|
.env.dev
|
||||||
|
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
|
|||||||
@@ -28,7 +28,7 @@
|
|||||||
> ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次
|
> ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次
|
||||||
> ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语(
|
> ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语(
|
||||||
|
|
||||||
关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言!
|
关于麦麦的开发和建议相关的讨论群:766798517(不建议发布无关消息)这里不会有麦麦发言!
|
||||||
|
|
||||||
## 开发计划TODO:LIST
|
## 开发计划TODO:LIST
|
||||||
|
|
||||||
@@ -41,16 +41,13 @@
|
|||||||
- config自动生成和检测
|
- config自动生成和检测
|
||||||
- log别用print
|
- log别用print
|
||||||
- 给发送消息写专门的类
|
- 给发送消息写专门的类
|
||||||
|
- 改进表情包发送逻辑l
|
||||||
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
<img src="docs/qq.png" width="300" />
|
|
||||||
</div>
|
|
||||||
|
|
||||||
## 📚 详细文档
|
## 📚 详细文档
|
||||||
- [项目详细介绍和架构说明](docs/doc1.md) - 包含完整的项目结构、文件说明和核心功能实现细节(由claude-3.5-sonnet生成)
|
- [项目详细介绍和架构说明](docs/doc1.md) - 包含完整的项目结构、文件说明和核心功能实现细节(由claude-3.5-sonnet生成)
|
||||||
|
|
||||||
### 安装方法(还没测试好,现在部署可能遇到未知问题!!!!)
|
### 安装方法(还没测试好,随时outdated ,现在部署可能遇到未知问题!!!!)
|
||||||
|
|
||||||
#### Linux 使用 Docker Compose 部署
|
#### Linux 使用 Docker Compose 部署
|
||||||
获取项目根目录中的```docker-compose.yml```文件,运行以下命令
|
获取项目根目录中的```docker-compose.yml```文件,运行以下命令
|
||||||
|
|||||||
64
bot.py
64
bot.py
@@ -4,27 +4,59 @@ from nonebot.adapters.onebot.v11 import Adapter
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# 加载全局环境变量
|
'''彩蛋'''
|
||||||
root_dir = os.path.dirname(os.path.abspath(__file__))
|
from colorama import init, Fore
|
||||||
env_path=os.path.join(root_dir, "config",'.env')
|
init()
|
||||||
|
text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
|
||||||
|
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
|
||||||
|
rainbow_text = ""
|
||||||
|
for i, char in enumerate(text):
|
||||||
|
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
|
||||||
|
print(rainbow_text)
|
||||||
|
'''彩蛋'''
|
||||||
|
|
||||||
logger.info(f"尝试从 {env_path} 加载环境变量配置")
|
# 首先加载基础环境变量
|
||||||
if os.path.exists(env_path):
|
if os.path.exists(".env"):
|
||||||
load_dotenv(env_path)
|
load_dotenv(".env")
|
||||||
logger.success("成功加载环境变量配置")
|
logger.success("成功加载基础环境变量配置")
|
||||||
else:
|
else:
|
||||||
logger.error(f"环境变量配置文件不存在: {env_path}")
|
logger.error("基础环境变量配置文件 .env 不存在")
|
||||||
|
exit(1)
|
||||||
|
# 根据 ENVIRONMENT 加载对应的环境配置
|
||||||
|
env = os.getenv("ENVIRONMENT")
|
||||||
|
env_file = f".env.{env}"
|
||||||
|
|
||||||
|
if env_file == ".env.dev" and os.path.exists(env_file):
|
||||||
|
logger.success("加载开发环境变量配置")
|
||||||
|
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
|
elif env_file == ".env.prod" and os.path.exists(env_file):
|
||||||
|
logger.success("加载环境变量配置")
|
||||||
|
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
|
else:
|
||||||
|
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
# 初始化 NoneBot
|
|
||||||
nonebot.init(
|
nonebot.init(
|
||||||
# napcat 默认使用 8080 端口
|
# 从环境变量中读取配置
|
||||||
websocket_port=8080,
|
websocket_port=os.getenv("PORT", 8080),
|
||||||
# 设置日志级别
|
host=os.getenv("HOST", "127.0.0.1"),
|
||||||
log_level="INFO",
|
log_level="INFO",
|
||||||
# 设置超级用户
|
# 添加自定义配置
|
||||||
superusers={"你的QQ号"},
|
mongodb_host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||||
# TODO: 这样写会忽略环境变量需要优化 https://nonebot.dev/docs/appendices/config
|
mongodb_port=os.getenv("MONGODB_PORT", 27017),
|
||||||
_env_file=env_path
|
database_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||||
|
mongodb_username=os.getenv("MONGODB_USERNAME", ""),
|
||||||
|
mongodb_password=os.getenv("MONGODB_PASSWORD", ""),
|
||||||
|
mongodb_auth_source=os.getenv("MONGODB_AUTH_SOURCE", ""),
|
||||||
|
# API相关配置
|
||||||
|
chat_any_where_key=os.getenv("CHAT_ANY_WHERE_KEY", ""),
|
||||||
|
siliconflow_key=os.getenv("SILICONFLOW_KEY", ""),
|
||||||
|
chat_any_where_base_url=os.getenv("CHAT_ANY_WHERE_BASE_URL", "https://api.chatanywhere.tech/v1"),
|
||||||
|
siliconflow_base_url=os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1/"),
|
||||||
|
deep_seek_key=os.getenv("DEEP_SEEK_KEY", ""),
|
||||||
|
deep_seek_base_url=os.getenv("DEEP_SEEK_BASE_URL", "https://api.deepseek.com/v1"),
|
||||||
|
# 插件配置
|
||||||
|
plugins=os.getenv("PLUGINS", ["src2.plugins.chat"])
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册适配器
|
# 注册适配器
|
||||||
|
|||||||
@@ -3,3 +3,4 @@ cd .
|
|||||||
|
|
||||||
REM 执行nb run命令
|
REM 执行nb run命令
|
||||||
nb run
|
nb run
|
||||||
|
pause
|
||||||
@@ -11,16 +11,18 @@ from .relationship_manager import relationship_manager
|
|||||||
from ..schedule.schedule_generator import bot_schedule
|
from ..schedule.schedule_generator import bot_schedule
|
||||||
from .willing_manager import willing_manager
|
from .willing_manager import willing_manager
|
||||||
|
|
||||||
|
|
||||||
# 获取驱动器
|
# 获取驱动器
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source= config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
print("\033[1;32m[初始化数据库完成]\033[0m")
|
print("\033[1;32m[初始化数据库完成]\033[0m")
|
||||||
|
|
||||||
@@ -37,7 +39,7 @@ emoji_manager.initialize()
|
|||||||
|
|
||||||
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
||||||
# 创建机器人实例
|
# 创建机器人实例
|
||||||
chat_bot = ChatBot(global_config)
|
chat_bot = ChatBot()
|
||||||
# 注册消息处理器
|
# 注册消息处理器
|
||||||
group_msg = on_message()
|
group_msg = on_message()
|
||||||
# 创建定时任务
|
# 创建定时任务
|
||||||
|
|||||||
@@ -18,10 +18,9 @@ from .utils import is_mentioned_bot_in_txt, calculate_typing_time
|
|||||||
from ..memory_system.memory import memory_graph
|
from ..memory_system.memory import memory_graph
|
||||||
|
|
||||||
class ChatBot:
|
class ChatBot:
|
||||||
def __init__(self, config: BotConfig):
|
def __init__(self):
|
||||||
self.config = config
|
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
self.gpt = LLMResponseGenerator(config)
|
self.gpt = LLMResponseGenerator()
|
||||||
self.bot = None # bot 实例引用
|
self.bot = None # bot 实例引用
|
||||||
self._started = False
|
self._started = False
|
||||||
|
|
||||||
@@ -39,11 +38,11 @@ class ChatBot:
|
|||||||
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
||||||
"""处理收到的群消息"""
|
"""处理收到的群消息"""
|
||||||
|
|
||||||
if event.group_id not in self.config.talk_allowed_groups:
|
if event.group_id not in global_config.talk_allowed_groups:
|
||||||
return
|
return
|
||||||
self.bot = bot # 更新 bot 实例
|
self.bot = bot # 更新 bot 实例
|
||||||
|
|
||||||
if event.user_id in self.config.ban_user_id:
|
if event.user_id in global_config.ban_user_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 打印原始消息内容
|
# 打印原始消息内容
|
||||||
@@ -120,7 +119,7 @@ class ChatBot:
|
|||||||
event.group_id,
|
event.group_id,
|
||||||
topic[0] if topic else None,
|
topic[0] if topic else None,
|
||||||
is_mentioned,
|
is_mentioned,
|
||||||
self.config,
|
global_config,
|
||||||
event.user_id,
|
event.user_id,
|
||||||
message.is_emoji,
|
message.is_emoji,
|
||||||
interested_rate
|
interested_rate
|
||||||
@@ -144,9 +143,13 @@ class ChatBot:
|
|||||||
|
|
||||||
# 如果生成了回复,发送并记录
|
# 如果生成了回复,发送并记录
|
||||||
|
|
||||||
|
'''
|
||||||
|
生成回复后的内容
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id)
|
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
|
||||||
accu_typing_time = 0
|
accu_typing_time = 0
|
||||||
for msg in response:
|
for msg in response:
|
||||||
print(f"当前消息: {msg}")
|
print(f"当前消息: {msg}")
|
||||||
@@ -157,7 +160,7 @@ class ChatBot:
|
|||||||
|
|
||||||
bot_message = Message(
|
bot_message = Message(
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
user_id=self.config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
message_based_id=event.message_id,
|
message_based_id=event.message_id,
|
||||||
raw_message=msg,
|
raw_message=msg,
|
||||||
@@ -174,7 +177,7 @@ class ChatBot:
|
|||||||
|
|
||||||
|
|
||||||
bot_response_time = tinking_time_point
|
bot_response_time = tinking_time_point
|
||||||
if random() < self.config.emoji_chance:
|
if random() < global_config.emoji_chance:
|
||||||
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
|
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
|
||||||
if emoji_path:
|
if emoji_path:
|
||||||
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
||||||
@@ -186,7 +189,7 @@ class ChatBot:
|
|||||||
|
|
||||||
bot_message = Message(
|
bot_message = Message(
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
user_id=self.config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
message_id=0,
|
message_id=0,
|
||||||
raw_message=emoji_cq,
|
raw_message=emoji_cq,
|
||||||
plain_text=emoji_cq,
|
plain_text=emoji_cq,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import configparser
|
|||||||
import tomli
|
import tomli
|
||||||
import sys
|
import sys
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -111,7 +112,6 @@ class BotConfig:
|
|||||||
# 获取配置文件路径
|
# 获取配置文件路径
|
||||||
bot_config_path = BotConfig.get_default_config_path()
|
bot_config_path = BotConfig.get_default_config_path()
|
||||||
config_dir = os.path.dirname(bot_config_path)
|
config_dir = os.path.dirname(bot_config_path)
|
||||||
env_path = os.path.join(config_dir, '.env')
|
|
||||||
|
|
||||||
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
|
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
|
||||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||||
@@ -126,10 +126,11 @@ class LLMConfig:
|
|||||||
DEEP_SEEK_BASE_URL: str = None
|
DEEP_SEEK_BASE_URL: str = None
|
||||||
|
|
||||||
llm_config = LLMConfig()
|
llm_config = LLMConfig()
|
||||||
llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
|
config = get_driver().config
|
||||||
llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
|
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
|
||||||
llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
|
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
|
||||||
llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
|
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
|
||||||
|
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
|
||||||
|
|
||||||
|
|
||||||
if not global_config.enable_advance_output:
|
if not global_config.enable_advance_output:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from PIL import Image
|
|||||||
import os
|
import os
|
||||||
from random import random
|
from random import random
|
||||||
from nonebot.adapters.onebot.v11 import Bot
|
from nonebot.adapters.onebot.v11 import Bot
|
||||||
from .config import global_config, llm_config
|
from .config import global_config
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from .utils_image import storage_image,storage_emoji
|
from .utils_image import storage_image,storage_emoji
|
||||||
@@ -16,6 +16,10 @@ from .utils_user import get_user_nickname
|
|||||||
#包含CQ码类
|
#包含CQ码类
|
||||||
import urllib3
|
import urllib3
|
||||||
from urllib3.util import create_urllib3_context
|
from urllib3.util import create_urllib3_context
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
|
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
|
||||||
ctx = create_urllib3_context()
|
ctx = create_urllib3_context()
|
||||||
@@ -179,7 +183,7 @@ class CQCode:
|
|||||||
"""调用AI接口获取表情包描述"""
|
"""调用AI接口获取表情包描述"""
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
"Authorization": f"Bearer {config.siliconflow_key}"
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -206,7 +210,7 @@ class CQCode:
|
|||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
f"{config.siliconflow_base_url}chat/completions",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=30
|
timeout=30
|
||||||
@@ -224,7 +228,7 @@ class CQCode:
|
|||||||
"""调用AI接口获取普通图片描述"""
|
"""调用AI接口获取普通图片描述"""
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
"Authorization": f"Bearer {config.siliconflow_key}"
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -251,7 +255,7 @@ class CQCode:
|
|||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
f"{config.siliconflow_base_url}chat/completions",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=30
|
timeout=30
|
||||||
|
|||||||
@@ -10,10 +10,14 @@ import hashlib
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import base64
|
import base64
|
||||||
import shutil
|
import shutil
|
||||||
from .config import global_config, llm_config
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
class EmojiManager:
|
class EmojiManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -93,7 +97,7 @@ class EmojiManager:
|
|||||||
# 准备请求数据
|
# 准备请求数据
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
"Authorization": f"Bearer {config.siliconflow_key}"
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -115,7 +119,7 @@ class EmojiManager:
|
|||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
f"{config.siliconflow_base_url}chat/completions",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload
|
json=payload
|
||||||
) as response:
|
) as response:
|
||||||
@@ -249,7 +253,7 @@ class EmojiManager:
|
|||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
"Authorization": f"Bearer {config.siliconflow_key}"
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -276,7 +280,7 @@ class EmojiManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
f"{config.siliconflow_base_url}chat/completions",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload
|
json=payload
|
||||||
) as response:
|
) as response:
|
||||||
|
|||||||
@@ -1,34 +1,34 @@
|
|||||||
from typing import Dict, Any, List, Optional, Union, Tuple
|
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import asyncio
|
import asyncio
|
||||||
import requests
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from .message import Message
|
from .message import Message
|
||||||
from .config import BotConfig, global_config
|
from .config import global_config
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import os
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from ..schedule.schedule_generator import bot_schedule
|
|
||||||
from .prompt_builder import prompt_builder
|
from .prompt_builder import prompt_builder
|
||||||
from .config import llm_config, global_config
|
from .config import global_config
|
||||||
from .utils import process_llm_response
|
from .utils import process_llm_response
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
class LLMResponseGenerator:
|
class LLMResponseGenerator:
|
||||||
def __init__(self, config: BotConfig):
|
def __init__(self):
|
||||||
self.config = config
|
if global_config.API_USING == "siliconflow":
|
||||||
if self.config.API_USING == "siliconflow":
|
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=llm_config.SILICONFLOW_API_KEY,
|
api_key=config.siliconflow_key,
|
||||||
base_url=llm_config.SILICONFLOW_BASE_URL
|
base_url=config.siliconflow_base_url
|
||||||
)
|
)
|
||||||
elif self.config.API_USING == "deepseek":
|
elif global_config.API_USING == "deepseek":
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=llm_config.DEEP_SEEK_API_KEY,
|
api_key=config.deep_seek_key,
|
||||||
base_url=llm_config.DEEP_SEEK_BASE_URL
|
base_url=config.deep_seek_base_url
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
@@ -52,6 +52,7 @@ class LLMResponseGenerator:
|
|||||||
else:
|
else:
|
||||||
self.current_model_type = 'r1_distill' # 默认使用 R1-Distill
|
self.current_model_type = 'r1_distill' # 默认使用 R1-Distill
|
||||||
|
|
||||||
|
|
||||||
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
|
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
|
||||||
if self.current_model_type == 'r1':
|
if self.current_model_type == 'r1':
|
||||||
model_response = await self._generate_r1_response(message)
|
model_response = await self._generate_r1_response(message)
|
||||||
@@ -84,7 +85,8 @@ class LLMResponseGenerator:
|
|||||||
else:
|
else:
|
||||||
relationship_value = 0.0
|
relationship_value = 0.0
|
||||||
|
|
||||||
# 构建prompt
|
|
||||||
|
''' 构建prompt '''
|
||||||
prompt,prompt_check = prompt_builder._build_prompt(
|
prompt,prompt_check = prompt_builder._build_prompt(
|
||||||
message_txt=message.processed_plain_text,
|
message_txt=message.processed_plain_text,
|
||||||
sender_name=sender_name,
|
sender_name=sender_name,
|
||||||
@@ -92,6 +94,7 @@ class LLMResponseGenerator:
|
|||||||
group_id=message.group_id
|
group_id=message.group_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 设置默认参数
|
# 设置默认参数
|
||||||
default_params = {
|
default_params = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
@@ -113,6 +116,7 @@ class LLMResponseGenerator:
|
|||||||
if model_params:
|
if model_params:
|
||||||
default_params.update(model_params)
|
default_params.update(model_params)
|
||||||
|
|
||||||
|
|
||||||
def create_completion():
|
def create_completion():
|
||||||
return self.client.chat.completions.create(**default_params)
|
return self.client.chat.completions.create(**default_params)
|
||||||
|
|
||||||
@@ -122,6 +126,7 @@ class LLMResponseGenerator:
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# 读空气模块
|
# 读空气模块
|
||||||
|
air = 0
|
||||||
reasoning_content_check=''
|
reasoning_content_check=''
|
||||||
content_check=''
|
content_check=''
|
||||||
if global_config.enable_kuuki_read:
|
if global_config.enable_kuuki_read:
|
||||||
@@ -135,21 +140,26 @@ class LLMResponseGenerator:
|
|||||||
content_check = response_check.choices[0].message.content
|
content_check = response_check.choices[0].message.content
|
||||||
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
|
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
|
||||||
if 'yes' not in content_check.lower():
|
if 'yes' not in content_check.lower():
|
||||||
self.db.db.reasoning_logs.insert_one({
|
air = 1
|
||||||
'time': time.time(),
|
#稀释读空气的判定
|
||||||
'group_id': message.group_id,
|
if air == 1 and random.random() < 0.3:
|
||||||
'user': sender_name,
|
self.db.db.reasoning_logs.insert_one({
|
||||||
'message': message.processed_plain_text,
|
'time': time.time(),
|
||||||
'model': model_name,
|
'group_id': message.group_id,
|
||||||
'reasoning_check': reasoning_content_check,
|
'user': sender_name,
|
||||||
'response_check': content_check,
|
'message': message.processed_plain_text,
|
||||||
'reasoning': "",
|
'model': model_name,
|
||||||
'response': "",
|
'reasoning_check': reasoning_content_check,
|
||||||
'prompt': prompt,
|
'response_check': content_check,
|
||||||
'prompt_check': prompt_check,
|
'reasoning': "",
|
||||||
'model_params': default_params
|
'response': "",
|
||||||
})
|
'prompt': prompt,
|
||||||
return None
|
'prompt_check': prompt_check,
|
||||||
|
'model_params': default_params
|
||||||
|
})
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -193,7 +203,7 @@ class LLMResponseGenerator:
|
|||||||
|
|
||||||
async def _generate_r1_response(self, message: Message) -> Optional[str]:
|
async def _generate_r1_response(self, message: Message) -> Optional[str]:
|
||||||
"""使用 DeepSeek-R1 模型生成回复"""
|
"""使用 DeepSeek-R1 模型生成回复"""
|
||||||
if self.config.API_USING == "deepseek":
|
if global_config.API_USING == "deepseek":
|
||||||
return await self._generate_base_response(
|
return await self._generate_base_response(
|
||||||
message,
|
message,
|
||||||
"deepseek-reasoner",
|
"deepseek-reasoner",
|
||||||
@@ -208,7 +218,7 @@ class LLMResponseGenerator:
|
|||||||
|
|
||||||
async def _generate_v3_response(self, message: Message) -> Optional[str]:
|
async def _generate_v3_response(self, message: Message) -> Optional[str]:
|
||||||
"""使用 DeepSeek-V3 模型生成回复"""
|
"""使用 DeepSeek-V3 模型生成回复"""
|
||||||
if self.config.API_USING == "deepseek":
|
if global_config.API_USING == "deepseek":
|
||||||
return await self._generate_base_response(
|
return await self._generate_base_response(
|
||||||
message,
|
message,
|
||||||
"deepseek-chat",
|
"deepseek-chat",
|
||||||
@@ -259,7 +269,7 @@ class LLMResponseGenerator:
|
|||||||
messages = [{"role": "user", "content": prompt}]
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
if self.config.API_USING == "deepseek":
|
if global_config.API_USING == "deepseek":
|
||||||
model = "deepseek-chat"
|
model = "deepseek-chat"
|
||||||
else:
|
else:
|
||||||
model = "Pro/deepseek-ai/DeepSeek-V3"
|
model = "Pro/deepseek-ai/DeepSeek-V3"
|
||||||
@@ -296,4 +306,4 @@ class LLMResponseGenerator:
|
|||||||
return processed_response, emotion_tags
|
return processed_response, emotion_tags
|
||||||
|
|
||||||
# 创建全局实例
|
# 创建全局实例
|
||||||
llm_response = LLMResponseGenerator(global_config)
|
llm_response = LLMResponseGenerator()
|
||||||
@@ -66,12 +66,15 @@ class PromptBuilder:
|
|||||||
overlapping_second_layer.update(overlap)
|
overlapping_second_layer.update(overlap)
|
||||||
|
|
||||||
# 合并所有需要的记忆
|
# 合并所有需要的记忆
|
||||||
if all_first_layer_items:
|
# if all_first_layer_items:
|
||||||
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
||||||
if overlapping_second_layer:
|
# if overlapping_second_layer:
|
||||||
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
||||||
|
|
||||||
all_memories = all_first_layer_items + list(overlapping_second_layer)
|
# 使用集合去重
|
||||||
|
all_memories = list(set(all_first_layer_items) | set(overlapping_second_layer))
|
||||||
|
if all_memories:
|
||||||
|
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}")
|
||||||
|
|
||||||
if all_memories: # 只在列表非空时选择随机项
|
if all_memories: # 只在列表非空时选择随机项
|
||||||
random_item = choice(all_memories)
|
random_item = choice(all_memories)
|
||||||
@@ -181,6 +184,10 @@ class PromptBuilder:
|
|||||||
prompt += f"{prompt_ger}\n"
|
prompt += f"{prompt_ger}\n"
|
||||||
prompt += f"{extra_info}\n"
|
prompt += f"{extra_info}\n"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
'''读空气prompt处理'''
|
||||||
|
|
||||||
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
|
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
|
||||||
prompt_personality_check = ''
|
prompt_personality_check = ''
|
||||||
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
|
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
|
||||||
|
|||||||
@@ -1,14 +1,17 @@
|
|||||||
from typing import Optional, Dict, List
|
from typing import Optional, Dict, List
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from .message import Message
|
from .message import Message
|
||||||
from .config import global_config, llm_config
|
|
||||||
import jieba
|
import jieba
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
class TopicIdentifier:
|
class TopicIdentifier:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=llm_config.SILICONFLOW_API_KEY,
|
api_key=config.siliconflow_key,
|
||||||
base_url=llm_config.SILICONFLOW_BASE_URL
|
base_url=config.siliconflow_base_url
|
||||||
)
|
)
|
||||||
|
|
||||||
def identify_topic_llm(self, text: str) -> Optional[str]:
|
def identify_topic_llm(self, text: str) -> Optional[str]:
|
||||||
|
|||||||
@@ -4,11 +4,15 @@ from typing import List
|
|||||||
from .message import Message
|
from .message import Message
|
||||||
import requests
|
import requests
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .config import llm_config, global_config
|
from .config import global_config
|
||||||
import re
|
import re
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import math
|
import math
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
def combine_messages(messages: List[Message]) -> str:
|
def combine_messages(messages: List[Message]) -> str:
|
||||||
@@ -64,7 +68,7 @@ def get_embedding(text):
|
|||||||
"encoding_format": "float"
|
"encoding_format": "float"
|
||||||
}
|
}
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
|
"Authorization": f"Bearer {config.siliconflow_key}",
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ from ...common.database import Database
|
|||||||
import zlib # 用于 CRC32
|
import zlib # 用于 CRC32
|
||||||
import base64
|
import base64
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
|
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
|
||||||
@@ -37,12 +41,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
|||||||
|
|
||||||
# 连接数据库
|
# 连接数据库
|
||||||
db = Database(
|
db = Database(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否已存在相同哈希值的图片
|
# 检查是否已存在相同哈希值的图片
|
||||||
|
|||||||
@@ -58,8 +58,8 @@ class WillingManager:
|
|||||||
if group_id in config.talk_frequency_down_groups:
|
if group_id in config.talk_frequency_down_groups:
|
||||||
reply_probability = reply_probability / 3.5
|
reply_probability = reply_probability / 3.5
|
||||||
|
|
||||||
if is_mentioned_bot and user_id == int(964959351):
|
# if is_mentioned_bot and user_id == int(1026294844):
|
||||||
reply_probability = 1
|
# reply_probability = 1
|
||||||
|
|
||||||
return reply_probability
|
return reply_probability
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,10 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import time
|
import time
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
# 添加项目根目录到 Python 路径
|
# 添加项目根目录到 Python 路径
|
||||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
@@ -13,12 +17,12 @@ from src.plugins.chat.config import llm_config
|
|||||||
|
|
||||||
# 直接配置数据库连接信息
|
# 直接配置数据库连接信息
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
|
|
||||||
class KnowledgeLibrary:
|
class KnowledgeLibrary:
|
||||||
|
|||||||
@@ -168,10 +168,12 @@ def main():
|
|||||||
memory_graph.load_graph_from_db()
|
memory_graph.load_graph_from_db()
|
||||||
# 展示两种不同的可视化方式
|
# 展示两种不同的可视化方式
|
||||||
print("\n按连接数量着色的图谱:")
|
print("\n按连接数量着色的图谱:")
|
||||||
visualize_graph(memory_graph, color_by_memory=False)
|
# visualize_graph(memory_graph, color_by_memory=False)
|
||||||
|
visualize_graph_lite(memory_graph, color_by_memory=False)
|
||||||
|
|
||||||
print("\n按记忆数量着色的图谱:")
|
print("\n按记忆数量着色的图谱:")
|
||||||
visualize_graph(memory_graph, color_by_memory=True)
|
# visualize_graph(memory_graph, color_by_memory=True)
|
||||||
|
visualize_graph_lite(memory_graph, color_by_memory=True)
|
||||||
|
|
||||||
# memory_graph.save_graph_to_db()
|
# memory_graph.save_graph_to_db()
|
||||||
|
|
||||||
@@ -262,7 +264,89 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||||
|
# 设置中文字体
|
||||||
|
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||||
|
|
||||||
|
G = memory_graph.G
|
||||||
|
|
||||||
|
# 创建一个新图用于可视化
|
||||||
|
H = G.copy()
|
||||||
|
|
||||||
|
# 移除只有一条记忆的节点和连接数少于3的节点
|
||||||
|
nodes_to_remove = []
|
||||||
|
for node in H.nodes():
|
||||||
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
|
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||||
|
degree = H.degree(node)
|
||||||
|
if memory_count <= 2 or degree <= 2:
|
||||||
|
nodes_to_remove.append(node)
|
||||||
|
|
||||||
|
H.remove_nodes_from(nodes_to_remove)
|
||||||
|
|
||||||
|
# 如果过滤后没有节点,则返回
|
||||||
|
if len(H.nodes()) == 0:
|
||||||
|
print("过滤后没有符合条件的节点可显示")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 保存图到本地
|
||||||
|
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
||||||
|
|
||||||
|
# 根据连接条数或记忆数量设置节点颜色
|
||||||
|
node_colors = []
|
||||||
|
nodes = list(H.nodes()) # 获取图中实际的节点列表
|
||||||
|
|
||||||
|
if color_by_memory:
|
||||||
|
# 计算每个节点的记忆数量
|
||||||
|
memory_counts = []
|
||||||
|
for node in nodes:
|
||||||
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
|
if isinstance(memory_items, list):
|
||||||
|
count = len(memory_items)
|
||||||
|
else:
|
||||||
|
count = 1 if memory_items else 0
|
||||||
|
memory_counts.append(count)
|
||||||
|
max_memories = max(memory_counts) if memory_counts else 1
|
||||||
|
|
||||||
|
for count in memory_counts:
|
||||||
|
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
||||||
|
if max_memories > 0:
|
||||||
|
intensity = min(1.0, count / max_memories)
|
||||||
|
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
||||||
|
else:
|
||||||
|
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
||||||
|
node_colors.append(color)
|
||||||
|
else:
|
||||||
|
# 使用原来的连接数量着色方案
|
||||||
|
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
|
||||||
|
for node in nodes:
|
||||||
|
degree = H.degree(node)
|
||||||
|
if max_degree > 0:
|
||||||
|
red = min(1.0, degree / max_degree)
|
||||||
|
blue = 1.0 - red
|
||||||
|
color = (red, 0, blue)
|
||||||
|
else:
|
||||||
|
color = (0, 0, 1)
|
||||||
|
node_colors.append(color)
|
||||||
|
|
||||||
|
# 绘制图形
|
||||||
|
plt.figure(figsize=(12, 8))
|
||||||
|
pos = nx.spring_layout(H, k=1, iterations=50)
|
||||||
|
nx.draw(H, pos,
|
||||||
|
with_labels=True,
|
||||||
|
node_color=node_colors,
|
||||||
|
node_size=2000,
|
||||||
|
font_size=10,
|
||||||
|
font_family='SimHei',
|
||||||
|
font_weight='bold')
|
||||||
|
|
||||||
|
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
||||||
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,14 +2,18 @@ import os
|
|||||||
import requests
|
import requests
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
import time
|
import time
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
class LLMModel:
|
class LLMModel:
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
||||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
self.api_key = config.siliconflow_key
|
||||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
self.base_url = config.siliconflow_base_url
|
||||||
|
|
||||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||||
"""根据输入的提示生成模型的响应"""
|
"""根据输入的提示生成模型的响应"""
|
||||||
|
|||||||
@@ -3,14 +3,18 @@ import requests
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
import time
|
import time
|
||||||
from ..chat.config import BotConfig
|
from ..chat.config import BotConfig
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
class LLMModel:
|
class LLMModel:
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
||||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
self.api_key = config.siliconflow_key
|
||||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
self.base_url = config.siliconflow_base_url
|
||||||
|
|
||||||
if not self.api_key or not self.base_url:
|
if not self.api_key or not self.base_url:
|
||||||
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
||||||
|
|||||||
@@ -198,8 +198,6 @@ class Hippocampus:
|
|||||||
time_frequency = {'near':1,'mid':2,'far':2}
|
time_frequency = {'near':1,'mid':2,'far':2}
|
||||||
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
||||||
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
|
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
|
||||||
|
|
||||||
|
|
||||||
for i, input_text in enumerate(memory_sample, 1):
|
for i, input_text in enumerate(memory_sample, 1):
|
||||||
#加载进度可视化
|
#加载进度可视化
|
||||||
progress = (i / len(memory_sample)) * 100
|
progress = (i / len(memory_sample)) * 100
|
||||||
@@ -207,24 +205,25 @@ class Hippocampus:
|
|||||||
filled_length = int(bar_length * i // len(memory_sample))
|
filled_length = int(bar_length * i // len(memory_sample))
|
||||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||||
|
if input_text:
|
||||||
# 生成压缩后记忆
|
# 生成压缩后记忆
|
||||||
first_memory = set()
|
first_memory = set()
|
||||||
first_memory = self.memory_compress(input_text, 2.5)
|
first_memory = self.memory_compress(input_text, 2.5)
|
||||||
# 延时防止访问超频
|
# 延时防止访问超频
|
||||||
# time.sleep(5)
|
# time.sleep(5)
|
||||||
#将记忆加入到图谱中
|
#将记忆加入到图谱中
|
||||||
for topic, memory in first_memory:
|
for topic, memory in first_memory:
|
||||||
topics = segment_text(topic)
|
topics = segment_text(topic)
|
||||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||||
for split_topic in topics:
|
for split_topic in topics:
|
||||||
self.memory_graph.add_dot(split_topic,memory)
|
self.memory_graph.add_dot(split_topic,memory)
|
||||||
for split_topic in topics:
|
for split_topic in topics:
|
||||||
for other_split_topic in topics:
|
for other_split_topic in topics:
|
||||||
if split_topic != other_split_topic:
|
if split_topic != other_split_topic:
|
||||||
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
||||||
|
else:
|
||||||
self.memory_graph.save_graph_to_db()
|
print(f"空消息 跳过")
|
||||||
|
self.memory_graph.save_graph_to_db()
|
||||||
|
|
||||||
def memory_compress(self, input_text, rate=1):
|
def memory_compress(self, input_text, rate=1):
|
||||||
information_content = calculate_information_content(input_text)
|
information_content = calculate_information_content(input_text)
|
||||||
@@ -260,16 +259,19 @@ def topic_what(text, topic):
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
from nonebot import get_driver
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
#创建记忆图
|
#创建记忆图
|
||||||
memory_graph = Memory_graph()
|
memory_graph = Memory_graph()
|
||||||
|
|||||||
@@ -14,6 +14,37 @@ sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
|||||||
from src.common.database import Database # 使用正确的导入语法
|
from src.common.database import Database # 使用正确的导入语法
|
||||||
from src.plugins.memory_system.llm_module import LLMModel
|
from src.plugins.memory_system.llm_module import LLMModel
|
||||||
|
|
||||||
|
def calculate_information_content(text):
|
||||||
|
"""计算文本的信息量(熵)"""
|
||||||
|
# 统计字符频率
|
||||||
|
char_count = Counter(text)
|
||||||
|
total_chars = len(text)
|
||||||
|
|
||||||
|
# 计算熵
|
||||||
|
entropy = 0
|
||||||
|
for count in char_count.values():
|
||||||
|
probability = count / total_chars
|
||||||
|
entropy -= probability * math.log2(probability)
|
||||||
|
|
||||||
|
return entropy
|
||||||
|
|
||||||
|
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||||
|
"""从数据库中获取最接近指定时间戳的聊天记录"""
|
||||||
|
chat_text = ''
|
||||||
|
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||||
|
|
||||||
|
if closest_record:
|
||||||
|
closest_time = closest_record['time']
|
||||||
|
group_id = closest_record['group_id'] # 获取groupid
|
||||||
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
|
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||||
|
for record in chat_record:
|
||||||
|
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||||
|
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
|
||||||
|
return chat_text
|
||||||
|
|
||||||
|
return ''
|
||||||
|
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
@@ -102,7 +133,8 @@ class Memory_graph:
|
|||||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||||
chat_text = ''
|
chat_text = ''
|
||||||
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||||
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
|
||||||
|
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||||
|
|
||||||
if closest_record:
|
if closest_record:
|
||||||
closest_time = closest_record['time']
|
closest_time = closest_record['time']
|
||||||
@@ -110,8 +142,9 @@ class Memory_graph:
|
|||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||||
for record in chat_record:
|
for record in chat_record:
|
||||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
if record:
|
||||||
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||||
|
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
||||||
return chat_text
|
return chat_text
|
||||||
|
|
||||||
return [] # 如果没有找到记录,返回空列表
|
return [] # 如果没有找到记录,返回空列表
|
||||||
@@ -187,155 +220,80 @@ class Memory_graph:
|
|||||||
for edge in edges:
|
for edge in edges:
|
||||||
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
|
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
|
||||||
|
|
||||||
def calculate_information_content(text):
|
# 海马体
|
||||||
|
class Hippocampus:
|
||||||
|
def __init__(self,memory_graph:Memory_graph):
|
||||||
|
self.memory_graph = memory_graph
|
||||||
|
self.llm_model = LLMModel()
|
||||||
|
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||||
|
|
||||||
"""计算文本的信息量(熵)"""
|
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
||||||
# 统计字符频率
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
char_count = Counter(text)
|
chat_text = []
|
||||||
total_chars = len(text)
|
#短期:1h 中期:4h 长期:24h
|
||||||
|
for _ in range(time_frequency.get('near')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
for _ in range(time_frequency.get('mid')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
for _ in range(time_frequency.get('far')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
return chat_text
|
||||||
|
|
||||||
# 计算熵
|
def build_memory(self,chat_size=12):
|
||||||
entropy = 0
|
#最近消息获取频率
|
||||||
for count in char_count.values():
|
time_frequency = {'near':1,'mid':2,'far':2}
|
||||||
probability = count / total_chars
|
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
||||||
entropy -= probability * math.log2(probability)
|
|
||||||
|
|
||||||
return entropy
|
#加载进度可视化
|
||||||
|
for i, input_text in enumerate(memory_sample, 1):
|
||||||
|
progress = (i / len(memory_sample)) * 100
|
||||||
|
bar_length = 30
|
||||||
|
filled_length = int(bar_length * i // len(memory_sample))
|
||||||
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
|
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||||
|
# print(f"第{i}条消息: {input_text}")
|
||||||
|
if input_text:
|
||||||
|
# 生成压缩后记忆
|
||||||
|
first_memory = set()
|
||||||
|
first_memory = self.memory_compress(input_text, 2.5)
|
||||||
|
#将记忆加入到图谱中
|
||||||
|
for topic, memory in first_memory:
|
||||||
|
topics = segment_text(topic)
|
||||||
|
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||||
|
for split_topic in topics:
|
||||||
|
self.memory_graph.add_dot(split_topic,memory)
|
||||||
|
for split_topic in topics:
|
||||||
|
for other_split_topic in topics:
|
||||||
|
if split_topic != other_split_topic:
|
||||||
|
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
||||||
|
else:
|
||||||
|
print(f"空消息 跳过")
|
||||||
|
|
||||||
|
self.memory_graph.save_graph_to_db()
|
||||||
|
|
||||||
# Database.initialize(
|
def memory_compress(self, input_text, rate=1):
|
||||||
# global_config.MONGODB_HOST,
|
information_content = calculate_information_content(input_text)
|
||||||
# global_config.MONGODB_PORT,
|
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
||||||
# global_config.DATABASE_NAME
|
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
||||||
# )
|
topic_prompt = find_topic(input_text, topic_num)
|
||||||
# memory_graph = Memory_graph()
|
topic_response = self.llm_model.generate_response(topic_prompt)
|
||||||
|
|
||||||
# llm_model = LLMModel()
|
|
||||||
# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
|
||||||
|
|
||||||
# memory_graph.load_graph_from_db()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# 初始化数据库
|
|
||||||
Database.initialize(
|
|
||||||
host= os.getenv("MONGODB_HOST"),
|
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
|
||||||
)
|
|
||||||
|
|
||||||
memory_graph = Memory_graph()
|
|
||||||
# 创建LLM模型实例
|
|
||||||
llm_model = LLMModel()
|
|
||||||
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
|
||||||
|
|
||||||
# 使用当前时间戳进行测试
|
|
||||||
current_timestamp = datetime.datetime.now().timestamp()
|
|
||||||
chat_text = []
|
|
||||||
|
|
||||||
chat_size =25
|
|
||||||
|
|
||||||
for _ in range(30): # 循环10次
|
|
||||||
random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间
|
|
||||||
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
|
||||||
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
|
||||||
chat_text.append(chat_) # 拼接所有text
|
|
||||||
# time.sleep(1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for i, input_text in enumerate(chat_text, 1):
|
|
||||||
|
|
||||||
progress = (i / len(chat_text)) * 100
|
|
||||||
bar_length = 30
|
|
||||||
filled_length = int(bar_length * i // len(chat_text))
|
|
||||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
|
||||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})")
|
|
||||||
|
|
||||||
# print(input_text)
|
|
||||||
first_memory = set()
|
|
||||||
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
|
|
||||||
# time.sleep(5)
|
|
||||||
|
|
||||||
#将记忆加入到图谱中
|
|
||||||
for topic, memory in first_memory:
|
|
||||||
topics = segment_text(topic)
|
|
||||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
|
||||||
for split_topic in topics:
|
|
||||||
memory_graph.add_dot(split_topic,memory)
|
|
||||||
for split_topic in topics:
|
|
||||||
for other_split_topic in topics:
|
|
||||||
if split_topic != other_split_topic:
|
|
||||||
memory_graph.connect_dot(split_topic, other_split_topic)
|
|
||||||
|
|
||||||
# memory_graph.store_memory()
|
|
||||||
|
|
||||||
# 展示两种不同的可视化方式
|
|
||||||
print("\n按连接数量着色的图谱:")
|
|
||||||
visualize_graph(memory_graph, color_by_memory=False)
|
|
||||||
|
|
||||||
print("\n按记忆数量着色的图谱:")
|
|
||||||
visualize_graph(memory_graph, color_by_memory=True)
|
|
||||||
|
|
||||||
memory_graph.save_graph_to_db()
|
|
||||||
# memory_graph.load_graph_from_db()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
|
||||||
if query.lower() == '退出':
|
|
||||||
break
|
|
||||||
items_list = memory_graph.get_related_item(query)
|
|
||||||
if items_list:
|
|
||||||
# print(items_list)
|
|
||||||
for memory_item in items_list:
|
|
||||||
print(memory_item)
|
|
||||||
else:
|
|
||||||
print("未找到相关记忆。")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
query = input("请输入问题:")
|
|
||||||
|
|
||||||
if query.lower() == '退出':
|
|
||||||
break
|
|
||||||
|
|
||||||
topic_prompt = find_topic(query, 3)
|
|
||||||
topic_response = llm_model.generate_response(topic_prompt)
|
|
||||||
# 检查 topic_response 是否为元组
|
# 检查 topic_response 是否为元组
|
||||||
if isinstance(topic_response, tuple):
|
if isinstance(topic_response, tuple):
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||||
else:
|
else:
|
||||||
topics = topic_response.split(",")
|
topics = topic_response.split(",")
|
||||||
print(topics)
|
compressed_memory = set()
|
||||||
|
for topic in topics:
|
||||||
for keyword in topics:
|
topic_what_prompt = topic_what(input_text,topic)
|
||||||
items_list = memory_graph.get_related_item(keyword)
|
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
|
||||||
if items_list:
|
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
||||||
print(items_list)
|
return compressed_memory
|
||||||
|
|
||||||
def memory_compress(input_text, llm_model, llm_model_small, rate=1):
|
|
||||||
information_content = calculate_information_content(input_text)
|
|
||||||
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
|
||||||
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
|
||||||
print(topic_num)
|
|
||||||
topic_prompt = find_topic(input_text, topic_num)
|
|
||||||
topic_response = llm_model.generate_response(topic_prompt)
|
|
||||||
# 检查 topic_response 是否为元组
|
|
||||||
if isinstance(topic_response, tuple):
|
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
|
||||||
else:
|
|
||||||
topics = topic_response.split(",")
|
|
||||||
print(topics)
|
|
||||||
compressed_memory = set()
|
|
||||||
for topic in topics:
|
|
||||||
topic_what_prompt = topic_what(input_text,topic)
|
|
||||||
topic_what_response = llm_model_small.generate_response(topic_what_prompt)
|
|
||||||
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
|
||||||
return compressed_memory
|
|
||||||
|
|
||||||
|
|
||||||
def segment_text(text):
|
def segment_text(text):
|
||||||
seg_text = list(jieba.cut(text))
|
seg_text = list(jieba.cut(text))
|
||||||
@@ -356,18 +314,37 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
|
|
||||||
G = memory_graph.G
|
G = memory_graph.G
|
||||||
|
|
||||||
|
# 创建一个新图用于可视化
|
||||||
|
H = G.copy()
|
||||||
|
|
||||||
|
# 移除只有一条记忆的节点和连接数少于3的节点
|
||||||
|
nodes_to_remove = []
|
||||||
|
for node in H.nodes():
|
||||||
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
|
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||||
|
degree = H.degree(node)
|
||||||
|
if memory_count <= 1 or degree <= 2:
|
||||||
|
nodes_to_remove.append(node)
|
||||||
|
|
||||||
|
H.remove_nodes_from(nodes_to_remove)
|
||||||
|
|
||||||
|
# 如果过滤后没有节点,则返回
|
||||||
|
if len(H.nodes()) == 0:
|
||||||
|
print("过滤后没有符合条件的节点可显示")
|
||||||
|
return
|
||||||
|
|
||||||
# 保存图到本地
|
# 保存图到本地
|
||||||
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
|
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
||||||
|
|
||||||
# 根据连接条数或记忆数量设置节点颜色
|
# 根据连接条数或记忆数量设置节点颜色
|
||||||
node_colors = []
|
node_colors = []
|
||||||
nodes = list(G.nodes()) # 获取图中实际的节点列表
|
nodes = list(H.nodes()) # 获取图中实际的节点列表
|
||||||
|
|
||||||
if color_by_memory:
|
if color_by_memory:
|
||||||
# 计算每个节点的记忆数量
|
# 计算每个节点的记忆数量
|
||||||
memory_counts = []
|
memory_counts = []
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
memory_items = G.nodes[node].get('memory_items', [])
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
if isinstance(memory_items, list):
|
if isinstance(memory_items, list):
|
||||||
count = len(memory_items)
|
count = len(memory_items)
|
||||||
else:
|
else:
|
||||||
@@ -385,9 +362,9 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
node_colors.append(color)
|
node_colors.append(color)
|
||||||
else:
|
else:
|
||||||
# 使用原来的连接数量着色方案
|
# 使用原来的连接数量着色方案
|
||||||
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
|
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
degree = G.degree(node)
|
degree = H.degree(node)
|
||||||
if max_degree > 0:
|
if max_degree > 0:
|
||||||
red = min(1.0, degree / max_degree)
|
red = min(1.0, degree / max_degree)
|
||||||
blue = 1.0 - red
|
blue = 1.0 - red
|
||||||
@@ -398,8 +375,8 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
|
|
||||||
# 绘制图形
|
# 绘制图形
|
||||||
plt.figure(figsize=(12, 8))
|
plt.figure(figsize=(12, 8))
|
||||||
pos = nx.spring_layout(G, k=1, iterations=50)
|
pos = nx.spring_layout(H, k=1, iterations=50)
|
||||||
nx.draw(G, pos,
|
nx.draw(H, pos,
|
||||||
with_labels=True,
|
with_labels=True,
|
||||||
node_color=node_colors,
|
node_color=node_colors,
|
||||||
node_size=2000,
|
node_size=2000,
|
||||||
@@ -411,6 +388,71 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 初始化数据库
|
||||||
|
Database.initialize(
|
||||||
|
host= os.getenv("MONGODB_HOST"),
|
||||||
|
port= int(os.getenv("MONGODB_PORT")),
|
||||||
|
db_name= os.getenv("DATABASE_NAME"),
|
||||||
|
username= os.getenv("MONGODB_USERNAME"),
|
||||||
|
password= os.getenv("MONGODB_PASSWORD"),
|
||||||
|
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 创建记忆图
|
||||||
|
memory_graph = Memory_graph()
|
||||||
|
# 加载数据库中存储的记忆图
|
||||||
|
memory_graph.load_graph_from_db()
|
||||||
|
# 创建海马体
|
||||||
|
hippocampus = Hippocampus(memory_graph)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
||||||
|
|
||||||
|
# 构建记忆
|
||||||
|
hippocampus.build_memory(chat_size=25)
|
||||||
|
|
||||||
|
# 展示两种不同的可视化方式
|
||||||
|
print("\n按连接数量着色的图谱:")
|
||||||
|
visualize_graph(memory_graph, color_by_memory=False)
|
||||||
|
|
||||||
|
print("\n按记忆数量着色的图谱:")
|
||||||
|
visualize_graph(memory_graph, color_by_memory=True)
|
||||||
|
|
||||||
|
# 交互式查询
|
||||||
|
while True:
|
||||||
|
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||||
|
if query.lower() == '退出':
|
||||||
|
break
|
||||||
|
items_list = memory_graph.get_related_item(query)
|
||||||
|
if items_list:
|
||||||
|
for memory_item in items_list:
|
||||||
|
print(memory_item)
|
||||||
|
else:
|
||||||
|
print("未找到相关记忆。")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
query = input("请输入问题:")
|
||||||
|
|
||||||
|
if query.lower() == '退出':
|
||||||
|
break
|
||||||
|
|
||||||
|
topic_prompt = find_topic(query, 3)
|
||||||
|
topic_response = hippocampus.llm_model.generate_response(topic_prompt)
|
||||||
|
# 检查 topic_response 是否为元组
|
||||||
|
if isinstance(topic_response, tuple):
|
||||||
|
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||||
|
else:
|
||||||
|
topics = topic_response.split(",")
|
||||||
|
print(topics)
|
||||||
|
|
||||||
|
for keyword in topics:
|
||||||
|
items_list = memory_graph.get_related_item(keyword)
|
||||||
|
if items_list:
|
||||||
|
print(items_list)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
@@ -4,14 +4,19 @@ from typing import List, Dict
|
|||||||
from .schedule_llm_module import LLMModel
|
from .schedule_llm_module import LLMModel
|
||||||
from ...common.database import Database # 使用正确的导入语法
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
|
|
||||||
class ScheduleGenerator:
|
class ScheduleGenerator:
|
||||||
|
|||||||
@@ -1,20 +1,24 @@
|
|||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
class LLMModel:
|
class LLMModel:
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
||||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs):
|
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs):
|
||||||
if api_using == "deepseek":
|
if api_using == "deepseek":
|
||||||
self.api_key = os.getenv("DEEP_SEEK_KEY")
|
self.api_key = config.deep_seek_key
|
||||||
self.base_url = os.getenv("DEEP_SEEK_BASE_URL")
|
self.base_url = config.deep_seek_base_url
|
||||||
if model_name != "Pro/deepseek-ai/DeepSeek-R1":
|
if model_name != "Pro/deepseek-ai/DeepSeek-R1":
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
else:
|
else:
|
||||||
self.model_name = "deepseek-reasoner"
|
self.model_name = "deepseek-reasoner"
|
||||||
else:
|
else:
|
||||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
self.api_key = config.siliconflow_key
|
||||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
self.base_url = config.siliconflow_base_url
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user