import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set

import tomli
from nonebot.log import logger

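
# This module defines the BotConfig dataclass and, at import time, loads
# config/bot_config_dev.toml (if present) or config/bot_config.toml into the
# module-level `global_config` instance that the rest of the bot imports.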
@dataclass
class BotConfig:
    """Bot configuration."""

    BOT_QQ: Optional[int] = 1
    BOT_NICKNAME: Optional[str] = None

    # Message-handling settings
    MIN_TEXT_LENGTH: int = 2  # minimum text length worth processing
    MAX_CONTEXT_SIZE: int = 15  # maximum number of context messages
    emoji_chance: float = 0.2  # base probability of replying with an emoji sticker

    ENABLE_PIC_TRANSLATE: bool = True  # whether to enable image translation

    # Declared with field(default_factory=set) so every instance gets its own
    # set instead of sharing one mutable class attribute.
    talk_allowed_groups: Set[int] = field(default_factory=set)
    talk_frequency_down_groups: Set[int] = field(default_factory=set)
    ban_user_id: Set[int] = field(default_factory=set)

    build_memory_interval: int = 60  # memory-building interval (seconds)
    EMOJI_CHECK_INTERVAL: int = 120  # emoji check interval (minutes)
    EMOJI_REGISTER_INTERVAL: int = 10  # emoji registration interval (minutes)

    ban_words: Set[str] = field(default_factory=set)

    # Model settings
    llm_reasoning: Dict[str, str] = field(default_factory=dict)
    llm_reasoning_minor: Dict[str, str] = field(default_factory=dict)
    llm_normal: Dict[str, str] = field(default_factory=dict)
    llm_normal_minor: Dict[str, str] = field(default_factory=dict)
    embedding: Dict[str, str] = field(default_factory=dict)
    vlm: Dict[str, str] = field(default_factory=dict)

    # Topic-extraction settings
    topic_extract: str = 'snownlp'  # only 'jieba', 'snownlp' and 'llm' are supported
    llm_topic_extract: Dict[str, str] = field(default_factory=dict)

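    # Each of the dicts above is filled verbatim from the corresponding
    # [model.*] (or [topic.llm_topic]) table in the TOML file; only the 'name'
    # key is read directly in this module, any other keys depend on how the
    # model clients consume these dicts.
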
    API_USING: str = "siliconflow"  # which API provider to use
    API_PAID: bool = False  # whether to use the paid API
    MODEL_R1_PROBABILITY: float = 0.8  # probability of choosing the R1 model
    MODEL_V3_PROBABILITY: float = 0.1  # probability of choosing the V3 model
    MODEL_R1_DISTILL_PROBABILITY: float = 0.1  # probability of choosing the R1-distill model

    enable_advance_output: bool = False  # whether to enable advanced (verbose) output
    enable_kuuki_read: bool = True  # whether to enable the "read the room" feature

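    # The three MODEL_*_PROBABILITY values are presumably used as selection
    # weights for a response; the defaults (0.8 + 0.1 + 0.1) sum to 1.0, and
    # custom values in [response] should likely do the same.
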
    # Default persona prompts (kept in Chinese, since they are sent to the LLM verbatim)
    PROMPT_PERSONALITY: List[str] = field(default_factory=lambda: [
        "曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧",
        "是一个女大学生,你有黑色头发,你会刷小红书"
    ])
    PROMPT_SCHEDULE_GEN: str = "一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书"

    @staticmethod
    def get_config_dir() -> str:
        """Return the config directory (<repo root>/config), creating it if missing."""
        current_dir = os.path.dirname(os.path.abspath(__file__))
        root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
        config_dir = os.path.join(root_dir, 'config')
        if not os.path.exists(config_dir):
            os.makedirs(config_dir)
        return config_dir

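    # Note: the '..', '..', '..' above assumes this file sits three directories
    # below the repository root (e.g. src/plugins/<plugin>/config.py); moving
    # the file changes where the config/ directory is looked up.
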
    @classmethod
    def load_config(cls, config_path: Optional[str] = None) -> "BotConfig":
        """Load configuration from a TOML file; missing keys keep their defaults."""
        config = cls()
        if config_path and os.path.exists(config_path):
            with open(config_path, "rb") as f:
                toml_dict = tomli.load(f)

            if 'personality' in toml_dict:
                personality_config = toml_dict['personality']
                personality = personality_config.get('prompt_personality', [])
                if len(personality) >= 2:
                    print(f"Loaded custom personality: {personality}")
                    config.PROMPT_PERSONALITY = personality
                print(f"Loaded custom schedule prompt: {personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)}")
                config.PROMPT_SCHEDULE_GEN = personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)

if "emoji" in toml_dict:
|
||
emoji_config = toml_dict["emoji"]
|
||
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
|
||
config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL)
|
||
|
||
if "cq_code" in toml_dict:
|
||
cq_code_config = toml_dict["cq_code"]
|
||
config.ENABLE_PIC_TRANSLATE = cq_code_config.get("enable_pic_translate", config.ENABLE_PIC_TRANSLATE)
|
||
|
||
            # Basic bot settings
            if "bot" in toml_dict:
                bot_config = toml_dict["bot"]
                bot_qq = bot_config.get("qq")
                if bot_qq is not None:  # int(None) would raise TypeError
                    config.BOT_QQ = int(bot_qq)
                config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)

if "response" in toml_dict:
|
||
response_config = toml_dict["response"]
|
||
config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
|
||
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
|
||
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
|
||
config.API_USING = response_config.get("api_using", config.API_USING)
|
||
config.API_PAID = response_config.get("api_paid", config.API_PAID)
|
||
|
||
            # Load model settings
            if "model" in toml_dict:
                model_config = toml_dict["model"]

                if "llm_reasoning" in model_config:
                    config.llm_reasoning = model_config["llm_reasoning"]

                if "llm_reasoning_minor" in model_config:
                    config.llm_reasoning_minor = model_config["llm_reasoning_minor"]

                if "llm_normal" in model_config:
                    config.llm_normal = model_config["llm_normal"]
                    # default the topic-extraction model to the normal chat model
                    config.llm_topic_extract = config.llm_normal

                if "llm_normal_minor" in model_config:
                    config.llm_normal_minor = model_config["llm_normal_minor"]

                if "vlm" in model_config:
                    config.vlm = model_config["vlm"]

                if "embedding" in model_config:
                    config.embedding = model_config["embedding"]

            if 'topic' in toml_dict:
                topic_config = toml_dict['topic']
                if 'topic_extract' in topic_config:
                    config.topic_extract = topic_config['topic_extract']
                    print(f"Loaded custom topic extractor: {config.topic_extract}")
                if config.topic_extract == 'llm' and 'llm_topic' in topic_config:
                    config.llm_topic_extract = topic_config['llm_topic']
                    print(f"Loaded custom topic-extraction model: {config.llm_topic_extract['name']}")

            # Message settings
            if "message" in toml_dict:
                msg_config = toml_dict["message"]
                config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH)
                config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
                config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
                config.ban_words = set(msg_config.get("ban_words", config.ban_words))

if "memory" in toml_dict:
|
||
memory_config = toml_dict["memory"]
|
||
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
|
||
|
||
            # Group settings
            if "groups" in toml_dict:
                groups_config = toml_dict["groups"]
                config.talk_allowed_groups = set(groups_config.get("talk_allowed", []))
                config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", []))
                config.ban_user_id = set(groups_config.get("ban_user_id", []))

            if "others" in toml_dict:
                others_config = toml_dict["others"]
                config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
                config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)

            logger.success(f"Loaded config file: {config_path}")
        else:
            logger.warning(f"Config file not found, falling back to defaults: {config_path}")

        return config

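
# Illustrative sketch of the TOML layout load_config expects. Section and key
# names come from the parsing code above; the values are placeholders.
#
#   [bot]
#   qq = 123456
#   nickname = "bot-nickname"
#
#   [personality]
#   prompt_personality = ["persona prompt 1", "persona prompt 2"]
#   prompt_schedule = "schedule prompt"
#
#   [message]
#   min_text_length = 2
#   max_context_size = 15
#   emoji_chance = 0.2
#   ban_words = []
#
#   [emoji]
#   check_interval = 120
#   register_interval = 10
#
#   [cq_code]
#   enable_pic_translate = true
#
#   [response]
#   model_r1_probability = 0.8
#   model_v3_probability = 0.1
#   model_r1_distill_probability = 0.1
#   api_using = "siliconflow"
#   api_paid = false
#
#   [model.llm_reasoning]      # likewise llm_reasoning_minor, llm_normal,
#   name = "model-name"        # llm_normal_minor, vlm, embedding
#
#   [topic]
#   topic_extract = "snownlp"  # or "jieba" / "llm" (with a [topic.llm_topic] table)
#
#   [memory]
#   build_memory_interval = 60
#
#   [groups]
#   talk_allowed = []
#   talk_frequency_down = []
#   ban_user_id = []
#
#   [others]
#   enable_advance_output = false
#   enable_kuuki_read = true
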
# Resolve the config file path: prefer the dev config, fall back to the default one.
bot_config_folder_path = BotConfig.get_config_dir()
print(f"Inspecting config directory: {bot_config_folder_path}")

bot_config_path = os.path.join(bot_config_folder_path, "bot_config_dev.toml")
if not os.path.exists(bot_config_path):
    # Dev config not found, use the default config file instead
    bot_config_path = os.path.join(bot_config_folder_path, "bot_config.toml")
    logger.info("Using the default bot config file")
else:
    logger.info("Found the development bot config file")

global_config = BotConfig.load_config(config_path=bot_config_path)

if not global_config.enable_advance_output:
    # logger.remove()
    pass
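
# Illustrative usage from another module (hypothetical example):
#
#     from .config import global_config
#
#     if message.group_id in global_config.talk_allowed_groups:
#         reply_as(global_config.BOT_NICKNAME)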