大修LLMReq

This commit is contained in:
UnCLAS-Prommer
2025-07-30 09:45:13 +08:00
parent 94db64c118
commit 3c40ceda4c
15 changed files with 2290 additions and 1995 deletions

View File

@@ -1,180 +1,128 @@
from dataclasses import dataclass, field
from typing import List, Dict, Union
import threading
import time
from packaging.version import Version
NEWEST_VER = "0.2.1" # 当前支持的最新版本
@dataclass
class APIProvider:
name: str = "" # API提供商名称
base_url: str = "" # API基础URL
api_key: str = field(repr=False, default="") # API密钥向后兼容
api_keys: List[str] = field(repr=False, default_factory=list) # API密钥列表新格式
client_type: str = "openai" # 客户端类型如openai/google等默认为openai
# 多API Key管理相关属性
_current_key_index: int = field(default=0, init=False, repr=False) # 当前使用的key索引
_key_failure_count: Dict[int, int] = field(default_factory=dict, init=False, repr=False) # 每个key的失败次数
_key_last_failure_time: Dict[int, float] = field(default_factory=dict, init=False, repr=False) # 每个key最后失败时间
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) # 线程锁
def __post_init__(self):
"""初始化后处理确保API keys列表正确"""
# 向后兼容如果只设置了api_key将其添加到api_keys列表
if self.api_key and not self.api_keys:
self.api_keys = [self.api_key]
# 如果api_keys不为空但api_key为空设置api_key为第一个
elif self.api_keys and not self.api_key:
self.api_key = self.api_keys[0]
# 初始化失败计数器
for i in range(len(self.api_keys)):
self._key_failure_count[i] = 0
self._key_last_failure_time[i] = 0
def get_current_api_key(self) -> str:
"""获取当前应该使用的API Key"""
with self._lock:
if not self.api_keys:
return ""
# 确保索引在有效范围内
if self._current_key_index >= len(self.api_keys):
self._current_key_index = 0
return self.api_keys[self._current_key_index]
def get_next_api_key(self) -> Union[str, None]:
"""获取下一个可用的API Key负载均衡"""
with self._lock:
if not self.api_keys:
return None
# 如果只有一个key直接返回
if len(self.api_keys) == 1:
return self.api_keys[0]
# 轮询到下一个key
self._current_key_index = (self._current_key_index + 1) % len(self.api_keys)
return self.api_keys[self._current_key_index]
def mark_key_failed(self, api_key: str) -> Union[str, None]:
"""标记某个API Key失败返回下一个可用的key"""
with self._lock:
if not self.api_keys or api_key not in self.api_keys:
return None
key_index = self.api_keys.index(api_key)
self._key_failure_count[key_index] += 1
self._key_last_failure_time[key_index] = time.time()
# 寻找下一个可用的key
current_time = time.time()
for _ in range(len(self.api_keys)):
self._current_key_index = (self._current_key_index + 1) % len(self.api_keys)
next_key_index = self._current_key_index
# 检查该key是否最近失败过5分钟内失败超过3次则暂时跳过
if (self._key_failure_count[next_key_index] <= 3 or
current_time - self._key_last_failure_time[next_key_index] > 300): # 5分钟后重试
return self.api_keys[next_key_index]
# 如果所有key都不可用返回当前key让上层处理
return api_key
def reset_key_failures(self, api_key: str | None = None):
"""重置失败计数(成功调用后调用)"""
with self._lock:
if api_key and api_key in self.api_keys:
key_index = self.api_keys.index(api_key)
self._key_failure_count[key_index] = 0
self._key_last_failure_time[key_index] = 0
else:
# 重置所有key的失败计数
for i in range(len(self.api_keys)):
self._key_failure_count[i] = 0
self._key_last_failure_time[i] = 0
def get_api_key_stats(self) -> Dict[str, Dict[str, Union[int, float]]]:
"""获取API Key使用统计"""
with self._lock:
stats = {}
for i, key in enumerate(self.api_keys):
# 只显示key的前8位和后4位中间用*代替
masked_key = f"{key[:8]}***{key[-4:]}" if len(key) > 12 else "***"
stats[masked_key] = {
"failure_count": self._key_failure_count.get(i, 0),
"last_failure_time": self._key_last_failure_time.get(i, 0),
"is_current": i == self._current_key_index
}
return stats
from .config_base import ConfigBase
@dataclass
class ModelInfo:
model_identifier: str = "" # 模型标识符用于URL调用
name: str = "" # 模型名称(用于模块调用)
api_provider: str = "" # API提供商如OpenAI、Azure等
class APIProvider(ConfigBase):
"""API提供商配置类"""
# 以下用于模型计费
price_in: float = 0.0 # 每M token输入价格
price_out: float = 0.0 # 每M token输出价格
name: str
"""API提供商名称"""
force_stream_mode: bool = False # 是否强制使用流式输出模式
# 新增:任务类型和能力字段
task_type: str = "" # 任务类型llm_normal, llm_reasoning, vision, embedding, speech
capabilities: List[str] = field(default_factory=list) # 模型能力text, vision, embedding, speech, tool_calling, reasoning
base_url: str
"""API基础URL"""
api_key: str = field(default_factory=str, repr=False)
"""API密钥列表"""
client_type: str = field(default="openai")
"""客户端类型如openai/google等默认为openai"""
max_retry: int = 2
"""最大重试次数单个模型API调用失败最多重试的次数"""
timeout: int = 10
"""API调用的超时时长超过这个时长本次请求将被视为“请求超时”单位"""
retry_interval: int = 10
"""重试间隔如果API调用失败重试的间隔时间单位"""
def get_api_key(self) -> str:
return self.api_key
@dataclass
class RequestConfig:
max_retry: int = 2 # 最大重试次数单个模型API调用失败最多重试的次数
timeout: int = (
10 # API调用的超时时长超过这个时长本次请求将被视为“请求超时”单位
)
retry_interval: int = 10 # 重试间隔如果API调用失败重试的间隔时间单位
default_temperature: float = 0.7 # 默认的温度如果bot_config.toml中没有设置temperature参数默认使用这个值
default_max_tokens: int = 1024 # 默认的最大输出token数如果bot_config.toml中没有设置max_tokens参数默认使用这个值
class ModelInfo(ConfigBase):
"""单个模型信息配置类"""
model_identifier: str
"""模型标识符用于URL调用"""
name: str
"""模型名称(用于模块调用)"""
api_provider: str
"""API提供商如OpenAI、Azure等"""
price_in: float = field(default=0.0)
"""每M token输入价格"""
price_out: float = field(default=0.0)
"""每M token输出价格"""
force_stream_mode: bool = field(default=False)
"""是否强制使用流式输出模式"""
has_thinking: bool = field(default=False)
"""是否有思考参数"""
enable_thinking: bool = field(default=False)
"""是否启用思考"""
@dataclass
class ModelUsageArgConfigItem:
"""模型使用的配置类
该类用于加载和存储子任务模型使用的配置
"""
class TaskConfig(ConfigBase):
"""任务配置类"""
name: str = "" # 模型名称
temperature: float | None = None # 温度
max_tokens: int | None = None # 最大token数
max_retry: int | None = None # 调用失败时的最大重试次数
model_list: list[str] = field(default_factory=list)
"""任务使用的模型列表"""
max_tokens: int = 1024
"""任务最大输出token数"""
temperature: float = 0.3
"""模型温度"""
@dataclass
class ModelUsageArgConfig:
"""子任务使用模型配置类
该类用于加载和存储子任务使用的模型配置
"""
class ModelTaskConfig(ConfigBase):
"""模型配置类"""
name: str = "" # 任务名称
usage: List[ModelUsageArgConfigItem] = field(
default_factory=lambda: []
) # 任务使用的模型列表
utils: TaskConfig
"""组件模型配置"""
utils_small: TaskConfig
"""组件小模型配置"""
replyer_1: TaskConfig
"""normal_chat首要回复模型模型配置"""
@dataclass
class ModuleConfig:
INNER_VERSION: Version | None = None # 配置文件版本
replyer_2: TaskConfig
"""normal_chat次要回复模型配置"""
req_conf: RequestConfig = field(default_factory=lambda: RequestConfig()) # 请求配置
api_providers: Dict[str, APIProvider] = field(
default_factory=lambda: {}
) # API提供商列表
models: Dict[str, ModelInfo] = field(default_factory=lambda: {}) # 模型列表
task_model_arg_map: Dict[str, ModelUsageArgConfig] = field(
default_factory=lambda: {}
)
memory: TaskConfig
"""记忆模型配置"""
emotion: TaskConfig
"""情绪模型配置"""
vlm: TaskConfig
"""视觉语言模型配置"""
voice: TaskConfig
"""语音识别模型配置"""
tool_use: TaskConfig
"""专注工具使用模型配置"""
planner: TaskConfig
"""规划模型配置"""
embedding: TaskConfig
"""嵌入模型配置"""
lpmm_entity_extract: TaskConfig
"""LPMM实体提取模型配置"""
lpmm_rdf_build: TaskConfig
"""LPMM RDF构建模型配置"""
lpmm_qa: TaskConfig
"""LPMM问答模型配置"""
def get_task(self, task_name: str) -> TaskConfig:
"""获取指定任务的配置"""
if hasattr(self, task_name):
return getattr(self, task_name)
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")

View File

@@ -1,16 +1,14 @@
import os
import tomlkit
import shutil
import sys
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass
from rich.traceback import install
from packaging import version
from packaging.specifiers import SpecifierSet
from packaging.version import Version, InvalidVersion
from typing import Any, Dict, List
from typing import List, Optional
from src.common.logger import get_logger
from src.config.config_base import ConfigBase
@@ -29,7 +27,6 @@ from src.config.official_configs import (
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
ModelConfig,
MessageReceiveConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
@@ -41,16 +38,12 @@ from src.config.official_configs import (
)
from .api_ada_configs import (
ModelUsageArgConfigItem,
ModelUsageArgConfig,
APIProvider,
ModelTaskConfig,
ModelInfo,
NEWEST_VER,
ModuleConfig,
APIProvider,
)
install(extra_lines=3)
@@ -64,275 +57,270 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.10.0-snapshot1"
MMC_VERSION = "0.10.0-snapshot.2"
# def _get_config_version(toml: Dict) -> Version:
# """提取配置文件的 SpecifierSet 版本数据
# Args:
# toml[dict]: 输入的配置文件字典
# Returns:
# Version
# """
# if "inner" in toml and "version" in toml["inner"]:
# config_version: str = toml["inner"]["version"]
# else:
# raise InvalidVersion("配置文件缺少版本信息,请检查配置文件。")
# try:
# return version.parse(config_version)
# except InvalidVersion as e:
# logger.error(
# "配置文件中 inner段 的 version 键是错误的版本描述\n"
# f"请检查配置文件,当前 version 键: {config_version}\n"
# f"错误信息: {e}"
# )
# raise e
def _get_config_version(toml: Dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据
Args:
toml[dict]: 输入的配置文件字典
Returns:
Version
"""
if "inner" in toml and "version" in toml["inner"]:
config_version: str = toml["inner"]["version"]
else:
config_version = "0.0.0" # 默认版本
try:
ver = version.parse(config_version)
except InvalidVersion as e:
logger.error(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
f"请检查配置文件,当前 version 键: {config_version}\n"
f"错误信息: {e}"
)
raise InvalidVersion(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
) from e
return ver
# def _request_conf(parent: Dict, config: ModuleConfig):
# request_conf_config = parent.get("request_conf")
# config.req_conf.max_retry = request_conf_config.get(
# "max_retry", config.req_conf.max_retry
# )
# config.req_conf.timeout = request_conf_config.get(
# "timeout", config.req_conf.timeout
# )
# config.req_conf.retry_interval = request_conf_config.get(
# "retry_interval", config.req_conf.retry_interval
# )
# config.req_conf.default_temperature = request_conf_config.get(
# "default_temperature", config.req_conf.default_temperature
# )
# config.req_conf.default_max_tokens = request_conf_config.get(
# "default_max_tokens", config.req_conf.default_max_tokens
# )
def _request_conf(parent: Dict, config: ModuleConfig):
request_conf_config = parent.get("request_conf")
config.req_conf.max_retry = request_conf_config.get(
"max_retry", config.req_conf.max_retry
)
config.req_conf.timeout = request_conf_config.get(
"timeout", config.req_conf.timeout
)
config.req_conf.retry_interval = request_conf_config.get(
"retry_interval", config.req_conf.retry_interval
)
config.req_conf.default_temperature = request_conf_config.get(
"default_temperature", config.req_conf.default_temperature
)
config.req_conf.default_max_tokens = request_conf_config.get(
"default_max_tokens", config.req_conf.default_max_tokens
)
# def _api_providers(parent: Dict, config: ModuleConfig):
# api_providers_config = parent.get("api_providers")
# for provider in api_providers_config:
# name = provider.get("name", None)
# base_url = provider.get("base_url", None)
# api_key = provider.get("api_key", None)
# api_keys = provider.get("api_keys", []) # 新增支持多个API Key
# client_type = provider.get("client_type", "openai")
# if name in config.api_providers: # 查重
# logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
# raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
# if name and base_url:
# # 处理API Key配置支持单个api_key或多个api_keys
# if api_keys:
# # 使用新格式api_keys列表
# logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key")
# elif api_key:
# # 向后兼容使用单个api_key
# api_keys = [api_key]
# logger.debug(f"API提供商 '{name}' 使用单个API Key向后兼容模式")
# else:
# logger.warning(f"API提供商 '{name}' 没有配置API Key某些功能可能不可用")
# config.api_providers[name] = APIProvider(
# name=name,
# base_url=base_url,
# api_key=api_key, # 保留向后兼容
# api_keys=api_keys, # 新格式
# client_type=client_type,
# )
# else:
# logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
# raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
def _api_providers(parent: Dict, config: ModuleConfig):
api_providers_config = parent.get("api_providers")
for provider in api_providers_config:
name = provider.get("name", None)
base_url = provider.get("base_url", None)
api_key = provider.get("api_key", None)
api_keys = provider.get("api_keys", []) # 新增支持多个API Key
client_type = provider.get("client_type", "openai")
# def _models(parent: Dict, config: ModuleConfig):
# models_config = parent.get("models")
# for model in models_config:
# model_identifier = model.get("model_identifier", None)
# name = model.get("name", model_identifier)
# api_provider = model.get("api_provider", None)
# price_in = model.get("price_in", 0.0)
# price_out = model.get("price_out", 0.0)
# force_stream_mode = model.get("force_stream_mode", False)
# task_type = model.get("task_type", "")
# capabilities = model.get("capabilities", [])
if name in config.api_providers: # 查重
logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
# if name in config.models: # 查重
# logger.error(f"重复的模型名称: {name},请检查配置文件。")
# raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
if name and base_url:
# 处理API Key配置支持单个api_key或多个api_keys
if api_keys:
# 使用新格式api_keys列表
logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key")
elif api_key:
# 向后兼容使用单个api_key
api_keys = [api_key]
logger.debug(f"API提供商 '{name}' 使用单个API Key向后兼容模式")
else:
logger.warning(f"API提供商 '{name}' 没有配置API Key某些功能可能不可用")
config.api_providers[name] = APIProvider(
name=name,
base_url=base_url,
api_key=api_key, # 保留向后兼容
api_keys=api_keys, # 新格式
client_type=client_type,
)
else:
logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
# if model_identifier and api_provider:
# # 检查API提供商是否存在
# if api_provider not in config.api_providers:
# logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
# raise ValueError(
# f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
# )
# config.models[name] = ModelInfo(
# name=name,
# model_identifier=model_identifier,
# api_provider=api_provider,
# price_in=price_in,
# price_out=price_out,
# force_stream_mode=force_stream_mode,
# task_type=task_type,
# capabilities=capabilities,
# )
# else:
# logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
# raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
def _models(parent: Dict, config: ModuleConfig):
models_config = parent.get("models")
for model in models_config:
model_identifier = model.get("model_identifier", None)
name = model.get("name", model_identifier)
api_provider = model.get("api_provider", None)
price_in = model.get("price_in", 0.0)
price_out = model.get("price_out", 0.0)
force_stream_mode = model.get("force_stream_mode", False)
task_type = model.get("task_type", "")
capabilities = model.get("capabilities", [])
# def _task_model_usage(parent: Dict, config: ModuleConfig):
# model_usage_configs = parent.get("task_model_usage")
# config.task_model_arg_map = {}
# for task_name, item in model_usage_configs.items():
# if task_name in config.task_model_arg_map:
# logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
# raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
if name in config.models: # 查重
logger.error(f"重复的模型名称: {name},请检查配置文件。")
raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
# usage = []
# if isinstance(item, Dict):
# if "model" in item:
# usage.append(
# ModelUsageArgConfigItem(
# name=item["model"],
# temperature=item.get("temperature", None),
# max_tokens=item.get("max_tokens", None),
# max_retry=item.get("max_retry", None),
# )
# )
# else:
# logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
# raise ValueError(
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
# )
# elif isinstance(item, List):
# for model in item:
# if isinstance(model, Dict):
# usage.append(
# ModelUsageArgConfigItem(
# name=model["model"],
# temperature=model.get("temperature", None),
# max_tokens=model.get("max_tokens", None),
# max_retry=model.get("max_retry", None),
# )
# )
# elif isinstance(model, str):
# usage.append(
# ModelUsageArgConfigItem(
# name=model,
# temperature=None,
# max_tokens=None,
# max_retry=None,
# )
# )
# else:
# logger.error(
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
# )
# raise ValueError(
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
# )
# elif isinstance(item, str):
# usage.append(
# ModelUsageArgConfigItem(
# name=item,
# temperature=None,
# max_tokens=None,
# max_retry=None,
# )
# )
if model_identifier and api_provider:
# 检查API提供商是否存在
if api_provider not in config.api_providers:
logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
raise ValueError(
f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
)
config.models[name] = ModelInfo(
name=name,
model_identifier=model_identifier,
api_provider=api_provider,
price_in=price_in,
price_out=price_out,
force_stream_mode=force_stream_mode,
task_type=task_type,
capabilities=capabilities,
)
else:
logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
# config.task_model_arg_map[task_name] = ModelUsageArgConfig(
# name=task_name,
# usage=usage,
# )
def _task_model_usage(parent: Dict, config: ModuleConfig):
model_usage_configs = parent.get("task_model_usage")
config.task_model_arg_map = {}
for task_name, item in model_usage_configs.items():
if task_name in config.task_model_arg_map:
logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
# def api_ada_load_config(config_path: str) -> ModuleConfig:
# """从TOML配置文件加载配置"""
# config = ModuleConfig()
usage = []
if isinstance(item, Dict):
if "model" in item:
usage.append(
ModelUsageArgConfigItem(
name=item["model"],
temperature=item.get("temperature", None),
max_tokens=item.get("max_tokens", None),
max_retry=item.get("max_retry", None),
)
)
else:
logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
raise ValueError(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
elif isinstance(item, List):
for model in item:
if isinstance(model, Dict):
usage.append(
ModelUsageArgConfigItem(
name=model["model"],
temperature=model.get("temperature", None),
max_tokens=model.get("max_tokens", None),
max_retry=model.get("max_retry", None),
)
)
elif isinstance(model, str):
usage.append(
ModelUsageArgConfigItem(
name=model,
temperature=None,
max_tokens=None,
max_retry=None,
)
)
else:
logger.error(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
raise ValueError(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
elif isinstance(item, str):
usage.append(
ModelUsageArgConfigItem(
name=item,
temperature=None,
max_tokens=None,
max_retry=None,
)
)
# include_configs: Dict[str, Dict[str, Any]] = {
# "request_conf": {
# "func": _request_conf,
# "support": ">=0.0.0",
# "necessary": False,
# },
# "api_providers": {"func": _api_providers, "support": ">=0.0.0"},
# "models": {"func": _models, "support": ">=0.0.0"},
# "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
# }
config.task_model_arg_map[task_name] = ModelUsageArgConfig(
name=task_name,
usage=usage,
)
# if os.path.exists(config_path):
# with open(config_path, "rb") as f:
# try:
# toml_dict = tomlkit.load(f)
# except tomlkit.TOMLDecodeError as e:
# logger.critical(
# f"配置文件model_list.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}"
# )
# exit(1)
# # 获取配置文件版本
# config.INNER_VERSION = _get_config_version(toml_dict)
def api_ada_load_config(config_path: str) -> ModuleConfig:
"""从TOML配置文件加载配置"""
config = ModuleConfig()
# # 检查版本
# if config.INNER_VERSION > Version(NEWEST_VER):
# logger.warning(
# f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
# )
include_configs: Dict[str, Dict[str, Any]] = {
"request_conf": {
"func": _request_conf,
"support": ">=0.0.0",
"necessary": False,
},
"api_providers": {"func": _api_providers, "support": ">=0.0.0"},
"models": {"func": _models, "support": ">=0.0.0"},
"task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
}
# # 解析配置文件
# # 如果在配置中找到了需要的项,调用对应项的闭包函数处理
# for key in include_configs:
# if key in toml_dict:
# group_specifier_set: SpecifierSet = SpecifierSet(
# include_configs[key]["support"]
# )
if os.path.exists(config_path):
with open(config_path, "rb") as f:
try:
toml_dict = tomlkit.load(f)
except tomlkit.TOMLDecodeError as e:
logger.critical(
f"配置文件model_list.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}"
)
exit(1)
# # 检查配置文件版本是否在支持范围内
# if config.INNER_VERSION in group_specifier_set:
# # 如果版本在支持范围内,检查是否存在通知
# if "notice" in include_configs[key]:
# logger.warning(include_configs[key]["notice"])
# # 调用闭包函数处理配置
# (include_configs[key]["func"])(toml_dict, config)
# else:
# # 如果版本不在支持范围内,崩溃并提示用户
# logger.error(
# f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
# f"当前程序仅支持以下版本范围: {group_specifier_set}"
# )
# raise InvalidVersion(
# f"当前程序仅支持以下版本范围: {group_specifier_set}"
# )
# 获取配置文件版本
config.INNER_VERSION = _get_config_version(toml_dict)
# # 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
# elif (
# "necessary" in include_configs[key]
# and include_configs[key].get("necessary") is False
# ):
# # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
# if key == "keywords_reaction":
# pass
# else:
# # 如果用户根本没有需要的配置项,提示缺少配置
# logger.error(f"配置文件中缺少必需的字段: '{key}'")
# raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
# 检查版本
if config.INNER_VERSION > Version(NEWEST_VER):
logger.warning(
f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
)
# logger.info(f"成功加载配置文件: {config_path}")
# 解析配置文件
# 如果在配置中找到了需要的项,调用对应项的闭包函数处理
for key in include_configs:
if key in toml_dict:
group_specifier_set: SpecifierSet = SpecifierSet(
include_configs[key]["support"]
)
# return config
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifier_set:
# 如果版本在支持范围内,检查是否存在通知
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
# 调用闭包函数处理配置
(include_configs[key]["func"])(toml_dict, config)
else:
# 如果版本不在支持范围内,崩溃并提示用户
logger.error(
f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
f"当前程序仅支持以下版本范围: {group_specifier_set}"
)
raise InvalidVersion(
f"当前程序仅支持以下版本范围: {group_specifier_set}"
)
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif (
"necessary" in include_configs[key]
and include_configs[key].get("necessary") is False
):
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
else:
# 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
logger.info(f"成功加载配置文件: {config_path}")
return config
def get_key_comment(toml_table, key):
# 获取key的注释如果有
@@ -361,7 +349,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue
if key not in old:
comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path + [str(key)], logs)
# 删减项
@@ -370,7 +358,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue
if key not in new:
comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
return logs
@@ -405,17 +393,13 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
if key in old:
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
else:
# 只要值发生变化就记录
if new[key] != old[key]:
logs.append(
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
)
changes.append((path + [str(key)], old[key], new[key]))
elif new[key] != old[key]:
logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
changes.append((path + [str(key)], old[key], new[key]))
return logs, changes
def _get_version_from_toml(toml_path):
def _get_version_from_toml(toml_path) -> Optional[str]:
"""从TOML文件中获取版本号"""
if not os.path.exists(toml_path):
return None
@@ -459,14 +443,13 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
target[key] = value
def _update_config_generic(config_name: str, template_name: str, should_quit_on_new: bool = True):
def _update_config_generic(config_name: str, template_name: str):
"""
通用的配置文件更新函数
Args:
config_name: 配置文件名(不含扩展名),如 'bot_config''model_config'
template_name: 模板文件名(不含扩展名),如 'bot_config_template''model_config_template'
should_quit_on_new: 创建新配置文件后是否退出程序
"""
# 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old")
@@ -484,19 +467,30 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_
template_version = _get_version_from_toml(template_path)
compare_version = _get_version_from_toml(compare_path)
# 检查配置文件是否存在
if not os.path.exists(old_config_path):
logger.info(f"{config_name}.toml配置文件不存在从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
# 新创建配置文件,退出
sys.exit(0)
compare_config = None
new_config = None
old_config = None
# 先读取 compare 下的模板(如果有),用于默认值变动检测
if os.path.exists(compare_path):
with open(compare_path, "r", encoding="utf-8") as f:
compare_config = tomlkit.load(f)
else:
compare_config = None
# 读取当前模板
with open(template_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查默认值变化并处理(只有 compare_config 存在时才做)
if compare_config is not None:
if compare_config:
# 读取旧配置
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
@@ -515,32 +509,16 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_
)
else:
logger.info(f"未检测到{config_name}模板默认值变动")
# 保存旧配置的变更(后续合并逻辑会用到 old_config
else:
old_config = None
# 检查 compare 下没有模板,或新模板版本更高,则复制
if not os.path.exists(compare_path):
shutil.copy2(template_path, compare_path)
logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
elif _version_tuple(template_version) > _version_tuple(compare_version):
shutil.copy2(template_path, compare_path)
logger.info(f"{config_name}模板版本较新已替换compare下的模板: {compare_path}")
else:
if _version_tuple(template_version) > _version_tuple(compare_version):
shutil.copy2(template_path, compare_path)
logger.info(f"{config_name}模板版本较新已替换compare下的模板: {compare_path}")
else:
logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
# 检查配置文件是否存在
if not os.path.exists(old_config_path):
logger.info(f"{config_name}.toml配置文件不存在从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,根据参数决定是否退出
if should_quit_on_new:
quit()
else:
return
logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
# 读取旧配置文件和模板文件(如果前面没读过 old_config这里再读一次
if old_config is None:
@@ -578,8 +556,7 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_
# 输出新增和删减项及注释
if old_config:
logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
logs = compare_dicts(new_config, old_config)
if logs:
if logs := compare_dicts(new_config, old_config):
for log in logs:
logger.info(log)
else:
@@ -597,12 +574,12 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_
def update_config():
"""更新bot_config.toml配置文件"""
_update_config_generic("bot_config", "bot_config_template", should_quit_on_new=True)
_update_config_generic("bot_config", "bot_config_template")
def update_model_config():
"""更新model_config.toml配置文件"""
_update_config_generic("model_config", "model_config_template", should_quit_on_new=False)
_update_config_generic("model_config", "model_config_template")
@dataclass
@@ -627,7 +604,6 @@ class Config(ConfigBase):
response_splitter: ResponseSplitterConfig
telemetry: TelemetryConfig
experimental: ExperimentalConfig
model: ModelConfig
maim_message: MaimMessageConfig
lpmm_knowledge: LPMMKnowledgeConfig
tool: ToolConfig
@@ -635,11 +611,48 @@ class Config(ConfigBase):
custom_prompt: CustomPromptConfig
voice: VoiceConfig
@dataclass
class APIAdapterConfig(ConfigBase):
"""API Adapter配置类"""
models: List[ModelInfo]
"""模型列表"""
model_task_config: ModelTaskConfig
"""模型任务配置"""
api_providers: List[APIProvider] = field(default_factory=list)
"""API提供商列表"""
def __post_init__(self):
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
def get_model_info(self, model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息"""
if not model_name:
raise ValueError("模型名称不能为空")
if model_name not in self.models_dict:
raise KeyError(f"模型 '{model_name}' 不存在")
return self.models_dict[model_name]
def get_provider(self, provider_name: str) -> APIProvider:
"""根据提供商名称获取API提供商信息"""
if not provider_name:
raise ValueError("API提供商名称不能为空")
if provider_name not in self.api_providers_dict:
raise KeyError(f"API提供商 '{provider_name}' 不存在")
return self.api_providers_dict[provider_name]
def load_config(config_path: str) -> Config:
"""
加载配置文件
:param config_path: 配置文件路径
:return: Config对象
Args:
config_path: 配置文件路径
Returns:
Config对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
@@ -653,12 +666,24 @@ def load_config(config_path: str) -> Config:
raise e
def get_config_dir() -> str:
def api_ada_load_config(config_path: str) -> APIAdapterConfig:
"""
获取配置目录
:return: 配置目录路径
加载API适配器配置文件
Args:
config_path: 配置文件路径
Returns:
APIAdapterConfig对象
"""
return CONFIG_DIR
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建APIAdapterConfig对象
try:
return APIAdapterConfig.from_dict(config_data)
except Exception as e:
logger.critical("API适配器配置文件解析失败")
raise e
# 获取配置文件路径
@@ -669,4 +694,4 @@ update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
logger.info("非常的新鲜,非常的美味!")
logger.info("非常的新鲜,非常的美味!")

View File

@@ -1,10 +1,9 @@
import re
from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from typing import Literal, Optional
from src.config.config_base import ConfigBase
from packaging.version import Version
"""
须知:
@@ -599,50 +598,3 @@ class LPMMKnowledgeConfig(ConfigBase):
embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致"""
@dataclass
class ModelConfig(ConfigBase):
"""模型配置类"""
model_max_output_length: int = 800 # 最大回复长度
utils: dict[str, Any] = field(default_factory=lambda: {})
"""组件模型配置"""
utils_small: dict[str, Any] = field(default_factory=lambda: {})
"""组件小模型配置"""
replyer_1: dict[str, Any] = field(default_factory=lambda: {})
"""normal_chat首要回复模型模型配置"""
replyer_2: dict[str, Any] = field(default_factory=lambda: {})
"""normal_chat次要回复模型配置"""
memory: dict[str, Any] = field(default_factory=lambda: {})
"""记忆模型配置"""
emotion: dict[str, Any] = field(default_factory=lambda: {})
"""情绪模型配置"""
vlm: dict[str, Any] = field(default_factory=lambda: {})
"""视觉语言模型配置"""
voice: dict[str, Any] = field(default_factory=lambda: {})
"""语音识别模型配置"""
tool_use: dict[str, Any] = field(default_factory=lambda: {})
"""专注工具使用模型配置"""
planner: dict[str, Any] = field(default_factory=lambda: {})
"""规划模型配置"""
embedding: dict[str, Any] = field(default_factory=lambda: {})
"""嵌入模型配置"""
lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM实体提取模型配置"""
lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM RDF构建模型配置"""
lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM问答模型配置"""

View File

@@ -62,8 +62,37 @@ class RespParseException(Exception):
self.message = message
def __str__(self):
return (
self.message
if self.message
else "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
)
return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
class PayLoadTooLargeError(Exception):
"""自定义异常类,用于处理请求体过大错误"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return "请求体过大,请尝试压缩图片或减少输入内容。"
class RequestAbortException(Exception):
"""自定义异常类,用于处理请求中断异常"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return self.message
class PermissionDeniedException(Exception):
"""自定义异常类,用于处理访问拒绝的异常"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return self.message

View File

@@ -1,380 +0,0 @@
import asyncio
from typing import Callable, Any
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from .base_client import BaseClient, APIResponse
from src.config.api_ada_configs import (
ModelInfo,
ModelUsageArgConfigItem,
RequestConfig,
ModuleConfig,
)
from ..exceptions import (
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption
from ..utils import compress_messages
from src.common.logger import get_logger
logger = get_logger("模型客户端")
def _check_retry(
remain_try: int,
retry_interval: int,
can_retry_msg: str,
cannot_retry_msg: str,
can_retry_callable: Callable | None = None,
**kwargs,
) -> tuple[int, Any | None]:
"""
辅助函数:检查是否可以重试
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param can_retry_msg: 可以重试时的提示信息
:param cannot_retry_msg: 不可以重试时的提示信息
:return: (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if remain_try > 0:
# 还有重试机会
logger.warning(f"{can_retry_msg}")
if can_retry_callable is not None:
return retry_interval, can_retry_callable(**kwargs)
else:
return retry_interval, None
else:
# 达到最大重试次数
logger.warning(f"{cannot_retry_msg}")
return -1, None # 不再重试请求该模型
def _handle_resp_not_ok(
e: RespNotOkException,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
):
"""
处理响应错误异常
:param e: 异常对象
:param task_name: 任务名称
:param model_name: 模型名称
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param messages: (消息列表, 是否已压缩过)
:return: (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
# 响应错误
if e.status_code in [401, 403]:
# API Key认证错误 - 让多API Key机制处理给一次重试机会
if remain_try > 0:
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"API Key认证失败错误代码-{e.status_code}多API Key机制会自动切换"
)
return 0, None # 立即重试让底层客户端切换API Key
else:
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"所有API Key都认证失败错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None # 不再重试请求该模型
elif e.status_code in [400, 402, 404]:
# 其他客户端错误(不应该重试)
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"请求失败,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None # 不再重试请求该模型
elif e.status_code == 413:
if messages and not messages[1]:
# 消息列表不为空且未压缩,尝试压缩消息
return _check_retry(
remain_try,
0,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,尝试压缩消息后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,压缩消息后仍然过大,放弃请求"
),
can_retry_callable=compress_messages,
messages=messages[0],
)
# 没有消息可压缩
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,无法压缩消息,放弃请求。"
)
return -1, None
elif e.status_code == 429:
# 请求过于频繁 - 让多API Key机制处理适当延迟后重试
return _check_retry(
remain_try,
min(retry_interval, 5), # 限制最大延迟为5秒让API Key切换更快生效
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"请求过于频繁多API Key机制会自动切换{min(retry_interval, 5)}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求过于频繁所有API Key都被限制放弃请求"
),
)
elif e.status_code >= 500:
# 服务器错误
return _check_retry(
remain_try,
retry_interval,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"服务器错误,将于{retry_interval}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"服务器错误,超过最大重试次数,请稍后再试"
),
)
else:
# 未知错误
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"未知错误,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None
def default_exception_handler(
e: Exception,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
) -> tuple[int, list[Message] | None]:
"""
默认异常处理函数
:param e: 异常对象
:param task_name: 任务名称
:param model_name: 模型名称
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param messages: (消息列表, 是否已压缩过)
:return (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if isinstance(e, NetworkConnectionError): # 网络连接错误
# 网络错误可能是某个API Key的端点问题给多API Key机制一次快速重试机会
return _check_retry(
remain_try,
min(retry_interval, 3), # 网络错误时减少等待时间让API Key切换更快
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"连接异常多API Key机制会尝试其他Key{min(retry_interval, 3)}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"连接异常超过最大重试次数请检查网络连接状态或URL是否正确"
),
)
elif isinstance(e, ReqAbortException):
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n请求被中断,详细信息-{str(e.message)}"
)
return -1, None # 不再重试请求该模型
elif isinstance(e, RespNotOkException):
return _handle_resp_not_ok(
e,
task_name,
model_name,
remain_try,
retry_interval,
messages,
)
elif isinstance(e, RespParseException):
# 响应解析错误
logger.error(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"响应解析错误,错误信息-{e.message}\n"
)
logger.debug(f"附加内容:\n{str(e.ext_info)}")
return -1, None # 不再重试请求该模型
else:
logger.error(
f"任务-'{task_name}' 模型-'{model_name}'\n未知异常,错误信息-{str(e)}"
)
return -1, None # 不再重试请求该模型
class ModelRequestHandler:
"""
模型请求处理器
"""
def __init__(
self,
task_name: str,
config: ModuleConfig,
api_client_map: dict[str, BaseClient],
):
self.task_name: str = task_name
"""任务名称"""
self.client_map: dict[str, BaseClient] = {}
"""API客户端列表"""
self.configs: list[tuple[ModelInfo, ModelUsageArgConfigItem]] = []
"""模型参数配置"""
self.req_conf: RequestConfig = config.req_conf
"""请求配置"""
# 获取模型与使用配置
for model_usage in config.task_model_arg_map[task_name].usage:
if model_usage.name not in config.models:
logger.error(f"Model '{model_usage.name}' not found in ModelManager")
raise KeyError(f"Model '{model_usage.name}' not found in ModelManager")
model_info = config.models[model_usage.name]
if model_info.api_provider not in self.client_map:
# 缓存API客户端
self.client_map[model_info.api_provider] = api_client_map[
model_info.api_provider
]
self.configs.append((model_info, model_usage)) # 添加模型与使用配置
async def get_response(
self,
messages: list[Message],
tool_options: list[ToolOption] | None = None,
response_format: RespFormat | None = None, # 暂不启用
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[ChatCompletion], APIResponse] | None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param messages: 消息列表
:param tool_options: 工具选项列表
:param response_format: 响应格式
:param stream_response_handler: 流式响应处理函数(可选)
:param async_response_parser: 响应解析函数(可选)
:param interrupt_flag: 中断信号量可选默认为None
:return: APIResponse
"""
# 遍历可用模型,若获取响应失败,则使用下一个模型继续请求
for config_item in self.configs:
client = self.client_map[config_item[0].api_provider]
model_info: ModelInfo = config_item[0]
model_usage_config: ModelUsageArgConfigItem = config_item[1]
remain_try = (
model_usage_config.max_retry or self.req_conf.max_retry
) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1
compressed_messages = None
retry_interval = self.req_conf.retry_interval
while remain_try > 0:
try:
return await client.get_response(
model_info,
message_list=(compressed_messages or messages),
tool_options=tool_options,
max_tokens=model_usage_config.max_tokens
or self.req_conf.default_max_tokens,
temperature=model_usage_config.temperature
or self.req_conf.default_temperature,
response_format=response_format,
stream_response_handler=stream_response_handler,
async_response_parser=async_response_parser,
interrupt_flag=interrupt_flag,
)
except Exception as e:
logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
handle_res = default_exception_handler(
e,
self.task_name,
model_info.name,
remain_try,
retry_interval=self.req_conf.retry_interval,
messages=(messages, compressed_messages is not None),
)
if handle_res[0] == -1:
# 等待间隔为-1表示不再请求该模型
remain_try = 0
elif handle_res[0] != 0:
# 等待间隔不为0表示需要等待
await asyncio.sleep(handle_res[0])
retry_interval *= 2
if handle_res[1] is not None:
# 压缩消息
compressed_messages = handle_res[1]
logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用")
raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败
async def get_embedding(
self,
embedding_input: str,
) -> APIResponse:
"""
获取嵌入向量
:param embedding_input: 嵌入输入
:return: APIResponse
"""
for config in self.configs:
client = self.client_map[config[0].api_provider]
model_info: ModelInfo = config[0]
model_usage_config: ModelUsageArgConfigItem = config[1]
remain_try = (
model_usage_config.max_retry or self.req_conf.max_retry
) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1
while remain_try:
try:
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
)
except Exception as e:
logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
handle_res = default_exception_handler(
e,
self.task_name,
model_info.name,
remain_try,
retry_interval=self.req_conf.retry_interval,
)
if handle_res[0] == -1:
# 等待间隔为-1表示不再请求该模型
remain_try = 0
elif handle_res[0] != 0:
# 等待间隔不为0表示需要等待
await asyncio.sleep(handle_res[0])
logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用")
raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败

View File

@@ -0,0 +1,380 @@
import asyncio
from typing import Callable, Any
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from .base_client import BaseClient, APIResponse
from src.config.api_ada_configs import (
ModelInfo,
ModelUsageArgConfigItem,
RequestConfig,
ModuleConfig,
)
from ..exceptions import (
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption
from ..utils import compress_messages
from src.common.logger import get_logger
logger = get_logger("模型客户端")
def _check_retry(
remain_try: int,
retry_interval: int,
can_retry_msg: str,
cannot_retry_msg: str,
can_retry_callable: Callable | None = None,
**kwargs,
) -> tuple[int, Any | None]:
"""
辅助函数:检查是否可以重试
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param can_retry_msg: 可以重试时的提示信息
:param cannot_retry_msg: 不可以重试时的提示信息
:return: (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if remain_try > 0:
# 还有重试机会
logger.warning(f"{can_retry_msg}")
if can_retry_callable is not None:
return retry_interval, can_retry_callable(**kwargs)
else:
return retry_interval, None
else:
# 达到最大重试次数
logger.warning(f"{cannot_retry_msg}")
return -1, None # 不再重试请求该模型
def _handle_resp_not_ok(
e: RespNotOkException,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
):
"""
处理响应错误异常
:param e: 异常对象
:param task_name: 任务名称
:param model_name: 模型名称
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param messages: (消息列表, 是否已压缩过)
:return: (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
# 响应错误
if e.status_code in [401, 403]:
# API Key认证错误 - 让多API Key机制处理给一次重试机会
if remain_try > 0:
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"API Key认证失败错误代码-{e.status_code}多API Key机制会自动切换"
)
return 0, None # 立即重试让底层客户端切换API Key
else:
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"所有API Key都认证失败错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None # 不再重试请求该模型
elif e.status_code in [400, 402, 404]:
# 其他客户端错误(不应该重试)
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"请求失败,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None # 不再重试请求该模型
elif e.status_code == 413:
if messages and not messages[1]:
# 消息列表不为空且未压缩,尝试压缩消息
return _check_retry(
remain_try,
0,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,尝试压缩消息后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,压缩消息后仍然过大,放弃请求"
),
can_retry_callable=compress_messages,
messages=messages[0],
)
# 没有消息可压缩
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,无法压缩消息,放弃请求。"
)
return -1, None
elif e.status_code == 429:
# 请求过于频繁 - 让多API Key机制处理适当延迟后重试
return _check_retry(
remain_try,
min(retry_interval, 5), # 限制最大延迟为5秒让API Key切换更快生效
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"请求过于频繁多API Key机制会自动切换{min(retry_interval, 5)}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求过于频繁所有API Key都被限制放弃请求"
),
)
elif e.status_code >= 500:
# 服务器错误
return _check_retry(
remain_try,
retry_interval,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"服务器错误,将于{retry_interval}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"服务器错误,超过最大重试次数,请稍后再试"
),
)
else:
# 未知错误
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"未知错误,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None
def default_exception_handler(
e: Exception,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
) -> tuple[int, list[Message] | None]:
"""
默认异常处理函数
:param e: 异常对象
:param task_name: 任务名称
:param model_name: 模型名称
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param messages: (消息列表, 是否已压缩过)
:return (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if isinstance(e, NetworkConnectionError): # 网络连接错误
# 网络错误可能是某个API Key的端点问题给多API Key机制一次快速重试机会
return _check_retry(
remain_try,
min(retry_interval, 3), # 网络错误时减少等待时间让API Key切换更快
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"连接异常多API Key机制会尝试其他Key{min(retry_interval, 3)}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"连接异常超过最大重试次数请检查网络连接状态或URL是否正确"
),
)
elif isinstance(e, ReqAbortException):
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n请求被中断,详细信息-{str(e.message)}"
)
return -1, None # 不再重试请求该模型
elif isinstance(e, RespNotOkException):
return _handle_resp_not_ok(
e,
task_name,
model_name,
remain_try,
retry_interval,
messages,
)
elif isinstance(e, RespParseException):
# 响应解析错误
logger.error(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"响应解析错误,错误信息-{e.message}\n"
)
logger.debug(f"附加内容:\n{str(e.ext_info)}")
return -1, None # 不再重试请求该模型
else:
logger.error(
f"任务-'{task_name}' 模型-'{model_name}'\n未知异常,错误信息-{str(e)}"
)
return -1, None # 不再重试请求该模型
class ModelRequestHandler:
"""
模型请求处理器
"""
def __init__(
self,
task_name: str,
config: ModuleConfig,
api_client_map: dict[str, BaseClient],
):
self.task_name: str = task_name
"""任务名称"""
self.client_map: dict[str, BaseClient] = {}
"""API客户端列表"""
self.configs: list[tuple[ModelInfo, ModelUsageArgConfigItem]] = []
"""模型参数配置"""
self.req_conf: RequestConfig = config.req_conf
"""请求配置"""
# 获取模型与使用配置
for model_usage in config.task_model_arg_map[task_name].usage:
if model_usage.name not in config.models:
logger.error(f"Model '{model_usage.name}' not found in ModelManager")
raise KeyError(f"Model '{model_usage.name}' not found in ModelManager")
model_info = config.models[model_usage.name]
if model_info.api_provider not in self.client_map:
# 缓存API客户端
self.client_map[model_info.api_provider] = api_client_map[
model_info.api_provider
]
self.configs.append((model_info, model_usage)) # 添加模型与使用配置
async def get_response(
self,
messages: list[Message],
tool_options: list[ToolOption] | None = None,
response_format: RespFormat | None = None, # 暂不启用
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[ChatCompletion], APIResponse] | None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param messages: 消息列表
:param tool_options: 工具选项列表
:param response_format: 响应格式
:param stream_response_handler: 流式响应处理函数(可选)
:param async_response_parser: 响应解析函数(可选)
:param interrupt_flag: 中断信号量可选默认为None
:return: APIResponse
"""
# 遍历可用模型,若获取响应失败,则使用下一个模型继续请求
for config_item in self.configs:
client = self.client_map[config_item[0].api_provider]
model_info: ModelInfo = config_item[0]
model_usage_config: ModelUsageArgConfigItem = config_item[1]
remain_try = (
model_usage_config.max_retry or self.req_conf.max_retry
) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1
compressed_messages = None
retry_interval = self.req_conf.retry_interval
while remain_try > 0:
try:
return await client.get_response(
model_info,
message_list=(compressed_messages or messages),
tool_options=tool_options,
max_tokens=model_usage_config.max_tokens
or self.req_conf.default_max_tokens,
temperature=model_usage_config.temperature
or self.req_conf.default_temperature,
response_format=response_format,
stream_response_handler=stream_response_handler,
async_response_parser=async_response_parser,
interrupt_flag=interrupt_flag,
)
except Exception as e:
logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
handle_res = default_exception_handler(
e,
self.task_name,
model_info.name,
remain_try,
retry_interval=self.req_conf.retry_interval,
messages=(messages, compressed_messages is not None),
)
if handle_res[0] == -1:
# 等待间隔为-1表示不再请求该模型
remain_try = 0
elif handle_res[0] != 0:
# 等待间隔不为0表示需要等待
await asyncio.sleep(handle_res[0])
retry_interval *= 2
if handle_res[1] is not None:
# 压缩消息
compressed_messages = handle_res[1]
logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用")
raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败
async def get_embedding(
self,
embedding_input: str,
) -> APIResponse:
"""
获取嵌入向量
:param embedding_input: 嵌入输入
:return: APIResponse
"""
for config in self.configs:
client = self.client_map[config[0].api_provider]
model_info: ModelInfo = config[0]
model_usage_config: ModelUsageArgConfigItem = config[1]
remain_try = (
model_usage_config.max_retry or self.req_conf.max_retry
) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1
while remain_try:
try:
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
)
except Exception as e:
logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
handle_res = default_exception_handler(
e,
self.task_name,
model_info.name,
remain_try,
retry_interval=self.req_conf.retry_interval,
)
if handle_res[0] == -1:
# 等待间隔为-1表示不再请求该模型
remain_try = 0
elif handle_res[0] != 0:
# 等待间隔不为0表示需要等待
await asyncio.sleep(handle_res[0])
logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用")
raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败

View File

@@ -81,10 +81,7 @@ class BaseClient:
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[
[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
]
| None = None,
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
@@ -114,3 +111,37 @@ class BaseClient:
:return: 嵌入响应
"""
raise RuntimeError("This method should be overridden in subclasses")
class ClientRegistry:
def __init__(self) -> None:
self.client_registry: dict[str, type[BaseClient]] = {}
def register_client_class(self, client_type: str):
"""
注册API客户端类
:param client_class: API客户端类
"""
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
if not issubclass(cls, BaseClient):
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
self.client_registry[client_type] = cls
return cls
return decorator
def get_client_class(self, client_type: str) -> type[BaseClient]:
"""
获取注册的API客户端类
Args:
client_type: 客户端类型
Returns:
type[BaseClient]: 注册的API客户端类
"""
if client_type not in self.client_registry:
raise KeyError(f"'{client_type}' 类型的 Client 未注册")
return self.client_registry[client_type]
client_registry = ClientRegistry()

View File

@@ -22,7 +22,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDelta
from .base_client import APIResponse, UsageRecord
from src.config.api_ada_configs import ModelInfo, APIProvider
from . import BaseClient
from .base_client import BaseClient, client_registry
from src.common.logger import get_logger
from ..exceptions import (
@@ -63,9 +63,7 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/{item[0].lower()};base64,{item[1]}"
},
"image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"},
}
)
elif isinstance(item, str):
@@ -120,13 +118,8 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]
if tool_option.params:
ret["parameters"] = {
"type": "object",
"properties": {
param.name: _convert_tool_param(param)
for param in tool_option.params
},
"required": [
param.name for param in tool_option.params if param.required
],
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
"required": [param.name for param in tool_option.params if param.required],
}
return ret
@@ -190,9 +183,7 @@ def _process_delta(
if tool_call_delta.function.arguments:
# 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中
tool_calls_buffer[tool_call_delta.index][2].write(
tool_call_delta.function.arguments
)
tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments)
return in_rc_flag
@@ -225,14 +216,12 @@ def _build_stream_api_resp(
if not isinstance(arguments, dict):
raise RespParseException(
None,
"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n"
f"{raw_arg_data}",
f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_arg_data}",
)
except json.JSONDecodeError as e:
raise RespParseException(
None,
"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:"
f"{raw_arg_data}",
f"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:{raw_arg_data}",
) from e
else:
arguments_buffer.close()
@@ -257,9 +246,7 @@ async def _default_stream_response_handler(
_in_rc_flag = False # 标记是否在推理内容块中
_rc_delta_buffer = io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[
tuple[str, str, io.StringIO]
] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
def _insure_buffer_closed():
@@ -280,7 +267,7 @@ async def _default_stream_response_handler(
delta = event.choices[0].delta # 获取当前块的delta内容
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
# 标记:有独立的推理内容块
_has_rc_attr_flag = True
@@ -334,10 +321,10 @@ def _default_normal_response_parser(
raise RespParseException(resp, "响应解析失败缺失choices字段")
message_part = resp.choices[0].message
if hasattr(message_part, "reasoning_content") and message_part.reasoning_content:
if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: # type: ignore
# 有有效的推理字段
api_response.content = message_part.content
api_response.reasoning_content = message_part.reasoning_content
api_response.reasoning_content = message_part.reasoning_content # type: ignore
elif message_part.content:
# 提取推理和内容
match = pattern.match(message_part.content)
@@ -358,16 +345,10 @@ def _default_normal_response_parser(
try:
arguments = json.loads(call.function.arguments)
if not isinstance(arguments, dict):
raise RespParseException(
resp, "响应解析失败,工具调用参数无法解析为字典类型"
)
api_response.tool_calls.append(
ToolCall(call.id, call.function.name, arguments)
)
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
except json.JSONDecodeError as e:
raise RespParseException(
resp, "响应解析失败,无法解析工具调用参数"
) from e
raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
# 提取Usage信息
if resp.usage:
@@ -385,63 +366,15 @@ def _default_normal_response_parser(
return api_response, _usage_record
@client_registry.register_client_class("openai")
class OpenaiClient(BaseClient):
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
# 不再在初始化时创建固定的client而是在请求时动态创建
self._clients_cache = {} # API Key -> AsyncOpenAI client 的缓存
def _get_client(self, api_key: str = None) -> AsyncOpenAI:
"""获取或创建对应API Key的客户端"""
if api_key is None:
api_key = self.api_provider.get_current_api_key()
if not api_key:
raise ValueError(f"API Provider '{self.api_provider.name}' 没有可用的API Key")
# 使用缓存避免重复创建客户端
if api_key not in self._clients_cache:
self._clients_cache[api_key] = AsyncOpenAI(
base_url=self.api_provider.base_url,
api_key=api_key,
max_retries=0,
)
return self._clients_cache[api_key]
async def _execute_with_fallback(self, func, *args, **kwargs):
"""执行请求并在失败时切换API Key"""
current_api_key = self.api_provider.get_current_api_key()
max_attempts = len(self.api_provider.api_keys) if self.api_provider.api_keys else 1
for attempt in range(max_attempts):
try:
client = self._get_client(current_api_key)
result = await func(client, *args, **kwargs)
# 成功时重置失败计数
self.api_provider.reset_key_failures(current_api_key)
return result
except (APIStatusError, APIConnectionError) as e:
# 记录失败并尝试下一个API Key
logger.warning(f"API Key失败 (尝试 {attempt + 1}/{max_attempts}): {str(e)}")
if attempt < max_attempts - 1: # 还有重试机会
next_api_key = self.api_provider.mark_key_failed(current_api_key)
if next_api_key and next_api_key != current_api_key:
current_api_key = next_api_key
logger.info(f"切换到下一个API Key: {current_api_key[:8]}***{current_api_key[-4:]}")
continue
# 所有API Key都失败了重新抛出异常
if isinstance(e, APIStatusError):
raise RespNotOkException(e.status_code, e.message) from e
elif isinstance(e, APIConnectionError):
raise NetworkConnectionError(str(e)) from e
except Exception as e:
# 其他异常直接抛出
raise e
self.client: AsyncOpenAI = AsyncOpenAI(
base_url=api_provider.base_url,
api_key=api_provider.api_key,
max_retries=0,
)
async def get_response(
self,
@@ -456,10 +389,7 @@ class OpenaiClient(BaseClient):
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[
[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
]
| None = None,
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
@@ -475,40 +405,6 @@ class OpenaiClient(BaseClient):
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
return await self._execute_with_fallback(
self._get_response_internal,
model_info,
message_list,
tool_options,
max_tokens,
temperature,
response_format,
stream_response_handler,
async_response_parser,
interrupt_flag,
)
async def _get_response_internal(
self,
client: AsyncOpenAI,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[
[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""内部方法执行实际的API调用"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
@@ -518,23 +414,19 @@ class OpenaiClient(BaseClient):
# 将messages构造为OpenAI API所需的格式
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
# 将tool_options转换为OpenAI API所需的格式
tools: Iterable[ChatCompletionToolParam] = (
_convert_tool_options(tool_options) if tool_options else NOT_GIVEN
)
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN
try:
if model_info.force_stream_mode:
req_task = asyncio.create_task(
client.chat.completions.create(
self.client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
response_format=response_format.to_dict()
if response_format
else NOT_GIVEN,
response_format=response_format.to_dict() if response_format else NOT_GIVEN,
)
)
while not req_task.done():
@@ -544,22 +436,18 @@ class OpenaiClient(BaseClient):
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态
resp, usage_record = await stream_response_handler(
req_task.result(), interrupt_flag
)
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
else:
# 发送请求并获取响应
req_task = asyncio.create_task(
client.chat.completions.create(
self.client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
response_format=response_format.to_dict()
if response_format
else NOT_GIVEN,
response_format=response_format.to_dict() if response_format else NOT_GIVEN,
)
)
while not req_task.done():
@@ -599,21 +487,8 @@ class OpenaiClient(BaseClient):
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
return await self._execute_with_fallback(
self._get_embedding_internal,
model_info,
embedding_input,
)
async def _get_embedding_internal(
self,
client: AsyncOpenAI,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""内部方法执行实际的嵌入API调用"""
try:
raw_response = await client.embeddings.create(
raw_response = await self.client.embeddings.create(
model=model_info.model_identifier,
input=embedding_input,
)

View File

@@ -2,7 +2,6 @@ import importlib
from typing import Dict
from src.config.config import model_config
from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
from src.common.logger import get_logger
from .model_client import ModelRequestHandler, BaseClient
@@ -10,83 +9,4 @@ from .model_client import ModelRequestHandler, BaseClient
logger = get_logger("模型管理器")
class ModelManager:
# TODO: 添加读写锁,防止异步刷新配置时发生数据竞争
def __init__(
self,
config: ModuleConfig,
):
self.config: ModuleConfig = config
"""配置信息"""
self.api_client_map: Dict[str, BaseClient] = {}
"""API客户端映射表"""
self._request_handler_cache: Dict[str, ModelRequestHandler] = {}
"""ModelRequestHandler缓存避免重复创建"""
for provider_name, api_provider in self.config.api_providers.items():
# 初始化API客户端
try:
# 根据配置动态加载实现
client_module = importlib.import_module(
f".model_client.{api_provider.client_type}_client", __package__
)
client_class = getattr(
client_module, f"{api_provider.client_type.capitalize()}Client"
)
if not issubclass(client_class, BaseClient):
raise TypeError(
f"'{client_class.__name__}' is not a subclass of 'BaseClient'"
)
self.api_client_map[api_provider.name] = client_class(
api_provider
) # 实例化放入api_client_map
except ImportError as e:
logger.error(f"Failed to import client module: {e}")
raise ImportError(
f"Failed to import client module for '{provider_name}': {e}"
) from e
def __getitem__(self, task_name: str) -> ModelRequestHandler:
"""
获取任务所需的模型客户端(封装)
使用缓存机制避免重复创建ModelRequestHandler
:param task_name: 任务名称
:return: 模型客户端
"""
if task_name not in self.config.task_model_arg_map:
raise KeyError(f"'{task_name}' not registered in ModelManager")
# 检查缓存中是否已存在
if task_name in self._request_handler_cache:
logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}")
return self._request_handler_cache[task_name]
# 创建新的ModelRequestHandler并缓存
logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}")
handler = ModelRequestHandler(
task_name=task_name,
config=self.config,
api_client_map=self.api_client_map,
)
self._request_handler_cache[task_name] = handler
return handler
def __setitem__(self, task_name: str, value: ModelUsageArgConfig):
"""
注册任务的模型使用配置
:param task_name: 任务名称
:param value: 模型使用配置
"""
self.config.task_model_arg_map[task_name] = value
def __contains__(self, task_name: str):
"""
判断任务是否已注册
:param task_name: 任务名称
:return: 是否在模型列表中
"""
return task_name in self.config.task_model_arg_map

View File

@@ -0,0 +1,92 @@
import importlib
from typing import Dict
from src.config.config import model_config
from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
from src.common.logger import get_logger
from .model_client import ModelRequestHandler, BaseClient
logger = get_logger("模型管理器")
class ModelManager:
# TODO: 添加读写锁,防止异步刷新配置时发生数据竞争
def __init__(
self,
config: ModuleConfig,
):
self.config: ModuleConfig = config
"""配置信息"""
self.api_client_map: Dict[str, BaseClient] = {}
"""API客户端映射表"""
self._request_handler_cache: Dict[str, ModelRequestHandler] = {}
"""ModelRequestHandler缓存避免重复创建"""
for provider_name, api_provider in self.config.api_providers.items():
# 初始化API客户端
try:
# 根据配置动态加载实现
client_module = importlib.import_module(
f".model_client.{api_provider.client_type}_client", __package__
)
client_class = getattr(
client_module, f"{api_provider.client_type.capitalize()}Client"
)
if not issubclass(client_class, BaseClient):
raise TypeError(
f"'{client_class.__name__}' is not a subclass of 'BaseClient'"
)
self.api_client_map[api_provider.name] = client_class(
api_provider
) # 实例化放入api_client_map
except ImportError as e:
logger.error(f"Failed to import client module: {e}")
raise ImportError(
f"Failed to import client module for '{provider_name}': {e}"
) from e
def __getitem__(self, task_name: str) -> ModelRequestHandler:
"""
获取任务所需的模型客户端(封装)
使用缓存机制避免重复创建ModelRequestHandler
:param task_name: 任务名称
:return: 模型客户端
"""
if task_name not in self.config.task_model_arg_map:
raise KeyError(f"'{task_name}' not registered in ModelManager")
# 检查缓存中是否已存在
if task_name in self._request_handler_cache:
logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}")
return self._request_handler_cache[task_name]
# 创建新的ModelRequestHandler并缓存
logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}")
handler = ModelRequestHandler(
task_name=task_name,
config=self.config,
api_client_map=self.api_client_map,
)
self._request_handler_cache[task_name] = handler
return handler
def __setitem__(self, task_name: str, value: ModelUsageArgConfig):
"""
注册任务的模型使用配置
:param task_name: 任务名称
:param value: 模型使用配置
"""
self.config.task_model_arg_map[task_name] = value
def __contains__(self, task_name: str):
"""
判断任务是否已注册
:param task_name: 任务名称
:return: 是否在模型列表中
"""
return task_name in self.config.task_model_arg_map

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,778 @@
import re
from datetime import datetime
from typing import Tuple, Union
from src.common.logger import get_logger
import base64
from PIL import Image
import io
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from src.config.config import global_config
from rich.traceback import install
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException, PayLoadTooLargeError, RequestAbortException, PermissionDeniedException
install(extra_lines=3)
logger = get_logger("model_utils")
# 导入具体的异常类型用于精确的异常处理
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
SPECIFIC_EXCEPTIONS_AVAILABLE = True
# 新架构导入 - 使用延迟导入以支持fallback模式
from .model_manager_bak import ModelManager
from .model_client import ModelRequestHandler
from .payload_content.message import MessageBuilder
# 不在模块级别初始化ModelManager延迟到实际使用时
ModelManager_class = ModelManager
model_manager = None # 延迟初始化
# 添加请求处理器缓存,避免重复创建
_request_handler_cache = {} # 格式: {(model_name, task_name): ModelRequestHandler}
NEW_ARCHITECTURE_AVAILABLE = True
logger.info("新架构模块导入成功")
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高",
}
class LLMRequest:
"""
重构后的LLM请求类基于新的model_manager和model_client架构
保持向后兼容的API接口
"""
# 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
"o1",
"o1-2024-12-17",
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"o1-pro",
"o1-pro-2025-03-19",
"o3",
"o3-2025-04-16",
"o3-mini",
"o3-mini-2025-01-31",
"o4-mini",
"o4-mini-2025-04-16",
]
def __init__(self, model: dict, **kwargs):
"""
初始化LLM请求实例
Args:
model: 模型配置字典,兼容旧格式和新格式
**kwargs: 额外参数
"""
logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('model_name', model.get('name', 'Unknown'))}")
logger.debug(f"🔍 [模型初始化] 输入的模型配置: {model}")
logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}")
# 兼容新旧模型配置格式
# 新格式使用 model_name旧格式使用 name
self.model_name: str = model.get("model_name", model.get("name", ""))
# 如果传入的配置不完整,自动从全局配置中获取完整配置
if not all(key in model for key in ["task_type", "capabilities"]):
logger.debug("🔍 [模型初始化] 检测到不完整的模型配置,尝试获取完整配置")
if (full_model_config := self._get_full_model_config(self.model_name)):
logger.debug("🔍 [模型初始化] 成功获取完整模型配置,合并配置信息")
# 合并配置:运行时参数优先,但添加缺失的配置字段
model = {**full_model_config, **model}
logger.debug(f"🔍 [模型初始化] 合并后的模型配置: {model}")
else:
logger.warning(f"⚠️ [模型初始化] 无法获取模型 {self.model_name} 的完整配置,使用原始配置")
# 在新架构中provider信息从model_config.toml自动获取不需要在这里设置
self.provider = model.get("provider", "") # 保留兼容性,但在新架构中不使用
# 从全局配置中获取任务配置
self.request_type = kwargs.pop("request_type", "default")
# 确定使用哪个任务配置
task_name = self._determine_task_name(model)
# 初始化 request_handler
self.request_handler = None
# 尝试初始化新架构
if NEW_ARCHITECTURE_AVAILABLE and ModelManager_class is not None:
try:
# 延迟初始化ModelManager
global model_manager, _request_handler_cache
if model_manager is None:
from src.config.config import model_config
model_manager = ModelManager_class(model_config)
logger.debug("🔍 [模型初始化] ModelManager延迟初始化成功")
# 构建缓存键
cache_key = (self.model_name, task_name)
# 检查是否已有缓存的请求处理器
if cache_key in _request_handler_cache:
self.request_handler = _request_handler_cache[cache_key]
logger.debug(f"🚀 [性能优化] 从LLMRequest缓存获取请求处理器: {cache_key}")
else:
# 使用新架构获取模型请求处理器
self.request_handler = model_manager[task_name]
_request_handler_cache[cache_key] = self.request_handler
logger.debug(f"🔧 [性能优化] 创建并缓存LLMRequest请求处理器: {cache_key}")
logger.debug(f"🔍 [模型初始化] 成功获取模型请求处理器,任务: {task_name}")
self.use_new_architecture = True
except Exception as e:
logger.warning(f"无法使用新架构,任务 {task_name} 初始化失败: {e}")
logger.warning("回退到兼容模式,某些功能可能受限")
self.request_handler = None
self.use_new_architecture = False
else:
logger.warning("新架构不可用,使用兼容模式")
logger.warning("回退到兼容模式,某些功能可能受限")
self.request_handler = None
self.use_new_architecture = False
# 保存原始参数用于向后兼容
self.params = kwargs
# 兼容性属性,从模型配置中提取
# 新格式和旧格式都支持
self.enable_thinking = model.get("enable_thinking", False)
self.temp = model.get("temperature", model.get("temp", 0.7)) # 新格式用temperature旧格式用temp
self.thinking_budget = model.get("thinking_budget", 4096)
self.stream = model.get("stream", False)
self.pri_in = model.get("pri_in", 0)
self.pri_out = model.get("pri_out", 0)
self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
# 记录配置文件中声明了哪些参数(不管值是什么)
self.has_enable_thinking = "enable_thinking" in model
self.has_thinking_budget = "thinking_budget" in model
self.pri_out = model.get("pri_out", 0)
self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
# 记录配置文件中声明了哪些参数(不管值是什么)
self.has_enable_thinking = "enable_thinking" in model
self.has_thinking_budget = "thinking_budget" in model
logger.debug("🔍 [模型初始化] 模型参数设置完成:")
logger.debug(f" - model_name: {self.model_name}")
logger.debug(f" - provider: {self.provider}")
logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}")
logger.debug(f" - enable_thinking: {self.enable_thinking}")
logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}")
logger.debug(f" - thinking_budget: {self.thinking_budget}")
logger.debug(f" - temp: {self.temp}")
logger.debug(f" - stream: {self.stream}")
logger.debug(f" - max_tokens: {self.max_tokens}")
logger.debug(f" - use_new_architecture: {self.use_new_architecture}")
# 获取数据库实例
self._init_database()
logger.debug(f"🔍 [模型初始化] 初始化完成request_type: {self.request_type}")
def _determine_task_name(self, model: dict) -> str:
"""
根据模型配置确定任务名称
优先使用配置文件中明确定义的任务类型,避免基于模型名称的脆弱推断
Args:
model: 模型配置字典
Returns:
任务名称
"""
# 调试信息:打印模型配置字典的所有键
logger.debug(f"🔍 [任务确定] 模型配置字典的所有键: {list(model.keys())}")
logger.debug(f"🔍 [任务确定] 模型配置字典内容: {model}")
# 获取模型名称
model_name = model.get("model_name", model.get("name", ""))
# 方法1: 优先使用配置文件中明确定义的 task_type 字段
if "task_type" in model:
task_type = model["task_type"]
logger.debug(f"🎯 [任务确定] 使用配置中的 task_type: {task_type}")
return task_type
# 方法2: 使用 capabilities 字段来推断主要任务类型
if "capabilities" in model:
capabilities = model["capabilities"]
if isinstance(capabilities, list):
# 按优先级顺序检查能力
if "vision" in capabilities:
logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: vision")
return "vision"
elif "embedding" in capabilities:
logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: embedding")
return "embedding"
elif "speech" in capabilities:
logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: speech")
return "speech"
elif "text" in capabilities:
# 如果只有文本能力则根据request_type细分
task = "llm_reasoning" if self.request_type == "reasoning" else "llm_normal"
logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 和 request_type {self.request_type} 推断为: {task}")
return task
# 方法3: 向后兼容 - 基于模型名称的关键字推断(不推荐但保留兼容性)
logger.warning(f"⚠️ [任务确定] 配置中未找到 task_type 或 capabilities回退到基于模型名称的推断: {model_name}")
logger.warning("⚠️ [建议] 请在 model_config.toml 中为模型添加明确的 task_type 或 capabilities 字段")
# 保留原有的关键字匹配逻辑作为fallback
if any(keyword in model_name.lower() for keyword in ["vlm", "vision", "gpt-4o", "claude", "vl-"]):
logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: vision")
return "vision"
elif any(keyword in model_name.lower() for keyword in ["embed", "text-embedding", "bge-"]):
logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: embedding")
return "embedding"
elif any(keyword in model_name.lower() for keyword in ["whisper", "speech", "voice"]):
logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: speech")
return "speech"
else:
# 根据request_type确定映射到配置文件中定义的任务
task = "llm_reasoning" if self.request_type == "reasoning" else "llm_normal"
logger.debug(f"🎯 [任务确定] 从 request_type {self.request_type} 推断为: {task}")
return task
def _get_full_model_config(self, model_name: str) -> dict | None:
"""
根据模型名称从全局配置中获取完整的模型配置
现在直接使用已解析的ModelInfo对象不再读取TOML文件
Args:
model_name: 模型名称
Returns:
完整的模型配置字典如果找不到则返回None
"""
try:
from src.config.config import model_config
return self._get_model_config_from_parsed(model_name, model_config)
except Exception as e:
logger.warning(f"⚠️ [配置查找] 获取模型配置时出错: {str(e)}")
return None
def _get_model_config_from_parsed(self, model_name: str, model_config) -> dict | None:
"""
从已解析的配置对象中获取模型配置
使用扩展后的ModelInfo类包含task_type和capabilities字段
"""
try:
# 直接通过模型名称查找
if model_name in model_config.models:
model_info = model_config.models[model_name]
logger.debug(f"🔍 [配置查找] 找到模型 {model_name} 的配置对象: {model_info}")
# 将ModelInfo对象转换为字典
model_dict = {
"model_identifier": model_info.model_identifier,
"name": model_info.name,
"api_provider": model_info.api_provider,
"price_in": model_info.price_in,
"price_out": model_info.price_out,
"force_stream_mode": model_info.force_stream_mode,
"task_type": model_info.task_type,
"capabilities": model_info.capabilities,
}
logger.debug(f"🔍 [配置查找] 转换后的模型配置字典: {model_dict}")
return model_dict
# 如果直接查找失败尝试通过model_identifier查找
for name, model_info in model_config.models.items():
if (model_info.model_identifier == model_name or
hasattr(model_info, 'model_name') and model_info.model_name == model_name):
logger.debug(f"🔍 [配置查找] 通过标识符找到模型 {model_name} (配置名称: {name})")
# 同样转换为字典
model_dict = {
"model_identifier": model_info.model_identifier,
"name": model_info.name,
"api_provider": model_info.api_provider,
"price_in": model_info.price_in,
"price_out": model_info.price_out,
"force_stream_mode": model_info.force_stream_mode,
"task_type": model_info.task_type,
"capabilities": model_info.capabilities,
}
return model_dict
return None
except Exception as e:
logger.warning(f"⚠️ [配置查找] 从已解析配置获取模型配置时出错: {str(e)}")
return None
@staticmethod
def _init_database():
"""初始化数据库集合"""
try:
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def _record_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
user_id: str = "system",
request_type: str | None = None,
endpoint: str = "/chat/completions",
):
"""记录模型使用情况到数据库
Args:
prompt_tokens: 输入token数
completion_tokens: 输出token数
total_tokens: 总token数
user_id: 用户ID默认为system
request_type: 请求类型
endpoint: API端点
"""
# 如果 request_type 为 None则使用实例变量中的值
if request_type is None:
request_type = self.request_type
try:
# 使用 Peewee 模型创建记录
LLMUsage.create(
model_name=self.model_name,
user_id=user_id,
request_type=request_type,
endpoint=endpoint,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost=self._calculate_cost(prompt_tokens, completion_tokens),
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
)
logger.debug(
f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本
使用模型的pri_in和pri_out价格计算输入和输出的成本
Args:
prompt_tokens: 输入token数量
completion_tokens: 输出token数量
Returns:
float: 总成本(元)
"""
# 使用模型的pri_in和pri_out计算成本
input_cost = (prompt_tokens / 1000000) * self.pri_in
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取"""
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
reasoning = match[1].strip() if match else ""
return content, reasoning
def _handle_model_exception(self, e: Exception, operation: str) -> None:
"""
统一的模型异常处理方法
根据异常类型提供更精确的错误信息和处理策略
Args:
e: 捕获的异常
operation: 操作类型(用于日志记录)
"""
operation_desc = {
"image": "图片响应生成",
"voice": "语音识别",
"text": "文本响应生成",
"embedding": "向量嵌入获取"
}
op_name = operation_desc.get(operation, operation)
if SPECIFIC_EXCEPTIONS_AVAILABLE:
# 使用具体异常类型进行精确处理
if isinstance(e, NetworkConnectionError):
logger.error(f"模型 {self.model_name} {op_name}失败: 网络连接错误")
raise RuntimeError("网络连接异常请检查网络连接状态或API服务器地址是否正确") from e
elif isinstance(e, ReqAbortException):
logger.error(f"模型 {self.model_name} {op_name}失败: 请求被中断")
raise RuntimeError("请求被中断或取消,请稍后重试") from e
elif isinstance(e, RespNotOkException):
logger.error(f"模型 {self.model_name} {op_name}失败: HTTP响应错误 {e.status_code}")
# 重新抛出原始异常,保留详细的状态码信息
raise e
elif isinstance(e, RespParseException):
logger.error(f"模型 {self.model_name} {op_name}失败: 响应解析错误")
raise RuntimeError("API响应格式异常请检查模型配置或联系管理员") from e
else:
# 未知异常,使用通用处理
logger.error(f"模型 {self.model_name} {op_name}失败: 未知错误 {type(e).__name__}: {str(e)}")
self._handle_generic_exception(e, op_name)
else:
# 如果无法导入具体异常,使用通用处理
logger.error(f"模型 {self.model_name} {op_name}失败: {str(e)}")
self._handle_generic_exception(e, op_name)
def _handle_generic_exception(self, e: Exception, operation: str) -> None:
"""
通用异常处理(向后兼容的错误字符串匹配)
Args:
e: 捕获的异常
operation: 操作描述
"""
error_str = str(e)
# 基于错误消息内容的分类处理
if "401" in error_str or "API key" in error_str or "认证" in error_str:
raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e
elif "429" in error_str or "频繁" in error_str or "rate limit" in error_str:
raise RuntimeError("请求过于频繁,请稍后再试") from e
elif "500" in error_str or "503" in error_str or "服务器" in error_str:
raise RuntimeError("服务器负载过高模型回复失败QAQ") from e
elif "413" in error_str or "payload" in error_str.lower() or "过大" in error_str:
raise RuntimeError("请求体过大,请尝试压缩图片或减少输入内容") from e
elif "timeout" in error_str.lower() or "超时" in error_str:
raise RuntimeError("请求超时,请检查网络连接或稍后重试") from e
else:
raise RuntimeError(f"模型 {self.model_name} {operation}失败: {str(e)}") from e
# === 主要API方法 ===
# 这些方法提供与新架构的桥接
async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple:
"""
根据输入的提示和图片生成模型的异步响应
使用新架构的模型请求处理器
"""
if not self.use_new_architecture:
raise RuntimeError(
f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
)
if self.request_handler is None:
raise RuntimeError(
f"模型 {self.model_name} 请求处理器未初始化,无法处理图片请求"
)
if MessageBuilder is None:
raise RuntimeError("MessageBuilder不可用请检查新架构配置")
try:
# 构建包含图片的消息
message_builder = MessageBuilder()
message_builder.add_text_content(prompt).add_image_content(
image_format=image_format,
image_base64=image_base64
)
messages = [message_builder.build()]
# 使用新架构发送请求(只传递支持的参数)
response = await self.request_handler.get_response( # type: ignore
messages=messages,
tool_options=None,
response_format=None
)
# 新架构返回的是 APIResponse 对象,直接提取内容
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
# 从内容中提取<think>标签的推理内容(向后兼容)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
# 记录token使用情况
if response.usage:
self._record_usage(
prompt_tokens=response.usage.prompt_tokens or 0,
completion_tokens=response.usage.completion_tokens or 0,
total_tokens=response.usage.total_tokens or 0,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions"
)
# 返回格式兼容旧版本
if tool_calls:
return content, reasoning_content, tool_calls
else:
return content, reasoning_content
except Exception as e:
self._handle_model_exception(e, "image")
# 这行代码永远不会执行因为_handle_model_exception总是抛出异常
# 但是为了满足类型检查的要求,我们添加一个不可达的返回语句
return "", "" # pragma: no cover
async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple:
"""
根据输入的语音文件生成模型的异步响应
使用新架构的模型请求处理器
"""
if not self.use_new_architecture:
raise RuntimeError(
f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
)
if self.request_handler is None:
raise RuntimeError(
f"模型 {self.model_name} 请求处理器未初始化,无法处理语音请求"
)
try:
# 构建语音识别请求参数
# 注意:新架构中的语音识别可能使用不同的方法
# 这里先使用get_response方法可能需要根据实际API调整
response = await self.request_handler.get_response( # type: ignore
messages=[], # 语音识别可能不需要消息
tool_options=None
)
# 新架构返回的是 APIResponse 对象,直接提取文本内容
return (response.content,) if response.content else ("",)
except Exception as e:
self._handle_model_exception(e, "voice")
# 不可达的返回语句,仅用于满足类型检查
return ("",) # pragma: no cover
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
"""
异步方式根据输入的提示生成模型的响应
使用新架构的模型请求处理器,如无法使用则抛出错误
"""
if not self.use_new_architecture:
raise RuntimeError(
f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
)
if self.request_handler is None:
raise RuntimeError(
f"模型 {self.model_name} 请求处理器未初始化,无法生成响应"
)
if MessageBuilder is None:
raise RuntimeError("MessageBuilder不可用请检查新架构配置")
try:
# 构建消息
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
messages = [message_builder.build()]
# 使用新架构发送请求(只传递支持的参数)
response = await self.request_handler.get_response( # type: ignore
messages=messages,
tool_options=None,
response_format=None
)
# 新架构返回的是 APIResponse 对象,直接提取内容
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
# 从内容中提取<think>标签的推理内容(向后兼容)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
# 记录token使用情况
if response.usage:
self._record_usage(
prompt_tokens=response.usage.prompt_tokens or 0,
completion_tokens=response.usage.completion_tokens or 0,
total_tokens=response.usage.total_tokens or 0,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions"
)
# 返回格式兼容旧版本
if tool_calls:
return content, (reasoning_content, self.model_name, tool_calls)
else:
return content, (reasoning_content, self.model_name)
except Exception as e:
self._handle_model_exception(e, "text")
# 不可达的返回语句,仅用于满足类型检查
return "", ("", self.model_name) # pragma: no cover
async def get_embedding(self, text: str) -> Union[list, None]:
"""
异步方法获取文本的embedding向量
使用新架构的模型请求处理器
Args:
text: 需要获取embedding的文本
Returns:
list: embedding向量如果失败则返回None
"""
if not text:
logger.debug("该消息没有长度不再发送获取embedding向量的请求")
return None
if not self.use_new_architecture:
logger.warning(f"模型 {self.model_name} 无法使用新架构embedding请求将被跳过")
return None
if self.request_handler is None:
logger.warning(f"模型 {self.model_name} 请求处理器未初始化embedding请求将被跳过")
return None
try:
# 构建embedding请求参数
# 使用新架构的get_embedding方法
response = await self.request_handler.get_embedding(text) # type: ignore
# 新架构返回的是 APIResponse 对象直接提取embedding
if response.embedding:
embedding = response.embedding
# 记录token使用情况
if response.usage:
self._record_usage(
prompt_tokens=response.usage.prompt_tokens or 0,
completion_tokens=response.usage.completion_tokens or 0,
total_tokens=response.usage.total_tokens or 0,
user_id="system",
request_type=self.request_type,
endpoint="/embeddings"
)
return embedding
else:
logger.warning(f"模型 {self.model_name} 返回的embedding响应为空")
return None
except Exception as e:
# 对于embedding请求我们记录错误但不抛出异常而是返回None
# 这是为了保持与原有行为的兼容性
try:
self._handle_model_exception(e, "embedding")
except RuntimeError:
# 捕获_handle_model_exception抛出的RuntimeError转换为警告日志
logger.warning(f"模型 {self.model_name} embedding请求失败返回None: {str(e)}")
return None
def compress_base64_image_by_scale(base64_data: str, target_size: int = int(0.8 * 1024 * 1024)) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
# 确保base64字符串只包含ASCII字符
if isinstance(base64_data, str):
base64_data = base64_data.encode("ascii", errors="ignore").decode("ascii")
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2 * 1024 * 1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
n_frames = getattr(img, 'n_frames', 1)
for frame_idx in range(n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format="GIF",
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get("duration", 100),
loop=img.info.get("loop", 0),
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == "PNG" and img.mode in ("RGBA", "LA"):
resized_img.save(output_buffer, format="PNG", optimize=True)
else:
resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.info(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB")
return base64.b64encode(compressed_data).decode("utf-8")
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data

View File

@@ -1,5 +1,5 @@
[inner]
version = "5.0.0"
version = "6.0.0"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请在修改后将version的值进行变更
@@ -213,98 +213,10 @@ file_log_level = "DEBUG" # 文件日志级别,可选: DEBUG, INFO, WARNING, ER
suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"] # 完全屏蔽的库
library_log_levels = { "aiohttp" = "WARNING"} # 设置特定库的日志级别
#下面的模型若使用硅基流动则不需要更改使用ds官方则改成.env自定义的宏使用自定义模型则选择定位相似的模型自己填写
# stream = <true|false> : 用于指定模型是否是使用流式输出
# pri_in = <float> : 用于指定模型输入价格
# pri_out = <float> : 用于指定模型输出价格
# temp = <float> : 用于指定模型温度
# enable_thinking = <true|false> : 用于指定模型是否启用思考
# thinking_budget = <int> : 用于指定模型思考最长长度
[debug]
show_prompt = false # 是否显示prompt
[model]
model_max_output_length = 800 # 模型单次返回的最大token数
#------------模型任务配置------------
# 所有模型名称需要对应 model_config.toml 中配置的模型名称
[model.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2 # 模型温度新V3建议0.1-0.3
max_tokens = 800 # 最大输出token数
[model.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
model_name = "qwen3-8b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考
[model.replyer_1] # 首要回复模型,还用于表达器和表达方式学习
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2 # 模型温度新V3建议0.1-0.3
max_tokens = 800
[model.replyer_2] # 次要回复模型
model_name = "siliconflow-deepseek-r1" # 对应 model_config.toml 中的模型名称
temperature = 0.7 # 模型温度
max_tokens = 800
[model.planner] #决策:负责决定麦麦该做什么的模型
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.3
max_tokens = 800
[model.emotion] #负责麦麦的情绪变化
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.3
max_tokens = 800
[model.memory] # 记忆模型
model_name = "qwen3-30b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考
[model.vlm] # 图像识别模型
model_name = "qwen2.5-vl-72b" # 对应 model_config.toml 中的模型名称
max_tokens = 800
[model.voice] # 语音识别模型
model_name = "sensevoice-small" # 对应 model_config.toml 中的模型名称
[model.tool_use] #工具调用模型,需要使用支持工具调用的模型
model_name = "qwen3-14b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考qwen3 only
#嵌入模型
[model.embedding]
model_name = "bge-m3" # 对应 model_config.toml 中的模型名称
#------------LPMM知识库模型------------
[model.lpmm_entity_extract] # 实体提取模型
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2
max_tokens = 800
[model.lpmm_rdf_build] # RDF构建模型
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2
max_tokens = 800
[model.lpmm_qa] # 问答模型
model_name = "deepseek-r1-distill-qwen-32b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考
[maim_message]
auth_token = [] # 认证令牌用于API验证为空则不启用验证
# 以下项目若要使用需要打开use_custom并单独配置maim_message的服务器
@@ -320,8 +232,4 @@ key_file = "" # SSL密钥文件路径仅在use_wss=true时有效
enable = true
[experimental] #实验性功能
enable_friend_chat = false # 是否启用好友聊天
enable_friend_chat = false # 是否启用好友聊天

View File

@@ -1,5 +1,5 @@
[inner]
version = "0.2.1"
version = "1.0.0"
# 配置文件版本号迭代规则同bot_config.toml
#
@@ -42,53 +42,31 @@ version = "0.2.1"
# - 未配置新字段时会自动回退到基于模型名称的推断
[request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释)
#max_retry = 2 # 最大重试次数单个模型API调用失败最多重试的次数
#timeout = 10 # API调用的超时时长超过这个时长本次请求将被视为“请求超时”单位
#retry_interval = 10 # 重试间隔如果API调用失败重试的间隔时间单位
#default_temperature = 0.7 # 默认的温度如果bot_config.toml中没有设置temperature参数默认使用这个值
#default_max_tokens = 1024 # 默认的最大输出token数如果bot_config.toml中没有设置max_tokens参数默认使用这个值
max_retry = 2 # 最大重试次数单个模型API调用失败最多重试的次数
timeout = 10 # API调用的超时时长超过这个时长本次请求将被视为“请求超时”单位
retry_interval = 10 # 重试间隔如果API调用失败重试的间隔时间单位
default_temperature = 0.7 # 默认的温度如果bot_config.toml中没有设置temperature参数默认使用这个值
default_max_tokens = 1024 # 默认的最大输出token数如果bot_config.toml中没有设置max_tokens参数默认使用这个值
[[api_providers]] # API服务提供商可以配置多个
name = "DeepSeek" # API服务商名称可随意命名在models的api-provider中需使用这个命名
base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL
# 支持多个API Key实现自动切换和负载均衡
api_keys = [ # API Key列表多个key支持错误自动切换和负载均衡
"sk-your-first-key-here",
"sk-your-second-key-here",
"sk-your-third-key-here"
]
# 向后兼容如果只有一个key也可以使用单个key字段
#key = "******" # API Key 可选默认为None
api_key = "sk-your-first-key-here"
client_type = "openai" # 请求客户端(可选,默认值为"openai"使用gimini等Google系模型时请配置为"gemini"
[[api_providers]] # 特殊Google的Gimini使用特殊API与OpenAI格式不兼容需要配置client为"gemini"
name = "Google"
base_url = "https://api.google.com/v1"
# Google API同样支持多key配置
api_keys = [
"your-google-api-key-1",
"your-google-api-key-2"
]
api_key = "your-google-api-key-1"
client_type = "gemini"
[[api_providers]]
name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
# 单个key的示例向后兼容
key = "******"
#
#[[api_providers]]
#name = "LocalHost"
#base_url = "https://localhost:8888"
#key = "lm-studio"
[[models]] # 模型(可以配置多个)
# 模型标识符API服务商提供的模型标识符
model_identifier = "deepseek-chat"
# 模型名称可随意命名在bot_config.toml中需使用这个命名
#可选若无该字段则将自动使用model_identifier填充
name = "deepseek-v3"
# API服务商名称对应在api_providers中配置的服务商名称
api_provider = "DeepSeek"
@@ -111,20 +89,15 @@ price_out = 8.0
model_identifier = "deepseek-reasoner"
name = "deepseek-r1"
api_provider = "DeepSeek"
# 推理模型的配置示例
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "text", "tool_calling", "reasoning",]
price_in = 4.0
price_out = 16.0
has_thinking = true # 有无思考参数
enable_thinking = true # 是否启用思考
[[models]]
model_identifier = "Pro/deepseek-ai/DeepSeek-V3"
name = "siliconflow-deepseek-v3"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 2.0
price_out = 8.0
@@ -132,8 +105,6 @@ price_out = 8.0
model_identifier = "Pro/deepseek-ai/DeepSeek-R1"
name = "siliconflow-deepseek-r1"
api_provider = "SiliconFlow"
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
price_in = 4.0
price_out = 16.0
@@ -141,8 +112,6 @@ price_out = 16.0
model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
name = "deepseek-r1-distill-qwen-32b"
api_provider = "SiliconFlow"
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
price_in = 4.0
price_out = 16.0
@@ -150,8 +119,6 @@ price_out = 16.0
model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text"]
price_in = 0
price_out = 0
@@ -159,8 +126,6 @@ price_out = 0
model_identifier = "Qwen/Qwen3-14B"
name = "qwen3-14b"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 0.5
price_out = 2.0
@@ -168,8 +133,6 @@ price_out = 2.0
model_identifier = "Qwen/Qwen3-30B-A3B"
name = "qwen3-30b"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 0.7
price_out = 2.8
@@ -177,11 +140,6 @@ price_out = 2.8
model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct"
name = "qwen2.5-vl-72b"
api_provider = "SiliconFlow"
# 视觉模型的配置示例
task_type = "vision"
capabilities = ["vision", "text"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "vision", "text",]
price_in = 4.13
price_out = 4.13
@@ -189,11 +147,6 @@ price_out = 4.13
model_identifier = "FunAudioLLM/SenseVoiceSmall"
name = "sensevoice-small"
api_provider = "SiliconFlow"
# 语音模型的配置示例
task_type = "speech"
capabilities = ["speech"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "audio",]
price_in = 0
price_out = 0
@@ -210,11 +163,73 @@ price_in = 0
price_out = 0
[task_model_usage]
llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0}
llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0}
embedding = "siliconflow-bge-m3"
#schedule = [
# "deepseek-v3",
# "deepseek-r1",
#]
[model.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型
model_list = ["siliconflow-deepseek-v3","qwen3-8b"]
temperature = 0.2 # 模型温度新V3建议0.1-0.3
max_tokens = 800 # 最大输出token数
[model.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
model_name = "qwen3-8b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
[model.replyer_1] # 首要回复模型,还用于表达器和表达方式学习
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2 # 模型温度新V3建议0.1-0.3
max_tokens = 800
[model.replyer_2] # 次要回复模型
model_name = "siliconflow-deepseek-r1" # 对应 model_config.toml 中的模型名称
temperature = 0.7 # 模型温度
max_tokens = 800
[model.planner] #决策:负责决定麦麦该做什么的模型
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.3
max_tokens = 800
[model.emotion] #负责麦麦的情绪变化
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.3
max_tokens = 800
[model.memory] # 记忆模型
model_name = "qwen3-30b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考
[model.vlm] # 图像识别模型
model_name = "qwen2.5-vl-72b" # 对应 model_config.toml 中的模型名称
max_tokens = 800
[model.voice] # 语音识别模型
model_name = "sensevoice-small" # 对应 model_config.toml 中的模型名称
[model.tool_use] #工具调用模型,需要使用支持工具调用的模型
model_name = "qwen3-14b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考qwen3 only
#嵌入模型
[model.embedding]
model_name = "bge-m3" # 对应 model_config.toml 中的模型名称
#------------LPMM知识库模型------------
[model.lpmm_entity_extract] # 实体提取模型
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2
max_tokens = 800
[model.lpmm_rdf_build] # RDF构建模型
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2
max_tokens = 800
[model.lpmm_qa] # 问答模型
model_name = "deepseek-r1-distill-qwen-32b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考