refactor: 重构配置模块

This commit is contained in:
Oct-autumn
2025-05-16 16:50:53 +08:00
parent 5d5033452d
commit 021e7f1a97
52 changed files with 902 additions and 1102 deletions

View File

@@ -1,64 +1,68 @@
import os
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from dataclasses import field, dataclass
import tomli
import tomlkit
import shutil
from datetime import datetime
from pathlib import Path
from packaging import version
from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet, InvalidSpecifier
from tomlkit import TOMLDocument
from tomlkit.items import Table
from src.common.logger_manager import get_logger
from rich.traceback import install
from src.config.config_base import ConfigBase
from src.config.official_configs import (
BotConfig,
ChatTargetConfig,
PersonalityConfig,
IdentityConfig,
PlatformsConfig,
ChatConfig,
NormalChatConfig,
FocusChatConfig,
EmojiConfig,
MemoryConfig,
MoodConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
ModelConfig,
)
install(extra_lines=3)
# 配置主程序日志格式
logger = get_logger("config")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
is_test = True
mai_version_main = "0.6.4"
mai_version_fix = "snapshot-1"
CONFIG_DIR = "config"
TEMPLATE_DIR = "template"
if mai_version_fix:
if is_test:
mai_version = f"test-{mai_version_main}-{mai_version_fix}"
else:
mai_version = f"{mai_version_main}-{mai_version_fix}"
else:
if is_test:
mai_version = f"test-{mai_version_main}"
else:
mai_version = mai_version_main
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.7.0-snapshot.1"
def update_config():
# 获取根目录路径
root_dir = Path(__file__).parent.parent.parent
template_dir = root_dir / "template"
config_dir = root_dir / "config"
old_config_dir = config_dir / "old"
old_config_dir = f"{CONFIG_DIR}/old"
# 定义文件路径
template_path = template_dir / "bot_config_template.toml"
old_config_path = config_dir / "bot_config.toml"
new_config_path = config_dir / "bot_config.toml"
template_path = f"{TEMPLATE_DIR}/bot_config_template.toml"
old_config_path = f"{CONFIG_DIR}/bot_config.toml"
new_config_path = f"{CONFIG_DIR}/bot_config.toml"
# 检查配置文件是否存在
if not old_config_path.exists():
if not os.path.exists(old_config_path):
logger.info("配置文件不存在,从模板创建新配置")
# 创建文件夹
old_config_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(template_path, old_config_path)
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,直接返回
return quit()
quit()
# 读取旧配置文件和模板文件
with open(old_config_path, "r", encoding="utf-8") as f:
@@ -75,13 +79,15 @@ def update_config():
return
else:
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建old目录如果不存在
old_config_dir.mkdir(exist_ok=True)
os.makedirs(old_config_dir, exist_ok=True)
# 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
old_backup_path = f"{old_config_dir}/bot_config_{timestamp}.toml"
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
@@ -91,24 +97,23 @@ def update_config():
shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}")
# 递归更新配置
def update_dict(target, source):
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
if not value:
target[key] = tomlkit.array()
else:
target[key] = tomlkit.array(value)
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
@@ -123,619 +128,57 @@ def update_config():
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成")
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
quit()
@dataclass
class BotConfig:
"""机器人配置类"""
INNER_VERSION: Version = None
MAI_VERSION: str = mai_version # 硬编码的版本信息
# bot
BOT_QQ: Optional[str] = "114514"
BOT_NICKNAME: Optional[str] = None
BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它
# group
talk_allowed_groups = set()
talk_frequency_down_groups = set()
ban_user_id = set()
# personality
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内谁再写3000字小作文敲谁脑袋
personality_sides: List[str] = field(
default_factory=lambda: [
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面",
]
)
expression_style = "描述麦麦说话的表达风格,表达习惯"
# identity
identity_detail: List[str] = field(
default_factory=lambda: [
"身份特点",
"身份特点",
]
)
height: int = 170 # 身高 单位厘米
weight: int = 50 # 体重 单位千克
age: int = 20 # 年龄 单位岁
gender: str = "" # 性别
appearance: str = "用几句话描述外貌特征" # 外貌特征
# chat
allow_focus_mode: bool = True # 是否允许专注聊天状态
base_normal_chat_num: int = 3 # 最多允许多少个群进行普通聊天
base_focused_chat_num: int = 2 # 最多允许多少个群进行专注聊天
observation_context_size: int = 12 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩
message_buffer: bool = True # 消息缓冲器
ban_words = set()
ban_msgs_regex = set()
# focus_chat
reply_trigger_threshold: float = 3.0 # 心流聊天触发阈值,越低越容易触发
default_decay_rate_per_second: float = 0.98 # 默认衰减率,越大衰减越慢
consecutive_no_reply_threshold = 3
compressed_length: int = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度超过心流观察到的上下文长度会压缩最短压缩长度为5
compress_length_limit: int = 5 # 最多压缩份数,超过该数值的压缩上下文会被删除
# normal_chat
model_reasoning_probability: float = 0.7 # 麦麦回答时选择推理模型(主要)模型概率
model_normal_probability: float = 0.3 # 麦麦回答时选择一般模型(次要)模型概率
emoji_chance: float = 0.2 # 发送表情包的基础概率
thinking_timeout: int = 120 # 思考时间
willing_mode: str = "classical" # 意愿模式
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
emoji_response_penalty: float = 0.0 # 表情包回复惩罚
mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
at_bot_inevitable_reply: bool = False # @bot 必然回复
# emoji
max_emoji_num: int = 200 # 表情包最大数量
max_reach_deletion: bool = True # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
save_pic: bool = False # 是否保存图片
save_emoji: bool = False # 是否保存表情包
steal_emoji: bool = True # 是否偷取表情包,让麦麦可以发送她保存的这些表情包
EMOJI_CHECK: bool = False # 是否开启过滤
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
# memory
build_memory_interval: int = 600 # 记忆构建间隔(秒)
memory_build_distribution: list = field(
default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4]
) # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
build_memory_sample_num: int = 10 # 记忆构建采样数量
build_memory_sample_length: int = 20 # 记忆构建采样长度
memory_compress_rate: float = 0.1 # 记忆压缩率
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
consolidate_memory_interval: int = 1000 # 记忆整合间隔(秒)
consolidation_similarity_threshold: float = 0.7 # 相似度阈值
consolidate_memory_percentage: float = 0.01 # 检查节点比例
memory_ban_words: list = field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
) # 添加新的配置项默认值
# mood
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
mood_intensity_factor: float = 0.7 # 情绪强度因子
# keywords
keywords_reaction_rules = [] # 关键词回复规则
# chinese_typo
chinese_typo_enable = True # 是否启用中文错别字生成器
chinese_typo_error_rate = 0.03 # 单字替换概率
chinese_typo_min_freq = 7 # 最小字频阈值
chinese_typo_tone_error_rate = 0.2 # 声调错误概率
chinese_typo_word_replace_rate = 0.02 # 整词替换概率
# response_splitter
enable_kaomoji_protection = False # 是否启用颜文字保护
enable_response_splitter = True # 是否启用回复分割器
response_max_length = 100 # 回复允许的最大长度
response_max_sentence_num = 3 # 回复允许的最大句子数
model_max_output_length: int = 800 # 最大回复长度
# remote
remote_enable: bool = True # 是否启用远程控制
# experimental
enable_friend_chat: bool = False # 是否启用好友聊天
# enable_think_flow: bool = False # 是否启用思考流程
talk_allowed_private = set()
enable_pfc_chatting: bool = False # 是否启用PFC聊天
# 模型配置
llm_reasoning: dict[str, str] = field(default_factory=lambda: {})
# llm_reasoning_minor: dict[str, str] = field(default_factory=lambda: {})
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {})
llm_summary: Dict[str, str] = field(default_factory=lambda: {})
embedding: Dict[str, str] = field(default_factory=lambda: {})
vlm: Dict[str, str] = field(default_factory=lambda: {})
moderation: Dict[str, str] = field(default_factory=lambda: {})
llm_observation: Dict[str, str] = field(default_factory=lambda: {})
llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {})
llm_heartflow: Dict[str, str] = field(default_factory=lambda: {})
llm_tool_use: Dict[str, str] = field(default_factory=lambda: {})
llm_plan: Dict[str, str] = field(default_factory=lambda: {})
api_urls: Dict[str, str] = field(default_factory=lambda: {})
@staticmethod
def get_config_dir() -> str:
"""获取配置文件目录"""
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", ".."))
config_dir = os.path.join(root_dir, "config")
if not os.path.exists(config_dir):
os.makedirs(config_dir)
return config_dir
@classmethod
def convert_to_specifierset(cls, value: str) -> SpecifierSet:
"""将 字符串 版本表达式转换成 SpecifierSet
Args:
value[str]: 版本表达式(字符串)
Returns:
SpecifierSet
"""
try:
converted = SpecifierSet(value)
except InvalidSpecifier:
logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码")
exit(1)
return converted
@classmethod
def get_config_version(cls, toml: dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据
Args:
toml[dict]: 输入的配置文件字典
Returns:
Version
"""
if "inner" in toml:
try:
config_version: str = toml["inner"]["version"]
except KeyError as e:
logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e
else:
toml["inner"] = {"version": "0.0.0"}
config_version = toml["inner"]["version"]
try:
ver = version.parse(config_version)
except InvalidVersion as e:
logger.error(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
"请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n"
"本项目在不同的版本下有不同的模板,请注意识别"
)
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e
return ver
@classmethod
def load_config(cls, config_path: str = None) -> "BotConfig":
"""从TOML配置文件加载配置"""
config = cls()
def personality(parent: dict):
personality_config = parent["personality"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
config.personality_core = personality_config.get("personality_core", config.personality_core)
config.personality_sides = personality_config.get("personality_sides", config.personality_sides)
if config.INNER_VERSION in SpecifierSet(">=1.7.0"):
config.expression_style = personality_config.get("expression_style", config.expression_style)
def identity(parent: dict):
identity_config = parent["identity"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
config.identity_detail = identity_config.get("identity_detail", config.identity_detail)
config.height = identity_config.get("height", config.height)
config.weight = identity_config.get("weight", config.weight)
config.age = identity_config.get("age", config.age)
config.gender = identity_config.get("gender", config.gender)
config.appearance = identity_config.get("appearance", config.appearance)
def emoji(parent: dict):
emoji_config = parent["emoji"]
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT)
config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK)
if config.INNER_VERSION in SpecifierSet(">=1.1.1"):
config.max_emoji_num = emoji_config.get("max_emoji_num", config.max_emoji_num)
config.max_reach_deletion = emoji_config.get("max_reach_deletion", config.max_reach_deletion)
if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
config.save_pic = emoji_config.get("save_pic", config.save_pic)
config.save_emoji = emoji_config.get("save_emoji", config.save_emoji)
config.steal_emoji = emoji_config.get("steal_emoji", config.steal_emoji)
def bot(parent: dict):
# 机器人基础配置
bot_config = parent["bot"]
bot_qq = bot_config.get("qq")
config.BOT_QQ = str(bot_qq)
config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)
config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES)
def chat(parent: dict):
chat_config = parent["chat"]
config.allow_focus_mode = chat_config.get("allow_focus_mode", config.allow_focus_mode)
config.base_normal_chat_num = chat_config.get("base_normal_chat_num", config.base_normal_chat_num)
config.base_focused_chat_num = chat_config.get("base_focused_chat_num", config.base_focused_chat_num)
config.observation_context_size = chat_config.get(
"observation_context_size", config.observation_context_size
)
config.message_buffer = chat_config.get("message_buffer", config.message_buffer)
config.ban_words = chat_config.get("ban_words", config.ban_words)
for r in chat_config.get("ban_msgs_regex", config.ban_msgs_regex):
config.ban_msgs_regex.add(re.compile(r))
def normal_chat(parent: dict):
normal_chat_config = parent["normal_chat"]
config.model_reasoning_probability = normal_chat_config.get(
"model_reasoning_probability", config.model_reasoning_probability
)
config.model_normal_probability = normal_chat_config.get(
"model_normal_probability", config.model_normal_probability
)
config.emoji_chance = normal_chat_config.get("emoji_chance", config.emoji_chance)
config.thinking_timeout = normal_chat_config.get("thinking_timeout", config.thinking_timeout)
config.willing_mode = normal_chat_config.get("willing_mode", config.willing_mode)
config.response_willing_amplifier = normal_chat_config.get(
"response_willing_amplifier", config.response_willing_amplifier
)
config.response_interested_rate_amplifier = normal_chat_config.get(
"response_interested_rate_amplifier", config.response_interested_rate_amplifier
)
config.down_frequency_rate = normal_chat_config.get("down_frequency_rate", config.down_frequency_rate)
config.emoji_response_penalty = normal_chat_config.get(
"emoji_response_penalty", config.emoji_response_penalty
)
config.mentioned_bot_inevitable_reply = normal_chat_config.get(
"mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply
)
config.at_bot_inevitable_reply = normal_chat_config.get(
"at_bot_inevitable_reply", config.at_bot_inevitable_reply
)
def focus_chat(parent: dict):
focus_chat_config = parent["focus_chat"]
config.compressed_length = focus_chat_config.get("compressed_length", config.compressed_length)
config.compress_length_limit = focus_chat_config.get("compress_length_limit", config.compress_length_limit)
config.reply_trigger_threshold = focus_chat_config.get(
"reply_trigger_threshold", config.reply_trigger_threshold
)
config.default_decay_rate_per_second = focus_chat_config.get(
"default_decay_rate_per_second", config.default_decay_rate_per_second
)
config.consecutive_no_reply_threshold = focus_chat_config.get(
"consecutive_no_reply_threshold", config.consecutive_no_reply_threshold
)
def model(parent: dict):
# 加载模型配置
model_config: dict = parent["model"]
config_list = [
"llm_reasoning",
# "llm_reasoning_minor",
"llm_normal",
"llm_topic_judge",
"llm_summary",
"vlm",
"embedding",
"llm_tool_use",
"llm_observation",
"llm_sub_heartflow",
"llm_plan",
"llm_heartflow",
"llm_PFC_action_planner",
"llm_PFC_chat",
"llm_PFC_reply_checker",
]
for item in config_list:
if item in model_config:
cfg_item: dict = model_config[item]
# base_url 的例子: SILICONFLOW_BASE_URL
# key 的例子: SILICONFLOW_KEY
cfg_target = {
"name": "",
"base_url": "",
"key": "",
"stream": False,
"pri_in": 0,
"pri_out": 0,
"temp": 0.7,
}
if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
cfg_target = cfg_item
elif config.INNER_VERSION in SpecifierSet(">=0.0.1"):
stable_item = ["name", "pri_in", "pri_out"]
stream_item = ["stream"]
if config.INNER_VERSION in SpecifierSet(">=1.0.1"):
stable_item.append("stream")
pricing_item = ["pri_in", "pri_out"]
# 从配置中原始拷贝稳定字段
for i in stable_item:
# 如果 字段 属于计费项 且获取不到,那默认值是 0
if i in pricing_item and i not in cfg_item:
cfg_target[i] = 0
if i in stream_item and i not in cfg_item:
cfg_target[i] = False
else:
# 没有特殊情况则原样复制
try:
cfg_target[i] = cfg_item[i]
except KeyError as e:
logger.error(f"{item} 中的必要字段不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e
# 如果配置中有temp参数就使用配置中的值
if "temp" in cfg_item:
cfg_target["temp"] = cfg_item["temp"]
else:
# 如果没有temp参数就删除默认值
cfg_target.pop("temp", None)
provider = cfg_item.get("provider")
if provider is None:
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
cfg_target["base_url"] = f"{provider}_BASE_URL"
cfg_target["key"] = f"{provider}_KEY"
# 如果 列表中的项目在 model_config 中,利用反射来设置对应项目
setattr(config, item, cfg_target)
else:
logger.error(f"模型 {item} 在config中不存在请检查或尝试更新配置文件")
raise KeyError(f"模型 {item} 在config中不存在请检查或尝试更新配置文件")
def memory(parent: dict):
memory_config = parent["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
config.memory_forget_percentage = memory_config.get(
"memory_forget_percentage", config.memory_forget_percentage
)
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
if config.INNER_VERSION in SpecifierSet(">=0.0.11"):
config.memory_build_distribution = memory_config.get(
"memory_build_distribution", config.memory_build_distribution
)
config.build_memory_sample_num = memory_config.get(
"build_memory_sample_num", config.build_memory_sample_num
)
config.build_memory_sample_length = memory_config.get(
"build_memory_sample_length", config.build_memory_sample_length
)
if config.INNER_VERSION in SpecifierSet(">=1.5.1"):
config.consolidate_memory_interval = memory_config.get(
"consolidate_memory_interval", config.consolidate_memory_interval
)
config.consolidation_similarity_threshold = memory_config.get(
"consolidation_similarity_threshold", config.consolidation_similarity_threshold
)
config.consolidate_memory_percentage = memory_config.get(
"consolidate_memory_percentage", config.consolidate_memory_percentage
)
def remote(parent: dict):
remote_config = parent["remote"]
config.remote_enable = remote_config.get("enable", config.remote_enable)
def mood(parent: dict):
mood_config = parent["mood"]
config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval)
config.mood_decay_rate = mood_config.get("mood_decay_rate", config.mood_decay_rate)
config.mood_intensity_factor = mood_config.get("mood_intensity_factor", config.mood_intensity_factor)
def keywords_reaction(parent: dict):
keywords_reaction_config = parent["keywords_reaction"]
if keywords_reaction_config.get("enable", False):
config.keywords_reaction_rules = keywords_reaction_config.get("rules", config.keywords_reaction_rules)
for rule in config.keywords_reaction_rules:
if rule.get("enable", False) and "regex" in rule:
rule["regex"] = [re.compile(r) for r in rule.get("regex", [])]
def chinese_typo(parent: dict):
chinese_typo_config = parent["chinese_typo"]
config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
config.chinese_typo_tone_error_rate = chinese_typo_config.get(
"tone_error_rate", config.chinese_typo_tone_error_rate
)
config.chinese_typo_word_replace_rate = chinese_typo_config.get(
"word_replace_rate", config.chinese_typo_word_replace_rate
)
def response_splitter(parent: dict):
response_splitter_config = parent["response_splitter"]
config.enable_response_splitter = response_splitter_config.get(
"enable_response_splitter", config.enable_response_splitter
)
config.response_max_length = response_splitter_config.get("response_max_length", config.response_max_length)
config.response_max_sentence_num = response_splitter_config.get(
"response_max_sentence_num", config.response_max_sentence_num
)
if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
config.enable_kaomoji_protection = response_splitter_config.get(
"enable_kaomoji_protection", config.enable_kaomoji_protection
)
if config.INNER_VERSION in SpecifierSet(">=1.6.0"):
config.model_max_output_length = response_splitter_config.get(
"model_max_output_length", config.model_max_output_length
)
def groups(parent: dict):
groups_config = parent["groups"]
# config.talk_allowed_groups = set(groups_config.get("talk_allowed", []))
config.talk_allowed_groups = set(str(group) for group in groups_config.get("talk_allowed", []))
# config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", []))
config.talk_frequency_down_groups = set(
str(group) for group in groups_config.get("talk_frequency_down", [])
)
# config.ban_user_id = set(groups_config.get("ban_user_id", []))
config.ban_user_id = set(str(user) for user in groups_config.get("ban_user_id", []))
def experimental(parent: dict):
experimental_config = parent["experimental"]
config.enable_friend_chat = experimental_config.get("enable_friend_chat", config.enable_friend_chat)
# config.enable_think_flow = experimental_config.get("enable_think_flow", config.enable_think_flow)
config.talk_allowed_private = set(str(user) for user in experimental_config.get("talk_allowed_private", []))
if config.INNER_VERSION in SpecifierSet(">=1.1.0"):
config.enable_pfc_chatting = experimental_config.get("pfc_chatting", config.enable_pfc_chatting)
# 版本表达式:>=1.0.0,<2.0.0
# 允许字段func: method, support: str, notice: str, necessary: bool
# 如果使用 notice 字段,在该组配置加载时,会展示该字段对用户的警示
# 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以
# 正常执行程序,但是会看到这条自定义提示
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
# 主版本号:当你做了不兼容的 API 修改,
# 次版本号:当你做了向下兼容的功能性新增,
# 修订号:当你做了向下兼容的问题修正。
# 先行版本号及版本编译信息可以加到"主版本号.次版本号.修订号"的后面,作为延伸。
# 如果你做了break的修改就应该改动主版本号
# 如果做了一个兼容修改,就不应该要求这个选项是必须的!
include_configs = {
"bot": {"func": bot, "support": ">=0.0.0"},
"groups": {"func": groups, "support": ">=0.0.0"},
"personality": {"func": personality, "support": ">=0.0.0"},
"identity": {"func": identity, "support": ">=1.2.4"},
"emoji": {"func": emoji, "support": ">=0.0.0"},
"model": {"func": model, "support": ">=0.0.0"},
"memory": {"func": memory, "support": ">=0.0.0", "necessary": False},
"mood": {"func": mood, "support": ">=0.0.0"},
"remote": {"func": remote, "support": ">=0.0.10", "necessary": False},
"keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False},
"chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False},
"response_splitter": {"func": response_splitter, "support": ">=0.0.11", "necessary": False},
"experimental": {"func": experimental, "support": ">=0.0.11", "necessary": False},
"chat": {"func": chat, "support": ">=1.6.0", "necessary": False},
"normal_chat": {"func": normal_chat, "support": ">=1.6.0", "necessary": False},
"focus_chat": {"func": focus_chat, "support": ">=1.6.0", "necessary": False},
}
# 原地修改,将 字符串版本表达式 转换成 版本对象
for key in include_configs:
item_support = include_configs[key]["support"]
include_configs[key]["support"] = cls.convert_to_specifierset(item_support)
if os.path.exists(config_path):
with open(config_path, "rb") as f:
try:
toml_dict = tomli.load(f)
except tomli.TOMLDecodeError as e:
logger.critical(f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}")
exit(1)
# 获取配置文件版本
config.INNER_VERSION = cls.get_config_version(toml_dict)
# 如果在配置中找到了需要的项,调用对应项的闭包函数处理
for key in include_configs:
if key in toml_dict:
group_specifierset: SpecifierSet = include_configs[key]["support"]
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifierset:
# 如果版本在支持范围内,检查是否存在通知
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
include_configs[key]["func"](toml_dict)
else:
# 如果版本不在支持范围内,崩溃并提示用户
logger.error(
f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
f"当前程序仅支持以下版本范围: {group_specifierset}"
)
raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False:
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
else:
# 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
# identity_detail字段非空检查
if not config.identity_detail:
logger.error("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
raise ValueError("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
logger.success(f"成功加载配置文件: {config_path}")
return config
class Config(ConfigBase):
"""配置类"""
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
bot: BotConfig
chat_target: ChatTargetConfig
personality: PersonalityConfig
identity: IdentityConfig
platforms: PlatformsConfig
chat: ChatConfig
normal_chat: NormalChatConfig
focus_chat: FocusChatConfig
emoji: EmojiConfig
memory: MemoryConfig
mood: MoodConfig
keyword_reaction: KeywordReactionConfig
chinese_typo: ChineseTypoConfig
response_splitter: ResponseSplitterConfig
telemetry: TelemetryConfig
experimental: ExperimentalConfig
model: ModelConfig
def load_config(config_path: str) -> Config:
"""
加载配置文件
:param config_path: 配置文件路径
:return: Config对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建Config对象
try:
return Config.from_dict(config_data)
except Exception as e:
logger.critical("配置文件解析失败")
raise e
# 获取配置文件路径
logger.info(f"MaiCore当前版本: {mai_version}")
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config()
bot_config_floder_path = BotConfig.get_config_dir()
logger.info(f"正在品鉴配置文件目录: {bot_config_floder_path}")
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
if os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件
logger.info(f"异常的新鲜,异常的美味: {bot_config_path}")
else:
# 配置文件不存在
logger.error("配置文件不存在,请检查路径: {bot_config_path}")
raise FileNotFoundError(f"配置文件不存在: {bot_config_path}")
global_config = BotConfig.load_config(config_path=bot_config_path)
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=f"{CONFIG_DIR}/bot_config.toml")
logger.info("非常的新鲜,非常的美味!")

116
src/config/config_base.py Normal file
View File

@@ -0,0 +1,116 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args
T = TypeVar("T", bound="ConfigBase")
TOML_DICT_TYPE = {
int,
float,
str,
bool,
list,
dict,
}
@dataclass
class ConfigBase:
"""配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
field_type = f.type
try:
init_args[field_name] = cls._convert_field(value, field_type)
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
return cls(**init_args)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""
转换字段值为指定类型
1. 对于嵌套的 dataclass递归调用相应的 from_dict 方法
2. 对于泛型集合类型list, set, tuple递归转换每个元素
3. 对于基础类型int, str, float, bool直接转换
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
return field_type.from_dict(value)
# 处理泛型集合类型list, set, tuple
field_origin_type = get_origin(field_type)
field_type_args = get_args(field_type)
if field_origin_type in {list, set, tuple}:
# 检查提供的value是否为list
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
return [cls._convert_field(item, field_type_args[0]) for item in value]
elif field_origin_type is set:
return {cls._convert_field(item, field_type_args[0]) for item in value}
elif field_origin_type is tuple:
# 检查提供的value长度是否与类型参数一致
if len(value) != len(field_type_args):
raise TypeError(
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args))
if field_origin_type is dict:
# 检查提供的value是否为dict
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
# 检查字典的键值类型
if len(field_type_args) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_type_args
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理基础类型,例如 int, str 等
if field_type is Any or isinstance(value, field_type):
return value
# 其他类型,尝试直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
def __str__(self):
"""返回配置类的字符串表示"""
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"

View File

@@ -0,0 +1,399 @@
from dataclasses import dataclass, field
from typing import Any
from src.config.config_base import ConfigBase
"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
@dataclass
class BotConfig(ConfigBase):
"""QQ机器人配置类"""
qq_account: str
"""QQ账号"""
nickname: str
"""昵称"""
alias_names: list[str] = field(default_factory=lambda: [])
"""别名列表"""
@dataclass
class ChatTargetConfig(ConfigBase):
"""
聊天目标配置类
此类中有聊天的群组和用户配置
"""
talk_allowed_groups: set[str] = field(default_factory=lambda: set())
"""允许聊天的群组列表"""
talk_frequency_down_groups: set[str] = field(default_factory=lambda: set())
"""降低聊天频率的群组列表"""
ban_user_id: set[str] = field(default_factory=lambda: set())
"""禁止聊天的用户列表"""
@dataclass
class PersonalityConfig(ConfigBase):
"""人格配置类"""
personality_core: str
"""核心人格"""
expression_style: str
"""表达风格"""
personality_sides: list[str] = field(default_factory=lambda: [])
"""人格侧写"""
@dataclass
class IdentityConfig(ConfigBase):
"""个体特征配置类"""
height: int = 170
"""身高(单位:厘米)"""
weight: float = 50
"""体重(单位:千克)"""
age: int = 18
"""年龄(单位:岁)"""
gender: str = ""
"""性别(男/女)"""
appearance: str = "可爱"
"""外貌描述"""
identity_detail: list[str] = field(default_factory=lambda: [])
"""身份特征"""
@dataclass
class PlatformsConfig(ConfigBase):
"""平台配置类"""
qq: str
"""QQ适配器连接URL配置"""
@dataclass
class ChatConfig(ConfigBase):
"""聊天配置类"""
allow_focus_mode: bool = True
"""是否允许专注聊天状态"""
base_normal_chat_num: int = 3
"""最多允许多少个群进行普通聊天"""
base_focused_chat_num: int = 2
"""最多允许多少个群进行专注聊天"""
observation_context_size: int = 12
"""可观察到的最长上下文大小,超过这个值的上下文会被压缩"""
message_buffer: bool = True
"""消息缓冲器"""
ban_words: set[str] = field(default_factory=lambda: set())
"""过滤词列表"""
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表"""
@dataclass
class NormalChatConfig(ConfigBase):
"""普通聊天配置类"""
reasoning_model_probability: float = 0.3
"""
发言时选择推理模型的概率0-1之间
选择普通模型的概率为 1 - reasoning_normal_model_probability
"""
emoji_chance: float = 0.2
"""发送表情包的基础概率"""
thinking_timeout: int = 120
"""最长思考时间"""
willing_mode: str = "classical"
"""意愿模式"""
response_willing_amplifier: float = 1.0
"""回复意愿放大系数"""
response_interested_rate_amplifier: float = 1.0
"""回复兴趣度放大系数"""
down_frequency_rate: float = 3.0
"""降低回复频率的群组回复意愿降低系数"""
emoji_response_penalty: float = 0.0
"""表情包回复惩罚系数"""
mentioned_bot_inevitable_reply: bool = False
"""提及 bot 必然回复"""
at_bot_inevitable_reply: bool = False
"""@bot 必然回复"""
@dataclass
class FocusChatConfig(ConfigBase):
"""专注聊天配置类"""
reply_trigger_threshold: float = 3.0
"""心流聊天触发阈值,越低越容易触发"""
default_decay_rate_per_second: float = 0.98
"""默认衰减率,越大衰减越快"""
consecutive_no_reply_threshold: int = 3
"""连续不回复的次数阈值"""
compressed_length: int = 5
"""心流上下文压缩的最短压缩长度超过心流观察到的上下文长度会压缩最短压缩长度为5"""
compress_length_limit: int = 5
"""最多压缩份数,超过该数值的压缩上下文会被删除"""
@dataclass
class EmojiConfig(ConfigBase):
"""表情包配置类"""
max_reg_num: int = 200
"""表情包最大注册数量"""
do_replace: bool = True
"""达到最大注册数量时替换旧表情包"""
check_interval: int = 120
"""表情包检查间隔(分钟)"""
save_pic: bool = False
"""是否保存图片"""
cache_emoji: bool = True
"""是否缓存表情包"""
steal_emoji: bool = True
"""是否偷取表情包,让麦麦可以发送她保存的这些表情包"""
content_filtration: bool = False
"""是否开启表情包过滤"""
filtration_prompt: str = "符合公序良俗"
"""表情包过滤要求"""
@dataclass
class MemoryConfig(ConfigBase):
"""记忆配置类"""
memory_build_interval: int = 600
"""记忆构建间隔(秒)"""
memory_build_distribution: tuple[
float,
float,
float,
float,
float,
float,
] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
"""记忆构建分布参数分布1均值标准差权重分布2均值标准差权重"""
memory_build_sample_num: int = 8
"""记忆构建采样数量"""
memory_build_sample_length: int = 40
"""记忆构建采样长度"""
memory_compress_rate: float = 0.1
"""记忆压缩率"""
forget_memory_interval: int = 1000
"""记忆遗忘间隔(秒)"""
memory_forget_time: int = 24
"""记忆遗忘时间(小时)"""
memory_forget_percentage: float = 0.01
"""记忆遗忘比例"""
consolidate_memory_interval: int = 1000
"""记忆整合间隔(秒)"""
consolidation_similarity_threshold: float = 0.7
"""整合相似度阈值"""
consolidate_memory_percentage: float = 0.01
"""整合检查节点比例"""
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
"""不允许记忆的词列表"""
@dataclass
class MoodConfig(ConfigBase):
"""情绪配置类"""
mood_update_interval: int = 1
"""情绪更新间隔(秒)"""
mood_decay_rate: float = 0.95
"""情绪衰减率"""
mood_intensity_factor: float = 0.7
"""情绪强度因子"""
@dataclass
class KeywordRuleConfig(ConfigBase):
"""关键词规则配置类"""
enable: bool = True
"""是否启用关键词规则"""
keywords: list[str] = field(default_factory=lambda: [])
"""关键词列表"""
regex: list[str] = field(default_factory=lambda: [])
"""正则表达式列表"""
reaction: str = ""
"""关键词触发的反应"""
@dataclass
class KeywordReactionConfig(ConfigBase):
"""关键词配置类"""
enable: bool = True
"""是否启用关键词反应"""
rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
"""关键词反应规则列表"""
@dataclass
class ChineseTypoConfig(ConfigBase):
"""中文错别字配置类"""
enable: bool = True
"""是否启用中文错别字生成器"""
error_rate: float = 0.01
"""单字替换概率"""
min_freq: int = 9
"""最小字频阈值"""
tone_error_rate: float = 0.1
"""声调错误概率"""
word_replace_rate: float = 0.006
"""整词替换概率"""
@dataclass
class ResponseSplitterConfig(ConfigBase):
"""回复分割器配置类"""
enable: bool = True
"""是否启用回复分割器"""
max_length: int = 256
"""回复允许的最大长度"""
max_sentence_num: int = 3
"""回复允许的最大句子数"""
enable_kaomoji_protection: bool = False
"""是否启用颜文字保护"""
@dataclass
class TelemetryConfig(ConfigBase):
"""遥测配置类"""
enable: bool = True
"""是否启用遥测"""
@dataclass
class ExperimentalConfig(ConfigBase):
"""实验功能配置类"""
enable_friend_chat: bool = False
"""是否启用好友聊天"""
talk_allowed_private: set[str] = field(default_factory=lambda: set())
"""允许聊天的私聊列表"""
pfc_chatting: bool = False
"""是否启用PFC"""
@dataclass
class ModelConfig(ConfigBase):
"""模型配置类"""
model_max_output_length: int = 800 # 最大回复长度
reasoning: dict[str, Any] = field(default_factory=lambda: {})
"""推理模型配置"""
normal: dict[str, Any] = field(default_factory=lambda: {})
"""普通模型配置"""
topic_judge: dict[str, Any] = field(default_factory=lambda: {})
"""主题判断模型配置"""
summary: dict[str, Any] = field(default_factory=lambda: {})
"""摘要模型配置"""
vlm: dict[str, Any] = field(default_factory=lambda: {})
"""视觉语言模型配置"""
heartflow: dict[str, Any] = field(default_factory=lambda: {})
"""心流模型配置"""
observation: dict[str, Any] = field(default_factory=lambda: {})
"""观察模型配置"""
sub_heartflow: dict[str, Any] = field(default_factory=lambda: {})
"""子心流模型配置"""
plan: dict[str, Any] = field(default_factory=lambda: {})
"""计划模型配置"""
embedding: dict[str, Any] = field(default_factory=lambda: {})
"""嵌入模型配置"""
pfc_action_planner: dict[str, Any] = field(default_factory=lambda: {})
"""PFC动作规划模型配置"""
pfc_chat: dict[str, Any] = field(default_factory=lambda: {})
"""PFC聊天模型配置"""
pfc_reply_checker: dict[str, Any] = field(default_factory=lambda: {})
"""PFC回复检查模型配置"""
tool_use: dict[str, Any] = field(default_factory=lambda: {})
"""工具使用模型配置"""