Merge remote-tracking branch 'upstream/debug'
This commit is contained in:
26
.env
Normal file
26
.env
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# 您不应该修改默认值,这个文件被仓库索引,请修改.env.prod
|
||||||
|
ENVIRONMENT=prod
|
||||||
|
# HOST=127.0.0.1
|
||||||
|
# PORT=8080
|
||||||
|
|
||||||
|
# COMMAND_START=["/"]
|
||||||
|
|
||||||
|
# # 插件配置
|
||||||
|
# PLUGINS=["src2.plugins.chat"]
|
||||||
|
|
||||||
|
# # 默认配置
|
||||||
|
# MONGODB_HOST=127.0.0.1
|
||||||
|
# MONGODB_PORT=27017
|
||||||
|
# DATABASE_NAME=MegBot
|
||||||
|
|
||||||
|
# MONGODB_USERNAME = "" # 默认空值
|
||||||
|
# MONGODB_PASSWORD = "" # 默认空值
|
||||||
|
# MONGODB_AUTH_SOURCE = "" # 默认空值
|
||||||
|
|
||||||
|
# #key and url
|
||||||
|
# CHAT_ANY_WHERE_KEY=
|
||||||
|
# SILICONFLOW_KEY=
|
||||||
|
# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
||||||
|
# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
|
# DEEP_SEEK_KEY=
|
||||||
|
# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -2,20 +2,15 @@ data/
|
|||||||
mongodb/
|
mongodb/
|
||||||
NapCat.Framework.Windows.Once/
|
NapCat.Framework.Windows.Once/
|
||||||
log/
|
log/
|
||||||
src/plugins/memory
|
|
||||||
config/bot_config.toml
|
|
||||||
/test
|
/test
|
||||||
message_queue_content.txt
|
message_queue_content.txt
|
||||||
message_queue_content.bat
|
message_queue_content.bat
|
||||||
message_queue_window.bat
|
message_queue_window.bat
|
||||||
message_queue_window.txt
|
message_queue_window.txt
|
||||||
reasoning_content.txt
|
|
||||||
reasoning_content.bat
|
|
||||||
reasoning_window.bat
|
|
||||||
queue_update.txt
|
queue_update.txt
|
||||||
memory_graph.gml
|
memory_graph.gml
|
||||||
.env.dev
|
.env.*
|
||||||
|
config/bot_config_dev.toml
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
@@ -147,7 +142,6 @@ celerybeat.pid
|
|||||||
*.sage.py
|
*.sage.py
|
||||||
|
|
||||||
# Environments
|
# Environments
|
||||||
.env
|
|
||||||
.venv
|
.venv
|
||||||
env/
|
env/
|
||||||
venv/
|
venv/
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ FROM nonebot/nb-cli:latest
|
|||||||
WORKDIR /
|
WORKDIR /
|
||||||
COPY . /MaiMBot/
|
COPY . /MaiMBot/
|
||||||
WORKDIR /MaiMBot
|
WORKDIR /MaiMBot
|
||||||
RUN mv config/env.example config/.env \
|
|
||||||
&& mv config/bot_config_toml config/bot_config.toml
|
|
||||||
RUN pip install --upgrade -r requirements.txt
|
RUN pip install --upgrade -r requirements.txt
|
||||||
VOLUME [ "/MaiMBot/config" ]
|
VOLUME [ "/MaiMBot/config" ]
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
ENTRYPOINT [ "nb","run" ]
|
ENTRYPOINT [ "nb","run" ]
|
||||||
|
|||||||
@@ -75,8 +75,9 @@ NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
|
|||||||
- 在Napcat的网络设置中添加ws反向代理:ws://localhost:8080/onebot/v11/ws
|
- 在Napcat的网络设置中添加ws反向代理:ws://localhost:8080/onebot/v11/ws
|
||||||
|
|
||||||
4. **配置文件设置**
|
4. **配置文件设置**
|
||||||
- 将.env文件打开,填上你的apikey(硅基流动或deepseekapi)
|
- 修改.env的 变量值为 prod
|
||||||
- 将bot_config.toml文件打开,并填写相关内容,不然无法正常运行
|
- 将.env.prod文件打开,填上你的apikey(硅基流动或deepseekapi)
|
||||||
|
- 将bot_config_toml改名为bot_config.toml,打开并填写相关内容,不然无法正常运行
|
||||||
|
|
||||||
#### .env 文件配置说明
|
#### .env 文件配置说明
|
||||||
```ini
|
```ini
|
||||||
|
|||||||
39
bot.py
39
bot.py
@@ -29,35 +29,25 @@ env_file = f".env.{env}"
|
|||||||
if env_file == ".env.dev" and os.path.exists(env_file):
|
if env_file == ".env.dev" and os.path.exists(env_file):
|
||||||
logger.success("加载开发环境变量配置")
|
logger.success("加载开发环境变量配置")
|
||||||
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
elif env_file == ".env.prod" and os.path.exists(env_file):
|
elif os.path.exists(".env.prod"):
|
||||||
logger.success("加载环境变量配置")
|
logger.success("加载环境变量配置")
|
||||||
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
else:
|
else:
|
||||||
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
nonebot.init(
|
# 获取所有环境变量
|
||||||
# 从环境变量中读取配置
|
env_config = {key: os.getenv(key) for key in os.environ}
|
||||||
websocket_port=os.getenv("PORT", 8080),
|
|
||||||
host=os.getenv("HOST", "127.0.0.1"),
|
# 设置基础配置
|
||||||
log_level="INFO",
|
base_config = {
|
||||||
# 添加自定义配置
|
"websocket_port": int(env_config.get("PORT", 8080)),
|
||||||
mongodb_host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
"host": env_config.get("HOST", "127.0.0.1"),
|
||||||
mongodb_port=os.getenv("MONGODB_PORT", 27017),
|
"log_level": "INFO",
|
||||||
database_name=os.getenv("DATABASE_NAME", "MegBot"),
|
}
|
||||||
mongodb_username=os.getenv("MONGODB_USERNAME", ""),
|
|
||||||
mongodb_password=os.getenv("MONGODB_PASSWORD", ""),
|
# 合并配置
|
||||||
mongodb_auth_source=os.getenv("MONGODB_AUTH_SOURCE", ""),
|
nonebot.init(**base_config, **env_config)
|
||||||
# API相关配置
|
|
||||||
chat_any_where_key=os.getenv("CHAT_ANY_WHERE_KEY", ""),
|
|
||||||
siliconflow_key=os.getenv("SILICONFLOW_KEY", ""),
|
|
||||||
chat_any_where_base_url=os.getenv("CHAT_ANY_WHERE_BASE_URL", "https://api.chatanywhere.tech/v1"),
|
|
||||||
siliconflow_base_url=os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1/"),
|
|
||||||
deep_seek_key=os.getenv("DEEP_SEEK_KEY", ""),
|
|
||||||
deep_seek_base_url=os.getenv("DEEP_SEEK_BASE_URL", "https://api.deepseek.com/v1"),
|
|
||||||
# 插件配置
|
|
||||||
plugins=os.getenv("PLUGINS", ["src2.plugins.chat"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# 注册适配器
|
# 注册适配器
|
||||||
driver = nonebot.get_driver()
|
driver = nonebot.get_driver()
|
||||||
@@ -67,4 +57,5 @@ driver.register_adapter(Adapter)
|
|||||||
nonebot.load_plugins("src/plugins")
|
nonebot.load_plugins("src/plugins")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
nonebot.run()
|
nonebot.run()
|
||||||
46
config/auto_format.py
Normal file
46
config/auto_format.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import tomli
|
||||||
|
import tomli_w
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
|
def sync_configs():
|
||||||
|
# 读取两个配置文件
|
||||||
|
try:
|
||||||
|
with open('bot_config_dev.toml', 'rb') as f: # tomli需要使用二进制模式读取
|
||||||
|
dev_config = tomli.load(f)
|
||||||
|
|
||||||
|
with open('bot_config.toml', 'rb') as f:
|
||||||
|
prod_config = tomli.load(f)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
print(f"错误:找不到配置文件 - {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except tomli.TOMLDecodeError as e:
|
||||||
|
print(f"错误:TOML格式解析失败 - {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 递归合并配置
|
||||||
|
def merge_configs(source, target):
|
||||||
|
for key, value in source.items():
|
||||||
|
if key not in target:
|
||||||
|
target[key] = value
|
||||||
|
elif isinstance(value, dict) and isinstance(target[key], dict):
|
||||||
|
merge_configs(value, target[key])
|
||||||
|
|
||||||
|
# 将dev配置的新属性合并到prod配置中
|
||||||
|
merge_configs(dev_config, prod_config)
|
||||||
|
|
||||||
|
# 保存更新后的配置
|
||||||
|
try:
|
||||||
|
with open('bot_config.toml', 'wb') as f: # tomli_w需要使用二进制模式写入
|
||||||
|
tomli_w.dump(prod_config, f)
|
||||||
|
print("配置文件同步完成!")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误:保存配置文件失败 - {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 确保在正确的目录下运行
|
||||||
|
script_dir = Path(__file__).parent
|
||||||
|
os.chdir(script_dir)
|
||||||
|
sync_configs()
|
||||||
61
config/bot_config.toml
Normal file
61
config/bot_config.toml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
[bot]
|
||||||
|
qq = 123
|
||||||
|
nickname = "麦麦"
|
||||||
|
|
||||||
|
[message]
|
||||||
|
min_text_length = 2
|
||||||
|
max_context_size = 15
|
||||||
|
emoji_chance = 0.2
|
||||||
|
|
||||||
|
[emoji]
|
||||||
|
check_interval = 120
|
||||||
|
register_interval = 10
|
||||||
|
|
||||||
|
[cq_code]
|
||||||
|
enable_pic_translate = false
|
||||||
|
|
||||||
|
[response]
|
||||||
|
api_using = "siliconflow"
|
||||||
|
api_paid = true
|
||||||
|
model_r1_probability = 0.8
|
||||||
|
model_v3_probability = 0.1
|
||||||
|
model_r1_distill_probability = 0.1
|
||||||
|
|
||||||
|
[memory]
|
||||||
|
build_memory_interval = 300
|
||||||
|
|
||||||
|
[others]
|
||||||
|
enable_advance_output = true
|
||||||
|
|
||||||
|
[groups]
|
||||||
|
talk_allowed = [
|
||||||
|
123,
|
||||||
|
123,
|
||||||
|
]
|
||||||
|
talk_frequency_down = []
|
||||||
|
ban_user_id = []
|
||||||
|
|
||||||
|
[model.llm_reasoning]
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-R1"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_reasoning_minor]
|
||||||
|
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal]
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-V3"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal_minor]
|
||||||
|
name = "deepseek-ai/DeepSeek-V2.5"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.vlm]
|
||||||
|
name = "deepseek-ai/deepseek-vl2"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
[bot]
|
|
||||||
qq = 123456 #填入你的机器人QQ
|
|
||||||
nickname = "麦麦" #你希望bot被称呼的名字
|
|
||||||
|
|
||||||
[message]
|
|
||||||
min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息
|
|
||||||
max_context_size = 15 # 麦麦获得的上下文数量,超出数量后自动丢弃
|
|
||||||
emoji_chance = 0.2 # 麦麦使用表情包的概率
|
|
||||||
|
|
||||||
[emoji]
|
|
||||||
check_interval = 120
|
|
||||||
register_interval = 10
|
|
||||||
|
|
||||||
[cq_code]
|
|
||||||
enable_pic_translate = false
|
|
||||||
|
|
||||||
|
|
||||||
[response]
|
|
||||||
api_using = "siliconflow" # 选择大模型API,可选值为siliconflow,deepseek,建议使用siliconflow,因为识图api目前只支持siliconflow的deepseek-vl2模型
|
|
||||||
model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
|
|
||||||
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
|
|
||||||
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
|
|
||||||
|
|
||||||
[memory]
|
|
||||||
build_memory_interval = 300 # 记忆构建间隔
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[others]
|
|
||||||
enable_advance_output = true # 开启后输出更多日志,false关闭true开启
|
|
||||||
|
|
||||||
|
|
||||||
[groups]
|
|
||||||
|
|
||||||
talk_allowed = [
|
|
||||||
123456,12345678
|
|
||||||
] #可以回复消息的群
|
|
||||||
|
|
||||||
talk_frequency_down = [
|
|
||||||
123456,12345678
|
|
||||||
] #降低回复频率的群
|
|
||||||
|
|
||||||
ban_user_id = [
|
|
||||||
123456,12345678
|
|
||||||
] #禁止回复消息的QQ号
|
|
||||||
@@ -27,7 +27,7 @@ services:
|
|||||||
- mongodb:/data/db
|
- mongodb:/data/db
|
||||||
- mongodbCONFIG:/data/configdb
|
- mongodbCONFIG:/data/configdb
|
||||||
image: mongo:latest
|
image: mongo:latest
|
||||||
|
|
||||||
maimbot:
|
maimbot:
|
||||||
container_name: maimbot
|
container_name: maimbot
|
||||||
environment:
|
environment:
|
||||||
@@ -41,8 +41,8 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- maimbotCONFIG:/MaiMBot/config
|
- maimbotCONFIG:/MaiMBot/config
|
||||||
- maimbotDATA:/MaiMBot/data
|
- maimbotDATA:/MaiMBot/data
|
||||||
|
- ./.env.prod:/MaiMBot/.env.prod
|
||||||
image: sengokucola/maimbot:latest
|
image: sengokucola/maimbot:latest
|
||||||
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
maimbotCONFIG:
|
maimbotCONFIG:
|
||||||
@@ -51,4 +51,5 @@ volumes:
|
|||||||
napcatCONFIG:
|
napcatCONFIG:
|
||||||
mongodb:
|
mongodb:
|
||||||
mongodbCONFIG:
|
mongodbCONFIG:
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
BIN
docs/qq.png
BIN
docs/qq.png
Binary file not shown.
|
Before Width: | Height: | Size: 191 KiB |
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
@@ -7,6 +7,23 @@ import threading
|
|||||||
import queue
|
import queue
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 获取当前文件的目录
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
# 获取项目根目录
|
||||||
|
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
if os.path.exists(os.path.join(root_dir, '.env.dev')):
|
||||||
|
load_dotenv(os.path.join(root_dir, '.env.dev'))
|
||||||
|
print("成功加载开发环境配置")
|
||||||
|
elif os.path.exists(os.path.join(root_dir, '.env.prod')):
|
||||||
|
load_dotenv(os.path.join(root_dir, '.env.prod'))
|
||||||
|
print("成功加载生产环境配置")
|
||||||
|
else:
|
||||||
|
print("未找到环境配置文件")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -14,14 +31,23 @@ from typing import Optional
|
|||||||
class Database:
|
class Database:
|
||||||
_instance: Optional["Database"] = None
|
_instance: Optional["Database"] = None
|
||||||
|
|
||||||
def __init__(self, host: str, port: int, db_name: str):
|
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None):
|
||||||
self.client = MongoClient(host, port)
|
if username and password:
|
||||||
|
self.client = MongoClient(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
authSource=auth_source or 'admin'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.client = MongoClient(host, port)
|
||||||
self.db = self.client[db_name]
|
self.db = self.client[db_name]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, host: str, port: int, db_name: str) -> "Database":
|
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None) -> "Database":
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = cls(host, port, db_name)
|
cls._instance = cls(host, port, db_name, username, password, auth_source)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ async def start_background_tasks():
|
|||||||
"""启动后台任务"""
|
"""启动后台任务"""
|
||||||
# 只启动表情包管理任务
|
# 只启动表情包管理任务
|
||||||
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
||||||
|
await bot_schedule.initialize()
|
||||||
bot_schedule.print_schedule()
|
bot_schedule.print_schedule()
|
||||||
|
|
||||||
@driver.on_startup
|
@driver.on_startup
|
||||||
@@ -90,7 +91,7 @@ async def monitor_relationships():
|
|||||||
async def build_memory_task():
|
async def build_memory_task():
|
||||||
"""每30秒执行一次记忆构建"""
|
"""每30秒执行一次记忆构建"""
|
||||||
print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
|
print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
|
||||||
await hippocampus.build_memory(chat_size=12)
|
await hippocampus.build_memory(chat_size=30)
|
||||||
print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")
|
print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message as EventMessa
|
|||||||
from .message import Message,MessageSet
|
from .message import Message,MessageSet
|
||||||
from .config import BotConfig, global_config
|
from .config import BotConfig, global_config
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
from .llm_generator import LLMResponseGenerator
|
from .llm_generator import ResponseGenerator
|
||||||
from .message_stream import MessageStream, MessageStreamContainer
|
from .message_stream import MessageStream, MessageStreamContainer
|
||||||
from .topic_identifier import topic_identifier
|
from .topic_identifier import topic_identifier
|
||||||
from random import random, choice
|
from random import random, choice
|
||||||
@@ -20,7 +20,7 @@ from ..memory_system.memory import memory_graph
|
|||||||
class ChatBot:
|
class ChatBot:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
self.gpt = LLMResponseGenerator()
|
self.gpt = ResponseGenerator()
|
||||||
self.bot = None # bot 实例引用
|
self.bot = None # bot 实例引用
|
||||||
self._started = False
|
self._started = False
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Any, Optional, Set
|
from typing import Dict, Any, Optional, Set
|
||||||
import os
|
import os
|
||||||
from nonebot.log import logger, default_format
|
from nonebot.log import logger, default_format
|
||||||
@@ -32,7 +32,15 @@ class BotConfig:
|
|||||||
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
||||||
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
vlm: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
|
||||||
API_USING: str = "siliconflow" # 使用的API
|
API_USING: str = "siliconflow" # 使用的API
|
||||||
|
API_PAID: bool = False # 是否使用付费API
|
||||||
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
|
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
|
||||||
MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率
|
MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率
|
||||||
MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率
|
MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率
|
||||||
@@ -48,20 +56,19 @@ class BotConfig:
|
|||||||
PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书"
|
PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default_config_path() -> str:
|
def get_config_dir() -> str:
|
||||||
"""获取默认配置文件路径"""
|
"""获取配置文件目录"""
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||||
config_dir = os.path.join(root_dir, 'config')
|
config_dir = os.path.join(root_dir, 'config')
|
||||||
return os.path.join(config_dir, 'bot_config.toml')
|
if not os.path.exists(config_dir):
|
||||||
|
os.makedirs(config_dir)
|
||||||
|
return config_dir
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_config(cls, config_path: str = None) -> "BotConfig":
|
def load_config(cls, config_path: str = None) -> "BotConfig":
|
||||||
"""从TOML配置文件加载配置"""
|
"""从TOML配置文件加载配置"""
|
||||||
if config_path is None:
|
|
||||||
config_path = cls.get_default_config_path()
|
|
||||||
logger.info(f"使用默认配置文件路径: {config_path}")
|
|
||||||
|
|
||||||
config = cls()
|
config = cls()
|
||||||
if os.path.exists(config_path):
|
if os.path.exists(config_path):
|
||||||
with open(config_path, "rb") as f:
|
with open(config_path, "rb") as f:
|
||||||
@@ -89,6 +96,26 @@ class BotConfig:
|
|||||||
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
|
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
|
||||||
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
|
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
|
||||||
config.API_USING = response_config.get("api_using", config.API_USING)
|
config.API_USING = response_config.get("api_using", config.API_USING)
|
||||||
|
config.API_PAID = response_config.get("api_paid", config.API_PAID)
|
||||||
|
|
||||||
|
# 加载模型配置
|
||||||
|
if "model" in toml_dict:
|
||||||
|
model_config = toml_dict["model"]
|
||||||
|
|
||||||
|
if "llm_reasoning" in model_config:
|
||||||
|
config.llm_reasoning = model_config["llm_reasoning"]
|
||||||
|
|
||||||
|
if "llm_reasoning_minor" in model_config:
|
||||||
|
config.llm_reasoning_minor = model_config["llm_reasoning_minor"]
|
||||||
|
|
||||||
|
if "llm_normal" in model_config:
|
||||||
|
config.llm_normal = model_config["llm_normal"]
|
||||||
|
|
||||||
|
if "llm_normal_minor" in model_config:
|
||||||
|
config.llm_normal_minor = model_config["llm_normal_minor"]
|
||||||
|
|
||||||
|
if "vlm" in model_config:
|
||||||
|
config.vlm = model_config["vlm"]
|
||||||
|
|
||||||
# 消息配置
|
# 消息配置
|
||||||
if "message" in toml_dict:
|
if "message" in toml_dict:
|
||||||
@@ -125,12 +152,21 @@ class BotConfig:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
# 获取配置文件路径
|
# 获取配置文件路径
|
||||||
bot_config_path = BotConfig.get_default_config_path()
|
|
||||||
config_dir = os.path.dirname(bot_config_path)
|
|
||||||
|
|
||||||
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
|
bot_config_floder_path = BotConfig.get_config_dir()
|
||||||
|
print(f"正在品鉴配置文件目录: {bot_config_floder_path}")
|
||||||
|
bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
|
||||||
|
if not os.path.exists(bot_config_path):
|
||||||
|
# 如果开发环境配置文件不存在,则使用默认配置文件
|
||||||
|
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
||||||
|
logger.info("使用默认配置文件")
|
||||||
|
else:
|
||||||
|
logger.info("已找到开发环境配置文件")
|
||||||
|
|
||||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMConfig:
|
class LLMConfig:
|
||||||
"""机器人配置类"""
|
"""机器人配置类"""
|
||||||
@@ -151,3 +187,4 @@ llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
|
|||||||
if not global_config.enable_advance_output:
|
if not global_config.enable_advance_output:
|
||||||
# logger.remove()
|
# logger.remove()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import time
|
|||||||
import asyncio
|
import asyncio
|
||||||
from .utils_image import storage_image,storage_emoji
|
from .utils_image import storage_image,storage_emoji
|
||||||
from .utils_user import get_user_nickname
|
from .utils_user import get_user_nickname
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
#解析各种CQ码
|
#解析各种CQ码
|
||||||
#包含CQ码类
|
#包含CQ码类
|
||||||
import urllib3
|
import urllib3
|
||||||
@@ -57,6 +58,11 @@ class CQCode:
|
|||||||
translated_plain_text: Optional[str] = None
|
translated_plain_text: Optional[str] = None
|
||||||
reply_message: Dict = None # 存储回复消息
|
reply_message: Dict = None # 存储回复消息
|
||||||
image_base64: Optional[str] = None
|
image_base64: Optional[str] = None
|
||||||
|
_llm: Optional[LLM_request] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""初始化LLM实例"""
|
||||||
|
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
|
||||||
|
|
||||||
def translate(self):
|
def translate(self):
|
||||||
"""根据CQ码类型进行相应的翻译处理"""
|
"""根据CQ码类型进行相应的翻译处理"""
|
||||||
@@ -161,7 +167,7 @@ class CQCode:
|
|||||||
# 将 base64 字符串转换为字节类型
|
# 将 base64 字符串转换为字节类型
|
||||||
image_bytes = base64.b64decode(base64_str)
|
image_bytes = base64.b64decode(base64_str)
|
||||||
storage_emoji(image_bytes)
|
storage_emoji(image_bytes)
|
||||||
return self.get_image_description(base64_str)
|
return self.get_emoji_description(base64_str)
|
||||||
else:
|
else:
|
||||||
return '[表情包]'
|
return '[表情包]'
|
||||||
|
|
||||||
@@ -181,93 +187,23 @@ class CQCode:
|
|||||||
|
|
||||||
def get_emoji_description(self, image_base64: str) -> str:
|
def get_emoji_description(self, image_base64: str) -> str:
|
||||||
"""调用AI接口获取表情包描述"""
|
"""调用AI接口获取表情包描述"""
|
||||||
headers = {
|
try:
|
||||||
"Content-Type": "application/json",
|
prompt = "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
|
||||||
"Authorization": f"Bearer {config.siliconflow_key}"
|
description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
||||||
}
|
return f"[表情包:{description}]"
|
||||||
|
except Exception as e:
|
||||||
payload = {
|
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
||||||
"model": "deepseek-ai/deepseek-vl2",
|
return "[表情包]"
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 50,
|
|
||||||
"temperature": 0.4
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"{config.siliconflow_base_url}chat/completions",
|
|
||||||
headers=headers,
|
|
||||||
json=payload,
|
|
||||||
timeout=30
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result_json = response.json()
|
|
||||||
if "choices" in result_json and len(result_json["choices"]) > 0:
|
|
||||||
description = result_json["choices"][0]["message"]["content"]
|
|
||||||
return f"[表情包:{description}]"
|
|
||||||
|
|
||||||
raise ValueError(f"AI接口调用失败: {response.text}")
|
|
||||||
|
|
||||||
def get_image_description(self, image_base64: str) -> str:
|
def get_image_description(self, image_base64: str) -> str:
|
||||||
"""调用AI接口获取普通图片描述"""
|
"""调用AI接口获取普通图片描述"""
|
||||||
headers = {
|
try:
|
||||||
"Content-Type": "application/json",
|
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||||
"Authorization": f"Bearer {config.siliconflow_key}"
|
description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
||||||
}
|
return f"[图片:{description}]"
|
||||||
|
except Exception as e:
|
||||||
payload = {
|
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
||||||
"model": "deepseek-ai/deepseek-vl2",
|
return "[图片]"
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 300,
|
|
||||||
"temperature": 0.6
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"{config.siliconflow_base_url}chat/completions",
|
|
||||||
headers=headers,
|
|
||||||
json=payload,
|
|
||||||
timeout=30
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result_json = response.json()
|
|
||||||
if "choices" in result_json and len(result_json["choices"]) > 0:
|
|
||||||
description = result_json["choices"][0]["message"]["content"]
|
|
||||||
return f"[图片:{description}]"
|
|
||||||
|
|
||||||
raise ValueError(f"AI接口调用失败: {response.text}")
|
|
||||||
|
|
||||||
def translate_forward(self) -> str:
|
def translate_forward(self) -> str:
|
||||||
"""处理转发消息"""
|
"""处理转发消息"""
|
||||||
@@ -349,7 +285,7 @@ class CQCode:
|
|||||||
# 创建Message对象
|
# 创建Message对象
|
||||||
from .message import Message
|
from .message import Message
|
||||||
if self.reply_message == None:
|
if self.reply_message == None:
|
||||||
print(f"\033[1;31m[错误]\033[0m 回复消息为空")
|
# print(f"\033[1;31m[错误]\033[0m 回复消息为空")
|
||||||
return '[回复某人消息]'
|
return '[回复某人消息]'
|
||||||
|
|
||||||
if self.reply_message.sender.user_id:
|
if self.reply_message.sender.user_id:
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from ..chat.config import global_config
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -43,6 +45,7 @@ class EmojiManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
self._scan_task = None
|
self._scan_task = None
|
||||||
|
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50)
|
||||||
|
|
||||||
def _ensure_emoji_dir(self):
|
def _ensure_emoji_dir(self):
|
||||||
"""确保表情存储目录存在"""
|
"""确保表情存储目录存在"""
|
||||||
@@ -87,55 +90,23 @@ class EmojiManager:
|
|||||||
print(f"\033[1;31m[错误]\033[0m 记录表情使用失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m 记录表情使用失败: {str(e)}")
|
||||||
|
|
||||||
async def _get_emotion_from_text(self, text: str) -> List[str]:
|
async def _get_emotion_from_text(self, text: str) -> List[str]:
|
||||||
"""从文本中识别情感关键词,使用DeepSeek API进行分析
|
"""从文本中识别情感关键词
|
||||||
Args:
|
Args:
|
||||||
text: 输入文本
|
text: 输入文本
|
||||||
Returns:
|
Returns:
|
||||||
List[str]: 匹配到的情感标签列表
|
List[str]: 匹配到的情感标签列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 准备请求数据
|
prompt = f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签,不要输出其他任何内容。'
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {config.siliconflow_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
content, _ = await self.llm.generate_response(prompt)
|
||||||
"model": "deepseek-ai/DeepSeek-V3",
|
emotion = content.strip().lower()
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签,不要输出其他任何内容。'
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 50,
|
|
||||||
"temperature": 0.3
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
if emotion in self.EMOTION_KEYWORDS:
|
||||||
async with session.post(
|
print(f"\033[1;32m[成功]\033[0m 识别到的情感: {emotion}")
|
||||||
f"{config.siliconflow_base_url}chat/completions",
|
return [emotion]
|
||||||
headers=headers,
|
|
||||||
json=payload
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
print(f"\033[1;31m[错误]\033[0m API请求失败: {await response.text()}")
|
|
||||||
return ['neutral']
|
|
||||||
|
|
||||||
result = json.loads(await response.text())
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
emotion = result["choices"][0]["message"]["content"].strip().lower()
|
|
||||||
# 确保返回的标签是有效的
|
|
||||||
if emotion in self.EMOTION_KEYWORDS:
|
|
||||||
print(f"\033[1;32m[成功]\033[0m 识别到的情感: {emotion}")
|
|
||||||
return [emotion] # 返回单个情感标签的列表
|
|
||||||
|
|
||||||
return ['neutral'] # 如果无法识别情感,返回neutral
|
return ['neutral']
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m 情感分析失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m 情感分析失败: {str(e)}")
|
||||||
@@ -250,52 +221,20 @@ class EmojiManager:
|
|||||||
|
|
||||||
async def _get_emoji_tag(self, image_base64: str) -> str:
|
async def _get_emoji_tag(self, image_base64: str) -> str:
|
||||||
"""获取表情包的标签"""
|
"""获取表情包的标签"""
|
||||||
async with aiohttp.ClientSession() as session:
|
try:
|
||||||
headers = {
|
prompt = '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签,不要输出其他任何内容,只输出情感标签就好'
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {config.siliconflow_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
content, _ = await self.llm.generate_response_for_image(prompt, image_base64)
|
||||||
"model": "deepseek-ai/deepseek-vl2",
|
tag_result = content.strip().lower()
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签,不要输出其他任何内容,只输出情感标签就好'
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 60,
|
|
||||||
"temperature": 0.3
|
|
||||||
}
|
|
||||||
|
|
||||||
async with session.post(
|
valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"]
|
||||||
f"{config.siliconflow_base_url}chat/completions",
|
for tag_match in valid_tags:
|
||||||
headers=headers,
|
if tag_match in tag_result or tag_match == tag_result:
|
||||||
json=payload
|
return tag_match
|
||||||
) as response:
|
print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_result}, 跳过")
|
||||||
if response.status == 200:
|
|
||||||
result = await response.json()
|
except Exception as e:
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}")
|
||||||
tag_result = result["choices"][0]["message"]["content"].strip().lower()
|
|
||||||
|
|
||||||
valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"]
|
|
||||||
for tag_match in valid_tags:
|
|
||||||
if tag_match in tag_result or tag_match == tag_result:
|
|
||||||
return tag_match
|
|
||||||
print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_match}, 跳过")
|
|
||||||
else:
|
|
||||||
print(f"\033[1;31m[错误]\033[0m 获取标签失败, 状态码: {response.status}")
|
|
||||||
|
|
||||||
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
|
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
|
||||||
return "skip" # 默认标签
|
return "skip" # 默认标签
|
||||||
|
|||||||
@@ -13,274 +13,120 @@ from .prompt_builder import prompt_builder
|
|||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .utils import process_llm_response
|
from .utils import process_llm_response
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
class LLMResponseGenerator:
|
class ResponseGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if global_config.API_USING == "siliconflow":
|
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
|
||||||
self.client = OpenAI(
|
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
|
||||||
api_key=config.siliconflow_key,
|
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7)
|
||||||
base_url=config.siliconflow_base_url
|
|
||||||
)
|
|
||||||
elif global_config.API_USING == "deepseek":
|
|
||||||
self.client = OpenAI(
|
|
||||||
api_key=config.deep_seek_key,
|
|
||||||
base_url=config.deep_seek_base_url
|
|
||||||
)
|
|
||||||
|
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
|
|
||||||
# 当前使用的模型类型
|
|
||||||
self.current_model_type = 'r1' # 默认使用 R1
|
self.current_model_type = 'r1' # 默认使用 R1
|
||||||
|
|
||||||
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
|
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
# 从global_config中获取模型概率值
|
# 从global_config中获取模型概率值并选择模型
|
||||||
model_r1_probability = global_config.MODEL_R1_PROBABILITY
|
|
||||||
model_v3_probability = global_config.MODEL_V3_PROBABILITY
|
|
||||||
model_r1_distill_probability = global_config.MODEL_R1_DISTILL_PROBABILITY
|
|
||||||
|
|
||||||
# 生成随机数并根据概率选择模型
|
|
||||||
rand = random.random()
|
rand = random.random()
|
||||||
if rand < model_r1_probability:
|
if rand < global_config.MODEL_R1_PROBABILITY:
|
||||||
self.current_model_type = 'r1'
|
self.current_model_type = 'r1'
|
||||||
elif rand < model_r1_probability + model_v3_probability:
|
current_model = self.model_r1
|
||||||
|
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
|
||||||
self.current_model_type = 'v3'
|
self.current_model_type = 'v3'
|
||||||
|
current_model = self.model_v3
|
||||||
else:
|
else:
|
||||||
self.current_model_type = 'r1_distill' # 默认使用 R1-Distill
|
self.current_model_type = 'r1_distill'
|
||||||
|
current_model = self.model_r1_distill
|
||||||
|
|
||||||
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
|
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
|
||||||
if self.current_model_type == 'r1':
|
|
||||||
model_response = await self._generate_r1_response(message)
|
|
||||||
elif self.current_model_type == 'v3':
|
|
||||||
model_response = await self._generate_v3_response(message)
|
|
||||||
else:
|
|
||||||
model_response = await self._generate_r1_distill_response(message)
|
|
||||||
|
|
||||||
# 打印情感标签
|
model_response = await self._generate_response_with_model(message, current_model)
|
||||||
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
|
|
||||||
model_response, emotion = await self._process_response(model_response)
|
|
||||||
|
|
||||||
if model_response:
|
if model_response:
|
||||||
print(f"为 '{model_response}' 获取到的情感标签为:{emotion}")
|
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
|
||||||
valuedict={
|
model_response, emotion = await self._process_response(model_response)
|
||||||
|
if model_response:
|
||||||
|
print(f"为 '{model_response}' 获取到的情感标签为:{emotion}")
|
||||||
|
valuedict={
|
||||||
'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25
|
'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25
|
||||||
}
|
}
|
||||||
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
|
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
|
||||||
|
|
||||||
|
return model_response, emotion
|
||||||
|
return None, []
|
||||||
|
|
||||||
return model_response, emotion
|
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
|
||||||
|
"""使用指定的模型生成回复"""
|
||||||
async def _generate_base_response(
|
|
||||||
self,
|
|
||||||
message: Message,
|
|
||||||
model_name: str,
|
|
||||||
model_params: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Optional[str]:
|
|
||||||
sender_name = message.user_nickname or f"用户{message.user_id}"
|
sender_name = message.user_nickname or f"用户{message.user_id}"
|
||||||
if message.user_cardname:
|
if message.user_cardname:
|
||||||
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
|
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
|
||||||
|
|
||||||
# 获取关系值
|
# 获取关系值
|
||||||
if relationship_manager.get_relationship(message.user_id):
|
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
|
||||||
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value
|
if relationship_value != 0.0:
|
||||||
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
||||||
else:
|
|
||||||
relationship_value = 0.0
|
|
||||||
|
|
||||||
|
# 构建prompt
|
||||||
''' 构建prompt '''
|
prompt, prompt_check = prompt_builder._build_prompt(
|
||||||
prompt,prompt_check = prompt_builder._build_prompt(
|
|
||||||
message_txt=message.processed_plain_text,
|
message_txt=message.processed_plain_text,
|
||||||
sender_name=sender_name,
|
sender_name=sender_name,
|
||||||
relationship_value=relationship_value,
|
relationship_value=relationship_value,
|
||||||
group_id=message.group_id
|
group_id=message.group_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 设置默认参数
|
|
||||||
default_params = {
|
|
||||||
"model": model_name,
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"stream": False,
|
|
||||||
"max_tokens": 2048,
|
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
|
|
||||||
default_params_check = {
|
|
||||||
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
|
|
||||||
"messages": [{"role": "user", "content": prompt_check}],
|
|
||||||
"stream": False,
|
|
||||||
"max_tokens": 2048,
|
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
|
|
||||||
default_params_check = {
|
|
||||||
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
|
|
||||||
"messages": [{"role": "user", "content": prompt_check}],
|
|
||||||
"stream": False,
|
|
||||||
"max_tokens": 1024,
|
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
|
|
||||||
default_params_check = {
|
|
||||||
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
|
|
||||||
"messages": [{"role": "user", "content": prompt_check}],
|
|
||||||
"stream": False,
|
|
||||||
"max_tokens": 1024,
|
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
|
|
||||||
# 更新参数
|
|
||||||
if model_params:
|
|
||||||
default_params.update(model_params)
|
|
||||||
|
|
||||||
|
|
||||||
def create_completion():
|
|
||||||
return self.client.chat.completions.create(**default_params)
|
|
||||||
|
|
||||||
def create_completion_check():
|
|
||||||
return self.client.chat.completions.create(**default_params_check)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
# 读空气模块
|
# 读空气模块
|
||||||
air = 0
|
|
||||||
reasoning_content_check=''
|
|
||||||
content_check=''
|
|
||||||
if global_config.enable_kuuki_read:
|
if global_config.enable_kuuki_read:
|
||||||
response_check = await loop.run_in_executor(None, create_completion_check)
|
content_check, reasoning_content_check = await self.model_v3.generate_response(prompt_check)
|
||||||
if response_check:
|
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
|
||||||
reasoning_content_check = ""
|
if 'yes' not in content_check.lower() and random.random() < 0.3:
|
||||||
if hasattr(response_check.choices[0].message, "reasoning"):
|
self._save_to_db(
|
||||||
reasoning_content_check = response_check.choices[0].message.reasoning or reasoning_content_check
|
message=message,
|
||||||
elif hasattr(response_check.choices[0].message, "reasoning_content"):
|
sender_name=sender_name,
|
||||||
reasoning_content_check = response_check.choices[0].message.reasoning_content or reasoning_content_check
|
prompt=prompt,
|
||||||
content_check = response_check.choices[0].message.content
|
prompt_check=prompt_check,
|
||||||
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
|
content="",
|
||||||
if 'yes' not in content_check.lower():
|
content_check=content_check,
|
||||||
air = 1
|
reasoning_content="",
|
||||||
#稀释读空气的判定
|
reasoning_content_check=reasoning_content_check
|
||||||
if air == 1 and random.random() < 0.3:
|
)
|
||||||
self.db.db.reasoning_logs.insert_one({
|
return None
|
||||||
'time': time.time(),
|
|
||||||
'group_id': message.group_id,
|
|
||||||
'user': sender_name,
|
|
||||||
'message': message.processed_plain_text,
|
|
||||||
'model': model_name,
|
|
||||||
'reasoning_check': reasoning_content_check,
|
|
||||||
'response_check': content_check,
|
|
||||||
'reasoning': "",
|
|
||||||
'response': "",
|
|
||||||
'prompt': prompt,
|
|
||||||
'prompt_check': prompt_check,
|
|
||||||
'model_params': default_params
|
|
||||||
})
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
response = await loop.run_in_executor(None, create_completion)
|
# 生成回复
|
||||||
|
content, reasoning_content = await model.generate_response(prompt)
|
||||||
|
|
||||||
# 检查响应内容
|
|
||||||
if not response:
|
|
||||||
print("请求未返回任何内容")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not response.choices or not response.choices[0].message.content:
|
|
||||||
print("请求返回的内容无效:", response)
|
|
||||||
return None
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
|
|
||||||
# 获取推理内容
|
|
||||||
reasoning_content = ""
|
|
||||||
if hasattr(response.choices[0].message, "reasoning"):
|
|
||||||
reasoning_content = response.choices[0].message.reasoning or reasoning_content
|
|
||||||
elif hasattr(response.choices[0].message, "reasoning_content"):
|
|
||||||
reasoning_content = response.choices[0].message.reasoning_content or reasoning_content
|
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
|
self._save_to_db(
|
||||||
|
message=message,
|
||||||
|
sender_name=sender_name,
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_check=prompt_check,
|
||||||
|
content=content,
|
||||||
|
content_check=content_check if global_config.enable_kuuki_read else "",
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
|
reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
||||||
|
content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
|
||||||
|
"""保存对话记录到数据库"""
|
||||||
self.db.db.reasoning_logs.insert_one({
|
self.db.db.reasoning_logs.insert_one({
|
||||||
'time': time.time(),
|
'time': time.time(),
|
||||||
'group_id': message.group_id,
|
'group_id': message.group_id,
|
||||||
'user': sender_name,
|
'user': sender_name,
|
||||||
'message': message.processed_plain_text,
|
'message': message.processed_plain_text,
|
||||||
'model': model_name,
|
'model': self.current_model_type,
|
||||||
'reasoning_check': reasoning_content_check,
|
'reasoning_check': reasoning_content_check,
|
||||||
'response_check': content_check,
|
'response_check': content_check,
|
||||||
'reasoning': reasoning_content,
|
'reasoning': reasoning_content,
|
||||||
'response': content,
|
'response': content,
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
'prompt_check': prompt_check,
|
'prompt_check': prompt_check
|
||||||
'model_params': default_params
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
async def _generate_r1_response(self, message: Message) -> Optional[str]:
|
|
||||||
"""使用 DeepSeek-R1 模型生成回复"""
|
|
||||||
if global_config.API_USING == "deepseek":
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"deepseek-reasoner",
|
|
||||||
{"temperature": 0.7, "max_tokens": 2048}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"Pro/deepseek-ai/DeepSeek-R1",
|
|
||||||
{"temperature": 0.7, "max_tokens": 2048}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _generate_v3_response(self, message: Message) -> Optional[str]:
|
|
||||||
"""使用 DeepSeek-V3 模型生成回复"""
|
|
||||||
if global_config.API_USING == "deepseek":
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"deepseek-chat",
|
|
||||||
{"temperature": 0.8, "max_tokens": 2048}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"Pro/deepseek-ai/DeepSeek-V3",
|
|
||||||
{"temperature": 0.8, "max_tokens": 2048}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _generate_r1_distill_response(self, message: Message) -> Optional[str]:
|
|
||||||
"""使用 DeepSeek-R1-Distill-Qwen-32B 模型生成回复"""
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
|
|
||||||
{"temperature": 0.7, "max_tokens": 2048}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_group_chat_context(self, message: Message) -> str:
|
|
||||||
"""获取群聊上下文"""
|
|
||||||
recent_messages = self.db.db.messages.find(
|
|
||||||
{"group_id": message.group_id}
|
|
||||||
).sort("time", -1).limit(15)
|
|
||||||
|
|
||||||
messages_list = list(recent_messages)[::-1]
|
|
||||||
group_chat = ""
|
|
||||||
|
|
||||||
for msg_dict in messages_list:
|
|
||||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(msg_dict['time']))
|
|
||||||
display_name = msg_dict.get('user_nickname', f"用户{msg_dict['user_id']}")
|
|
||||||
cardname = msg_dict.get('user_cardname', '')
|
|
||||||
display_name = f"[({msg_dict['user_id']}){display_name}]{cardname}" if cardname!='' else display_name
|
|
||||||
content = msg_dict.get('processed_plain_text', msg_dict['plain_text'])
|
|
||||||
|
|
||||||
group_chat += f"[{time_str}] {display_name}: {content}\n"
|
|
||||||
|
|
||||||
return group_chat
|
|
||||||
|
|
||||||
async def _get_emotion_tags(self, content: str) -> List[str]:
|
async def _get_emotion_tags(self, content: str) -> List[str]:
|
||||||
"""提取情感标签"""
|
"""提取情感标签"""
|
||||||
@@ -291,33 +137,12 @@ class LLMResponseGenerator:
|
|||||||
输出:
|
输出:
|
||||||
'''
|
'''
|
||||||
|
|
||||||
messages = [{"role": "user", "content": prompt}]
|
content, _ = await self.model_v3.generate_response(prompt)
|
||||||
|
return [content.strip()] if content else ["neutral"]
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if global_config.API_USING == "deepseek":
|
|
||||||
model = "deepseek-chat"
|
|
||||||
else:
|
|
||||||
model = "Pro/deepseek-ai/DeepSeek-V3"
|
|
||||||
create_completion = partial(
|
|
||||||
self.client.chat.completions.create,
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
stream=False,
|
|
||||||
max_tokens=30,
|
|
||||||
temperature=0.6
|
|
||||||
)
|
|
||||||
response = await loop.run_in_executor(None, create_completion)
|
|
||||||
|
|
||||||
if response.choices[0].message.content:
|
|
||||||
# 确保返回的是列表格式
|
|
||||||
emotion_tag = response.choices[0].message.content.strip()
|
|
||||||
return [emotion_tag] # 将单个标签包装成列表返回
|
|
||||||
|
|
||||||
return ["neutral"] # 如果无法获取情感标签,返回默认值
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取情感标签时出错: {e}")
|
print(f"获取情感标签时出错: {e}")
|
||||||
return ["neutral"] # 发生错误时返回默认值
|
return ["neutral"]
|
||||||
|
|
||||||
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
||||||
"""处理响应内容,返回处理后的内容和情感标签"""
|
"""处理响应内容,返回处理后的内容和情感标签"""
|
||||||
@@ -325,10 +150,6 @@ class LLMResponseGenerator:
|
|||||||
return None, []
|
return None, []
|
||||||
|
|
||||||
emotion_tags = await self._get_emotion_tags(content)
|
emotion_tags = await self._get_emotion_tags(content)
|
||||||
|
|
||||||
processed_response = process_llm_response(content)
|
processed_response = process_llm_response(content)
|
||||||
|
|
||||||
return processed_response, emotion_tags
|
return processed_response, emotion_tags
|
||||||
|
|
||||||
# 创建全局实例
|
|
||||||
llm_response = LLMResponseGenerator()
|
|
||||||
@@ -72,12 +72,15 @@ class PromptBuilder:
|
|||||||
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
||||||
|
|
||||||
# 使用集合去重
|
# 使用集合去重
|
||||||
all_memories = list(set(all_first_layer_items) | set(overlapping_second_layer))
|
# 从每个来源随机选择2条记忆(如果有的话)
|
||||||
|
selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else []
|
||||||
|
selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else []
|
||||||
|
|
||||||
|
# 合并并去重
|
||||||
|
all_memories = list(set(selected_first_layer + selected_second_layer))
|
||||||
if all_memories:
|
if all_memories:
|
||||||
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}")
|
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}")
|
||||||
|
random_item = " ".join(all_memories)
|
||||||
if all_memories: # 只在列表非空时选择随机项
|
|
||||||
random_item = choice(all_memories)
|
|
||||||
memory_prompt = f"看到这些聊天,你想起来{random_item}\n"
|
memory_prompt = f"看到这些聊天,你想起来{random_item}\n"
|
||||||
else:
|
else:
|
||||||
memory_prompt = "" # 如果没有记忆,则返回空字符串
|
memory_prompt = "" # 如果没有记忆,则返回空字符串
|
||||||
@@ -150,7 +153,7 @@ class PromptBuilder:
|
|||||||
if personality_choice < 4/6: # 第一种人格
|
if personality_choice < 4/6: # 第一种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]},{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]},{promt_info_prompt},
|
||||||
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
||||||
请注意把握群里的聊天内容,不要回复的太有条理,可以有个性。'''
|
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
|
||||||
elif personality_choice < 1: # 第二种人格
|
elif personality_choice < 1: # 第二种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]},{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]},{promt_info_prompt},
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from openai import OpenAI
|
|||||||
from .message import Message
|
from .message import Message
|
||||||
import jieba
|
import jieba
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from .config import global_config
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -24,7 +25,7 @@ class TopicIdentifier:
|
|||||||
消息内容:{text}"""
|
消息内容:{text}"""
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model="Pro/deepseek-ai/DeepSeek-V3",
|
model=global_config.SILICONFLOW_MODEL_V3,
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
max_tokens=10
|
max_tokens=10
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
from typing import Tuple, Union
|
|
||||||
import time
|
|
||||||
from nonebot import get_driver
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
class LLMModel:
|
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
|
||||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
|
||||||
self.model_name = model_name
|
|
||||||
self.params = kwargs
|
|
||||||
self.api_key = config.siliconflow_key
|
|
||||||
self.base_url = config.siliconflow_base_url
|
|
||||||
|
|
||||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
|
||||||
"""根据输入的提示生成模型的响应"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 构建请求体
|
|
||||||
data = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"temperature": 0.5,
|
|
||||||
**self.params
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
|
||||||
|
|
||||||
max_retries = 3
|
|
||||||
base_wait_time = 15 # 基础等待时间(秒)
|
|
||||||
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
|
||||||
response = requests.post(api_url, headers=headers, json=data)
|
|
||||||
|
|
||||||
if response.status_code == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
|
||||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
time.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
content = result["choices"][0]["message"]["content"]
|
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
|
||||||
return content, reasoning_content
|
|
||||||
return "没有返回结果", ""
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
|
||||||
time.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
return f"请求失败: {str(e)}", ""
|
|
||||||
|
|
||||||
return "达到最大重试次数,请求仍然失败", ""
|
|
||||||
@@ -2,15 +2,17 @@ import os
|
|||||||
import requests
|
import requests
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
import time
|
import time
|
||||||
from ..chat.config import BotConfig
|
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
from src.plugins.chat.config import BotConfig, global_config
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
class LLMModel:
|
class LLMModel:
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
||||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
def __init__(self, model_name=global_config.SILICONFLOW_MODEL_V3, **kwargs):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
self.api_key = config.siliconflow_key
|
self.api_key = config.siliconflow_key
|
||||||
@@ -21,7 +23,7 @@ class LLMModel:
|
|||||||
|
|
||||||
print(f"API URL: {self.base_url}") # 打印 base_url 用于调试
|
print(f"API URL: {self.base_url}") # 打印 base_url 用于调试
|
||||||
|
|
||||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||||
"""根据输入的提示生成模型的响应"""
|
"""根据输入的提示生成模型的响应"""
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
@@ -44,28 +46,28 @@ class LLMModel:
|
|||||||
|
|
||||||
for retry in range(max_retries):
|
for retry in range(max_retries):
|
||||||
try:
|
try:
|
||||||
response = requests.post(api_url, headers=headers, json=data)
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
if response.status_code == 429:
|
except Exception as e:
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
|
||||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
time.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
content = result["choices"][0]["message"]["content"]
|
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
|
||||||
return content, reasoning_content
|
|
||||||
return "没有返回结果", ""
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
time.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,16 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import os
|
import os
|
||||||
import jieba
|
import jieba
|
||||||
from .llm_module import LLMModel
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import math
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import datetime
|
import datetime
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
import sys
|
|
||||||
from ...common.database import Database # 使用正确的导入语法
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
|
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
@@ -169,8 +166,8 @@ class Memory_graph:
|
|||||||
class Hippocampus:
|
class Hippocampus:
|
||||||
def __init__(self,memory_graph:Memory_graph):
|
def __init__(self,memory_graph:Memory_graph):
|
||||||
self.memory_graph = memory_graph
|
self.memory_graph = memory_graph
|
||||||
self.llm_model = LLMModel()
|
self.llm_model = LLM_request(model = global_config.llm_normal,temperature=0.5)
|
||||||
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
self.llm_model_small = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
|
||||||
|
|
||||||
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
||||||
current_timestamp = datetime.datetime.now().timestamp()
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
@@ -193,6 +190,24 @@ class Hippocampus:
|
|||||||
chat_text.append(chat_)
|
chat_text.append(chat_)
|
||||||
return chat_text
|
return chat_text
|
||||||
|
|
||||||
|
async def memory_compress(self, input_text, rate=1):
|
||||||
|
information_content = calculate_information_content(input_text)
|
||||||
|
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
||||||
|
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
||||||
|
topic_prompt = find_topic(input_text, topic_num)
|
||||||
|
topic_response = await self.llm_model.generate_response(topic_prompt)
|
||||||
|
# 检查 topic_response 是否为元组
|
||||||
|
if isinstance(topic_response, tuple):
|
||||||
|
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||||
|
else:
|
||||||
|
topics = topic_response.split(",")
|
||||||
|
compressed_memory = set()
|
||||||
|
for topic in topics:
|
||||||
|
topic_what_prompt = topic_what(input_text,topic)
|
||||||
|
topic_what_response = await self.llm_model_small.generate_response(topic_what_prompt)
|
||||||
|
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
||||||
|
return compressed_memory
|
||||||
|
|
||||||
async def build_memory(self,chat_size=12):
|
async def build_memory(self,chat_size=12):
|
||||||
#最近消息获取频率
|
#最近消息获取频率
|
||||||
time_frequency = {'near':1,'mid':2,'far':2}
|
time_frequency = {'near':1,'mid':2,'far':2}
|
||||||
@@ -208,9 +223,7 @@ class Hippocampus:
|
|||||||
if input_text:
|
if input_text:
|
||||||
# 生成压缩后记忆
|
# 生成压缩后记忆
|
||||||
first_memory = set()
|
first_memory = set()
|
||||||
first_memory = self.memory_compress(input_text, 2.5)
|
first_memory = await self.memory_compress(input_text, 2.5)
|
||||||
# 延时防止访问超频
|
|
||||||
# time.sleep(5)
|
|
||||||
#将记忆加入到图谱中
|
#将记忆加入到图谱中
|
||||||
for topic, memory in first_memory:
|
for topic, memory in first_memory:
|
||||||
topics = segment_text(topic)
|
topics = segment_text(topic)
|
||||||
@@ -224,28 +237,6 @@ class Hippocampus:
|
|||||||
else:
|
else:
|
||||||
print(f"空消息 跳过")
|
print(f"空消息 跳过")
|
||||||
self.memory_graph.save_graph_to_db()
|
self.memory_graph.save_graph_to_db()
|
||||||
|
|
||||||
def memory_compress(self, input_text, rate=1):
|
|
||||||
information_content = calculate_information_content(input_text)
|
|
||||||
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
|
||||||
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
|
||||||
# print(topic_num)
|
|
||||||
topic_prompt = find_topic(input_text, topic_num)
|
|
||||||
topic_response = self.llm_model.generate_response(topic_prompt)
|
|
||||||
# 检查 topic_response 是否为元组
|
|
||||||
if isinstance(topic_response, tuple):
|
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
|
||||||
else:
|
|
||||||
topics = topic_response.split(",")
|
|
||||||
# print(topics)
|
|
||||||
compressed_memory = set()
|
|
||||||
for topic in topics:
|
|
||||||
if topic=='' or '[' in topic:
|
|
||||||
continue
|
|
||||||
topic_what_prompt = topic_what(input_text,topic)
|
|
||||||
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
|
|
||||||
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
|
||||||
return compressed_memory
|
|
||||||
|
|
||||||
|
|
||||||
def segment_text(text):
|
def segment_text(text):
|
||||||
|
|||||||
199
src/plugins/models/utils_model.py
Normal file
199
src/plugins/models/utils_model.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
from typing import Tuple, Union
|
||||||
|
from nonebot import get_driver
|
||||||
|
from ..chat.config import global_config
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
class LLM_request:
|
||||||
|
def __init__(self, model = global_config.llm_normal,**kwargs):
|
||||||
|
# 将大写的配置键转换为小写并从config中获取实际值
|
||||||
|
try:
|
||||||
|
self.api_key = getattr(config, model["key"])
|
||||||
|
self.base_url = getattr(config, model["base_url"])
|
||||||
|
except AttributeError as e:
|
||||||
|
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||||
|
self.model_name = model["name"]
|
||||||
|
self.params = kwargs
|
||||||
|
|
||||||
|
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||||
|
"""根据输入的提示生成模型的异步响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的chat/completions端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
|
|
||||||
|
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
||||||
|
"""根据输入的提示和图片生成模型的异步响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的chat/completions端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
|
|
||||||
|
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
||||||
|
"""同步方法:根据输入的提示和图片生成模型的响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的chat/completions端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from .schedule_llm_module import LLMModel
|
|
||||||
from ...common.database import Database # 使用正确的导入语法
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from ..chat.config import global_config
|
from src.plugins.chat.config import global_config
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -21,22 +21,27 @@ Database.initialize(
|
|||||||
|
|
||||||
class ScheduleGenerator:
|
class ScheduleGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if global_config.API_USING == "siliconflow":
|
#根据global_config.llm_normal这一字典配置指定模型
|
||||||
self.llm_scheduler = LLMModel(model_name="Pro/deepseek-ai/DeepSeek-V3")
|
# self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9)
|
||||||
elif global_config.API_USING == "deepseek":
|
self.llm_scheduler = LLM_request(model = global_config.llm_normal,temperature=0.9)
|
||||||
self.llm_scheduler = LLMModel(model_name="deepseek-chat",api_using="deepseek")
|
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
|
self.today_schedule_text = ""
|
||||||
|
self.today_schedule = {}
|
||||||
|
self.tomorrow_schedule_text = ""
|
||||||
|
self.tomorrow_schedule = {}
|
||||||
|
self.yesterday_schedule_text = ""
|
||||||
|
self.yesterday_schedule = {}
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
today = datetime.datetime.now()
|
today = datetime.datetime.now()
|
||||||
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
|
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
|
||||||
yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
|
yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
|
||||||
|
|
||||||
self.today_schedule_text, self.today_schedule = self.generate_daily_schedule(target_date=today)
|
self.today_schedule_text, self.today_schedule = await self.generate_daily_schedule(target_date=today)
|
||||||
|
self.tomorrow_schedule_text, self.tomorrow_schedule = await self.generate_daily_schedule(target_date=tomorrow,read_only=True)
|
||||||
self.tomorrow_schedule_text, self.tomorrow_schedule = self.generate_daily_schedule(target_date=tomorrow,read_only=True)
|
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(target_date=yesterday,read_only=True)
|
||||||
self.yesterday_schedule_text, self.yesterday_schedule = self.generate_daily_schedule(target_date=yesterday,read_only=True)
|
|
||||||
|
|
||||||
def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
|
async def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
|
||||||
if target_date is None:
|
if target_date is None:
|
||||||
target_date = datetime.datetime.now()
|
target_date = datetime.datetime.now()
|
||||||
|
|
||||||
@@ -60,7 +65,7 @@ class ScheduleGenerator:
|
|||||||
3. 晚上的计划和休息时间
|
3. 晚上的计划和休息时间
|
||||||
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床"。"""
|
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床"。"""
|
||||||
|
|
||||||
schedule_text, _ = self.llm_scheduler.generate_response(prompt)
|
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
|
||||||
# print(self.schedule_text)
|
# print(self.schedule_text)
|
||||||
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
|
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
from typing import Tuple, Union
|
|
||||||
from nonebot import get_driver
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
class LLMModel:
|
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
|
||||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs):
|
|
||||||
if api_using == "deepseek":
|
|
||||||
self.api_key = config.deep_seek_key
|
|
||||||
self.base_url = config.deep_seek_base_url
|
|
||||||
if model_name != "Pro/deepseek-ai/DeepSeek-R1":
|
|
||||||
self.model_name = model_name
|
|
||||||
else:
|
|
||||||
self.model_name = "deepseek-reasoner"
|
|
||||||
else:
|
|
||||||
self.api_key = config.siliconflow_key
|
|
||||||
self.base_url = config.siliconflow_base_url
|
|
||||||
self.model_name = model_name
|
|
||||||
self.params = kwargs
|
|
||||||
|
|
||||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
|
||||||
"""根据输入的提示生成模型的响应"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 构建请求体
|
|
||||||
data = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"temperature": 0.9,
|
|
||||||
**self.params
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.post(api_url, headers=headers, json=data)
|
|
||||||
response.raise_for_status() # 检查响应状态
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
content = result["choices"][0]["message"]["content"]
|
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
|
||||||
return content, reasoning_content # 返回内容和推理内容
|
|
||||||
return "没有返回结果", "" # 返回两个值
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
return f"请求失败: {str(e)}", "" # 返回错误信息和空字符串
|
|
||||||
|
|
||||||
# 示例用法
|
|
||||||
if __name__ == "__main__":
|
|
||||||
model = LLMModel() # 默认使用 DeepSeek-V3 模型
|
|
||||||
prompt = "你好,你喜欢我吗?"
|
|
||||||
result, reasoning = model.generate_response(prompt)
|
|
||||||
print("回复内容:", result)
|
|
||||||
print("推理内容:", reasoning)
|
|
||||||
Reference in New Issue
Block a user