初始化

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
parent ff7d1177fa
commit 2d4745cd58
257 changed files with 69069 additions and 0 deletions

View File

@@ -0,0 +1,139 @@
from dataclasses import dataclass, field
from .config_base import ConfigBase
@dataclass
class APIProvider(ConfigBase):
    """Configuration for a single API provider (endpoint + credentials)."""

    name: str
    """Provider name (must be unique across providers)"""

    base_url: str
    """Base URL of the API endpoint"""

    api_key: str = field(default_factory=str, repr=False)
    """API key; repr=False keeps it out of repr() output. Defaults to "" which fails validation."""

    client_type: str = field(default="openai")
    """Client type, e.g. "openai" / "google"; defaults to "openai" """

    max_retry: int = 2
    """Maximum number of retries for a single failed model API call"""

    timeout: int = 10
    """API call timeout (seconds); beyond this the request is treated as timed out"""

    retry_interval: int = 10
    """Interval between retries after a failed API call (seconds)"""

    def get_api_key(self) -> str:
        """Return the configured API key."""
        return self.api_key

    def __post_init__(self):
        """Validate required fields right after dataclass initialization."""
        if not self.api_key:
            raise ValueError("API密钥不能为空请在配置中设置有效的API密钥。")
        # base_url may be empty only for the "gemini" client type, which
        # presumably builds its own endpoint — TODO confirm with the client code.
        if not self.base_url and self.client_type != "gemini":
            raise ValueError("API基础URL不能为空请在配置中设置有效的基础URL。")
        if not self.name:
            raise ValueError("API提供商名称不能为空请在配置中设置有效的名称。")
@dataclass
class ModelInfo(ConfigBase):
    """Configuration for a single model."""

    model_identifier: str
    """Model identifier used when calling the API (URL/request payload)"""

    name: str
    """Model name used by in-process callers"""

    api_provider: str
    """Name of the API provider; must match an APIProvider.name"""

    price_in: float = field(default=0.0)
    """Input price per million tokens"""

    price_out: float = field(default=0.0)
    """Output price per million tokens"""

    force_stream_mode: bool = field(default=False)
    """Whether to force streaming output mode"""

    extra_params: dict = field(default_factory=dict)
    """Extra parameters passed along with API calls"""

    def __post_init__(self):
        # Fail fast on blank required fields.
        if not self.model_identifier:
            raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。")
        if not self.name:
            raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。")
        if not self.api_provider:
            raise ValueError("API提供商不能为空请在配置中设置有效的API提供商。")
@dataclass
class TaskConfig(ConfigBase):
    """Per-task model-usage configuration."""

    model_list: list[str] = field(default_factory=list)
    """Names of the models this task may use"""

    max_tokens: int = 1024
    """Maximum output tokens for this task"""

    temperature: float = 0.3
    """Sampling temperature for this task"""
@dataclass
class ModelTaskConfig(ConfigBase):
    """Maps each functional task of the bot to its TaskConfig."""

    utils: TaskConfig
    """Utility-component model configuration"""

    utils_small: TaskConfig
    """Small utility-component model configuration"""

    replyer_1: TaskConfig
    """normal_chat primary replyer model configuration"""

    replyer_2: TaskConfig
    """normal_chat secondary replyer model configuration"""

    emotion: TaskConfig
    """Emotion model configuration"""

    vlm: TaskConfig
    """Vision-language model configuration"""

    voice: TaskConfig
    """Speech-recognition model configuration"""

    tool_use: TaskConfig
    """Dedicated tool-use model configuration"""

    planner: TaskConfig
    """Planner model configuration"""

    embedding: TaskConfig
    """Embedding model configuration"""

    lpmm_entity_extract: TaskConfig
    """LPMM entity-extraction model configuration"""

    lpmm_rdf_build: TaskConfig
    """LPMM RDF-building model configuration"""

    lpmm_qa: TaskConfig
    """LPMM question-answering model configuration"""

    def get_task(self, task_name: str) -> TaskConfig:
        """Return the TaskConfig for *task_name*; raise ValueError if unknown."""
        sentinel = object()
        task = getattr(self, task_name, sentinel)
        if task is sentinel:
            raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
        return task

479
src/config/config.py Normal file
View File

@@ -0,0 +1,479 @@
import os
import tomlkit
import shutil
import sys
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass
from rich.traceback import install
from typing import List, Optional
from src.common.logger import get_logger
from src.config.config_base import ConfigBase
from src.config.official_configs import (
DatabaseConfig,
BotConfig,
PersonalityConfig,
ExpressionConfig,
ChatConfig,
NormalChatConfig,
EmojiConfig,
MemoryConfig,
MoodConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
MessageReceiveConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
RelationshipConfig,
ToolConfig,
VoiceConfig,
DebugConfig,
CustomPromptConfig,
)
from .api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
install(extra_lines=3)
# 配置主程序日志格式
logger = get_logger("config")
# 获取当前文件所在目录的父目录的父目录即MaiBot项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
CONFIG_DIR = os.path.join(PROJECT_ROOT, "config")
TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.10.0-snapshot.5"
def get_key_comment(toml_table, key):
    """Best-effort lookup of the comment attached to *key* in a tomlkit table.

    Tries, in order: the table's own trivia comment, the item stored under
    *key* in the table's value mapping, and the table's keys (tomlkit KeyType
    objects carry their own trivia). Returns None when nothing is found.
    """
    if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
        # NOTE(review): this returns the *table's* own comment, not the key's,
        # and short-circuits the per-key lookups below — verify intent.
        return toml_table.trivia.comment
    if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
        item = toml_table.value.get(key)
        if item is not None and hasattr(item, "trivia"):
            return item.trivia.comment
    if hasattr(toml_table, "keys"):
        for k in toml_table.keys():
            if isinstance(k, KeyType) and k.key == key:
                return k.trivia.comment
    return None
def compare_dicts(new, old, path=None, logs=None):
    """Recursively diff two (possibly tomlkit) dicts, collecting log lines.

    Records keys added (present only in *new*) and removed (present only in
    *old*) together with their comments. The "version" key is ignored at
    every level; shared sub-tables are compared recursively.

    Returns:
        list[str]: accumulated human-readable log lines.
    """
    if path is None:
        path = []
    if logs is None:
        logs = []
    # Added keys.
    for key in new:
        if key == "version":
            continue
        if key not in old:
            comment = get_key_comment(new, key)
            logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
        elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
            compare_dicts(new[key], old[key], path + [str(key)], logs)
    # Removed keys.
    for key in old:
        if key == "version":
            continue
        if key not in new:
            comment = get_key_comment(old, key)
            logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
    return logs
def get_value_by_path(d, path):
    """Walk *path* (a sequence of keys) into nested dicts; None when absent."""
    node = d
    for key in path:
        if not isinstance(node, dict) or key not in node:
            return None
        node = node[key]
    return node
def set_value_by_path(d, path, value):
    """Set *value* at nested *path* in dict *d*, creating/overwriting
    intermediate dicts along the way."""
    node = d
    for key in path[:-1]:
        # Any missing or non-dict intermediate is replaced by a fresh dict.
        if not isinstance(node.get(key), dict):
            node[key] = {}
        node = node[key]
    node[path[-1]] = value
def compare_default_values(new, old, path=None, logs=None, changes=None):
    """Recursively find keys whose default value changed between two templates.

    Ignores "version"; recurses into shared sub-tables.

    Returns:
        tuple[list[str], list[tuple[list[str], Any, Any]]]:
            (log lines, changes) where each change is (path, old_default, new_default).
    """
    if path is None:
        path = []
    if logs is None:
        logs = []
    if changes is None:
        changes = []
    for key in new:
        if key == "version":
            continue
        if key in old:
            if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
                compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
            elif new[key] != old[key]:
                logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
                changes.append((path + [str(key)], old[key], new[key]))
    return logs, changes
def _get_version_from_toml(toml_path) -> Optional[str]:
    """Read [inner].version from a TOML file; None when the file or key is absent."""
    if not os.path.exists(toml_path):
        return None
    with open(toml_path, "r", encoding="utf-8") as f:
        doc = tomlkit.load(f)
    if "inner" in doc and "version" in doc["inner"]:  # type: ignore
        return doc["inner"]["version"]  # type: ignore
    return None
def _version_tuple(v):
"""将版本字符串转换为元组以便比较"""
if v is None:
return (0,)
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
    """
    Recursively copy values from *source* into *target* for keys that already
    exist in *target*. Keys absent from *target* are ignored (template defines
    the schema); the "version" key is never overwritten.
    """
    for key, value in source.items():
        # Never carry the version field over from the old config.
        if key == "version":
            continue
        if key in target:
            target_value = target[key]
            if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
                _update_dict(target_value, value)
            else:
                try:
                    # Arrays need special handling so empty arrays stay empty arrays.
                    if isinstance(value, list):
                        target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
                    else:
                        # Other types: wrap via tomlkit.item to keep TOML typing intact.
                        target[key] = tomlkit.item(value)
                except (TypeError, ValueError):
                    # Fall back to raw assignment when conversion fails.
                    target[key] = value
def _update_config_generic(config_name: str, template_name: str):
    """
    Generic config-file updater.

    On first run, copies the template into place and exits so the user can
    fill it in. Otherwise: detects template default-value changes (via the
    snapshot in template/compare), auto-upgrades values the user never
    changed, and — when the template version differs — backs up the old
    config, re-creates it from the template, and merges the old values back.

    Args:
        config_name: config file name without extension, e.g. 'bot_config' or 'model_config'
        template_name: template file name without extension, e.g. 'bot_config_template'
    """
    old_config_dir = os.path.join(CONFIG_DIR, "old")
    compare_dir = os.path.join(TEMPLATE_DIR, "compare")
    # File paths. NOTE: old_config_path and new_config_path are intentionally
    # the same path — the old file is moved to a backup before the new one is written.
    template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
    old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
    new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
    compare_path = os.path.join(compare_dir, f"{template_name}.toml")
    # Create the compare directory if it does not exist.
    os.makedirs(compare_dir, exist_ok=True)
    template_version = _get_version_from_toml(template_path)
    compare_version = _get_version_from_toml(compare_path)
    # First run: no config file yet — create it from the template and stop.
    if not os.path.exists(old_config_path):
        logger.info(f"{config_name}.toml配置文件不存在从模板创建新配置")
        os.makedirs(CONFIG_DIR, exist_ok=True)  # create the config folder
        shutil.copy2(template_path, old_config_path)  # copy the template file
        logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
        # A fresh config file was just created: exit so the user can fill it in.
        sys.exit(0)
    compare_config = None
    new_config = None
    old_config = None
    # Read the previously-seen template from compare/ (if any),
    # used for default-value change detection.
    if os.path.exists(compare_path):
        with open(compare_path, "r", encoding="utf-8") as f:
            compare_config = tomlkit.load(f)
    # Read the current template.
    with open(template_path, "r", encoding="utf-8") as f:
        new_config = tomlkit.load(f)
    # Detect and handle default-value changes (only when compare_config exists).
    if compare_config:
        # Read the user's existing config.
        with open(old_config_path, "r", encoding="utf-8") as f:
            old_config = tomlkit.load(f)
        logs, changes = compare_default_values(new_config, compare_config)
        if logs:
            logger.info(f"检测到{config_name}模板默认值变动如下:")
            for log in logs:
                logger.info(log)
            # If the user's value still equals the old default,
            # silently upgrade it to the new default.
            for path, old_default, new_default in changes:
                old_value = get_value_by_path(old_config, path)
                if old_value == old_default:
                    set_value_by_path(old_config, path, new_default)
                    logger.info(
                        f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
                    )
        else:
            logger.info(f"未检测到{config_name}模板默认值变动")
    # Copy the template into compare/ when it is missing there,
    # or when the current template has a higher version.
    if not os.path.exists(compare_path):
        shutil.copy2(template_path, compare_path)
        logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
    elif _version_tuple(template_version) > _version_tuple(compare_version):
        shutil.copy2(template_path, compare_path)
        logger.info(f"{config_name}模板版本较新已替换compare下的模板: {compare_path}")
    else:
        logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
    # Read the old config file if it was not already loaded above.
    if old_config is None:
        with open(old_config_path, "r", encoding="utf-8") as f:
            old_config = tomlkit.load(f)
    # new_config has already been read above.
    # Compare [inner].version between the old config and the template.
    if old_config and "inner" in old_config and "inner" in new_config:
        old_version = old_config["inner"].get("version")  # type: ignore
        new_version = new_config["inner"].get("version")  # type: ignore
        if old_version and new_version and old_version == new_version:
            logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
            return
        else:
            logger.info(
                f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
            )
    else:
        logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
    # Create the old/ backup directory if it does not exist.
    os.makedirs(old_config_dir, exist_ok=True)
    # Build a timestamped backup file name.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
    # Move the old config file into old/.
    shutil.move(old_config_path, old_backup_path)
    logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
    # Copy the template into the config directory as the new config.
    shutil.copy2(template_path, new_config_path)
    logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
    # Log added/removed keys together with their comments.
    if old_config:
        logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
        if logs := compare_dicts(new_config, old_config):
            for log in logs:
                logger.info(log)
        else:
            logger.info("无新增或删减项")
    # Merge the user's old values into the new config.
    logger.info(f"开始合并{config_name}新旧配置...")
    _update_dict(new_config, old_config)
    # Save the updated config (tomlkit preserves comments and formatting).
    with open(new_config_path, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(new_config))
    logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
def update_config():
    """Create/update bot_config.toml from its template (see _update_config_generic)."""
    _update_config_generic("bot_config", "bot_config_template")
def update_model_config():
    """Create/update model_config.toml from its template (see _update_config_generic)."""
    _update_config_generic("model_config", "model_config_template")
@dataclass
class Config(ConfigBase):
    """Top-level configuration aggregating every config section."""

    # Hard-coded version string; init=False keeps this defaulted field legal
    # ahead of the required fields below (it is not an __init__ parameter).
    MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False)

    database: DatabaseConfig
    bot: BotConfig
    personality: PersonalityConfig
    relationship: RelationshipConfig
    chat: ChatConfig
    message_receive: MessageReceiveConfig
    normal_chat: NormalChatConfig
    emoji: EmojiConfig
    expression: ExpressionConfig
    memory: MemoryConfig
    mood: MoodConfig
    keyword_reaction: KeywordReactionConfig
    chinese_typo: ChineseTypoConfig
    response_post_process: ResponsePostProcessConfig
    response_splitter: ResponseSplitterConfig
    telemetry: TelemetryConfig
    experimental: ExperimentalConfig
    maim_message: MaimMessageConfig
    lpmm_knowledge: LPMMKnowledgeConfig
    tool: ToolConfig
    debug: DebugConfig
    custom_prompt: CustomPromptConfig
    voice: VoiceConfig
@dataclass
class APIAdapterConfig(ConfigBase):
    """API adapter configuration: providers, models, and per-task model setup."""

    models: List[ModelInfo]
    """Configured models"""

    model_task_config: ModelTaskConfig
    """Per-task model configuration"""

    api_providers: List[APIProvider] = field(default_factory=list)
    """Configured API providers"""

    def __post_init__(self):
        """Validate both lists, reject duplicate names, build lookup tables,
        and check that every model references an existing provider."""
        if not self.models:
            raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
        if not self.api_providers:
            raise ValueError("API提供商列表不能为空请在配置中设置有效的API提供商列表。")
        # Duplicate provider names are configuration errors.
        provider_names = [provider.name for provider in self.api_providers]
        if len(set(provider_names)) != len(provider_names):
            raise ValueError("API提供商名称存在重复请检查配置文件。")
        # Duplicate model names are configuration errors.
        model_names = [model.name for model in self.models]
        if len(set(model_names)) != len(model_names):
            raise ValueError("模型名称存在重复,请检查配置文件。")
        # Name-based lookup tables used by the getters below.
        self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
        self.models_dict = {model.name: model for model in self.models}
        for model in self.models:
            if not model.model_identifier:
                raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
            if not model.api_provider or model.api_provider not in self.api_providers_dict:
                raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")

    def get_model_info(self, model_name: str) -> ModelInfo:
        """Return the ModelInfo registered under *model_name*."""
        if not model_name:
            raise ValueError("模型名称不能为空")
        if model_name not in self.models_dict:
            raise KeyError(f"模型 '{model_name}' 不存在")
        return self.models_dict[model_name]

    def get_provider(self, provider_name: str) -> APIProvider:
        """Return the APIProvider registered under *provider_name*."""
        if not provider_name:
            raise ValueError("API提供商名称不能为空")
        if provider_name not in self.api_providers_dict:
            raise KeyError(f"API提供商 '{provider_name}' 不存在")
        return self.api_providers_dict[provider_name]
def load_config(config_path: str) -> Config:
    """
    Load the bot configuration file.

    Args:
        config_path: path to the TOML configuration file

    Returns:
        the parsed Config object
    """
    # Parse the TOML file, then build the Config object from it.
    with open(config_path, "r", encoding="utf-8") as f:
        raw_data = tomlkit.load(f)
    try:
        return Config.from_dict(raw_data)
    except Exception:
        logger.critical("配置文件解析失败")
        raise
def api_ada_load_config(config_path: str) -> APIAdapterConfig:
    """
    Load the API adapter configuration file.

    Args:
        config_path: path to the TOML configuration file

    Returns:
        the parsed APIAdapterConfig object
    """
    # Parse the TOML file, then build the APIAdapterConfig object from it.
    with open(config_path, "r", encoding="utf-8") as f:
        raw_data = tomlkit.load(f)
    try:
        return APIAdapterConfig.from_dict(raw_data)
    except Exception:
        logger.critical("API适配器配置文件解析失败")
        raise
# --- Module import side effects: migrate config files, load them, init the DB ---
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config()
update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
# Initialize the database connection. The import is deferred to this point —
# presumably to avoid an import cycle; TODO confirm.
logger.info("正在初始化数据库连接...")
from src.common.database.database import initialize_sql_database

try:
    initialize_sql_database(global_config.database)
    logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库")
except Exception as e:
    logger.error(f"数据库连接初始化失败: {e}")
    raise e
# Initialize the database table schema (import also deferred).
logger.info("正在初始化数据库表结构...")
from src.common.database.sqlalchemy_models import initialize_database as init_db

try:
    init_db()
    logger.info("数据库表结构初始化完成")
except Exception as e:
    logger.error(f"数据库表结构初始化失败: {e}")
    raise e
logger.info("非常的新鲜,非常的美味!")

135
src/config/config_base.py Normal file
View File

@@ -0,0 +1,135 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
T = TypeVar("T", bound="ConfigBase")
TOML_DICT_TYPE = {
int,
float,
str,
bool,
list,
dict,
}
@dataclass
class ConfigBase:
"""配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
field_type = f.type
try:
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
return cls(**init_args)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""
转换字段值为指定类型
1. 对于嵌套的 dataclass递归调用相应的 from_dict 方法
2. 对于泛型集合类型list, set, tuple递归转换每个元素
3. 对于基础类型int, str, float, bool直接转换
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
return field_type.from_dict(value)
# 处理泛型集合类型list, set, tuple
field_origin_type = get_origin(field_type)
field_type_args = get_args(field_type)
if field_origin_type in {list, set, tuple}:
# 检查提供的value是否为list
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
# 如果列表元素类型是ConfigBase的子类则对每个元素调用from_dict
if (
field_type_args
and isinstance(field_type_args[0], type)
and issubclass(field_type_args[0], ConfigBase)
):
return [field_type_args[0].from_dict(item) for item in value]
return [cls._convert_field(item, field_type_args[0]) for item in value]
elif field_origin_type is set:
return {cls._convert_field(item, field_type_args[0]) for item in value}
elif field_origin_type is tuple:
# 检查提供的value长度是否与类型参数一致
if len(value) != len(field_type_args):
raise TypeError(
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
if field_origin_type is dict:
# 检查提供的value是否为dict
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
# 检查字典的键值类型
if len(field_type_args) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_type_args
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理基础类型,例如 int, str 等
if field_origin_type is type(None) and value is None: # 处理Optional类型
return None
# 处理Literal类型
if field_origin_type is Literal or get_origin(field_type) is Literal:
# 获取Literal的允许值
allowed_values = get_args(field_type)
if value in allowed_values:
return value
else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
if field_type is Any or isinstance(value, field_type):
return value
# 其他类型,尝试直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
def __str__(self):
"""返回配置类的字符串表示"""
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"

View File

@@ -0,0 +1,814 @@
import re
from dataclasses import dataclass, field
from typing import Literal, Optional
from src.config.config_base import ConfigBase
"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
@dataclass
class DatabaseConfig(ConfigBase):
    """Database configuration."""

    database_type: Literal["sqlite", "mysql"] = "sqlite"
    """Database type: "sqlite" or "mysql" """

    # SQLite settings
    sqlite_path: str = "data/MaiBot.db"
    """Path to the SQLite database file"""

    # MySQL settings
    mysql_host: str = "localhost"
    """MySQL server address"""

    mysql_port: int = 3306
    """MySQL server port"""

    mysql_database: str = "maibot"
    """MySQL database name"""

    mysql_user: str = "root"
    """MySQL user name"""

    mysql_password: str = ""
    """MySQL password"""

    mysql_charset: str = "utf8mb4"
    """MySQL character set"""

    mysql_unix_socket: str = ""
    """MySQL Unix socket path (optional, local connections; takes precedence over host/port)"""

    # MySQL SSL settings
    mysql_ssl_mode: str = "DISABLED"
    """SSL mode: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY"""

    mysql_ssl_ca: str = ""
    """Path to the SSL CA certificate"""

    mysql_ssl_cert: str = ""
    """Path to the SSL client certificate"""

    mysql_ssl_key: str = ""
    """Path to the SSL client key"""

    # MySQL advanced settings
    mysql_autocommit: bool = True
    """Auto-commit transactions"""

    mysql_sql_mode: str = "TRADITIONAL"
    """SQL mode"""

    # Connection-pool settings
    connection_pool_size: int = 10
    """Connection pool size (MySQL only)"""

    connection_timeout: int = 10
    """Connection timeout (seconds)"""
@dataclass
class BotConfig(ConfigBase):
    """QQ bot account configuration."""

    platform: str
    """Platform identifier"""

    qq_account: str
    """QQ account number"""

    nickname: str
    """Bot nickname"""

    alias_names: list[str] = field(default_factory=lambda: [])
    """Alias names for the bot"""
@dataclass
class PersonalityConfig(ConfigBase):
    """Personality configuration."""

    personality_core: str
    """Core personality description"""

    personality_side: str
    """Personality side profile"""

    identity: str = ""
    """Identity traits"""

    reply_style: str = ""
    """Expression/reply style"""

    compress_personality: bool = True
    """Compress the personality text: saves tokens and speeds replies at the cost of
    some detail; can be disabled when the persona text is short"""

    compress_identity: bool = True
    """Compress the identity text: same trade-off as compress_personality"""
@dataclass
class RelationshipConfig(ConfigBase):
    """Relationship-system configuration."""

    enable_relationship: bool = True
    """Whether the relationship system is enabled"""

    relation_frequency: int = 1
    """Relationship frequency: how fast relationships are built"""
@dataclass
class ChatConfig(ConfigBase):
    """Chat behaviour configuration."""

    max_context_size: int = 18
    """Context length (number of messages kept)"""

    replyer_random_probability: float = 0.5
    """
    Probability (0-1) of choosing the primary replyer model when speaking;
    the secondary model is chosen with probability 1 - replyer_random_probability.
    """

    thinking_timeout: int = 40
    """Maximum thinking/planning time (seconds); slower runs are abandoned (usually a slow API)"""

    talk_frequency: float = 1
    """Default reply-frequency threshold"""

    mentioned_bot_inevitable_reply: bool = False
    """Always reply when the bot is mentioned"""

    at_bot_inevitable_reply: bool = False
    """Always reply when the bot is @-ed"""

    # Merged per-time-slot frequency configuration (global + per-stream).
    talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
    """
    Unified time-slot frequency configuration.
    Format: [["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
    Global example:
        [["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
    Per-stream example:
        [
            ["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"],                  # global default
            ["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"],  # specific group chat
            ["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # specific private chat
        ]
    Notes:
    - A first element of "" marks the global default configuration.
    - A first element of "platform:id:type" marks a specific chat stream.
    - Remaining elements are "time,frequency" pairs: the frequency applies
      from that time until the next configured time point.
    - Priority: stream-specific > global > default talk_frequency.
    """

    focus_value: float = 1.0
    """Focus ability; the lower it is, the easier the bot focuses, and the more tokens are consumed"""

    def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
        """Return the talk frequency for the current time and chat stream.

        Args:
            chat_stream_id: chat-stream id — an md5 hash, compared against
                hashes derived from the configured "platform:id:type" strings.

        Returns:
            float: the applicable frequency value.
        """
        if not self.talk_frequency_adjust:
            return self.talk_frequency
        # Stream-specific configuration wins first.
        if chat_stream_id:
            stream_frequency = self._get_stream_specific_frequency(chat_stream_id)
            if stream_frequency is not None:
                return stream_frequency
        # Then the global time-slot configuration (first element == "").
        global_frequency = self._get_global_frequency()
        if global_frequency is not None:
            return global_frequency
        # Nothing matched: fall back to the default.
        return self.talk_frequency

    def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
        """Return the frequency for the current time of day.

        Args:
            time_freq_list: list of "HH:MM,frequency" strings.

        Returns:
            float: the frequency in effect now, or None when nothing parses.
        """
        from datetime import datetime

        current_time = datetime.now().strftime("%H:%M")
        current_hour, current_minute = map(int, current_time.split(":"))
        current_minutes = current_hour * 60 + current_minute
        # Parse the "time,frequency" pairs; malformed entries are skipped.
        time_freq_pairs = []
        for time_freq_str in time_freq_list:
            try:
                time_str, freq_str = time_freq_str.split(",")
                hour, minute = map(int, time_str.split(":"))
                frequency = float(freq_str)
                minutes = hour * 60 + minute
                time_freq_pairs.append((minutes, frequency))
            except (ValueError, IndexError):
                continue
        if not time_freq_pairs:
            return None
        # Sort by time of day.
        time_freq_pairs.sort(key=lambda x: x[0])
        # Pick the last slot whose start time is <= now.
        current_frequency = None
        for minutes, frequency in time_freq_pairs:
            if current_minutes >= minutes:
                current_frequency = frequency
            else:
                break
        # Before the first configured time: use the last slot (wraps past midnight).
        if current_frequency is None and time_freq_pairs:
            current_frequency = time_freq_pairs[-1][1]
        return current_frequency

    def _get_stream_specific_frequency(self, chat_stream_id: str):
        """Return the current-time frequency for a specific chat stream.

        Args:
            chat_stream_id: chat-stream id (md5 hash).

        Returns:
            float: the frequency, or None when this stream has no configuration.
        """
        # Find the config entry whose derived chat_id matches.
        for config_item in self.talk_frequency_adjust:
            if not config_item or len(config_item) < 2:
                continue
            stream_config_str = config_item[0]  # e.g. "qq:1026294844:group"
            # Turn the config string into a chat_id hash.
            config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
            if config_chat_id is None:
                continue
            # Compare the derived chat_id against the requested one.
            if config_chat_id != chat_stream_id:
                continue
            # Delegate to the shared time-slot parser.
            return self._get_time_based_frequency(config_item[1:])
        return None

    def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
        """Derive a chat_id hash from a "platform:id:type" config string.

        Returns:
            str: the md5 chat_id, or None when the string cannot be parsed.
        """
        try:
            parts = stream_config_str.split(":")
            if len(parts) != 3:
                return None
            platform = parts[0]
            id_str = parts[1]
            stream_type = parts[2]
            # "group" marks a group chat; anything else is treated as private.
            is_group = stream_type == "group"
            # Same hashing scheme as ChatStream.get_stream_id: groups hash
            # [platform, id]; private chats append "private" — TODO confirm
            # this stays in sync with ChatStream.
            import hashlib

            if is_group:
                components = [platform, str(id_str)]
            else:
                components = [platform, str(id_str), "private"]
            key = "_".join(components)
            return hashlib.md5(key.encode()).hexdigest()
        except (ValueError, IndexError):
            return None

    def _get_global_frequency(self) -> Optional[float]:
        """Return the current-time frequency from the global config entry.

        Returns:
            float: the frequency, or None when no global entry ("" key) exists.
        """
        for config_item in self.talk_frequency_adjust:
            if not config_item or len(config_item) < 2:
                continue
            # The global default entry has "" as its first element.
            if config_item[0] == "":
                return self._get_time_based_frequency(config_item[1:])
        return None
@dataclass
class MessageReceiveConfig(ConfigBase):
    """Incoming-message filtering configuration."""

    ban_words: set[str] = field(default_factory=lambda: set())
    """Words that cause a message to be filtered out"""

    ban_msgs_regex: set[str] = field(default_factory=lambda: set())
    """Regular expressions that cause a message to be filtered out"""
@dataclass
class NormalChatConfig(ConfigBase):
    """Normal-chat configuration."""

    willing_mode: str = "classical"
    """Willingness mode"""
@dataclass
class ExpressionConfig(ConfigBase):
    """Expression-learning configuration."""

    expression_learning: list[list] = field(default_factory=lambda: [])
    """
    Expression-learning configuration, optionally per chat stream.
    Format: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
    Example:
        [
            ["", "enable", "enable", 1.0],                   # global: use expressions, learn, intensity 1.0
            ["qq:1919810:private", "enable", "enable", 1.5], # specific private chat, intensity 1.5
            ["qq:114514:private", "enable", "disable", 0.5], # specific private chat, learning disabled
        ]
    Notes:
    - 1st element: chat_stream_id; "" marks the global configuration
    - 2nd element: whether learned expressions are used ("enable"/"disable")
    - 3rd element: whether expressions are learned ("enable"/"disable")
    - 4th element: learning intensity (float); minimum learning interval = 300/intensity seconds
    """

    expression_groups: list[list[str]] = field(default_factory=list)
    """
    Expression-learning sharing groups.
    Format: [["qq:12345:group", "qq:67890:private"]]
    """

    def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
        """Derive a chat_id hash from a "platform:id:type" config string.

        Returns:
            str: the md5 chat_id, or None when the string cannot be parsed.
        """
        try:
            parts = stream_config_str.split(":")
            if len(parts) != 3:
                return None
            platform = parts[0]
            id_str = parts[1]
            stream_type = parts[2]
            # "group" marks a group chat; anything else is treated as private.
            is_group = stream_type == "group"
            # Same hashing scheme as ChatStream.get_stream_id (and the
            # duplicate helper on ChatConfig) — TODO confirm they stay in sync.
            import hashlib

            if is_group:
                components = [platform, str(id_str)]
            else:
                components = [platform, str(id_str), "private"]
            key = "_".join(components)
            return hashlib.md5(key.encode()).hexdigest()
        except (ValueError, IndexError):
            return None

    def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, float]:
        """Return the expression configuration for a chat stream.

        Args:
            chat_stream_id: chat-stream id (md5 hash).

        Returns:
            tuple: (use expressions, learn expressions, learning value).
            NOTE(review): configured entries return the learning *intensity*
            (e.g. 1.0) while the defaults below return 300, which reads as an
            interval in seconds — callers presumably convert via
            300/intensity; verify this inconsistency.
        """
        if not self.expression_learning:
            # No configuration at all: enable everything with the default value.
            return True, True, 300
        # Stream-specific configuration wins first.
        if chat_stream_id:
            specific_config = self._get_stream_specific_config(chat_stream_id)
            if specific_config is not None:
                return specific_config
        # Then the global configuration (first element == "").
        global_config = self._get_global_config()
        if global_config is not None:
            return global_config
        # Nothing matched: fall back to the defaults.
        return True, True, 300

    def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, float]]:
        """Return the expression configuration for a specific chat stream.

        Args:
            chat_stream_id: chat-stream id (md5 hash).

        Returns:
            tuple: (use expressions, learn expressions, learning intensity),
            or None when this stream has no configuration.
        """
        for config_item in self.expression_learning:
            if not config_item or len(config_item) < 4:
                continue
            stream_config_str = config_item[0]  # e.g. "qq:1026294844:group"
            # Skip the global entry ("" key).
            if stream_config_str == "":
                continue
            # Turn the config string into a chat_id hash.
            config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
            if config_chat_id is None:
                continue
            # Compare the derived chat_id against the requested one.
            if config_chat_id != chat_stream_id:
                continue
            # Parse the entry; malformed entries are skipped.
            try:
                use_expression = config_item[1].lower() == "enable"
                enable_learning = config_item[2].lower() == "enable"
                learning_intensity = float(config_item[3])
                return use_expression, enable_learning, learning_intensity
            except (ValueError, IndexError):
                continue
        return None

    def _get_global_config(self) -> Optional[tuple[bool, bool, float]]:
        """Return the global expression configuration ("" key).

        Returns:
            tuple: (use expressions, learn expressions, learning intensity),
            or None when there is no global entry.
        """
        for config_item in self.expression_learning:
            if not config_item or len(config_item) < 4:
                continue
            # The global entry has "" as its first element.
            if config_item[0] == "":
                try:
                    use_expression = config_item[1].lower() == "enable"
                    enable_learning = config_item[2].lower() == "enable"
                    learning_intensity = float(config_item[3])
                    return use_expression, enable_learning, learning_intensity
                except (ValueError, IndexError):
                    continue
        return None
@dataclass
class ToolConfig(ConfigBase):
    """Tool configuration."""

    enable_tool: bool = False
    """Whether tools are enabled during chat"""
@dataclass
class VoiceConfig(ConfigBase):
    """Speech-recognition configuration."""

    enable_asr: bool = False
    """Whether automatic speech recognition is enabled"""
@dataclass
class EmojiConfig(ConfigBase):
    """Emoji/sticker configuration."""

    emoji_chance: float = 0.6
    """Base probability of sending a sticker"""

    emoji_activate_type: str = "random"
    """Sticker activation type: "random" triggers the sticker action randomly,
    "llm" lets the LLM decide whether to trigger it"""

    max_reg_num: int = 200
    """Maximum number of registered stickers"""

    do_replace: bool = True
    """Replace old stickers once the registration limit is reached"""

    check_interval: int = 120
    """Sticker check interval (minutes)"""

    steal_emoji: bool = True
    """Whether to "steal" stickers so the bot can send saved ones"""

    content_filtration: bool = False
    """Whether sticker content filtering is enabled"""

    filtration_prompt: str = "符合公序良俗"
    """Filtering requirement prompt for stickers"""
@dataclass
class MemoryConfig(ConfigBase):
    """Memory-system configuration."""

    enable_memory: bool = True  # master switch for the memory system

    memory_build_interval: int = 600
    """Memory build interval (seconds)"""

    memory_build_distribution: tuple[
        float,
        float,
        float,
        float,
        float,
        float,
    ] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
    """Memory build distribution parameters: (mean1, std1, weight1, mean2, std2, weight2)"""

    memory_build_sample_num: int = 8
    """Number of samples per memory build"""

    memory_build_sample_length: int = 40
    """Sample length per memory build"""

    memory_compress_rate: float = 0.1
    """Memory compression rate"""

    forget_memory_interval: int = 1000
    """Memory forgetting interval (seconds)"""

    memory_forget_time: int = 24
    """Memory forgetting age threshold (hours)"""

    memory_forget_percentage: float = 0.01
    """Fraction of memories forgotten per pass"""

    consolidate_memory_interval: int = 1000
    """Memory consolidation interval (seconds)"""

    consolidation_similarity_threshold: float = 0.7
    """Similarity threshold for consolidation"""

    consolidate_memory_percentage: float = 0.01
    """Fraction of nodes checked per consolidation pass"""

    memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
    """Words that must not be memorized"""

    enable_instant_memory: bool = True
    """Whether instant memory is enabled"""
@dataclass
class MoodConfig(ConfigBase):
    """Mood-system configuration."""

    enable_mood: bool = False
    """Whether the mood system is enabled"""

    mood_update_threshold: float = 1.0
    """Mood update threshold; higher means slower updates"""
@dataclass
class KeywordRuleConfig(ConfigBase):
    """A single keyword/regex trigger rule."""

    keywords: list[str] = field(default_factory=lambda: [])
    """Keywords that trigger this rule"""

    regex: list[str] = field(default_factory=lambda: [])
    """Regular expressions that trigger this rule"""

    reaction: str = ""
    """Reaction produced when this rule fires"""

    def __post_init__(self):
        """Validate the rule: it must have a trigger, a reaction, and
        compilable regular expressions."""
        if not (self.keywords or self.regex):
            raise ValueError("关键词规则必须至少包含keywords或regex中的一个")
        if not self.reaction:
            raise ValueError("关键词规则必须包含reaction")
        # Every regex must compile; surface the failing pattern in the error.
        for pattern in self.regex:
            try:
                re.compile(pattern)
            except re.error as exc:
                raise ValueError(f"无效的正则表达式 '{pattern}': {str(exc)}") from exc
@dataclass
class KeywordReactionConfig(ConfigBase):
    """Keyword-reaction configuration: collections of trigger rules."""

    keyword_rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
    """Keyword-based rules"""

    regex_rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
    """Regex-based rules"""

    def __post_init__(self):
        """Ensure every configured rule is a KeywordRuleConfig instance."""
        for rule in (*self.keyword_rules, *self.regex_rules):
            if not isinstance(rule, KeywordRuleConfig):
                raise ValueError(f"规则必须是KeywordRuleConfig类型而不是{type(rule).__name__}")
@dataclass
class CustomPromptConfig(ConfigBase):
    """Custom-prompt configuration."""

    image_prompt: str = ""
    """Prompt used for image handling"""
@dataclass
class ResponsePostProcessConfig(ConfigBase):
    """Reply post-processing configuration."""

    enable_response_post_process: bool = True
    """Whether reply post-processing (typo generator, reply splitter) is enabled"""
@dataclass
class ChineseTypoConfig(ConfigBase):
    """Chinese-typo generator configuration."""

    enable: bool = True
    """Whether the Chinese typo generator is enabled"""

    error_rate: float = 0.01
    """Per-character substitution probability"""

    min_freq: int = 9
    """Minimum character-frequency threshold"""

    tone_error_rate: float = 0.1
    """Tone-error probability"""

    word_replace_rate: float = 0.006
    """Whole-word replacement probability"""
@dataclass
class ResponseSplitterConfig(ConfigBase):
    """Reply-splitter configuration."""

    enable: bool = True
    """Whether the reply splitter is enabled"""

    max_length: int = 256
    """Maximum allowed reply length"""

    max_sentence_num: int = 3
    """Maximum allowed number of sentences per reply"""

    enable_kaomoji_protection: bool = False
    """Whether kaomoji protection is enabled"""
@dataclass
class TelemetryConfig(ConfigBase):
    """Telemetry configuration."""

    enable: bool = True
    """Whether telemetry is enabled"""
@dataclass
class DebugConfig(ConfigBase):
    """Debug configuration."""

    show_prompt: bool = False
    """Whether prompts are shown/logged"""
@dataclass
class ExperimentalConfig(ConfigBase):
    """Experimental-feature configuration."""

    enable_friend_chat: bool = False
    """Whether friend chat is enabled"""

    pfc_chatting: bool = False
    """Whether PFC is enabled"""
@dataclass
class MaimMessageConfig(ConfigBase):
    """maim_message transport configuration."""

    use_custom: bool = False
    """Whether to use a custom maim_message configuration"""

    host: str = "127.0.0.1"
    """Host address"""

    port: int = 8090
    """Port number"""

    mode: Literal["ws", "tcp"] = "ws"
    """Connection mode: "ws" or "tcp" """

    use_wss: bool = False
    """Whether to use a secure WSS connection"""

    cert_file: str = ""
    """SSL certificate file path (only used when use_wss=True)"""

    key_file: str = ""
    """SSL key file path (only used when use_wss=True)"""

    auth_token: list[str] = field(default_factory=lambda: [])
    """Auth tokens for API verification; empty disables verification"""
@dataclass
class LPMMKnowledgeConfig(ConfigBase):
"""LPMM知识库配置类"""
enable: bool = True
"""是否启用LPMM知识库"""
rag_synonym_search_top_k: int = 10
"""RAG同义词搜索的Top K数量"""
rag_synonym_threshold: float = 0.8
"""RAG同义词搜索的相似度阈值"""
info_extraction_workers: int = 3
"""信息提取工作线程数"""
qa_relation_search_top_k: int = 10
"""QA关系搜索的Top K数量"""
qa_relation_threshold: float = 0.75
"""QA关系搜索的相似度阈值"""
qa_paragraph_search_top_k: int = 1000
"""QA段落搜索的Top K数量"""
qa_paragraph_node_weight: float = 0.05
"""QA段落节点权重"""
qa_ent_filter_top_k: int = 10
"""QA实体过滤的Top K数量"""
qa_ppr_damping: float = 0.8
"""QA PageRank阻尼系数"""
qa_res_top_k: int = 10
"""QA最终结果的Top K数量"""
embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致"""