feat: add task type and capability fields to the model config; improve model initialization logic

墨梓柒
2025-07-29 09:57:20 +08:00
parent 254958fe85
commit 7313529dcb
4 changed files with 201 additions and 20 deletions

View File

@@ -85,7 +85,7 @@ class APIProvider:
        # If no key is available, return the current key and let the caller handle it
        return api_key

-    def reset_key_failures(self, api_key: str = None):
+    def reset_key_failures(self, api_key: str | None = None):
        """Reset the failure counters (called after a successful call)."""
        with self._lock:
            if api_key and api_key in self.api_keys:
@@ -125,6 +125,10 @@ class ModelInfo:
    force_stream_mode: bool = False  # Whether to force streaming output mode
+    # New: task type and capability fields
+    task_type: str = ""  # Task type: llm_normal, llm_reasoning, vision, embedding, speech
+    capabilities: List[str] = field(default_factory=list)  # Model capabilities: text, vision, embedding, speech, tool_calling, reasoning

@dataclass
class RequestConfig:
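To see why this stays backward compatible, here is a minimal standalone sketch of the extended dataclass (field names mirror the diff; the surrounding module and its other fields are assumed):

from dataclasses import dataclass, field
from typing import List

@dataclass
class ModelInfo:
    # Required identity fields, as configured under [[models]]
    model_identifier: str
    name: str
    api_provider: str
    # Optional fields with safe defaults
    price_in: float = 0.0
    price_out: float = 0.0
    force_stream_mode: bool = False
    # New fields added by this commit
    task_type: str = ""  # e.g. "llm_normal"
    capabilities: List[str] = field(default_factory=list)  # e.g. ["text", "tool_calling"]

# A legacy config that never mentions the new fields still constructs cleanly:
legacy = ModelInfo("deepseek-chat", "deepseek-v3", "DeepSeek")
assert legacy.task_type == "" and legacy.capabilities == []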

View File

@@ -162,6 +162,8 @@ def _models(parent: Dict, config: ModuleConfig):
    price_in = model.get("price_in", 0.0)
    price_out = model.get("price_out", 0.0)
    force_stream_mode = model.get("force_stream_mode", False)
+    task_type = model.get("task_type", "")
+    capabilities = model.get("capabilities", [])

    if name in config.models:  # duplicate check
        logger.error(f"Duplicate model name: {name}; please check the configuration file.")
@@ -181,6 +183,8 @@ def _models(parent: Dict, config: ModuleConfig):
            price_in=price_in,
            price_out=price_out,
            force_stream_mode=force_stream_mode,
+            task_type=task_type,
+            capabilities=capabilities,
        )
    else:
        logger.error(f"Model '{name}' has an incomplete configuration; please check the configuration file.")

View File

@@ -131,12 +131,24 @@ class LLMRequest:
            **kwargs: extra parameters
        """
        logger.debug(f"🔍 [Model init] Initializing model: {model.get('model_name', model.get('name', 'Unknown'))}")
-        logger.debug(f"🔍 [Model init] Model config: {model}")
+        logger.debug(f"🔍 [Model init] Input model config: {model}")
        logger.debug(f"🔍 [Model init] Extra parameters: {kwargs}")

        # Compatible with both the new and the old model config format
        # The new format uses model_name; the old format uses name
        self.model_name: str = model.get("model_name", model.get("name", ""))

+        # If the supplied config is incomplete, fetch the full config from the global configuration
+        if not all(key in model for key in ["task_type", "capabilities"]):
+            logger.debug("🔍 [Model init] Incomplete model config detected; trying to fetch the full config")
+            if (full_model_config := self._get_full_model_config(self.model_name)):
+                logger.debug("🔍 [Model init] Fetched the full model config; merging")
+                # Merge configs: runtime parameters take precedence, missing fields are filled in
+                model = {**full_model_config, **model}
+                logger.debug(f"🔍 [Model init] Merged model config: {model}")
+            else:
+                logger.warning(f"⚠️ [Model init] Could not fetch the full config for model {self.model_name}; using the original config")

        # In the new architecture, provider info is read from model_config.toml automatically and does not need to be set here
        self.provider = model.get("provider", "")  # Kept for compatibility; unused in the new architecture
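The merge on the added lines leans on dict-unpacking order: keys in the right-hand dict win, so runtime parameters override the global config while missing fields get filled in. A small demonstration with made-up values:

full_model_config = {"name": "deepseek-v3", "task_type": "llm_normal",
                     "capabilities": ["text", "tool_calling"], "price_in": 2.0}
model = {"name": "deepseek-v3", "price_in": 0.0}  # incomplete runtime config

merged = {**full_model_config, **model}
print(merged["price_in"])   # 0.0 -- the runtime value wins
print(merged["task_type"])  # llm_normal -- filled in from the global config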
@@ -235,6 +247,13 @@ class LLMRequest:
        Returns:
            The task name
        """
+        # Debug info: print all keys of the model config dict
+        logger.debug(f"🔍 [Task resolution] Model config keys: {list(model.keys())}")
+        logger.debug(f"🔍 [Task resolution] Model config contents: {model}")
+        # Get the model name
+        model_name = model.get("model_name", model.get("name", ""))

        # Method 1: prefer the task_type field defined explicitly in the config file
        if "task_type" in model:
            task_type = model["task_type"]
@@ -262,7 +281,6 @@ class LLMRequest:
            return task

        # Method 3: backward compatibility - infer from keywords in the model name (not recommended, kept for compatibility)
-        model_name = model.get("model_name", model.get("name", ""))
        logger.warning(f"⚠️ [Task resolution] No task_type or capabilities in the config; falling back to model-name-based inference: {model_name}")
        logger.warning("⚠️ [Suggestion] Please add an explicit task_type or capabilities field for this model in model_config.toml")
@@ -282,6 +300,76 @@ class LLMRequest:
logger.debug(f"🎯 [任务确定] 从 request_type {self.request_type} 推断为: {task}") logger.debug(f"🎯 [任务确定] 从 request_type {self.request_type} 推断为: {task}")
return task return task
def _get_full_model_config(self, model_name: str) -> dict | None:
"""
根据模型名称从全局配置中获取完整的模型配置
现在直接使用已解析的ModelInfo对象不再读取TOML文件
Args:
model_name: 模型名称
Returns:
完整的模型配置字典如果找不到则返回None
"""
try:
from src.config.config import model_config
return self._get_model_config_from_parsed(model_name, model_config)
except Exception as e:
logger.warning(f"⚠️ [配置查找] 获取模型配置时出错: {str(e)}")
return None
def _get_model_config_from_parsed(self, model_name: str, model_config) -> dict | None:
"""
从已解析的配置对象中获取模型配置
使用扩展后的ModelInfo类包含task_type和capabilities字段
"""
try:
# 直接通过模型名称查找
if model_name in model_config.models:
model_info = model_config.models[model_name]
logger.debug(f"🔍 [配置查找] 找到模型 {model_name} 的配置对象: {model_info}")
# 将ModelInfo对象转换为字典
model_dict = {
"model_identifier": model_info.model_identifier,
"name": model_info.name,
"api_provider": model_info.api_provider,
"price_in": model_info.price_in,
"price_out": model_info.price_out,
"force_stream_mode": model_info.force_stream_mode,
"task_type": model_info.task_type,
"capabilities": model_info.capabilities,
}
logger.debug(f"🔍 [配置查找] 转换后的模型配置字典: {model_dict}")
return model_dict
# 如果直接查找失败尝试通过model_identifier查找
for name, model_info in model_config.models.items():
if (model_info.model_identifier == model_name or
hasattr(model_info, 'model_name') and model_info.model_name == model_name):
logger.debug(f"🔍 [配置查找] 通过标识符找到模型 {model_name} (配置名称: {name})")
# 同样转换为字典
model_dict = {
"model_identifier": model_info.model_identifier,
"name": model_info.name,
"api_provider": model_info.api_provider,
"price_in": model_info.price_in,
"price_out": model_info.price_out,
"force_stream_mode": model_info.force_stream_mode,
"task_type": model_info.task_type,
"capabilities": model_info.capabilities,
}
return model_dict
return None
except Exception as e:
logger.warning(f"⚠️ [配置查找] 从已解析配置获取模型配置时出错: {str(e)}")
return None
@staticmethod @staticmethod
def _init_database(): def _init_database():
"""初始化数据库集合""" """初始化数据库集合"""

View File

@@ -1,7 +1,45 @@
[inner]
-version = "0.1.1"
+version = "0.2.1"
# Config file version number; iteration rules are the same as for bot_config.toml
+#
+# === Multiple API key support ===
+# This config file supports multiple API keys per API provider, enabling:
+# 1. Automatic failover: when one API key fails, automatically switch to the next available key
+# 2. Load balancing: rotate among the available API keys to avoid per-key rate limits
+# 3. Backward compatibility: a single "key" field is still supported
+#
+# How to configure:
+# - Multiple keys: use the array form, api_keys = ["key1", "key2", "key3"]
+# - Single key: use the string form, key = "your-key" (backward compatible)
+#
+# Error handling:
+# - 401/403 (authentication errors): switch to the next API key immediately
+# - 429 (rate limit): wait and retry; switch keys if the failures persist
+# - Network errors: wait briefly and retry; switch keys on failure
+# - Other errors: handled by the normal retry mechanism
+#
+# === Task type and model capability configuration ===
+# To make task assignment more accurate and maintainable, a model's task type and
+# capabilities can now be configured explicitly:
+#
+# task_type (recommended):
+# - States explicitly which task a model is primarily used for
+# - Allowed values: llm_normal, llm_reasoning, vision, embedding, speech
+# - If unset, the system infers it from capabilities or the model name (not recommended)
+#
+# capabilities (recommended):
+# - Describes all capabilities the model supports
+# - Allowed values: text, vision, embedding, speech, tool_calling, reasoning
+# - Capabilities can be combined, e.g. ["text", "vision"]
+#
+# Configuration precedence:
+# 1. task_type (highest priority): specifies the task type directly
+# 2. capabilities (medium priority): the task type is inferred from the capabilities
+# 3. Model name keywords (lowest priority; do not rely on this)
+#
+# Backward compatibility:
+# - The model_flags field is still supported, but migrating to capabilities is recommended
+# - When the new fields are absent, the system falls back to model-name-based inference

[request_conf]  # Request configuration (all values here are defaults; to change one, uncomment the corresponding entry)
#max_retry = 2  # Maximum retries: how many times a failed API call to a single model is retried
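A hedged Python sketch of that precedence chain (the real resolver lives in LLMRequest; its capability-to-task mapping and keyword list may differ in detail):

def resolve_task(model: dict) -> str:
    # 1. An explicit task_type wins outright
    if model.get("task_type"):
        return model["task_type"]
    # 2. Otherwise infer from capabilities
    caps = model.get("capabilities", [])
    for cap, task in (("embedding", "embedding"), ("vision", "vision"),
                      ("speech", "speech"), ("reasoning", "llm_reasoning")):
        if cap in caps:
            return task
    # 3. Last resort: keywords in the model name (discouraged)
    name = model.get("name", "")
    if "r1" in name or "reasoner" in name:
        return "llm_reasoning"
    return "llm_normal"

print(resolve_task({"capabilities": ["text", "reasoning"]}))  # llm_reasoning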
@@ -13,20 +51,32 @@ version = "0.1.1"
[[api_providers]]  # API provider (multiple providers may be configured)
name = "DeepSeek"  # Provider name; any name works, but the api_provider field under models must use it
-base_url = "https://api.deepseek.cn"  # The provider's base URL
+base_url = "https://api.deepseek.cn/v1"  # The provider's base URL
-key = "******"  # API key (optional, defaults to None)
+# Multiple API keys are supported, enabling automatic failover and load balancing
+api_keys = [  # List of API keys; multiple keys enable automatic failover and load balancing
+    "sk-your-first-key-here",
+    "sk-your-second-key-here",
+    "sk-your-third-key-here"
+]
+# Backward compatible: with only one key, the single "key" field can still be used
+#key = "******"  # API key (optional, defaults to None)
-client_type = "openai"  # Request client (optional, defaults to "openai"); set to "google" for Google models such as Gemini
+client_type = "openai"  # Request client (optional, defaults to "openai"); set to "gemini" for Google models such as Gemini

-#[[api_providers]]  # Special case: Google's Gemini uses its own API, which is not OpenAI-compatible; set the client to "google"
-#name = "Google"
-#base_url = "https://api.google.com"
-#key = "******"
-#client_type = "google"
+[[api_providers]]  # Special case: Google's Gemini uses its own API, which is not OpenAI-compatible; set the client to "gemini"
+name = "Google"
+base_url = "https://api.google.com/v1"
+# The Google API also supports a multi-key configuration
+api_keys = [
+    "your-google-api-key-1",
+    "your-google-api-key-2"
+]
+client_type = "gemini"

-#[[api_providers]]
-#name = "SiliconFlow"
-#base_url = "https://api.siliconflow.cn"
-#key = "******"
+[[api_providers]]
+name = "SiliconFlow"
+base_url = "https://api.siliconflow.cn/v1"
+# Single-key example (backward compatible)
+key = "******"
#
#[[api_providers]]
#name = "LocalHost"
@@ -42,6 +92,13 @@ model_identifier = "deepseek-chat"
name = "deepseek-v3" name = "deepseek-v3"
# API服务商名称对应在api_providers中配置的服务商名称 # API服务商名称对应在api_providers中配置的服务商名称
api_provider = "DeepSeek" api_provider = "DeepSeek"
# 任务类型(推荐配置,明确指定模型主要用于什么任务)
# 可选值llm_normal, llm_reasoning, vision, embedding, speech
# 如果不配置系统会根据capabilities或模型名称自动推断
task_type = "llm_normal"
# 模型能力列表(推荐配置,描述模型支持的能力)
# 可选值text, vision, embedding, speech, tool_calling, reasoning
capabilities = ["text", "tool_calling"]
# 输入价格用于API调用统计单位元/兆token可选若无该字段默认值为0 # 输入价格用于API调用统计单位元/兆token可选若无该字段默认值为0
price_in = 2.0 price_in = 2.0
# 输出价格用于API调用统计单位元/兆token可选若无该字段默认值为0 # 输出价格用于API调用统计单位元/兆token可选若无该字段默认值为0
@@ -54,6 +111,10 @@ price_out = 8.0
model_identifier = "deepseek-reasoner" model_identifier = "deepseek-reasoner"
name = "deepseek-r1" name = "deepseek-r1"
api_provider = "DeepSeek" api_provider = "DeepSeek"
# 推理模型的配置示例
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "text", "tool_calling", "reasoning",] model_flags = [ "text", "tool_calling", "reasoning",]
price_in = 4.0 price_in = 4.0
price_out = 16.0 price_out = 16.0
@@ -62,6 +123,8 @@ price_out = 16.0
model_identifier = "Pro/deepseek-ai/DeepSeek-V3" model_identifier = "Pro/deepseek-ai/DeepSeek-V3"
name = "siliconflow-deepseek-v3" name = "siliconflow-deepseek-v3"
api_provider = "SiliconFlow" api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 2.0 price_in = 2.0
price_out = 8.0 price_out = 8.0
@@ -69,6 +132,8 @@ price_out = 8.0
model_identifier = "Pro/deepseek-ai/DeepSeek-R1" model_identifier = "Pro/deepseek-ai/DeepSeek-R1"
name = "siliconflow-deepseek-r1" name = "siliconflow-deepseek-r1"
api_provider = "SiliconFlow" api_provider = "SiliconFlow"
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
price_in = 4.0 price_in = 4.0
price_out = 16.0 price_out = 16.0
@@ -76,6 +141,8 @@ price_out = 16.0
model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
name = "deepseek-r1-distill-qwen-32b" name = "deepseek-r1-distill-qwen-32b"
api_provider = "SiliconFlow" api_provider = "SiliconFlow"
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
price_in = 4.0 price_in = 4.0
price_out = 16.0 price_out = 16.0
@@ -83,6 +150,8 @@ price_out = 16.0
model_identifier = "Qwen/Qwen3-8B" model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b" name = "qwen3-8b"
api_provider = "SiliconFlow" api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text"]
price_in = 0 price_in = 0
price_out = 0 price_out = 0
@@ -90,6 +159,8 @@ price_out = 0
model_identifier = "Qwen/Qwen3-14B" model_identifier = "Qwen/Qwen3-14B"
name = "qwen3-14b" name = "qwen3-14b"
api_provider = "SiliconFlow" api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 0.5 price_in = 0.5
price_out = 2.0 price_out = 2.0
@@ -97,6 +168,8 @@ price_out = 2.0
model_identifier = "Qwen/Qwen3-30B-A3B" model_identifier = "Qwen/Qwen3-30B-A3B"
name = "qwen3-30b" name = "qwen3-30b"
api_provider = "SiliconFlow" api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 0.7 price_in = 0.7
price_out = 2.8 price_out = 2.8
@@ -104,6 +177,10 @@ price_out = 2.8
model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct" model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct"
name = "qwen2.5-vl-72b" name = "qwen2.5-vl-72b"
api_provider = "SiliconFlow" api_provider = "SiliconFlow"
# 视觉模型的配置示例
task_type = "vision"
capabilities = ["vision", "text"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "vision", "text",] model_flags = [ "vision", "text",]
price_in = 4.13 price_in = 4.13
price_out = 4.13 price_out = 4.13
@@ -112,6 +189,10 @@ price_out = 4.13
model_identifier = "FunAudioLLM/SenseVoiceSmall" model_identifier = "FunAudioLLM/SenseVoiceSmall"
name = "sensevoice-small" name = "sensevoice-small"
api_provider = "SiliconFlow" api_provider = "SiliconFlow"
# 语音模型的配置示例
task_type = "speech"
capabilities = ["speech"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "audio",] model_flags = [ "audio",]
price_in = 0 price_in = 0
price_out = 0 price_out = 0
@@ -120,15 +201,19 @@ price_out = 0
model_identifier = "BAAI/bge-m3" model_identifier = "BAAI/bge-m3"
name = "bge-m3" name = "bge-m3"
api_provider = "SiliconFlow" api_provider = "SiliconFlow"
# 嵌入模型的配置示例
task_type = "embedding"
capabilities = ["text", "embedding"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "text", "embedding",] model_flags = [ "text", "embedding",]
price_in = 0 price_in = 0
price_out = 0 price_out = 0
[task_model_usage] [task_model_usage]
#llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0} llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0}
#llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0} llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0}
#embedding = "siliconflow-bge-m3" embedding = "siliconflow-bge-m3"
#schedule = [ #schedule = [
# "deepseek-v3", # "deepseek-v3",
# "deepseek-r1", # "deepseek-r1",