Files
Mofox-Core/src/config/api_ada_configs.py

156 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass, field
from .config_base import ConfigBase
@dataclass
class APIProvider(ConfigBase):
"""API提供商配置类"""
name: str
"""API提供商名称"""
base_url: str
"""API基础URL"""
api_key: str = field(default_factory=str, repr=False)
"""API密钥列表"""
client_type: str = field(default="openai")
"""客户端类型如openai/google等默认为openai"""
max_retry: int = 2
"""最大重试次数单个模型API调用失败最多重试的次数"""
timeout: int = 10
"""API调用的超时时长超过这个时长本次请求将被视为"请求超时",单位:秒)"""
retry_interval: int = 10
"""重试间隔如果API调用失败重试的间隔时间单位"""
enable_content_obfuscation: bool = field(default=False)
"""是否启用内容混淆(用于特定场景下的内容处理)"""
obfuscation_intensity: int = field(default=1)
"""混淆强度1-3级数值越高混淆程度越强"""
def get_api_key(self) -> str:
return self.api_key
def __post_init__(self):
"""确保api_key在repr中不被显示"""
if not self.api_key:
raise ValueError("API密钥不能为空请在配置中设置有效的API密钥。")
if not self.base_url and self.client_type != "gemini":
raise ValueError("API基础URL不能为空请在配置中设置有效的基础URL。")
if not self.name:
raise ValueError("API提供商名称不能为空请在配置中设置有效的名称。")
@dataclass
class ModelInfo(ConfigBase):
"""单个模型信息配置类"""
model_identifier: str
"""模型标识符用于URL调用"""
name: str
"""模型名称(用于模块调用)"""
api_provider: str
"""API提供商如OpenAI、Azure等"""
price_in: float = field(default=0.0)
"""每M token输入价格"""
price_out: float = field(default=0.0)
"""每M token输出价格"""
force_stream_mode: bool = field(default=False)
"""是否强制使用流式输出模式"""
extra_params: dict = field(default_factory=dict)
"""额外参数用于API调用时的额外配置"""
def __post_init__(self):
if not self.model_identifier:
raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。")
if not self.name:
raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。")
if not self.api_provider:
raise ValueError("API提供商不能为空请在配置中设置有效的API提供商。")
@dataclass
class TaskConfig(ConfigBase):
"""任务配置类"""
model_list: list[str] = field(default_factory=list)
"""任务使用的模型列表"""
max_tokens: int = 1024
"""任务最大输出token数"""
temperature: float = 0.3
"""模型温度"""
@dataclass
class ModelTaskConfig(ConfigBase):
"""模型配置类"""
utils: TaskConfig
"""组件模型配置"""
utils_small: TaskConfig
"""组件小模型配置"""
replyer_1: TaskConfig
"""normal_chat首要回复模型模型配置"""
replyer_2: TaskConfig
"""normal_chat次要回复模型配置"""
emotion: TaskConfig
"""情绪模型配置"""
vlm: TaskConfig
"""视觉语言模型配置"""
voice: TaskConfig
"""语音识别模型配置"""
tool_use: TaskConfig
"""专注工具使用模型配置"""
planner: TaskConfig
"""规划模型配置"""
embedding: TaskConfig
"""嵌入模型配置"""
lpmm_entity_extract: TaskConfig
"""LPMM实体提取模型配置"""
lpmm_rdf_build: TaskConfig
"""LPMM RDF构建模型配置"""
lpmm_qa: TaskConfig
"""LPMM问答模型配置"""
schedule_generator: TaskConfig
"""日程生成模型配置"""
utils_video: TaskConfig = field(default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=1500,
temperature=0.3
))
"""视频分析模型配置"""
def get_task(self, task_name: str) -> TaskConfig:
"""获取指定任务的配置"""
if hasattr(self, task_name):
return getattr(self, task_name)
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")