Files
Mofox-Core/src/mais4u/s4u_config.py
雅诺狐 2d4745cd58 初始化
2025-08-11 19:34:18 +08:00

368 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import tomlkit
import shutil
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table
from dataclasses import dataclass, fields, MISSING, field
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
from src.mais4u.constant_s4u import ENABLE_S4U
from src.common.logger import get_logger
logger = get_logger("s4u_config")
# 新增兼容dict和tomlkit Table
def is_dict_like(obj):
return isinstance(obj, (dict, Table))
# 新增递归将Table转为dict
def table_to_dict(obj):
if isinstance(obj, Table):
return {k: table_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, dict):
return {k: table_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [table_to_dict(i) for i in obj]
else:
return obj
# 获取mais4u模块目录
MAIS4U_ROOT = os.path.dirname(__file__)
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
TEMPLATE_PATH = os.path.join(CONFIG_DIR, "s4u_config_template.toml")
CONFIG_PATH = os.path.join(CONFIG_DIR, "s4u_config.toml")
# S4U配置版本
S4U_VERSION = "1.1.0"
T = TypeVar("T", bound="S4UConfigBase")
@dataclass
class S4UConfigBase:
"""S4U配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
"""从字典加载配置字段"""
data = table_to_dict(data) # 递归转dict兼容tomlkit Table
if not is_dict_like(data):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
field_type = f.type
try:
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
return cls(**init_args)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""转换字段值为指定类型"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase):
if not is_dict_like(value):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
return field_type.from_dict(value)
# 处理泛型集合类型list, set, tuple
field_origin_type = get_origin(field_type)
field_type_args = get_args(field_type)
if field_origin_type in {list, set, tuple}:
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
if (
field_type_args
and isinstance(field_type_args[0], type)
and issubclass(field_type_args[0], S4UConfigBase)
):
return [field_type_args[0].from_dict(item) for item in value]
return [cls._convert_field(item, field_type_args[0]) for item in value]
elif field_origin_type is set:
return {cls._convert_field(item, field_type_args[0]) for item in value}
elif field_origin_type is tuple:
if len(value) != len(field_type_args):
raise TypeError(
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
if field_origin_type is dict:
if not is_dict_like(value):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
if len(field_type_args) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_type_args
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理基础类型,例如 int, str 等
if field_origin_type is type(None) and value is None: # 处理Optional类型
return None
# 处理Literal类型
if field_origin_type is Literal or get_origin(field_type) is Literal:
allowed_values = get_args(field_type)
if value in allowed_values:
return value
else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
if field_type is Any or isinstance(value, field_type):
return value
# 其他类型,尝试直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
@dataclass
class S4UModelConfig(S4UConfigBase):
"""S4U模型配置类"""
# 主要对话模型配置
chat: dict[str, Any] = field(default_factory=lambda: {})
"""主要对话模型配置"""
# 规划模型配置原model_motion
motion: dict[str, Any] = field(default_factory=lambda: {})
"""规划模型配置"""
# 情感分析模型配置
emotion: dict[str, Any] = field(default_factory=lambda: {})
"""情感分析模型配置"""
# 记忆模型配置
memory: dict[str, Any] = field(default_factory=lambda: {})
"""记忆模型配置"""
# 工具使用模型配置
tool_use: dict[str, Any] = field(default_factory=lambda: {})
"""工具使用模型配置"""
# 嵌入模型配置
embedding: dict[str, Any] = field(default_factory=lambda: {})
"""嵌入模型配置"""
# 视觉语言模型配置
vlm: dict[str, Any] = field(default_factory=lambda: {})
"""视觉语言模型配置"""
# 知识库模型配置
knowledge: dict[str, Any] = field(default_factory=lambda: {})
"""知识库模型配置"""
# 实体提取模型配置
entity_extract: dict[str, Any] = field(default_factory=lambda: {})
"""实体提取模型配置"""
# 问答模型配置
qa: dict[str, Any] = field(default_factory=lambda: {})
"""问答模型配置"""
@dataclass
class S4UConfig(S4UConfigBase):
"""S4U聊天系统配置类"""
message_timeout_seconds: int = 120
"""普通消息存活时间(秒),超过此时间的消息将被丢弃"""
at_bot_priority_bonus: float = 100.0
"""@机器人时的优先级加成分数"""
recent_message_keep_count: int = 6
"""保留最近N条消息超出范围的普通消息将被移除"""
typing_delay: float = 0.1
"""打字延迟时间(秒),模拟真实打字速度"""
chars_per_second: float = 15.0
"""每秒字符数,用于计算动态打字延迟"""
min_typing_delay: float = 0.2
"""最小打字延迟(秒)"""
max_typing_delay: float = 2.0
"""最大打字延迟(秒)"""
enable_dynamic_typing_delay: bool = False
"""是否启用基于文本长度的动态打字延迟"""
vip_queue_priority: bool = True
"""是否启用VIP队列优先级系统"""
enable_message_interruption: bool = True
"""是否允许高优先级消息中断当前回复"""
enable_old_message_cleanup: bool = True
"""是否自动清理过旧的普通消息"""
enable_streaming_output: bool = True
"""是否启用流式输出false时全部生成后一次性发送"""
max_context_message_length: int = 20
"""上下文消息最大长度"""
max_core_message_length: int = 30
"""核心消息最大长度"""
# 模型配置
models: S4UModelConfig = field(default_factory=S4UModelConfig)
"""S4U模型配置"""
# 兼容性字段,保持向后兼容
@dataclass
class S4UGlobalConfig(S4UConfigBase):
"""S4U总配置类"""
s4u: S4UConfig
S4U_VERSION: str = S4U_VERSION
def update_s4u_config():
"""更新S4U配置文件"""
# 创建配置目录(如果不存在)
os.makedirs(CONFIG_DIR, exist_ok=True)
# 检查模板文件是否存在
if not os.path.exists(TEMPLATE_PATH):
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
logger.error("请确保模板文件存在后重新运行")
raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
# 检查配置文件是否存在
if not os.path.exists(CONFIG_PATH):
logger.info("S4U配置文件不存在从模板创建新配置")
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
logger.info(f"已创建S4U配置文件: {CONFIG_PATH}")
return
# 读取旧配置文件和模板文件
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
with open(TEMPLATE_PATH, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
else:
logger.info("S4U配置文件未检测到版本号可能是旧版本。将进行更新")
# 创建备份目录
old_config_dir = os.path.join(CONFIG_DIR, "old")
os.makedirs(old_config_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml")
# 移动旧配置文件到old目录
shutil.move(CONFIG_PATH, old_backup_path)
logger.info(f"已备份旧S4U配置文件到: {old_backup_path}")
# 复制模板文件到配置目录
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}")
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
logger.info("开始合并S4U新旧配置...")
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
logger.info("S4U配置文件更新完成")
def load_s4u_config(config_path: str) -> S4UGlobalConfig:
"""
加载S4U配置文件
:param config_path: 配置文件路径
:return: S4UGlobalConfig对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建S4UGlobalConfig对象
try:
return S4UGlobalConfig.from_dict(config_data)
except Exception as e:
logger.critical("S4U配置文件解析失败")
raise e
if not ENABLE_S4U:
s4u_config = None
s4u_config_main = None
else:
# 初始化S4U配置
logger.info(f"S4U当前版本: {S4U_VERSION}")
update_s4u_config()
logger.info("正在加载S4U配置文件...")
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
logger.info("S4U配置文件加载完成")
s4u_config: S4UConfig = s4u_config_main.s4u