初始化
This commit is contained in:
139
src/config/api_ada_configs.py
Normal file
139
src/config/api_ada_configs.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .config_base import ConfigBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIProvider(ConfigBase):
|
||||
"""API提供商配置类"""
|
||||
|
||||
name: str
|
||||
"""API提供商名称"""
|
||||
|
||||
base_url: str
|
||||
"""API基础URL"""
|
||||
|
||||
api_key: str = field(default_factory=str, repr=False)
|
||||
"""API密钥列表"""
|
||||
|
||||
client_type: str = field(default="openai")
|
||||
"""客户端类型(如openai/google等,默认为openai)"""
|
||||
|
||||
max_retry: int = 2
|
||||
"""最大重试次数(单个模型API调用失败,最多重试的次数)"""
|
||||
|
||||
timeout: int = 10
|
||||
"""API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)"""
|
||||
|
||||
retry_interval: int = 10
|
||||
"""重试间隔(如果API调用失败,重试的间隔时间,单位:秒)"""
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.api_key
|
||||
|
||||
def __post_init__(self):
|
||||
"""确保api_key在repr中不被显示"""
|
||||
if not self.api_key:
|
||||
raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。")
|
||||
if not self.base_url and self.client_type != "gemini":
|
||||
raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。")
|
||||
if not self.name:
|
||||
raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo(ConfigBase):
|
||||
"""单个模型信息配置类"""
|
||||
|
||||
model_identifier: str
|
||||
"""模型标识符(用于URL调用)"""
|
||||
|
||||
name: str
|
||||
"""模型名称(用于模块调用)"""
|
||||
|
||||
api_provider: str
|
||||
"""API提供商(如OpenAI、Azure等)"""
|
||||
|
||||
price_in: float = field(default=0.0)
|
||||
"""每M token输入价格"""
|
||||
|
||||
price_out: float = field(default=0.0)
|
||||
"""每M token输出价格"""
|
||||
|
||||
force_stream_mode: bool = field(default=False)
|
||||
"""是否强制使用流式输出模式"""
|
||||
|
||||
extra_params: dict = field(default_factory=dict)
|
||||
"""额外参数(用于API调用时的额外配置)"""
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.model_identifier:
|
||||
raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。")
|
||||
if not self.name:
|
||||
raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。")
|
||||
if not self.api_provider:
|
||||
raise ValueError("API提供商不能为空,请在配置中设置有效的API提供商。")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskConfig(ConfigBase):
|
||||
"""任务配置类"""
|
||||
|
||||
model_list: list[str] = field(default_factory=list)
|
||||
"""任务使用的模型列表"""
|
||||
|
||||
max_tokens: int = 1024
|
||||
"""任务最大输出token数"""
|
||||
|
||||
temperature: float = 0.3
|
||||
"""模型温度"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelTaskConfig(ConfigBase):
|
||||
"""模型配置类"""
|
||||
|
||||
utils: TaskConfig
|
||||
"""组件模型配置"""
|
||||
|
||||
utils_small: TaskConfig
|
||||
"""组件小模型配置"""
|
||||
|
||||
replyer_1: TaskConfig
|
||||
"""normal_chat首要回复模型模型配置"""
|
||||
|
||||
replyer_2: TaskConfig
|
||||
"""normal_chat次要回复模型配置"""
|
||||
|
||||
emotion: TaskConfig
|
||||
"""情绪模型配置"""
|
||||
|
||||
vlm: TaskConfig
|
||||
"""视觉语言模型配置"""
|
||||
|
||||
voice: TaskConfig
|
||||
"""语音识别模型配置"""
|
||||
|
||||
tool_use: TaskConfig
|
||||
"""专注工具使用模型配置"""
|
||||
|
||||
planner: TaskConfig
|
||||
"""规划模型配置"""
|
||||
|
||||
embedding: TaskConfig
|
||||
"""嵌入模型配置"""
|
||||
|
||||
lpmm_entity_extract: TaskConfig
|
||||
"""LPMM实体提取模型配置"""
|
||||
|
||||
lpmm_rdf_build: TaskConfig
|
||||
"""LPMM RDF构建模型配置"""
|
||||
|
||||
lpmm_qa: TaskConfig
|
||||
"""LPMM问答模型配置"""
|
||||
|
||||
def get_task(self, task_name: str) -> TaskConfig:
|
||||
"""获取指定任务的配置"""
|
||||
if hasattr(self, task_name):
|
||||
return getattr(self, task_name)
|
||||
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
|
||||
479
src/config/config.py
Normal file
479
src/config/config.py
Normal file
@@ -0,0 +1,479 @@
|
||||
import os
|
||||
import tomlkit
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from datetime import datetime
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table, KeyType
|
||||
from dataclasses import field, dataclass
|
||||
from rich.traceback import install
|
||||
from typing import List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.official_configs import (
|
||||
DatabaseConfig,
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
ExpressionConfig,
|
||||
ChatConfig,
|
||||
NormalChatConfig,
|
||||
EmojiConfig,
|
||||
MemoryConfig,
|
||||
MoodConfig,
|
||||
KeywordReactionConfig,
|
||||
ChineseTypoConfig,
|
||||
ResponsePostProcessConfig,
|
||||
ResponseSplitterConfig,
|
||||
TelemetryConfig,
|
||||
ExperimentalConfig,
|
||||
MessageReceiveConfig,
|
||||
MaimMessageConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
RelationshipConfig,
|
||||
ToolConfig,
|
||||
VoiceConfig,
|
||||
DebugConfig,
|
||||
CustomPromptConfig,
|
||||
)
|
||||
|
||||
from .api_ada_configs import (
|
||||
ModelTaskConfig,
|
||||
ModelInfo,
|
||||
APIProvider,
|
||||
)
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
# 配置主程序日志格式
|
||||
logger = get_logger("config")
|
||||
|
||||
# 获取当前文件所在目录的父目录的父目录(即MaiBot项目根目录)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
CONFIG_DIR = os.path.join(PROJECT_ROOT, "config")
|
||||
TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.10.0-snapshot.5"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
# 获取key的注释(如果有)
|
||||
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
|
||||
return toml_table.trivia.comment
|
||||
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
|
||||
item = toml_table.value.get(key)
|
||||
if item is not None and hasattr(item, "trivia"):
|
||||
return item.trivia.comment
|
||||
if hasattr(toml_table, "keys"):
|
||||
for k in toml_table.keys():
|
||||
if isinstance(k, KeyType) and k.key == key:
|
||||
return k.trivia.comment
|
||||
return None
|
||||
|
||||
|
||||
def compare_dicts(new, old, path=None, logs=None):
|
||||
# 递归比较两个dict,找出新增和删减项,收集注释
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
# 新增项
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in old:
|
||||
comment = get_key_comment(new, key)
|
||||
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
|
||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||
compare_dicts(new[key], old[key], path + [str(key)], logs)
|
||||
# 删减项
|
||||
for key in old:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in new:
|
||||
comment = get_key_comment(old, key)
|
||||
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
|
||||
return logs
|
||||
|
||||
|
||||
def get_value_by_path(d, path):
|
||||
for k in path:
|
||||
if isinstance(d, dict) and k in d:
|
||||
d = d[k]
|
||||
else:
|
||||
return None
|
||||
return d
|
||||
|
||||
|
||||
def set_value_by_path(d, path, value):
|
||||
for k in path[:-1]:
|
||||
if k not in d or not isinstance(d[k], dict):
|
||||
d[k] = {}
|
||||
d = d[k]
|
||||
d[path[-1]] = value
|
||||
|
||||
|
||||
def compare_default_values(new, old, path=None, logs=None, changes=None):
|
||||
# 递归比较两个dict,找出默认值变化项
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
if changes is None:
|
||||
changes = []
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key in old:
|
||||
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
|
||||
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
|
||||
elif new[key] != old[key]:
|
||||
logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
|
||||
changes.append((path + [str(key)], old[key], new[key]))
|
||||
return logs, changes
|
||||
|
||||
|
||||
def _get_version_from_toml(toml_path) -> Optional[str]:
|
||||
"""从TOML文件中获取版本号"""
|
||||
if not os.path.exists(toml_path):
|
||||
return None
|
||||
with open(toml_path, "r", encoding="utf-8") as f:
|
||||
doc = tomlkit.load(f)
|
||||
if "inner" in doc and "version" in doc["inner"]: # type: ignore
|
||||
return doc["inner"]["version"] # type: ignore
|
||||
return None
|
||||
|
||||
|
||||
def _version_tuple(v):
|
||||
"""将版本字符串转换为元组以便比较"""
|
||||
if v is None:
|
||||
return (0,)
|
||||
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
|
||||
|
||||
|
||||
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
|
||||
_update_dict(target_value, value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
# 如果是空数组,确保它保持为空数组
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
|
||||
def _update_config_generic(config_name: str, template_name: str):
|
||||
"""
|
||||
通用的配置文件更新函数
|
||||
|
||||
Args:
|
||||
config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config'
|
||||
template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template'
|
||||
"""
|
||||
# 获取根目录路径
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
|
||||
|
||||
# 定义文件路径
|
||||
template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
|
||||
old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
|
||||
new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
|
||||
compare_path = os.path.join(compare_dir, f"{template_name}.toml")
|
||||
|
||||
# 创建compare目录(如果不存在)
|
||||
os.makedirs(compare_dir, exist_ok=True)
|
||||
|
||||
template_version = _get_version_from_toml(template_path)
|
||||
compare_version = _get_version_from_toml(compare_path)
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(old_config_path):
|
||||
logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置")
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
|
||||
shutil.copy2(template_path, old_config_path) # 复制模板文件
|
||||
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
|
||||
# 新创建配置文件,退出
|
||||
sys.exit(0)
|
||||
|
||||
compare_config = None
|
||||
new_config = None
|
||||
old_config = None
|
||||
|
||||
# 先读取 compare 下的模板(如果有),用于默认值变动检测
|
||||
if os.path.exists(compare_path):
|
||||
with open(compare_path, "r", encoding="utf-8") as f:
|
||||
compare_config = tomlkit.load(f)
|
||||
|
||||
# 读取当前模板
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查默认值变化并处理(只有 compare_config 存在时才做)
|
||||
if compare_config:
|
||||
# 读取旧配置
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
logs, changes = compare_default_values(new_config, compare_config)
|
||||
if logs:
|
||||
logger.info(f"检测到{config_name}模板默认值变动如下:")
|
||||
for log in logs:
|
||||
logger.info(log)
|
||||
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
|
||||
for path, old_default, new_default in changes:
|
||||
old_value = get_value_by_path(old_config, path)
|
||||
if old_value == old_default:
|
||||
set_value_by_path(old_config, path, new_default)
|
||||
logger.info(
|
||||
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"未检测到{config_name}模板默认值变动")
|
||||
|
||||
# 检查 compare 下没有模板,或新模板版本更高,则复制
|
||||
if not os.path.exists(compare_path):
|
||||
shutil.copy2(template_path, compare_path)
|
||||
logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
|
||||
elif _version_tuple(template_version) > _version_tuple(compare_version):
|
||||
shutil.copy2(template_path, compare_path)
|
||||
logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}")
|
||||
else:
|
||||
logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
|
||||
|
||||
# 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
|
||||
if old_config is None:
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
# new_config 已经读取
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version") # type: ignore
|
||||
new_version = new_config["inner"].get("version") # type: ignore
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(
|
||||
f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
|
||||
)
|
||||
else:
|
||||
logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 创建old目录(如果不存在)
|
||||
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
|
||||
|
||||
# 移动旧配置文件到old目录
|
||||
shutil.move(old_config_path, old_backup_path)
|
||||
logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path, new_config_path)
|
||||
logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
|
||||
|
||||
# 输出新增和删减项及注释
|
||||
if old_config:
|
||||
logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
|
||||
if logs := compare_dicts(new_config, old_config):
|
||||
for log in logs:
|
||||
logger.info(log)
|
||||
else:
|
||||
logger.info("无新增或删减项")
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info(f"开始合并{config_name}新旧配置...")
|
||||
_update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(new_config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
||||
|
||||
|
||||
def update_config():
|
||||
"""更新bot_config.toml配置文件"""
|
||||
_update_config_generic("bot_config", "bot_config_template")
|
||||
|
||||
|
||||
def update_model_config():
|
||||
"""更新model_config.toml配置文件"""
|
||||
_update_config_generic("model_config", "model_config_template")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(ConfigBase):
|
||||
"""总配置类"""
|
||||
|
||||
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
|
||||
|
||||
database: DatabaseConfig
|
||||
bot: BotConfig
|
||||
personality: PersonalityConfig
|
||||
relationship: RelationshipConfig
|
||||
chat: ChatConfig
|
||||
message_receive: MessageReceiveConfig
|
||||
normal_chat: NormalChatConfig
|
||||
emoji: EmojiConfig
|
||||
expression: ExpressionConfig
|
||||
memory: MemoryConfig
|
||||
mood: MoodConfig
|
||||
keyword_reaction: KeywordReactionConfig
|
||||
chinese_typo: ChineseTypoConfig
|
||||
response_post_process: ResponsePostProcessConfig
|
||||
response_splitter: ResponseSplitterConfig
|
||||
telemetry: TelemetryConfig
|
||||
experimental: ExperimentalConfig
|
||||
maim_message: MaimMessageConfig
|
||||
lpmm_knowledge: LPMMKnowledgeConfig
|
||||
tool: ToolConfig
|
||||
debug: DebugConfig
|
||||
custom_prompt: CustomPromptConfig
|
||||
voice: VoiceConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIAdapterConfig(ConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
models: List[ModelInfo]
|
||||
"""模型列表"""
|
||||
|
||||
model_task_config: ModelTaskConfig
|
||||
"""模型任务配置"""
|
||||
|
||||
api_providers: List[APIProvider] = field(default_factory=list)
|
||||
"""API提供商列表"""
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.models:
|
||||
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
|
||||
if not self.api_providers:
|
||||
raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")
|
||||
|
||||
# 检查API提供商名称是否重复
|
||||
provider_names = [provider.name for provider in self.api_providers]
|
||||
if len(provider_names) != len(set(provider_names)):
|
||||
raise ValueError("API提供商名称存在重复,请检查配置文件。")
|
||||
|
||||
# 检查模型名称是否重复
|
||||
model_names = [model.name for model in self.models]
|
||||
if len(model_names) != len(set(model_names)):
|
||||
raise ValueError("模型名称存在重复,请检查配置文件。")
|
||||
|
||||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self.models_dict = {model.name: model for model in self.models}
|
||||
|
||||
for model in self.models:
|
||||
if not model.model_identifier:
|
||||
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
|
||||
if not model.api_provider or model.api_provider not in self.api_providers_dict:
|
||||
raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")
|
||||
|
||||
def get_model_info(self, model_name: str) -> ModelInfo:
|
||||
"""根据模型名称获取模型信息"""
|
||||
if not model_name:
|
||||
raise ValueError("模型名称不能为空")
|
||||
if model_name not in self.models_dict:
|
||||
raise KeyError(f"模型 '{model_name}' 不存在")
|
||||
return self.models_dict[model_name]
|
||||
|
||||
def get_provider(self, provider_name: str) -> APIProvider:
|
||||
"""根据提供商名称获取API提供商信息"""
|
||||
if not provider_name:
|
||||
raise ValueError("API提供商名称不能为空")
|
||||
if provider_name not in self.api_providers_dict:
|
||||
raise KeyError(f"API提供商 '{provider_name}' 不存在")
|
||||
return self.api_providers_dict[provider_name]
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
"""
|
||||
加载配置文件
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
Returns:
|
||||
Config对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建Config对象
|
||||
try:
|
||||
return Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
def api_ada_load_config(config_path: str) -> APIAdapterConfig:
|
||||
"""
|
||||
加载API适配器配置文件
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
Returns:
|
||||
APIAdapterConfig对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建APIAdapterConfig对象
|
||||
try:
|
||||
return APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("API适配器配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
# 获取配置文件路径
|
||||
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
|
||||
update_config()
|
||||
update_model_config()
|
||||
|
||||
logger.info("正在品鉴配置文件...")
|
||||
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
|
||||
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
|
||||
|
||||
# 初始化数据库连接
|
||||
logger.info("正在初始化数据库连接...")
|
||||
from src.common.database.database import initialize_sql_database
|
||||
try:
|
||||
initialize_sql_database(global_config.database)
|
||||
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接初始化失败: {e}")
|
||||
raise e
|
||||
|
||||
# 初始化数据库表结构
|
||||
logger.info("正在初始化数据库表结构...")
|
||||
from src.common.database.sqlalchemy_models import initialize_database as init_db
|
||||
try:
|
||||
init_db()
|
||||
logger.info("数据库表结构初始化完成")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库表结构初始化失败: {e}")
|
||||
raise e
|
||||
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
135
src/config/config_base.py
Normal file
135
src/config/config_base.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
|
||||
TOML_DICT_TYPE = {
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
bool,
|
||||
list,
|
||||
dict,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigBase:
|
||||
"""配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
|
||||
"""从字典加载配置字段"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
|
||||
init_args: dict[str, Any] = {}
|
||||
|
||||
for f in fields(cls):
|
||||
field_name = f.name
|
||||
|
||||
if field_name.startswith("_"):
|
||||
# 跳过以 _ 开头的字段
|
||||
continue
|
||||
|
||||
if field_name not in data:
|
||||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||||
# 跳过未提供且有默认值/默认构造方法的字段
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Missing required field: '{field_name}'")
|
||||
|
||||
value = data[field_name]
|
||||
field_type = f.type
|
||||
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
|
||||
except TypeError as e:
|
||||
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
|
||||
|
||||
return cls(**init_args)
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
|
||||
"""
|
||||
转换字段值为指定类型
|
||||
|
||||
1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法
|
||||
2. 对于泛型集合类型(list, set, tuple),递归转换每个元素
|
||||
3. 对于基础类型(int, str, float, bool),直接转换
|
||||
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
|
||||
"""
|
||||
|
||||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||||
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
return field_type.from_dict(value)
|
||||
|
||||
# 处理泛型集合类型(list, set, tuple)
|
||||
field_origin_type = get_origin(field_type)
|
||||
field_type_args = get_args(field_type)
|
||||
|
||||
if field_origin_type in {list, set, tuple}:
|
||||
# 检查提供的value是否为list
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if field_origin_type is list:
|
||||
# 如果列表元素类型是ConfigBase的子类,则对每个元素调用from_dict
|
||||
if (
|
||||
field_type_args
|
||||
and isinstance(field_type_args[0], type)
|
||||
and issubclass(field_type_args[0], ConfigBase)
|
||||
):
|
||||
return [field_type_args[0].from_dict(item) for item in value]
|
||||
return [cls._convert_field(item, field_type_args[0]) for item in value]
|
||||
elif field_origin_type is set:
|
||||
return {cls._convert_field(item, field_type_args[0]) for item in value}
|
||||
elif field_origin_type is tuple:
|
||||
# 检查提供的value长度是否与类型参数一致
|
||||
if len(value) != len(field_type_args):
|
||||
raise TypeError(
|
||||
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
|
||||
|
||||
if field_origin_type is dict:
|
||||
# 检查提供的value是否为dict
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
# 检查字典的键值类型
|
||||
if len(field_type_args) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||||
key_type, value_type = field_type_args
|
||||
|
||||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||||
|
||||
# 处理基础类型,例如 int, str 等
|
||||
if field_origin_type is type(None) and value is None: # 处理Optional类型
|
||||
return None
|
||||
|
||||
# 处理Literal类型
|
||||
if field_origin_type is Literal or get_origin(field_type) is Literal:
|
||||
# 获取Literal的允许值
|
||||
allowed_values = get_args(field_type)
|
||||
if value in allowed_values:
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||||
|
||||
if field_type is Any or isinstance(value, field_type):
|
||||
return value
|
||||
|
||||
# 其他类型,尝试直接转换
|
||||
try:
|
||||
return field_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
|
||||
|
||||
def __str__(self):
|
||||
"""返回配置类的字符串表示"""
|
||||
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
|
||||
814
src/config/official_configs.py
Normal file
814
src/config/official_configs.py
Normal file
@@ -0,0 +1,814 @@
|
||||
import re
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
|
||||
"""
|
||||
须知:
|
||||
1. 本文件中记录了所有的配置项
|
||||
2. 所有新增的class都需要继承自ConfigBase
|
||||
3. 所有新增的class都应在config.py中的Config类中添加字段
|
||||
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig(ConfigBase):
|
||||
"""数据库配置类"""
|
||||
|
||||
database_type: Literal["sqlite", "mysql"] = "sqlite"
|
||||
"""数据库类型,支持 sqlite 或 mysql"""
|
||||
|
||||
# SQLite 配置
|
||||
sqlite_path: str = "data/MaiBot.db"
|
||||
"""SQLite数据库文件路径"""
|
||||
|
||||
# MySQL 配置
|
||||
mysql_host: str = "localhost"
|
||||
"""MySQL服务器地址"""
|
||||
|
||||
mysql_port: int = 3306
|
||||
"""MySQL服务器端口"""
|
||||
|
||||
mysql_database: str = "maibot"
|
||||
"""MySQL数据库名"""
|
||||
|
||||
mysql_user: str = "root"
|
||||
"""MySQL用户名"""
|
||||
|
||||
mysql_password: str = ""
|
||||
"""MySQL密码"""
|
||||
|
||||
mysql_charset: str = "utf8mb4"
|
||||
"""MySQL字符集"""
|
||||
|
||||
mysql_unix_socket: str = ""
|
||||
"""MySQL Unix套接字路径(可选,用于本地连接,优先于host/port)"""
|
||||
|
||||
# MySQL SSL 配置
|
||||
mysql_ssl_mode: str = "DISABLED"
|
||||
"""SSL模式: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY"""
|
||||
|
||||
mysql_ssl_ca: str = ""
|
||||
"""SSL CA证书路径"""
|
||||
|
||||
mysql_ssl_cert: str = ""
|
||||
"""SSL客户端证书路径"""
|
||||
|
||||
mysql_ssl_key: str = ""
|
||||
"""SSL客户端密钥路径"""
|
||||
|
||||
# MySQL 高级配置
|
||||
mysql_autocommit: bool = True
|
||||
"""自动提交事务"""
|
||||
|
||||
mysql_sql_mode: str = "TRADITIONAL"
|
||||
"""SQL模式"""
|
||||
|
||||
# 连接池配置
|
||||
connection_pool_size: int = 10
|
||||
"""连接池大小(仅MySQL有效)"""
|
||||
|
||||
connection_timeout: int = 10
|
||||
"""连接超时时间(秒)"""
|
||||
|
||||
@dataclass
|
||||
class BotConfig(ConfigBase):
|
||||
"""QQ机器人配置类"""
|
||||
|
||||
platform: str
|
||||
"""平台"""
|
||||
|
||||
qq_account: str
|
||||
"""QQ账号"""
|
||||
|
||||
nickname: str
|
||||
"""昵称"""
|
||||
|
||||
alias_names: list[str] = field(default_factory=lambda: [])
|
||||
"""别名列表"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PersonalityConfig(ConfigBase):
|
||||
"""人格配置类"""
|
||||
|
||||
personality_core: str
|
||||
"""核心人格"""
|
||||
|
||||
personality_side: str
|
||||
"""人格侧写"""
|
||||
|
||||
identity: str = ""
|
||||
"""身份特征"""
|
||||
|
||||
reply_style: str = ""
|
||||
"""表达风格"""
|
||||
|
||||
compress_personality: bool = True
|
||||
"""是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭"""
|
||||
|
||||
compress_identity: bool = True
|
||||
"""是否压缩身份,压缩后会精简身份信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果不长,可以关闭"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelationshipConfig(ConfigBase):
|
||||
"""关系配置类"""
|
||||
|
||||
enable_relationship: bool = True
|
||||
"""是否启用关系系统"""
|
||||
|
||||
relation_frequency: int = 1
|
||||
"""关系频率,麦麦构建关系的速度"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatConfig(ConfigBase):
|
||||
"""聊天配置类"""
|
||||
|
||||
max_context_size: int = 18
|
||||
"""上下文长度"""
|
||||
|
||||
|
||||
replyer_random_probability: float = 0.5
|
||||
"""
|
||||
发言时选择推理模型的概率(0-1之间)
|
||||
选择普通模型的概率为 1 - reasoning_normal_model_probability
|
||||
"""
|
||||
|
||||
thinking_timeout: int = 40
|
||||
"""麦麦最长思考规划时间,超过这个时间的思考会放弃(往往是api反应太慢)"""
|
||||
|
||||
talk_frequency: float = 1
|
||||
"""回复频率阈值"""
|
||||
|
||||
mentioned_bot_inevitable_reply: bool = False
|
||||
"""提及 bot 必然回复"""
|
||||
|
||||
at_bot_inevitable_reply: bool = False
|
||||
"""@bot 必然回复"""
|
||||
|
||||
# 合并后的时段频率配置
|
||||
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
|
||||
"""
|
||||
统一的时段频率配置
|
||||
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
|
||||
|
||||
全局配置示例:
|
||||
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
|
||||
|
||||
特定聊天流配置示例:
|
||||
[
|
||||
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
|
||||
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
|
||||
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
|
||||
]
|
||||
|
||||
说明:
|
||||
- 当第一个元素为空字符串""时,表示全局默认配置
|
||||
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
|
||||
- 后续元素是"时间,频率"格式,表示从该时间开始使用该频率,直到下一个时间点
|
||||
- 优先级:特定聊天流配置 > 全局配置 > 默认 talk_frequency
|
||||
"""
|
||||
|
||||
focus_value: float = 1.0
|
||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||
|
||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 talk_frequency
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID,格式为 "platform:chat_id:type"
|
||||
|
||||
Returns:
|
||||
float: 对应的频率值
|
||||
"""
|
||||
if not self.talk_frequency_adjust:
|
||||
return self.talk_frequency
|
||||
|
||||
# 优先检查聊天流特定的配置
|
||||
if chat_stream_id:
|
||||
stream_frequency = self._get_stream_specific_frequency(chat_stream_id)
|
||||
if stream_frequency is not None:
|
||||
return stream_frequency
|
||||
|
||||
# 检查全局时段配置(第一个元素为空字符串的配置)
|
||||
global_frequency = self._get_global_frequency()
|
||||
if global_frequency is not None:
|
||||
return global_frequency
|
||||
|
||||
# 如果都没有匹配,返回默认值
|
||||
return self.talk_frequency
|
||||
|
||||
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
根据时间配置列表获取当前时段的频率
|
||||
|
||||
Args:
|
||||
time_freq_list: 时间频率配置列表,格式为 ["HH:MM,frequency", ...]
|
||||
|
||||
Returns:
|
||||
float: 频率值,如果没有配置则返回 None
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
current_time = datetime.now().strftime("%H:%M")
|
||||
current_hour, current_minute = map(int, current_time.split(":"))
|
||||
current_minutes = current_hour * 60 + current_minute
|
||||
|
||||
# 解析时间频率配置
|
||||
time_freq_pairs = []
|
||||
for time_freq_str in time_freq_list:
|
||||
try:
|
||||
time_str, freq_str = time_freq_str.split(",")
|
||||
hour, minute = map(int, time_str.split(":"))
|
||||
frequency = float(freq_str)
|
||||
minutes = hour * 60 + minute
|
||||
time_freq_pairs.append((minutes, frequency))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
if not time_freq_pairs:
|
||||
return None
|
||||
|
||||
# 按时间排序
|
||||
time_freq_pairs.sort(key=lambda x: x[0])
|
||||
|
||||
# 查找当前时间对应的频率
|
||||
current_frequency = None
|
||||
for minutes, frequency in time_freq_pairs:
|
||||
if current_minutes >= minutes:
|
||||
current_frequency = frequency
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果当前时间在所有配置时间之前,使用最后一个时间段的频率(跨天逻辑)
|
||||
if current_frequency is None and time_freq_pairs:
|
||||
current_frequency = time_freq_pairs[-1][1]
|
||||
|
||||
return current_frequency
|
||||
|
||||
def _get_stream_specific_frequency(self, chat_stream_id: str):
|
||||
"""
|
||||
获取特定聊天流在当前时间的频率
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
float: 频率值,如果没有配置则返回 None
|
||||
"""
|
||||
# 查找匹配的聊天流配置
|
||||
for config_item in self.talk_frequency_adjust:
|
||||
if not config_item or len(config_item) < 2:
|
||||
continue
|
||||
|
||||
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
|
||||
|
||||
# 解析配置字符串并生成对应的 chat_id
|
||||
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
|
||||
if config_chat_id is None:
|
||||
continue
|
||||
|
||||
# 比较生成的 chat_id
|
||||
if config_chat_id != chat_stream_id:
|
||||
continue
|
||||
|
||||
# 使用通用的时间频率解析方法
|
||||
return self._get_time_based_frequency(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""
|
||||
解析流配置字符串并生成对应的 chat_id
|
||||
|
||||
Args:
|
||||
stream_config_str: 格式为 "platform:id:type" 的字符串
|
||||
|
||||
Returns:
|
||||
str: 生成的 chat_id,如果解析失败则返回 None
|
||||
"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def _get_global_frequency(self) -> Optional[float]:
|
||||
"""
|
||||
获取全局默认频率配置
|
||||
|
||||
Returns:
|
||||
float: 频率值,如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.talk_frequency_adjust:
|
||||
if not config_item or len(config_item) < 2:
|
||||
continue
|
||||
|
||||
# 检查是否为全局默认配置(第一个元素为空字符串)
|
||||
if config_item[0] == "":
|
||||
return self._get_time_based_frequency(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageReceiveConfig(ConfigBase):
|
||||
"""消息接收配置类"""
|
||||
|
||||
ban_words: set[str] = field(default_factory=lambda: set())
|
||||
"""过滤词列表"""
|
||||
|
||||
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
||||
"""过滤正则表达式列表"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NormalChatConfig(ConfigBase):
|
||||
"""普通聊天配置类"""
|
||||
|
||||
willing_mode: str = "classical"
|
||||
"""意愿模式"""
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
expression_learning: list[list] = field(default_factory=lambda: [])
|
||||
"""
|
||||
表达学习配置列表,支持按聊天流配置
|
||||
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5
|
||||
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
|
||||
]
|
||||
|
||||
说明:
|
||||
- 第一位: chat_stream_id,空字符串表示全局配置
|
||||
- 第二位: 是否使用学到的表达 ("enable"/"disable")
|
||||
- 第三位: 是否学习表达 ("enable"/"disable")
|
||||
- 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒)
|
||||
"""
|
||||
|
||||
expression_groups: list[list[str]] = field(default_factory=list)
|
||||
"""
|
||||
表达学习互通组
|
||||
格式: [["qq:12345:group", "qq:67890:private"]]
|
||||
"""
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""
|
||||
解析流配置字符串并生成对应的 chat_id
|
||||
|
||||
Args:
|
||||
stream_config_str: 格式为 "platform:id:type" 的字符串
|
||||
|
||||
Returns:
|
||||
str: 生成的 chat_id,如果解析失败则返回 None
|
||||
"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]:
|
||||
"""
|
||||
根据聊天流ID获取表达配置
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔)
|
||||
"""
|
||||
if not self.expression_learning:
|
||||
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
|
||||
return True, True, 300
|
||||
|
||||
# 优先检查聊天流特定的配置
|
||||
if chat_stream_id:
|
||||
specific_config = self._get_stream_specific_config(chat_stream_id)
|
||||
if specific_config is not None:
|
||||
return specific_config
|
||||
|
||||
# 检查全局配置(第一个元素为空字符串的配置)
|
||||
global_config = self._get_global_config()
|
||||
if global_config is not None:
|
||||
return global_config
|
||||
|
||||
# 如果都没有匹配,返回默认值
|
||||
return True, True, 300
|
||||
|
||||
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]:
|
||||
"""
|
||||
获取特定聊天流的表达配置
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.expression_learning:
|
||||
if not config_item or len(config_item) < 4:
|
||||
continue
|
||||
|
||||
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
|
||||
|
||||
# 如果是空字符串,跳过(这是全局配置)
|
||||
if stream_config_str == "":
|
||||
continue
|
||||
|
||||
# 解析配置字符串并生成对应的 chat_id
|
||||
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
|
||||
if config_chat_id is None:
|
||||
continue
|
||||
|
||||
# 比较生成的 chat_id
|
||||
if config_chat_id != chat_stream_id:
|
||||
continue
|
||||
|
||||
# 解析配置
|
||||
try:
|
||||
use_expression = config_item[1].lower() == "enable"
|
||||
enable_learning = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def _get_global_config(self) -> Optional[tuple[bool, bool, int]]:
|
||||
"""
|
||||
获取全局表达配置
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.expression_learning:
|
||||
if not config_item or len(config_item) < 4:
|
||||
continue
|
||||
|
||||
# 检查是否为全局配置(第一个元素为空字符串)
|
||||
if config_item[0] == "":
|
||||
try:
|
||||
use_expression = config_item[1].lower() == "enable"
|
||||
enable_learning = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolConfig(ConfigBase):
|
||||
"""工具配置类"""
|
||||
|
||||
enable_tool: bool = False
|
||||
"""是否在聊天中启用工具"""
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
|
||||
enable_asr: bool = False
|
||||
"""是否启用语音识别"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmojiConfig(ConfigBase):
|
||||
"""表情包配置类"""
|
||||
|
||||
emoji_chance: float = 0.6
|
||||
"""发送表情包的基础概率"""
|
||||
|
||||
emoji_activate_type: str = "random"
|
||||
"""表情包激活类型,可选:random,llm,random下,表情包动作随机启用,llm下,表情包动作根据llm判断是否启用"""
|
||||
|
||||
max_reg_num: int = 200
|
||||
"""表情包最大注册数量"""
|
||||
|
||||
do_replace: bool = True
|
||||
"""达到最大注册数量时替换旧表情包"""
|
||||
|
||||
check_interval: int = 120
|
||||
"""表情包检查间隔(分钟)"""
|
||||
|
||||
steal_emoji: bool = True
|
||||
"""是否偷取表情包,让麦麦可以发送她保存的这些表情包"""
|
||||
|
||||
content_filtration: bool = False
|
||||
"""是否开启表情包过滤"""
|
||||
|
||||
filtration_prompt: str = "符合公序良俗"
|
||||
"""表情包过滤要求"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig(ConfigBase):
|
||||
"""记忆配置类"""
|
||||
|
||||
enable_memory: bool = True
|
||||
|
||||
memory_build_interval: int = 600
|
||||
"""记忆构建间隔(秒)"""
|
||||
|
||||
memory_build_distribution: tuple[
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
|
||||
"""记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重"""
|
||||
|
||||
memory_build_sample_num: int = 8
|
||||
"""记忆构建采样数量"""
|
||||
|
||||
memory_build_sample_length: int = 40
|
||||
"""记忆构建采样长度"""
|
||||
|
||||
memory_compress_rate: float = 0.1
|
||||
"""记忆压缩率"""
|
||||
|
||||
forget_memory_interval: int = 1000
|
||||
"""记忆遗忘间隔(秒)"""
|
||||
|
||||
memory_forget_time: int = 24
|
||||
"""记忆遗忘时间(小时)"""
|
||||
|
||||
memory_forget_percentage: float = 0.01
|
||||
"""记忆遗忘比例"""
|
||||
|
||||
consolidate_memory_interval: int = 1000
|
||||
"""记忆整合间隔(秒)"""
|
||||
|
||||
consolidation_similarity_threshold: float = 0.7
|
||||
"""整合相似度阈值"""
|
||||
|
||||
consolidate_memory_percentage: float = 0.01
|
||||
"""整合检查节点比例"""
|
||||
|
||||
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
|
||||
"""不允许记忆的词列表"""
|
||||
|
||||
enable_instant_memory: bool = True
|
||||
"""是否启用即时记忆"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoodConfig(ConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
enable_mood: bool = False
|
||||
"""是否启用情绪系统"""
|
||||
|
||||
mood_update_threshold: float = 1.0
|
||||
"""情绪更新阈值,越高,更新越慢"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeywordRuleConfig(ConfigBase):
|
||||
"""关键词规则配置类"""
|
||||
|
||||
keywords: list[str] = field(default_factory=lambda: [])
|
||||
"""关键词列表"""
|
||||
|
||||
regex: list[str] = field(default_factory=lambda: [])
|
||||
"""正则表达式列表"""
|
||||
|
||||
reaction: str = ""
|
||||
"""关键词触发的反应"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置"""
|
||||
if not self.keywords and not self.regex:
|
||||
raise ValueError("关键词规则必须至少包含keywords或regex中的一个")
|
||||
|
||||
if not self.reaction:
|
||||
raise ValueError("关键词规则必须包含reaction")
|
||||
|
||||
# 验证正则表达式
|
||||
for pattern in self.regex:
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as e:
|
||||
raise ValueError(f"无效的正则表达式 '{pattern}': {str(e)}") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeywordReactionConfig(ConfigBase):
|
||||
"""关键词配置类"""
|
||||
|
||||
keyword_rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
|
||||
"""关键词规则列表"""
|
||||
|
||||
regex_rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
|
||||
"""正则表达式规则列表"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置"""
|
||||
# 验证所有规则
|
||||
for rule in self.keyword_rules + self.regex_rules:
|
||||
if not isinstance(rule, KeywordRuleConfig):
|
||||
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
||||
|
||||
@dataclass
|
||||
class CustomPromptConfig(ConfigBase):
|
||||
"""自定义提示词配置类"""
|
||||
|
||||
image_prompt: str = ""
|
||||
"""图片提示词"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponsePostProcessConfig(ConfigBase):
|
||||
"""回复后处理配置类"""
|
||||
|
||||
enable_response_post_process: bool = True
|
||||
"""是否启用回复后处理,包括错别字生成器,回复分割器"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChineseTypoConfig(ConfigBase):
|
||||
"""中文错别字配置类"""
|
||||
|
||||
enable: bool = True
|
||||
"""是否启用中文错别字生成器"""
|
||||
|
||||
error_rate: float = 0.01
|
||||
"""单字替换概率"""
|
||||
|
||||
min_freq: int = 9
|
||||
"""最小字频阈值"""
|
||||
|
||||
tone_error_rate: float = 0.1
|
||||
"""声调错误概率"""
|
||||
|
||||
word_replace_rate: float = 0.006
|
||||
"""整词替换概率"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseSplitterConfig(ConfigBase):
|
||||
"""回复分割器配置类"""
|
||||
|
||||
enable: bool = True
|
||||
"""是否启用回复分割器"""
|
||||
|
||||
max_length: int = 256
|
||||
"""回复允许的最大长度"""
|
||||
|
||||
max_sentence_num: int = 3
|
||||
"""回复允许的最大句子数"""
|
||||
|
||||
enable_kaomoji_protection: bool = False
|
||||
"""是否启用颜文字保护"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TelemetryConfig(ConfigBase):
|
||||
"""遥测配置类"""
|
||||
|
||||
enable: bool = True
|
||||
"""是否启用遥测"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugConfig(ConfigBase):
|
||||
"""调试配置类"""
|
||||
|
||||
show_prompt: bool = False
|
||||
"""是否显示prompt"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExperimentalConfig(ConfigBase):
|
||||
"""实验功能配置类"""
|
||||
|
||||
enable_friend_chat: bool = False
|
||||
"""是否启用好友聊天"""
|
||||
|
||||
pfc_chatting: bool = False
|
||||
"""是否启用PFC"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaimMessageConfig(ConfigBase):
|
||||
"""maim_message配置类"""
|
||||
|
||||
use_custom: bool = False
|
||||
"""是否使用自定义的maim_message配置"""
|
||||
|
||||
host: str = "127.0.0.1"
|
||||
"""主机地址"""
|
||||
|
||||
port: int = 8090
|
||||
""""端口号"""
|
||||
|
||||
mode: Literal["ws", "tcp"] = "ws"
|
||||
"""连接模式,支持ws和tcp"""
|
||||
|
||||
use_wss: bool = False
|
||||
"""是否使用WSS安全连接"""
|
||||
|
||||
cert_file: str = ""
|
||||
"""SSL证书文件路径,仅在use_wss=True时有效"""
|
||||
|
||||
key_file: str = ""
|
||||
"""SSL密钥文件路径,仅在use_wss=True时有效"""
|
||||
|
||||
auth_token: list[str] = field(default_factory=lambda: [])
|
||||
"""认证令牌,用于API验证,为空则不启用验证"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LPMMKnowledgeConfig(ConfigBase):
|
||||
"""LPMM知识库配置类"""
|
||||
|
||||
enable: bool = True
|
||||
"""是否启用LPMM知识库"""
|
||||
|
||||
rag_synonym_search_top_k: int = 10
|
||||
"""RAG同义词搜索的Top K数量"""
|
||||
|
||||
rag_synonym_threshold: float = 0.8
|
||||
"""RAG同义词搜索的相似度阈值"""
|
||||
|
||||
info_extraction_workers: int = 3
|
||||
"""信息提取工作线程数"""
|
||||
|
||||
qa_relation_search_top_k: int = 10
|
||||
"""QA关系搜索的Top K数量"""
|
||||
|
||||
qa_relation_threshold: float = 0.75
|
||||
"""QA关系搜索的相似度阈值"""
|
||||
|
||||
qa_paragraph_search_top_k: int = 1000
|
||||
"""QA段落搜索的Top K数量"""
|
||||
|
||||
qa_paragraph_node_weight: float = 0.05
|
||||
"""QA段落节点权重"""
|
||||
|
||||
qa_ent_filter_top_k: int = 10
|
||||
"""QA实体过滤的Top K数量"""
|
||||
|
||||
qa_ppr_damping: float = 0.8
|
||||
"""QA PageRank阻尼系数"""
|
||||
|
||||
qa_res_top_k: int = 10
|
||||
"""QA最终结果的Top K数量"""
|
||||
|
||||
embedding_dimension: int = 1024
|
||||
"""嵌入向量维度,应该与模型的输出维度一致"""
|
||||
|
||||
Reference in New Issue
Block a user