from dataclasses import dataclass, field from .config_base import ConfigBase @dataclass class APIProvider(ConfigBase): """API提供商配置类""" name: str """API提供商名称""" base_url: str """API基础URL""" api_key: str = field(default_factory=str, repr=False) """API密钥列表""" client_type: str = field(default="openai") """客户端类型(如openai/google等,默认为openai)""" max_retry: int = 2 """最大重试次数(单个模型API调用失败,最多重试的次数)""" timeout: int = 10 """API调用的超时时长(超过这个时长,本次请求将被视为"请求超时",单位:秒)""" retry_interval: int = 10 """重试间隔(如果API调用失败,重试的间隔时间,单位:秒)""" enable_content_obfuscation: bool = field(default=False) """是否启用内容混淆(用于特定场景下的内容处理)""" obfuscation_intensity: int = field(default=1) """混淆强度(1-3级,数值越高混淆程度越强)""" def get_api_key(self) -> str: return self.api_key def __post_init__(self): """确保api_key在repr中不被显示""" if not self.api_key: raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。") if not self.base_url and self.client_type != "gemini": raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。") if not self.name: raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。") @dataclass class ModelInfo(ConfigBase): """单个模型信息配置类""" model_identifier: str """模型标识符(用于URL调用)""" name: str """模型名称(用于模块调用)""" api_provider: str """API提供商(如OpenAI、Azure等)""" price_in: float = field(default=0.0) """每M token输入价格""" price_out: float = field(default=0.0) """每M token输出价格""" force_stream_mode: bool = field(default=False) """是否强制使用流式输出模式""" extra_params: dict = field(default_factory=dict) """额外参数(用于API调用时的额外配置)""" def __post_init__(self): if not self.model_identifier: raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。") if not self.name: raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。") if not self.api_provider: raise ValueError("API提供商不能为空,请在配置中设置有效的API提供商。") @dataclass class TaskConfig(ConfigBase): """任务配置类""" model_list: list[str] = field(default_factory=list) """任务使用的模型列表""" max_tokens: int = 1024 """任务最大输出token数""" temperature: float = 0.3 """模型温度""" @dataclass class ModelTaskConfig(ConfigBase): """模型配置类""" utils: TaskConfig """组件模型配置""" utils_small: TaskConfig """组件小模型配置""" replyer_1: TaskConfig """normal_chat首要回复模型模型配置""" replyer_2: TaskConfig """normal_chat次要回复模型配置""" emotion: TaskConfig """情绪模型配置""" vlm: TaskConfig """视觉语言模型配置""" voice: TaskConfig """语音识别模型配置""" tool_use: TaskConfig """专注工具使用模型配置""" planner: TaskConfig """规划模型配置""" embedding: TaskConfig """嵌入模型配置""" lpmm_entity_extract: TaskConfig """LPMM实体提取模型配置""" lpmm_rdf_build: TaskConfig """LPMM RDF构建模型配置""" lpmm_qa: TaskConfig """LPMM问答模型配置""" schedule_generator: TaskConfig """日程生成模型配置""" utils_video: TaskConfig = field(default_factory=lambda: TaskConfig( model_list=["qwen2.5-vl-72b"], max_tokens=1500, temperature=0.3 )) """视频分析模型配置""" def get_task(self, task_name: str) -> TaskConfig: """获取指定任务的配置""" if hasattr(self, task_name): return getattr(self, task_name) raise ValueError(f"任务 '{task_name}' 未找到对应的配置")