修复硬编码错误

This commit is contained in:
雅诺狐
2025-08-19 18:00:41 +08:00
parent d1efa3b5c1
commit 3dfb138d2c
3 changed files with 433 additions and 52 deletions

View File

@@ -88,10 +88,10 @@ class ModelInfo(ValidatedConfigBase):
class TaskConfig(ValidatedConfigBase):
"""任务配置类"""
model_list: List[str] = Field(default_factory=list, description="任务使用的模型列表")
max_tokens: int = Field(default=1024, ge=1, le=100000, description="任务最大输出token数")
temperature: float = Field(default=0.3, ge=0.0, le=2.0, description="模型温度")
concurrency_count: int = Field(default=1, ge=1, le=10, description="并发请求数量默认为1不并发")
model_list: List[str] = Field(..., description="任务使用的模型列表")
max_tokens: int = Field(default=None, ge=1, le=100000, description="任务最大输出token数")
temperature: float = Field(default=None, ge=0.0, le=2.0, description="模型温度")
concurrency_count: int = Field(default=None, ge=1, le=10, description="并发请求数量")
@field_validator('model_list')
@classmethod
@@ -103,22 +103,30 @@ class TaskConfig(ValidatedConfigBase):
raise ValueError("模型列表中不能有重复的模型")
return v
@field_validator('max_tokens')
@classmethod
def validate_max_tokens(cls, v):
"""验证最大token数"""
if v <= 0:
raise ValueError("最大token数必须大于0")
if v > 100000:
raise ValueError("最大token数不能超过100000")
return v
class ModelTaskConfig(ValidatedConfigBase):
"""模型配置类"""
# 必需配置项
utils: TaskConfig = Field(..., description="组件模型配置")
utils_small: TaskConfig = Field(..., description="组件小模型配置")
replyer_1: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置")
replyer_2: TaskConfig = Field(..., description="normal_chat次要回复模型配置")
maizone: TaskConfig = Field(..., description="maizone专用模型")
emotion: TaskConfig = Field(..., description="情绪模型配置")
vlm: TaskConfig = Field(..., description="视觉语言模型配置")
voice: TaskConfig = Field(..., description="语音识别模型配置")
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
planner: TaskConfig = Field(..., description="规划模型配置")
embedding: TaskConfig = Field(..., description="嵌入模型配置")
lpmm_entity_extract: TaskConfig = Field(..., description="LPMM实体提取模型配置")
lpmm_rdf_build: TaskConfig = Field(..., description="LPMM RDF构建模型配置")
lpmm_qa: TaskConfig = Field(..., description="LPMM问答模型配置")
schedule_generator: TaskConfig = Field(..., description="日程生成模型配置")
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
<<<<<<< Updated upstream
# 可选配置项(有默认值)
utils_small: TaskConfig = Field(
default_factory=lambda: TaskConfig(
@@ -240,37 +248,87 @@ class ModelTaskConfig(ValidatedConfigBase):
),
description="日程生成模型配置"
)
=======
# 处理配置文件中命名不一致的问题
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
>>>>>>> Stashed changes
# 可选配置项(有默认值)
video_analysis: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=1500,
temperature=0.3
),
description="视频分析模型配置"
)
emoji_vlm: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=800
),
description="表情包识别模型配置"
)
anti_injection: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=200,
temperature=0.1
),
description="反注入检测专用模型配置"
)
@property
def video_analysis(self) -> TaskConfig:
"""视频分析模型配置(提供向后兼容的属性访问)"""
return self.utils_video
def get_task(self, task_name: str) -> TaskConfig:
"""获取指定任务的配置"""
# 处理向后兼容性如果请求video_analysis返回utils_video
if task_name == "video_analysis":
task_name = "utils_video"
if hasattr(self, task_name):
config = getattr(self, task_name)
if config is None:
raise ValueError(f"任务 '{task_name}' 未配置")
return config
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
models: List[ModelInfo] = Field(..., min_items=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: List[APIProvider] = Field(..., min_items=1, description="API提供商列表")
def __init__(self, **data):
super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
@field_validator('models')
@classmethod
def validate_models_list(cls, v):
"""验证模型列表"""
if not v:
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
# 检查模型名称是否重复
model_names = [model.name for model in v]
if len(model_names) != len(set(model_names)):
raise ValueError("模型名称存在重复,请检查配置文件。")
# 检查模型标识符是否有效
for model in v:
if not model.model_identifier:
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
return v
@field_validator('api_providers')
@classmethod
def validate_api_providers_list(cls, v):
"""验证API提供商列表"""
if not v:
raise ValueError("API提供商列表不能为空请在配置中设置有效的API提供商列表。")
# 检查API提供商名称是否重复
provider_names = [provider.name for provider in v]
if len(provider_names) != len(set(provider_names)):
raise ValueError("API提供商名称存在重复请检查配置文件。")
return v
def get_model_info(self, model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息"""
if not model_name:
raise ValueError("模型名称不能为空")
if model_name not in self.models_dict:
raise KeyError(f"模型 '{model_name}' 不存在")
return self.models_dict[model_name]
def get_provider(self, provider_name: str) -> APIProvider:
"""根据提供商名称获取API提供商信息"""
if not provider_name:
raise ValueError("API提供商名称不能为空")
if provider_name not in self.api_providers_dict:
raise KeyError(f"API提供商 '{provider_name}' 不存在")
return self.api_providers_dict[provider_name]

View File

@@ -487,10 +487,11 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建APIAdapterConfig对象各个配置类会自动进行 Pydantic 验证)
config_dict = dict(config_data)
try:
logger.info("正在解析和验证API适配器配置文件...")
config = APIAdapterConfig.from_dict(config_data)
config = APIAdapterConfig.from_dict(config_dict)
logger.info("API适配器配置文件解析和验证完成")
return config
except Exception as e: