本次更新彻底重构了动作规划器(Planner)的核心架构,废弃了原有的“大脑/小脑”并行决策模型,转而采用一个更简洁、高效的统一决策模型。 主要变更: - **统一决策**: 单个LLM调用现在可以一次性决策出所有需要执行的动作,并以JSON列表的形式返回。 - **架构简化**: 完全移除了 `sub_plan`(小脑)逻辑、`planner_small` 模型以及相关的并行处理和结果合并代码,大幅降低了复杂性。 - **配置精简**: 从配置文件中删除了与小脑相关的 `planner_size` 和 `include_personality` 选项,简化了用户配置。 - **提示词更新**: 更新了规划器的Prompt,明确指示LLM返回一个动作列表,即使只有一个动作或没有动作。 带来的好处: - **性能提升**: 减少了LLM API的调用次数,显著降低了单次规划的延迟和成本。 - **可维护性**: 代码逻辑更清晰、线性,易于理解和后续维护。 - **稳定性**: 减少了多路并发带来的不确定性和潜在的竞态问题。 BREAKING CHANGE: 移除了大脑/小脑规划器架构。 用户需要从 `model_config.toml` 中移除 `[model_task_config.planner_small]` 配置节,并从 `bot_config.toml` 中移除 `planner_size` 和 `include_personality` 配置项。
229 lines
10 KiB
Python
229 lines
10 KiB
Python
from typing import List, Dict, Any, Literal, Union
|
||
from pydantic import Field, field_validator
|
||
from threading import Lock
|
||
|
||
from src.config.config_base import ValidatedConfigBase
|
||
|
||
|
||
class APIProvider(ValidatedConfigBase):
|
||
"""API提供商配置类"""
|
||
|
||
name: str = Field(..., min_length=1, description="API提供商名称")
|
||
base_url: str = Field(..., description="API基础URL")
|
||
api_key: Union[str, List[str]] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询")
|
||
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(
|
||
default="openai", description="客户端类型(如openai/google等,默认为openai)"
|
||
)
|
||
max_retry: int = Field(default=2, ge=0, description="最大重试次数(单个模型API调用失败,最多重试的次数)")
|
||
timeout: int = Field(
|
||
default=10, ge=1, description="API调用的超时时长(超过这个时长,本次请求将被视为'请求超时',单位:秒)"
|
||
)
|
||
retry_interval: int = Field(default=10, ge=0, description="重试间隔(如果API调用失败,重试的间隔时间,单位:秒)")
|
||
enable_content_obfuscation: bool = Field(default=False, description="是否启用内容混淆(用于特定场景下的内容处理)")
|
||
obfuscation_intensity: int = Field(default=1, ge=1, le=3, description="混淆强度(1-3级,数值越高混淆程度越强)")
|
||
|
||
@field_validator("base_url")
|
||
@classmethod
|
||
def validate_base_url(cls, v):
|
||
"""验证base_url,确保URL格式正确"""
|
||
if v and not (v.startswith("http://") or v.startswith("https://")):
|
||
raise ValueError("base_url必须以http://或https://开头")
|
||
return v
|
||
|
||
@field_validator("api_key")
|
||
@classmethod
|
||
def validate_api_key(cls, v):
|
||
"""验证API密钥不能为空"""
|
||
if isinstance(v, str):
|
||
if not v.strip():
|
||
raise ValueError("API密钥不能为空")
|
||
elif isinstance(v, list):
|
||
if not v:
|
||
raise ValueError("API密钥列表不能为空")
|
||
for key in v:
|
||
if not isinstance(key, str) or not key.strip():
|
||
raise ValueError("API密钥列表中的密钥不能为空")
|
||
else:
|
||
raise ValueError("API密钥必须是字符串或字符串列表")
|
||
return v
|
||
|
||
def __init__(self, **data):
|
||
super().__init__(**data)
|
||
self._api_key_lock = Lock()
|
||
self._api_key_index = 0
|
||
|
||
def get_api_key(self) -> str:
|
||
with self._api_key_lock:
|
||
if isinstance(self.api_key, str):
|
||
return self.api_key
|
||
if not self.api_key:
|
||
raise ValueError("API密钥列表为空")
|
||
key = self.api_key[self._api_key_index]
|
||
self._api_key_index = (self._api_key_index + 1) % len(self.api_key)
|
||
return key
|
||
|
||
|
||
class ModelInfo(ValidatedConfigBase):
|
||
"""单个模型信息配置类"""
|
||
|
||
model_identifier: str = Field(..., min_length=1, description="模型标识符(用于URL调用)")
|
||
name: str = Field(..., min_length=1, description="模型名称(用于模块调用)")
|
||
api_provider: str = Field(..., min_length=1, description="API提供商(如OpenAI、Azure等)")
|
||
price_in: float = Field(default=0.0, ge=0, description="每M token输入价格")
|
||
price_out: float = Field(default=0.0, ge=0, description="每M token输出价格")
|
||
force_stream_mode: bool = Field(default=False, description="是否强制使用流式输出模式")
|
||
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)")
|
||
anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断")
|
||
|
||
@field_validator("price_in", "price_out")
|
||
@classmethod
|
||
def validate_prices(cls, v):
|
||
"""验证价格必须为非负数"""
|
||
if v < 0:
|
||
raise ValueError("价格不能为负数")
|
||
return v
|
||
|
||
@field_validator("model_identifier")
|
||
@classmethod
|
||
def validate_model_identifier(cls, v):
|
||
"""验证模型标识符不能为空且不能包含特殊字符"""
|
||
if not v or not v.strip():
|
||
raise ValueError("模型标识符不能为空")
|
||
# 检查是否包含危险字符
|
||
if any(char in v for char in [" ", "\n", "\t", "\r"]):
|
||
raise ValueError("模型标识符不能包含空格或换行符")
|
||
return v
|
||
|
||
@field_validator("name")
|
||
@classmethod
|
||
def validate_name(cls, v):
|
||
"""验证模型名称不能为空"""
|
||
if not v or not v.strip():
|
||
raise ValueError("模型名称不能为空")
|
||
return v
|
||
|
||
|
||
class TaskConfig(ValidatedConfigBase):
|
||
"""任务配置类"""
|
||
|
||
model_list: List[str] = Field(..., description="任务使用的模型列表")
|
||
max_tokens: int = Field(default=800, description="任务最大输出token数")
|
||
temperature: float = Field(default=0.7, description="模型温度")
|
||
concurrency_count: int = Field(default=1, description="并发请求数量")
|
||
|
||
@field_validator("model_list")
|
||
@classmethod
|
||
def validate_model_list(cls, v):
|
||
"""验证模型列表不能为空"""
|
||
if not v:
|
||
raise ValueError("模型列表不能为空")
|
||
if len(v) != len(set(v)):
|
||
raise ValueError("模型列表中不能有重复的模型")
|
||
return v
|
||
|
||
|
||
class ModelTaskConfig(ValidatedConfigBase):
|
||
"""模型配置类"""
|
||
|
||
# 必需配置项
|
||
utils: TaskConfig = Field(..., description="组件模型配置")
|
||
utils_small: TaskConfig = Field(..., description="组件小模型配置")
|
||
replyer: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置")
|
||
maizone: TaskConfig = Field(..., description="maizone专用模型")
|
||
emotion: TaskConfig = Field(..., description="情绪模型配置")
|
||
vlm: TaskConfig = Field(..., description="视觉语言模型配置")
|
||
voice: TaskConfig = Field(..., description="语音识别模型配置")
|
||
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
|
||
planner: TaskConfig = Field(..., description="规划模型配置")
|
||
embedding: TaskConfig = Field(..., description="嵌入模型配置")
|
||
lpmm_entity_extract: TaskConfig = Field(..., description="LPMM实体提取模型配置")
|
||
lpmm_rdf_build: TaskConfig = Field(..., description="LPMM RDF构建模型配置")
|
||
lpmm_qa: TaskConfig = Field(..., description="LPMM问答模型配置")
|
||
schedule_generator: TaskConfig = Field(..., description="日程生成模型配置")
|
||
monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置")
|
||
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
|
||
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
|
||
|
||
# 处理配置文件中命名不一致的问题
|
||
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
|
||
|
||
@property
|
||
def video_analysis(self) -> TaskConfig:
|
||
"""视频分析模型配置(提供向后兼容的属性访问)"""
|
||
return self.utils_video
|
||
|
||
def get_task(self, task_name: str) -> TaskConfig:
|
||
"""获取指定任务的配置"""
|
||
# 处理向后兼容性:如果请求video_analysis,返回utils_video
|
||
if task_name == "video_analysis":
|
||
task_name = "utils_video"
|
||
|
||
if hasattr(self, task_name):
|
||
config = getattr(self, task_name)
|
||
if config is None:
|
||
raise ValueError(f"任务 '{task_name}' 未配置")
|
||
return config
|
||
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
|
||
|
||
|
||
class APIAdapterConfig(ValidatedConfigBase):
|
||
"""API Adapter配置类"""
|
||
|
||
models: List[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||
api_providers: List[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||
|
||
def __init__(self, **data):
|
||
super().__init__(**data)
|
||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||
self.models_dict = {model.name: model for model in self.models}
|
||
|
||
@field_validator("models")
|
||
@classmethod
|
||
def validate_models_list(cls, v):
|
||
"""验证模型列表"""
|
||
if not v:
|
||
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
|
||
|
||
# 检查模型名称是否重复
|
||
model_names = [model.name for model in v]
|
||
if len(model_names) != len(set(model_names)):
|
||
raise ValueError("模型名称存在重复,请检查配置文件。")
|
||
|
||
# 检查模型标识符是否有效
|
||
for model in v:
|
||
if not model.model_identifier:
|
||
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
|
||
|
||
return v
|
||
|
||
@field_validator("api_providers")
|
||
@classmethod
|
||
def validate_api_providers_list(cls, v):
|
||
"""验证API提供商列表"""
|
||
if not v:
|
||
raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")
|
||
|
||
# 检查API提供商名称是否重复
|
||
provider_names = [provider.name for provider in v]
|
||
if len(provider_names) != len(set(provider_names)):
|
||
raise ValueError("API提供商名称存在重复,请检查配置文件。")
|
||
|
||
return v
|
||
|
||
def get_model_info(self, model_name: str) -> ModelInfo:
|
||
"""根据模型名称获取模型信息"""
|
||
if not model_name:
|
||
raise ValueError("模型名称不能为空")
|
||
if model_name not in self.models_dict:
|
||
raise KeyError(f"模型 '{model_name}' 不存在")
|
||
return self.models_dict[model_name]
|
||
|
||
def get_provider(self, provider_name: str) -> APIProvider:
|
||
"""根据提供商名称获取API提供商信息"""
|
||
if not provider_name:
|
||
raise ValueError("API提供商名称不能为空")
|
||
if provider_name not in self.api_providers_dict:
|
||
raise KeyError(f"API提供商 '{provider_name}' 不存在")
|
||
return self.api_providers_dict[provider_name]
|