368 lines
13 KiB
Python
368 lines
13 KiB
Python
import os
|
||
import tomlkit
|
||
import shutil
|
||
from datetime import datetime
|
||
from tomlkit import TOMLDocument
|
||
from tomlkit.items import Table
|
||
from dataclasses import dataclass, fields, MISSING, field
|
||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
|
||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||
from src.common.logger import get_logger
|
||
|
||
# Module-level logger used by the S4U config loading/updating code below.
logger = get_logger("s4u_config")
|
||
|
||
def is_dict_like(obj):
    """Return True when *obj* is a plain ``dict`` or a tomlkit ``Table``."""
    accepted_types = (dict, Table)
    return isinstance(obj, accepted_types)
|
||
|
||
def table_to_dict(obj):
    """Recursively rebuild mappings and lists as plain Python containers.

    Any ``dict`` or tomlkit ``Table`` becomes a new ``dict`` and any ``list`` becomes a
    new ``list``, with their contents converted the same way. Every other value
    (including scalar tomlkit items, if present) is returned unchanged.
    """
    if isinstance(obj, (Table, dict)):
        return {key: table_to_dict(val) for key, val in obj.items()}
    if isinstance(obj, list):
        return [table_to_dict(elem) for elem in obj]
    return obj
|
||
|
||
# Directory containing this module; S4U config files live in its "config" subfolder.
MAIS4U_ROOT = os.path.dirname(__file__)
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")

# Bundled template, and the user-editable config file generated from it.
TEMPLATE_PATH, CONFIG_PATH = (
    os.path.join(CONFIG_DIR, file_name)
    for file_name in ("s4u_config_template.toml", "s4u_config.toml")
)

# S4U configuration version.
S4U_VERSION = "1.1.0"

# Type variable used to type the alternate constructor of S4UConfigBase subclasses.
T = TypeVar("T", bound="S4UConfigBase")
|
||
|
||
|
||
@dataclass
class S4UConfigBase:
    """Base class for S4U configuration dataclasses.

    Subclasses declare their options as dataclass fields. ``from_dict`` builds an
    instance from a parsed TOML mapping and converts every value to the field's
    annotated type with ``_convert_field``.
    """

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
        """Build an instance of ``cls`` from a mapping of field names to raw values.

        :param data: plain ``dict`` or tomlkit table/document holding the values.
        :return: a new instance of ``cls``.
        :raises TypeError: if ``data`` is not a mapping or a field value has the wrong type.
        :raises ValueError: if a field without a default is missing from ``data``.
        :raises RuntimeError: if converting a field fails for any other reason.
        """
        # Recursively turn tomlkit tables into plain dicts before inspecting them.
        data = table_to_dict(data)
        if not is_dict_like(data):
            raise TypeError(f"Expected a dictionary, got {type(data).__name__}")

        init_args: dict[str, Any] = {}

        for f in fields(cls):
            field_name = f.name

            if field_name.startswith("_"):
                # Fields whose name starts with "_" are never read from the config.
                continue

            if field_name not in data:
                if f.default is not MISSING or f.default_factory is not MISSING:
                    # Absent but has a default / default_factory: let the dataclass fill it in.
                    continue
                raise ValueError(f"Missing required field: '{field_name}'")

            value = data[field_name]
            field_type = f.type

            try:
                init_args[field_name] = cls._convert_field(value, field_type)  # type: ignore
            except TypeError as e:
                raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
            except Exception as e:
                raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e

        return cls(**init_args)

    @staticmethod
    def _is_config_class(tp: Any) -> bool:
        """Return True if *tp* is a plain (non-parameterized) subclass of S4UConfigBase."""
        # Checking get_origin() first keeps parameterized generics such as
        # ``dict[str, Any]`` out of the issubclass() test.
        return get_origin(tp) is None and isinstance(tp, type) and issubclass(tp, S4UConfigBase)

    @classmethod
    def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
        """Convert a raw config value to the annotated *field_type*.

        Supported annotations: nested ``S4UConfigBase`` subclasses, ``list``/``set``/
        ``tuple`` (including variadic ``tuple[X, ...]``), ``dict[K, V]``,
        ``Optional``/``Union`` (both ``typing.Union`` and ``X | Y``), ``Literal``,
        ``Any``, and plain classes (kept if already an instance, else called on the value).

        :raises TypeError: if *value* cannot be represented as *field_type*.
        """
        import types
        from typing import Union

        # Nested config dataclass: delegate to its own from_dict.
        if cls._is_config_class(field_type):
            if not is_dict_like(value):
                raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
            return field_type.from_dict(value)

        field_origin_type = get_origin(field_type)
        field_type_args = get_args(field_type)

        # Optional[X] / Union[X, Y] / X | None. get_origin() reports typing.Union (or
        # types.UnionType for the "|" syntax), never NoneType, so unions need their
        # own branch; otherwise Optional fields fall through to isinstance()/a call
        # on the Union itself and fail.
        if field_origin_type is Union or field_origin_type is types.UnionType:
            return cls._convert_union(value, field_type, field_type_args)

        # Generic collections; TOML arrays always arrive as lists.
        if field_origin_type in {list, set, tuple}:
            if not isinstance(value, list):
                raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")

            if field_origin_type is tuple:
                # tuple[X, ...]: any length, every item converted to X.
                if len(field_type_args) == 2 and field_type_args[1] is Ellipsis:
                    return tuple(cls._convert_field(item, field_type_args[0]) for item in value)
                if len(value) != len(field_type_args):
                    raise TypeError(
                        f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
                    )
                return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=True))

            # Unparameterized typing.List / typing.Set carry no args: accept any item.
            item_type = field_type_args[0] if field_type_args else Any
            converted = [cls._convert_field(item, item_type) for item in value]
            return converted if field_origin_type is list else set(converted)

        if field_origin_type is dict:
            if not is_dict_like(value):
                raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
            if len(field_type_args) != 2:
                raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
            key_type, value_type = field_type_args
            return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}

        # Literal[...]: the value must be one of the listed constants.
        if field_origin_type is Literal:
            allowed_values = field_type_args
            if value in allowed_values:
                return value
            raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")

        if field_type is Any or isinstance(value, field_type):
            return value

        # Any other class: try calling it on the value (e.g. int -> float).
        try:
            return field_type(value)
        except (ValueError, TypeError) as e:
            raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e

    @classmethod
    def _convert_union(cls, value: Any, field_type: Any, members: tuple[Any, ...]) -> Any:
        """Convert *value* for a Union/Optional annotation whose members are *members*."""
        none_type = type(None)
        if value is None:
            if none_type in members:
                return None
            raise TypeError(f"None is not allowed for {field_type}")

        candidates = [member for member in members if member is not none_type]
        if len(candidates) == 1:
            # Plain Optional[X]: convert directly so X's own error surfaces unchanged.
            return cls._convert_field(value, candidates[0])

        # Keep the value untouched when it already is an instance of a plain member
        # class, so e.g. Union[int, str] leaves "5" as a string.
        for member in candidates:
            if member is Any:
                return value
            if get_origin(member) is None and isinstance(member, type) and isinstance(value, member):
                return value

        # Otherwise try converting to each member in declaration order.
        errors: list[str] = []
        for member in candidates:
            try:
                return cls._convert_field(value, member)
            except (TypeError, ValueError) as e:
                errors.append(str(e))
        raise TypeError(f"Value {value!r} does not match any type in {field_type}: {'; '.join(errors)}")
|
||
|
||
|
||
@dataclass
class S4UModelConfig(S4UConfigBase):
    """Model settings for S4U, one raw option mapping per model role.

    Each attribute is looked up by its own name when loaded via ``from_dict`` and
    defaults to an empty dict when absent.
    """

    # Main chat (dialogue) model
    chat: dict[str, Any] = field(default_factory=dict)
    # Planning model (formerly ``model_motion``)
    motion: dict[str, Any] = field(default_factory=dict)
    # Emotion analysis model
    emotion: dict[str, Any] = field(default_factory=dict)
    # Memory model
    memory: dict[str, Any] = field(default_factory=dict)
    # Tool-use model
    tool_use: dict[str, Any] = field(default_factory=dict)
    # Embedding model
    embedding: dict[str, Any] = field(default_factory=dict)
    # Vision-language model
    vlm: dict[str, Any] = field(default_factory=dict)
    # Knowledge-base model
    knowledge: dict[str, Any] = field(default_factory=dict)
    # Entity extraction model
    entity_extract: dict[str, Any] = field(default_factory=dict)
    # Question-answering model
    qa: dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
@dataclass
class S4UConfig(S4UConfigBase):
    """Settings for the S4U chat system (loaded from the ``s4u`` key of the global config)."""

    # --- Message queue -------------------------------------------------------
    # Lifetime of ordinary messages in seconds; older messages are discarded.
    message_timeout_seconds: int = 120
    # Priority bonus added when the bot is @-mentioned.
    at_bot_priority_bonus: float = 100.0
    # Keep the most recent N messages; ordinary messages beyond that are removed.
    recent_message_keep_count: int = 6

    # --- Typing simulation -----------------------------------------------------
    # Typing delay in seconds, to imitate a real typing speed.
    typing_delay: float = 0.1
    # Characters per second, used to compute the dynamic typing delay.
    chars_per_second: float = 15.0
    # Lower bound for the typing delay (seconds).
    min_typing_delay: float = 0.2
    # Upper bound for the typing delay (seconds).
    max_typing_delay: float = 2.0
    # Whether to scale the typing delay with the text length.
    enable_dynamic_typing_delay: bool = False

    # --- Behaviour switches ----------------------------------------------------
    # Whether the VIP queue priority system is enabled.
    vip_queue_priority: bool = True
    # Whether a higher-priority message may interrupt the reply in progress.
    enable_message_interruption: bool = True
    # Whether stale ordinary messages are cleaned up automatically.
    enable_old_message_cleanup: bool = True
    # Stream output; when False the whole reply is sent at once after generation.
    enable_streaming_output: bool = True

    # --- Context sizes ---------------------------------------------------------
    # Maximum length of the context messages.
    max_context_message_length: int = 20
    # Maximum length of the core messages.
    max_core_message_length: int = 30

    # --- Models ----------------------------------------------------------------
    # Per-role S4U model settings.
    models: S4UModelConfig = field(default_factory=S4UModelConfig)

    # Backward-compatibility fields would go here; none are currently defined.
|
||
|
||
|
||
|
||
@dataclass
class S4UGlobalConfig(S4UConfigBase):
    """Top-level S4U configuration object built from the whole config file."""

    # Required: the ``s4u`` section; from_dict raises ValueError if it is absent.
    s4u: S4UConfig
    # Config version string; defaults to the module-level S4U_VERSION constant.
    S4U_VERSION: str = field(default=S4U_VERSION)
|
||
|
||
|
||
def update_s4u_config():
    """Create or upgrade the user's S4U config file from the bundled template.

    1. Ensure ``CONFIG_DIR`` exists; raise if the template file is missing.
    2. If there is no config file yet, copy the template and stop.
    3. If both files declare the same ``[inner].version``, stop.
    4. Otherwise move the current config to ``CONFIG_DIR/old/s4u_config_<timestamp>.toml``,
       start again from a fresh copy of the template, carry over every old value whose
       key still exists in the template (``version`` keys excluded), and write the
       result with tomlkit so the template's comments and formatting are kept.

    :raises FileNotFoundError: if the template file does not exist.
    """

    def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
        """Copy values from *source* into *target* for keys that *target* already has.

        Nested tables are merged recursively. Keys present only in *source* are
        dropped, and keys named ``version`` are skipped so the template's version is
        the one written back (and the next run sees matching versions).
        """
        for key, value in source.items():
            if key == "version" or key not in target:
                continue
            target_value = target[key]
            if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
                update_dict(target_value, value)
                continue
            try:
                # tomlkit.item() wraps plain Python values (lists included) and passes
                # existing tomlkit items through. The previous
                # tomlkit.array(str(value)) re-parsed the value's printed form as TOML,
                # which only works when that text happens to be valid TOML (a plain
                # Python list of booleans, for instance, prints as "[True]").
                target[key] = tomlkit.item(value)
            except (TypeError, ValueError):
                # Value tomlkit cannot wrap: assign it as-is.
                target[key] = value

    # Make sure the config directory exists.
    os.makedirs(CONFIG_DIR, exist_ok=True)

    # Without a template there is nothing to create or upgrade from.
    if not os.path.exists(TEMPLATE_PATH):
        logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
        logger.error("请确保模板文件存在后重新运行")
        raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")

    # First run: no user config yet, so just copy the template.
    if not os.path.exists(CONFIG_PATH):
        logger.info("S4U配置文件不存在,从模板创建新配置")
        shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
        logger.info(f"已创建S4U配置文件: {CONFIG_PATH}")
        return

    # Load both the existing config and the template.
    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        old_config = tomlkit.load(f)
    with open(TEMPLATE_PATH, "r", encoding="utf-8") as f:
        new_config = tomlkit.load(f)

    # Skip the upgrade when both files declare the same [inner] version.
    if old_config and "inner" in old_config and "inner" in new_config:
        old_version = old_config["inner"].get("version")  # type: ignore
        new_version = new_config["inner"].get("version")  # type: ignore
        if old_version and new_version and old_version == new_version:
            logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新")
            return
        else:
            logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
    else:
        logger.info("S4U配置文件未检测到版本号,可能是旧版本。将进行更新")

    # Move the old config aside under a timestamped name before rewriting it.
    old_config_dir = os.path.join(CONFIG_DIR, "old")
    os.makedirs(old_config_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml")
    shutil.move(CONFIG_PATH, old_backup_path)
    logger.info(f"已备份旧S4U配置文件到: {old_backup_path}")

    # Start again from the template...
    shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
    logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}")

    # ...then carry the old values over onto it.
    logger.info("开始合并S4U新旧配置...")
    update_dict(new_config, old_config)

    # Save via tomlkit so the template's comments and formatting are preserved.
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(new_config))

    logger.info("S4U配置文件更新完成")
|
||
|
||
|
||
def load_s4u_config(config_path: str) -> S4UGlobalConfig:
    """Read the TOML file at *config_path* and build an ``S4UGlobalConfig`` from it.

    :param config_path: path of the S4U config file
    :return: the populated S4UGlobalConfig
    :raises Exception: anything raised by ``S4UGlobalConfig.from_dict`` is logged as
        critical and re-raised; errors while opening/parsing the file propagate
        without that log line.
    """
    with open(config_path, "r", encoding="utf-8") as config_file:
        raw_document = tomlkit.load(config_file)

    try:
        parsed = S4UGlobalConfig.from_dict(raw_document)
    except Exception:
        logger.critical("S4U配置文件解析失败")
        raise
    return parsed
|
||
|
||
|
||
if ENABLE_S4U:
    # S4U enabled: sync the config file with the template, then load it.
    logger.info(f"S4U当前版本: {S4U_VERSION}")
    update_s4u_config()

    logger.info("正在加载S4U配置文件...")
    s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
    logger.info("S4U配置文件加载完成!")

    s4u_config: S4UConfig = s4u_config_main.s4u
else:
    # S4U disabled: the module-level names still exist but hold None.
    s4u_config = None
    s4u_config_main = None