From 3c40ceda4cf5b27f237512c973465102514e192b Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Wed, 30 Jul 2025 09:45:13 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=A7=E4=BF=AELLMReq?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/api_ada_configs.py | 264 ++-- src/config/config.py | 629 +++++----- src/config/official_configs.py | 50 +- src/llm_models/exceptions.py | 39 +- src/llm_models/model_client/__init__.py | 380 ------ src/llm_models/model_client/__init__bak.py | 380 ++++++ src/llm_models/model_client/base_client.py | 39 +- src/llm_models/model_client/openai_client.py | 181 +-- src/llm_models/model_manager.py | 82 +- src/llm_models/model_manager_bak.py | 92 ++ src/llm_models/payload_content/__init__.py | 0 src/llm_models/utils_model.py | 1130 +++++++----------- src/llm_models/utils_model_bak.py | 778 ++++++++++++ template/bot_config_template.toml | 96 +- template/model_config_template.toml | 145 ++- 15 files changed, 2290 insertions(+), 1995 deletions(-) create mode 100644 src/llm_models/model_client/__init__bak.py create mode 100644 src/llm_models/model_manager_bak.py create mode 100644 src/llm_models/payload_content/__init__.py create mode 100644 src/llm_models/utils_model_bak.py diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index f5f5abe32..819872c1e 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,180 +1,128 @@ from dataclasses import dataclass, field -from typing import List, Dict, Union -import threading -import time -from packaging.version import Version - -NEWEST_VER = "0.2.1" # 当前支持的最新版本 - -@dataclass -class APIProvider: - name: str = "" # API提供商名称 - base_url: str = "" # API基础URL - api_key: str = field(repr=False, default="") # API密钥(向后兼容) - api_keys: List[str] = field(repr=False, default_factory=list) # API密钥列表(新格式) - client_type: str = "openai" # 客户端类型(如openai/google等,默认为openai) - - # 多API Key管理相关属性 - _current_key_index: int = field(default=0, init=False, repr=False) # 当前使用的key索引 - _key_failure_count: Dict[int, int] = field(default_factory=dict, init=False, repr=False) # 每个key的失败次数 - _key_last_failure_time: Dict[int, float] = field(default_factory=dict, init=False, repr=False) # 每个key最后失败时间 - _lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) # 线程锁 - - def __post_init__(self): - """初始化后处理,确保API keys列表正确""" - # 向后兼容:如果只设置了api_key,将其添加到api_keys列表 - if self.api_key and not self.api_keys: - self.api_keys = [self.api_key] - # 如果api_keys不为空但api_key为空,设置api_key为第一个 - elif self.api_keys and not self.api_key: - self.api_key = self.api_keys[0] - - # 初始化失败计数器 - for i in range(len(self.api_keys)): - self._key_failure_count[i] = 0 - self._key_last_failure_time[i] = 0 - - def get_current_api_key(self) -> str: - """获取当前应该使用的API Key""" - with self._lock: - if not self.api_keys: - return "" - - # 确保索引在有效范围内 - if self._current_key_index >= len(self.api_keys): - self._current_key_index = 0 - - return self.api_keys[self._current_key_index] - - def get_next_api_key(self) -> Union[str, None]: - """获取下一个可用的API Key(负载均衡)""" - with self._lock: - if not self.api_keys: - return None - - # 如果只有一个key,直接返回 - if len(self.api_keys) == 1: - return self.api_keys[0] - - # 轮询到下一个key - self._current_key_index = (self._current_key_index + 1) % len(self.api_keys) - return self.api_keys[self._current_key_index] - - def mark_key_failed(self, api_key: str) -> Union[str, None]: - """标记某个API Key失败,返回下一个可用的key""" - with self._lock: - if not self.api_keys or 
api_key not in self.api_keys: - return None - - key_index = self.api_keys.index(api_key) - self._key_failure_count[key_index] += 1 - self._key_last_failure_time[key_index] = time.time() - - # 寻找下一个可用的key - current_time = time.time() - for _ in range(len(self.api_keys)): - self._current_key_index = (self._current_key_index + 1) % len(self.api_keys) - next_key_index = self._current_key_index - - # 检查该key是否最近失败过(5分钟内失败超过3次则暂时跳过) - if (self._key_failure_count[next_key_index] <= 3 or - current_time - self._key_last_failure_time[next_key_index] > 300): # 5分钟后重试 - return self.api_keys[next_key_index] - - # 如果所有key都不可用,返回当前key(让上层处理) - return api_key - - def reset_key_failures(self, api_key: str | None = None): - """重置失败计数(成功调用后调用)""" - with self._lock: - if api_key and api_key in self.api_keys: - key_index = self.api_keys.index(api_key) - self._key_failure_count[key_index] = 0 - self._key_last_failure_time[key_index] = 0 - else: - # 重置所有key的失败计数 - for i in range(len(self.api_keys)): - self._key_failure_count[i] = 0 - self._key_last_failure_time[i] = 0 - - def get_api_key_stats(self) -> Dict[str, Dict[str, Union[int, float]]]: - """获取API Key使用统计""" - with self._lock: - stats = {} - for i, key in enumerate(self.api_keys): - # 只显示key的前8位和后4位,中间用*代替 - masked_key = f"{key[:8]}***{key[-4:]}" if len(key) > 12 else "***" - stats[masked_key] = { - "failure_count": self._key_failure_count.get(i, 0), - "last_failure_time": self._key_last_failure_time.get(i, 0), - "is_current": i == self._current_key_index - } - return stats +from .config_base import ConfigBase @dataclass -class ModelInfo: - model_identifier: str = "" # 模型标识符(用于URL调用) - name: str = "" # 模型名称(用于模块调用) - api_provider: str = "" # API提供商(如OpenAI、Azure等) +class APIProvider(ConfigBase): + """API提供商配置类""" - # 以下用于模型计费 - price_in: float = 0.0 # 每M token输入价格 - price_out: float = 0.0 # 每M token输出价格 + name: str + """API提供商名称""" - force_stream_mode: bool = False # 是否强制使用流式输出模式 - - # 新增:任务类型和能力字段 - task_type: str = "" # 任务类型:llm_normal, llm_reasoning, vision, embedding, speech - capabilities: List[str] = field(default_factory=list) # 模型能力:text, vision, embedding, speech, tool_calling, reasoning + base_url: str + """API基础URL""" + + api_key: str = field(default_factory=str, repr=False) + """API密钥列表""" + + client_type: str = field(default="openai") + """客户端类型(如openai/google等,默认为openai)""" + + max_retry: int = 2 + """最大重试次数(单个模型API调用失败,最多重试的次数)""" + + timeout: int = 10 + """API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)""" + + retry_interval: int = 10 + """重试间隔(如果API调用失败,重试的间隔时间,单位:秒)""" + + def get_api_key(self) -> str: + return self.api_key @dataclass -class RequestConfig: - max_retry: int = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) - timeout: int = ( - 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒) - ) - retry_interval: int = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒) - default_temperature: float = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值) - default_max_tokens: int = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值) +class ModelInfo(ConfigBase): + """单个模型信息配置类""" + + model_identifier: str + """模型标识符(用于URL调用)""" + + name: str + """模型名称(用于模块调用)""" + + api_provider: str + """API提供商(如OpenAI、Azure等)""" + + price_in: float = field(default=0.0) + """每M token输入价格""" + + price_out: float = field(default=0.0) + """每M token输出价格""" + + force_stream_mode: bool = field(default=False) + """是否强制使用流式输出模式""" + + has_thinking: bool = field(default=False) + """是否有思考参数""" + + enable_thinking: bool = field(default=False) + """是否启用思考""" @dataclass -class ModelUsageArgConfigItem: 
- """模型使用的配置类 - 该类用于加载和存储子任务模型使用的配置 - """ +class TaskConfig(ConfigBase): + """任务配置类""" - name: str = "" # 模型名称 - temperature: float | None = None # 温度 - max_tokens: int | None = None # 最大token数 - max_retry: int | None = None # 调用失败时的最大重试次数 + model_list: list[str] = field(default_factory=list) + """任务使用的模型列表""" + + max_tokens: int = 1024 + """任务最大输出token数""" + + temperature: float = 0.3 + """模型温度""" @dataclass -class ModelUsageArgConfig: - """子任务使用模型的配置类 - 该类用于加载和存储子任务使用的模型配置 - """ +class ModelTaskConfig(ConfigBase): + """模型配置类""" - name: str = "" # 任务名称 - usage: List[ModelUsageArgConfigItem] = field( - default_factory=lambda: [] - ) # 任务使用的模型列表 + utils: TaskConfig + """组件模型配置""" + utils_small: TaskConfig + """组件小模型配置""" + replyer_1: TaskConfig + """normal_chat首要回复模型模型配置""" -@dataclass -class ModuleConfig: - INNER_VERSION: Version | None = None # 配置文件版本 + replyer_2: TaskConfig + """normal_chat次要回复模型配置""" - req_conf: RequestConfig = field(default_factory=lambda: RequestConfig()) # 请求配置 - api_providers: Dict[str, APIProvider] = field( - default_factory=lambda: {} - ) # API提供商列表 - models: Dict[str, ModelInfo] = field(default_factory=lambda: {}) # 模型列表 - task_model_arg_map: Dict[str, ModelUsageArgConfig] = field( - default_factory=lambda: {} - ) \ No newline at end of file + memory: TaskConfig + """记忆模型配置""" + + emotion: TaskConfig + """情绪模型配置""" + + vlm: TaskConfig + """视觉语言模型配置""" + + voice: TaskConfig + """语音识别模型配置""" + + tool_use: TaskConfig + """专注工具使用模型配置""" + + planner: TaskConfig + """规划模型配置""" + + embedding: TaskConfig + """嵌入模型配置""" + + lpmm_entity_extract: TaskConfig + """LPMM实体提取模型配置""" + + lpmm_rdf_build: TaskConfig + """LPMM RDF构建模型配置""" + + lpmm_qa: TaskConfig + """LPMM问答模型配置""" + + def get_task(self, task_name: str) -> TaskConfig: + """获取指定任务的配置""" + if hasattr(self, task_name): + return getattr(self, task_name) + raise ValueError(f"任务 '{task_name}' 未找到对应的配置") diff --git a/src/config/config.py b/src/config/config.py index b8f24c5fa..298163b07 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,16 +1,14 @@ import os import tomlkit import shutil +import sys from datetime import datetime from tomlkit import TOMLDocument from tomlkit.items import Table, KeyType from dataclasses import field, dataclass from rich.traceback import install -from packaging import version -from packaging.specifiers import SpecifierSet -from packaging.version import Version, InvalidVersion -from typing import Any, Dict, List +from typing import List, Optional from src.common.logger import get_logger from src.config.config_base import ConfigBase @@ -29,7 +27,6 @@ from src.config.official_configs import ( ResponseSplitterConfig, TelemetryConfig, ExperimentalConfig, - ModelConfig, MessageReceiveConfig, MaimMessageConfig, LPMMKnowledgeConfig, @@ -41,16 +38,12 @@ from src.config.official_configs import ( ) from .api_ada_configs import ( - ModelUsageArgConfigItem, - ModelUsageArgConfig, - APIProvider, + ModelTaskConfig, ModelInfo, - NEWEST_VER, - ModuleConfig, + APIProvider, ) - install(extra_lines=3) @@ -64,275 +57,270 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template") # 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/ -MMC_VERSION = "0.10.0-snapshot1" +MMC_VERSION = "0.10.0-snapshot.2" +# def _get_config_version(toml: Dict) -> Version: +# """提取配置文件的 SpecifierSet 版本数据 +# Args: +# toml[dict]: 输入的配置文件字典 +# Returns: +# Version +# """ + +# if "inner" in toml and "version" in toml["inner"]: +# config_version: str = toml["inner"]["version"] +# else: 
+# raise InvalidVersion("配置文件缺少版本信息,请检查配置文件。") + +# try: +# return version.parse(config_version) +# except InvalidVersion as e: +# logger.error( +# "配置文件中 inner段 的 version 键是错误的版本描述\n" +# f"请检查配置文件,当前 version 键: {config_version}\n" +# f"错误信息: {e}" +# ) +# raise e -def _get_config_version(toml: Dict) -> Version: - """提取配置文件的 SpecifierSet 版本数据 - Args: - toml[dict]: 输入的配置文件字典 - Returns: - Version - """ - - if "inner" in toml and "version" in toml["inner"]: - config_version: str = toml["inner"]["version"] - else: - config_version = "0.0.0" # 默认版本 - - try: - ver = version.parse(config_version) - except InvalidVersion as e: - logger.error( - "配置文件中 inner段 的 version 键是错误的版本描述\n" - f"请检查配置文件,当前 version 键: {config_version}\n" - f"错误信息: {e}" - ) - raise InvalidVersion( - "配置文件中 inner段 的 version 键是错误的版本描述\n" - ) from e - - return ver +# def _request_conf(parent: Dict, config: ModuleConfig): +# request_conf_config = parent.get("request_conf") +# config.req_conf.max_retry = request_conf_config.get( +# "max_retry", config.req_conf.max_retry +# ) +# config.req_conf.timeout = request_conf_config.get( +# "timeout", config.req_conf.timeout +# ) +# config.req_conf.retry_interval = request_conf_config.get( +# "retry_interval", config.req_conf.retry_interval +# ) +# config.req_conf.default_temperature = request_conf_config.get( +# "default_temperature", config.req_conf.default_temperature +# ) +# config.req_conf.default_max_tokens = request_conf_config.get( +# "default_max_tokens", config.req_conf.default_max_tokens +# ) -def _request_conf(parent: Dict, config: ModuleConfig): - request_conf_config = parent.get("request_conf") - config.req_conf.max_retry = request_conf_config.get( - "max_retry", config.req_conf.max_retry - ) - config.req_conf.timeout = request_conf_config.get( - "timeout", config.req_conf.timeout - ) - config.req_conf.retry_interval = request_conf_config.get( - "retry_interval", config.req_conf.retry_interval - ) - config.req_conf.default_temperature = request_conf_config.get( - "default_temperature", config.req_conf.default_temperature - ) - config.req_conf.default_max_tokens = request_conf_config.get( - "default_max_tokens", config.req_conf.default_max_tokens - ) +# def _api_providers(parent: Dict, config: ModuleConfig): +# api_providers_config = parent.get("api_providers") +# for provider in api_providers_config: +# name = provider.get("name", None) +# base_url = provider.get("base_url", None) +# api_key = provider.get("api_key", None) +# api_keys = provider.get("api_keys", []) # 新增:支持多个API Key +# client_type = provider.get("client_type", "openai") + +# if name in config.api_providers: # 查重 +# logger.error(f"重复的API提供商名称: {name},请检查配置文件。") +# raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。") + +# if name and base_url: +# # 处理API Key配置:支持单个api_key或多个api_keys +# if api_keys: +# # 使用新格式:api_keys列表 +# logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key") +# elif api_key: +# # 向后兼容:使用单个api_key +# api_keys = [api_key] +# logger.debug(f"API提供商 '{name}' 使用单个API Key(向后兼容模式)") +# else: +# logger.warning(f"API提供商 '{name}' 没有配置API Key,某些功能可能不可用") + +# config.api_providers[name] = APIProvider( +# name=name, +# base_url=base_url, +# api_key=api_key, # 保留向后兼容 +# api_keys=api_keys, # 新格式 +# client_type=client_type, +# ) +# else: +# logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") +# raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") -def _api_providers(parent: Dict, config: ModuleConfig): - api_providers_config = parent.get("api_providers") - for provider in api_providers_config: - name = 
provider.get("name", None) - base_url = provider.get("base_url", None) - api_key = provider.get("api_key", None) - api_keys = provider.get("api_keys", []) # 新增:支持多个API Key - client_type = provider.get("client_type", "openai") +# def _models(parent: Dict, config: ModuleConfig): +# models_config = parent.get("models") +# for model in models_config: +# model_identifier = model.get("model_identifier", None) +# name = model.get("name", model_identifier) +# api_provider = model.get("api_provider", None) +# price_in = model.get("price_in", 0.0) +# price_out = model.get("price_out", 0.0) +# force_stream_mode = model.get("force_stream_mode", False) +# task_type = model.get("task_type", "") +# capabilities = model.get("capabilities", []) - if name in config.api_providers: # 查重 - logger.error(f"重复的API提供商名称: {name},请检查配置文件。") - raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。") +# if name in config.models: # 查重 +# logger.error(f"重复的模型名称: {name},请检查配置文件。") +# raise KeyError(f"重复的模型名称: {name},请检查配置文件。") - if name and base_url: - # 处理API Key配置:支持单个api_key或多个api_keys - if api_keys: - # 使用新格式:api_keys列表 - logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key") - elif api_key: - # 向后兼容:使用单个api_key - api_keys = [api_key] - logger.debug(f"API提供商 '{name}' 使用单个API Key(向后兼容模式)") - else: - logger.warning(f"API提供商 '{name}' 没有配置API Key,某些功能可能不可用") - - config.api_providers[name] = APIProvider( - name=name, - base_url=base_url, - api_key=api_key, # 保留向后兼容 - api_keys=api_keys, # 新格式 - client_type=client_type, - ) - else: - logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") - raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") +# if model_identifier and api_provider: +# # 检查API提供商是否存在 +# if api_provider not in config.api_providers: +# logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。") +# raise ValueError( +# f"未声明的API提供商 '{api_provider}' ,请检查配置文件。" +# ) +# config.models[name] = ModelInfo( +# name=name, +# model_identifier=model_identifier, +# api_provider=api_provider, +# price_in=price_in, +# price_out=price_out, +# force_stream_mode=force_stream_mode, +# task_type=task_type, +# capabilities=capabilities, +# ) +# else: +# logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。") +# raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。") -def _models(parent: Dict, config: ModuleConfig): - models_config = parent.get("models") - for model in models_config: - model_identifier = model.get("model_identifier", None) - name = model.get("name", model_identifier) - api_provider = model.get("api_provider", None) - price_in = model.get("price_in", 0.0) - price_out = model.get("price_out", 0.0) - force_stream_mode = model.get("force_stream_mode", False) - task_type = model.get("task_type", "") - capabilities = model.get("capabilities", []) +# def _task_model_usage(parent: Dict, config: ModuleConfig): +# model_usage_configs = parent.get("task_model_usage") +# config.task_model_arg_map = {} +# for task_name, item in model_usage_configs.items(): +# if task_name in config.task_model_arg_map: +# logger.error(f"子任务 {task_name} 已存在,请检查配置文件。") +# raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。") - if name in config.models: # 查重 - logger.error(f"重复的模型名称: {name},请检查配置文件。") - raise KeyError(f"重复的模型名称: {name},请检查配置文件。") +# usage = [] +# if isinstance(item, Dict): +# if "model" in item: +# usage.append( +# ModelUsageArgConfigItem( +# name=item["model"], +# temperature=item.get("temperature", None), +# max_tokens=item.get("max_tokens", None), +# max_retry=item.get("max_retry", None), +# ) +# ) +# else: +# logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。") +# 
raise ValueError( +# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" +# ) +# elif isinstance(item, List): +# for model in item: +# if isinstance(model, Dict): +# usage.append( +# ModelUsageArgConfigItem( +# name=model["model"], +# temperature=model.get("temperature", None), +# max_tokens=model.get("max_tokens", None), +# max_retry=model.get("max_retry", None), +# ) +# ) +# elif isinstance(model, str): +# usage.append( +# ModelUsageArgConfigItem( +# name=model, +# temperature=None, +# max_tokens=None, +# max_retry=None, +# ) +# ) +# else: +# logger.error( +# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" +# ) +# raise ValueError( +# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" +# ) +# elif isinstance(item, str): +# usage.append( +# ModelUsageArgConfigItem( +# name=item, +# temperature=None, +# max_tokens=None, +# max_retry=None, +# ) +# ) - if model_identifier and api_provider: - # 检查API提供商是否存在 - if api_provider not in config.api_providers: - logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。") - raise ValueError( - f"未声明的API提供商 '{api_provider}' ,请检查配置文件。" - ) - config.models[name] = ModelInfo( - name=name, - model_identifier=model_identifier, - api_provider=api_provider, - price_in=price_in, - price_out=price_out, - force_stream_mode=force_stream_mode, - task_type=task_type, - capabilities=capabilities, - ) - else: - logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。") - raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。") +# config.task_model_arg_map[task_name] = ModelUsageArgConfig( +# name=task_name, +# usage=usage, +# ) -def _task_model_usage(parent: Dict, config: ModuleConfig): - model_usage_configs = parent.get("task_model_usage") - config.task_model_arg_map = {} - for task_name, item in model_usage_configs.items(): - if task_name in config.task_model_arg_map: - logger.error(f"子任务 {task_name} 已存在,请检查配置文件。") - raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。") +# def api_ada_load_config(config_path: str) -> ModuleConfig: +# """从TOML配置文件加载配置""" +# config = ModuleConfig() - usage = [] - if isinstance(item, Dict): - if "model" in item: - usage.append( - ModelUsageArgConfigItem( - name=item["model"], - temperature=item.get("temperature", None), - max_tokens=item.get("max_tokens", None), - max_retry=item.get("max_retry", None), - ) - ) - else: - logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。") - raise ValueError( - f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" - ) - elif isinstance(item, List): - for model in item: - if isinstance(model, Dict): - usage.append( - ModelUsageArgConfigItem( - name=model["model"], - temperature=model.get("temperature", None), - max_tokens=model.get("max_tokens", None), - max_retry=model.get("max_retry", None), - ) - ) - elif isinstance(model, str): - usage.append( - ModelUsageArgConfigItem( - name=model, - temperature=None, - max_tokens=None, - max_retry=None, - ) - ) - else: - logger.error( - f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" - ) - raise ValueError( - f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" - ) - elif isinstance(item, str): - usage.append( - ModelUsageArgConfigItem( - name=item, - temperature=None, - max_tokens=None, - max_retry=None, - ) - ) +# include_configs: Dict[str, Dict[str, Any]] = { +# "request_conf": { +# "func": _request_conf, +# "support": ">=0.0.0", +# "necessary": False, +# }, +# "api_providers": {"func": _api_providers, "support": ">=0.0.0"}, +# "models": {"func": _models, "support": ">=0.0.0"}, +# "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"}, +# } - config.task_model_arg_map[task_name] = ModelUsageArgConfig( - name=task_name, - usage=usage, - ) +# if 
os.path.exists(config_path): +# with open(config_path, "rb") as f: +# try: +# toml_dict = tomlkit.load(f) +# except tomlkit.TOMLDecodeError as e: +# logger.critical( +# f"配置文件model_list.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}" +# ) +# exit(1) +# # 获取配置文件版本 +# config.INNER_VERSION = _get_config_version(toml_dict) -def api_ada_load_config(config_path: str) -> ModuleConfig: - """从TOML配置文件加载配置""" - config = ModuleConfig() +# # 检查版本 +# if config.INNER_VERSION > Version(NEWEST_VER): +# logger.warning( +# f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。" +# ) - include_configs: Dict[str, Dict[str, Any]] = { - "request_conf": { - "func": _request_conf, - "support": ">=0.0.0", - "necessary": False, - }, - "api_providers": {"func": _api_providers, "support": ">=0.0.0"}, - "models": {"func": _models, "support": ">=0.0.0"}, - "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"}, - } +# # 解析配置文件 +# # 如果在配置中找到了需要的项,调用对应项的闭包函数处理 +# for key in include_configs: +# if key in toml_dict: +# group_specifier_set: SpecifierSet = SpecifierSet( +# include_configs[key]["support"] +# ) - if os.path.exists(config_path): - with open(config_path, "rb") as f: - try: - toml_dict = tomlkit.load(f) - except tomlkit.TOMLDecodeError as e: - logger.critical( - f"配置文件model_list.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}" - ) - exit(1) +# # 检查配置文件版本是否在支持范围内 +# if config.INNER_VERSION in group_specifier_set: +# # 如果版本在支持范围内,检查是否存在通知 +# if "notice" in include_configs[key]: +# logger.warning(include_configs[key]["notice"]) +# # 调用闭包函数处理配置 +# (include_configs[key]["func"])(toml_dict, config) +# else: +# # 如果版本不在支持范围内,崩溃并提示用户 +# logger.error( +# f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n" +# f"当前程序仅支持以下版本范围: {group_specifier_set}" +# ) +# raise InvalidVersion( +# f"当前程序仅支持以下版本范围: {group_specifier_set}" +# ) - # 获取配置文件版本 - config.INNER_VERSION = _get_config_version(toml_dict) +# # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理 +# elif ( +# "necessary" in include_configs[key] +# and include_configs[key].get("necessary") is False +# ): +# # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理 +# if key == "keywords_reaction": +# pass +# else: +# # 如果用户根本没有需要的配置项,提示缺少配置 +# logger.error(f"配置文件中缺少必需的字段: '{key}'") +# raise KeyError(f"配置文件中缺少必需的字段: '{key}'") - # 检查版本 - if config.INNER_VERSION > Version(NEWEST_VER): - logger.warning( - f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。" - ) +# logger.info(f"成功加载配置文件: {config_path}") - # 解析配置文件 - # 如果在配置中找到了需要的项,调用对应项的闭包函数处理 - for key in include_configs: - if key in toml_dict: - group_specifier_set: SpecifierSet = SpecifierSet( - include_configs[key]["support"] - ) +# return config - # 检查配置文件版本是否在支持范围内 - if config.INNER_VERSION in group_specifier_set: - # 如果版本在支持范围内,检查是否存在通知 - if "notice" in include_configs[key]: - logger.warning(include_configs[key]["notice"]) - # 调用闭包函数处理配置 - (include_configs[key]["func"])(toml_dict, config) - else: - # 如果版本不在支持范围内,崩溃并提示用户 - logger.error( - f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n" - f"当前程序仅支持以下版本范围: {group_specifier_set}" - ) - raise InvalidVersion( - f"当前程序仅支持以下版本范围: {group_specifier_set}" - ) - - # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理 - elif ( - "necessary" in include_configs[key] - and include_configs[key].get("necessary") is False - ): - # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理 - if key == "keywords_reaction": - pass - else: - # 如果用户根本没有需要的配置项,提示缺少配置 - logger.error(f"配置文件中缺少必需的字段: '{key}'") - raise KeyError(f"配置文件中缺少必需的字段: '{key}'") - - 
logger.info(f"成功加载配置文件: {config_path}") - - return config def get_key_comment(toml_table, key): # 获取key的注释(如果有) @@ -361,7 +349,7 @@ def compare_dicts(new, old, path=None, logs=None): continue if key not in old: comment = get_key_comment(new, key) - logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): compare_dicts(new[key], old[key], path + [str(key)], logs) # 删减项 @@ -370,7 +358,7 @@ def compare_dicts(new, old, path=None, logs=None): continue if key not in new: comment = get_key_comment(old, key) - logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") return logs @@ -405,17 +393,13 @@ def compare_default_values(new, old, path=None, logs=None, changes=None): if key in old: if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)): compare_default_values(new[key], old[key], path + [str(key)], logs, changes) - else: - # 只要值发生变化就记录 - if new[key] != old[key]: - logs.append( - f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}" - ) - changes.append((path + [str(key)], old[key], new[key])) + elif new[key] != old[key]: + logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}") + changes.append((path + [str(key)], old[key], new[key])) return logs, changes -def _get_version_from_toml(toml_path): +def _get_version_from_toml(toml_path) -> Optional[str]: """从TOML文件中获取版本号""" if not os.path.exists(toml_path): return None @@ -459,14 +443,13 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic target[key] = value -def _update_config_generic(config_name: str, template_name: str, should_quit_on_new: bool = True): +def _update_config_generic(config_name: str, template_name: str): """ 通用的配置文件更新函数 - + Args: config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config' template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template' - should_quit_on_new: 创建新配置文件后是否退出程序 """ # 获取根目录路径 old_config_dir = os.path.join(CONFIG_DIR, "old") @@ -484,19 +467,30 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_ template_version = _get_version_from_toml(template_path) compare_version = _get_version_from_toml(compare_path) + # 检查配置文件是否存在 + if not os.path.exists(old_config_path): + logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置") + os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 + shutil.copy2(template_path, old_config_path) # 复制模板文件 + logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}") + # 新创建配置文件,退出 + sys.exit(0) + + compare_config = None + new_config = None + old_config = None + # 先读取 compare 下的模板(如果有),用于默认值变动检测 if os.path.exists(compare_path): with open(compare_path, "r", encoding="utf-8") as f: compare_config = tomlkit.load(f) - else: - compare_config = None # 读取当前模板 with open(template_path, "r", encoding="utf-8") as f: new_config = tomlkit.load(f) # 检查默认值变化并处理(只有 compare_config 存在时才做) - if compare_config is not None: + if compare_config: # 读取旧配置 with open(old_config_path, "r", encoding="utf-8") as f: old_config = tomlkit.load(f) @@ -515,32 +509,16 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_ ) else: logger.info(f"未检测到{config_name}模板默认值变动") - # 保存旧配置的变更(后续合并逻辑会用到 old_config) - else: - old_config = None # 检查 compare 下没有模板,或新模板版本更高,则复制 if 
not os.path.exists(compare_path): shutil.copy2(template_path, compare_path) logger.info(f"已将{config_name}模板文件复制到: {compare_path}") + elif _version_tuple(template_version) > _version_tuple(compare_version): + shutil.copy2(template_path, compare_path) + logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}") else: - if _version_tuple(template_version) > _version_tuple(compare_version): - shutil.copy2(template_path, compare_path) - logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}") - else: - logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}") - - # 检查配置文件是否存在 - if not os.path.exists(old_config_path): - logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置") - os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 - shutil.copy2(template_path, old_config_path) # 复制模板文件 - logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}") - # 如果是新创建的配置文件,根据参数决定是否退出 - if should_quit_on_new: - quit() - else: - return + logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}") # 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次) if old_config is None: @@ -578,8 +556,7 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_ # 输出新增和删减项及注释 if old_config: logger.info(f"{config_name}配置项变动如下:\n----------------------------------------") - logs = compare_dicts(new_config, old_config) - if logs: + if logs := compare_dicts(new_config, old_config): for log in logs: logger.info(log) else: @@ -597,12 +574,12 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_ def update_config(): """更新bot_config.toml配置文件""" - _update_config_generic("bot_config", "bot_config_template", should_quit_on_new=True) + _update_config_generic("bot_config", "bot_config_template") def update_model_config(): """更新model_config.toml配置文件""" - _update_config_generic("model_config", "model_config_template", should_quit_on_new=False) + _update_config_generic("model_config", "model_config_template") @dataclass @@ -627,7 +604,6 @@ class Config(ConfigBase): response_splitter: ResponseSplitterConfig telemetry: TelemetryConfig experimental: ExperimentalConfig - model: ModelConfig maim_message: MaimMessageConfig lpmm_knowledge: LPMMKnowledgeConfig tool: ToolConfig @@ -635,11 +611,48 @@ class Config(ConfigBase): custom_prompt: CustomPromptConfig voice: VoiceConfig + +@dataclass +class APIAdapterConfig(ConfigBase): + """API Adapter配置类""" + + models: List[ModelInfo] + """模型列表""" + + model_task_config: ModelTaskConfig + """模型任务配置""" + + api_providers: List[APIProvider] = field(default_factory=list) + """API提供商列表""" + + def __post_init__(self): + self.api_providers_dict = {provider.name: provider for provider in self.api_providers} + self.models_dict = {model.name: model for model in self.models} + + def get_model_info(self, model_name: str) -> ModelInfo: + """根据模型名称获取模型信息""" + if not model_name: + raise ValueError("模型名称不能为空") + if model_name not in self.models_dict: + raise KeyError(f"模型 '{model_name}' 不存在") + return self.models_dict[model_name] + + def get_provider(self, provider_name: str) -> APIProvider: + """根据提供商名称获取API提供商信息""" + if not provider_name: + raise ValueError("API提供商名称不能为空") + if provider_name not in self.api_providers_dict: + raise KeyError(f"API提供商 '{provider_name}' 不存在") + return self.api_providers_dict[provider_name] + + def load_config(config_path: str) -> Config: """ 加载配置文件 - :param config_path: 配置文件路径 - :return: Config对象 + Args: + config_path: 配置文件路径 + Returns: + Config对象 """ # 读取配置文件 with open(config_path, "r", encoding="utf-8") as f: @@ 
-653,12 +666,24 @@ def load_config(config_path: str) -> Config: raise e -def get_config_dir() -> str: +def api_ada_load_config(config_path: str) -> APIAdapterConfig: """ - 获取配置目录 - :return: 配置目录路径 + 加载API适配器配置文件 + Args: + config_path: 配置文件路径 + Returns: + APIAdapterConfig对象 """ - return CONFIG_DIR + # 读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + config_data = tomlkit.load(f) + + # 创建APIAdapterConfig对象 + try: + return APIAdapterConfig.from_dict(config_data) + except Exception as e: + logger.critical("API适配器配置文件解析失败") + raise e # 获取配置文件路径 @@ -669,4 +694,4 @@ update_model_config() logger.info("正在品鉴配置文件...") global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml")) model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml")) -logger.info("非常的新鲜,非常的美味!") \ No newline at end of file +logger.info("非常的新鲜,非常的美味!") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 08acf97c6..8f34a1843 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,10 +1,9 @@ import re from dataclasses import dataclass, field -from typing import Any, Literal, Optional +from typing import Literal, Optional from src.config.config_base import ConfigBase -from packaging.version import Version """ 须知: @@ -599,50 +598,3 @@ class LPMMKnowledgeConfig(ConfigBase): embedding_dimension: int = 1024 """嵌入向量维度,应该与模型的输出维度一致""" -@dataclass -class ModelConfig(ConfigBase): - """模型配置类""" - - model_max_output_length: int = 800 # 最大回复长度 - - utils: dict[str, Any] = field(default_factory=lambda: {}) - """组件模型配置""" - - utils_small: dict[str, Any] = field(default_factory=lambda: {}) - """组件小模型配置""" - - replyer_1: dict[str, Any] = field(default_factory=lambda: {}) - """normal_chat首要回复模型模型配置""" - - replyer_2: dict[str, Any] = field(default_factory=lambda: {}) - """normal_chat次要回复模型配置""" - - memory: dict[str, Any] = field(default_factory=lambda: {}) - """记忆模型配置""" - - emotion: dict[str, Any] = field(default_factory=lambda: {}) - """情绪模型配置""" - - vlm: dict[str, Any] = field(default_factory=lambda: {}) - """视觉语言模型配置""" - - voice: dict[str, Any] = field(default_factory=lambda: {}) - """语音识别模型配置""" - - tool_use: dict[str, Any] = field(default_factory=lambda: {}) - """专注工具使用模型配置""" - - planner: dict[str, Any] = field(default_factory=lambda: {}) - """规划模型配置""" - - embedding: dict[str, Any] = field(default_factory=lambda: {}) - """嵌入模型配置""" - - lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {}) - """LPMM实体提取模型配置""" - - lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {}) - """LPMM RDF构建模型配置""" - - lpmm_qa: dict[str, Any] = field(default_factory=lambda: {}) - """LPMM问答模型配置""" diff --git a/src/llm_models/exceptions.py b/src/llm_models/exceptions.py index 0ced8dd14..5b04f58c6 100644 --- a/src/llm_models/exceptions.py +++ b/src/llm_models/exceptions.py @@ -62,8 +62,37 @@ class RespParseException(Exception): self.message = message def __str__(self): - return ( - self.message - if self.message - else "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法" - ) + return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法" + + +class PayLoadTooLargeError(Exception): + """自定义异常类,用于处理请求体过大错误""" + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + def __str__(self): + return "请求体过大,请尝试压缩图片或减少输入内容。" + + +class RequestAbortException(Exception): + """自定义异常类,用于处理请求中断异常""" + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + def __str__(self): + 
return self.message + + +class PermissionDeniedException(Exception): + """自定义异常类,用于处理访问拒绝的异常""" + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + def __str__(self): + return self.message diff --git a/src/llm_models/model_client/__init__.py b/src/llm_models/model_client/__init__.py index 7e57c82d6..e69de29bb 100644 --- a/src/llm_models/model_client/__init__.py +++ b/src/llm_models/model_client/__init__.py @@ -1,380 +0,0 @@ -import asyncio -from typing import Callable, Any - -from openai import AsyncStream -from openai.types.chat import ChatCompletionChunk, ChatCompletion - -from .base_client import BaseClient, APIResponse -from src.config.api_ada_configs import ( - ModelInfo, - ModelUsageArgConfigItem, - RequestConfig, - ModuleConfig, -) -from ..exceptions import ( - NetworkConnectionError, - ReqAbortException, - RespNotOkException, - RespParseException, -) -from ..payload_content.message import Message -from ..payload_content.resp_format import RespFormat -from ..payload_content.tool_option import ToolOption -from ..utils import compress_messages -from src.common.logger import get_logger - -logger = get_logger("模型客户端") - - -def _check_retry( - remain_try: int, - retry_interval: int, - can_retry_msg: str, - cannot_retry_msg: str, - can_retry_callable: Callable | None = None, - **kwargs, -) -> tuple[int, Any | None]: - """ - 辅助函数:检查是否可以重试 - :param remain_try: 剩余尝试次数 - :param retry_interval: 重试间隔 - :param can_retry_msg: 可以重试时的提示信息 - :param cannot_retry_msg: 不可以重试时的提示信息 - :return: (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - if remain_try > 0: - # 还有重试机会 - logger.warning(f"{can_retry_msg}") - if can_retry_callable is not None: - return retry_interval, can_retry_callable(**kwargs) - else: - return retry_interval, None - else: - # 达到最大重试次数 - logger.warning(f"{cannot_retry_msg}") - return -1, None # 不再重试请求该模型 - - -def _handle_resp_not_ok( - e: RespNotOkException, - task_name: str, - model_name: str, - remain_try: int, - retry_interval: int = 10, - messages: tuple[list[Message], bool] | None = None, -): - """ - 处理响应错误异常 - :param e: 异常对象 - :param task_name: 任务名称 - :param model_name: 模型名称 - :param remain_try: 剩余尝试次数 - :param retry_interval: 重试间隔 - :param messages: (消息列表, 是否已压缩过) - :return: (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - # 响应错误 - if e.status_code in [401, 403]: - # API Key认证错误 - 让多API Key机制处理,给一次重试机会 - if remain_try > 0: - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}'\n" - f"API Key认证失败(错误代码-{e.status_code}),多API Key机制会自动切换" - ) - return 0, None # 立即重试,让底层客户端切换API Key - else: - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}'\n" - f"所有API Key都认证失败,错误代码-{e.status_code},错误信息-{e.message}" - ) - return -1, None # 不再重试请求该模型 - elif e.status_code in [400, 402, 404]: - # 其他客户端错误(不应该重试) - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}'\n" - f"请求失败,错误代码-{e.status_code},错误信息-{e.message}" - ) - return -1, None # 不再重试请求该模型 - elif e.status_code == 413: - if messages and not messages[1]: - # 消息列表不为空且未压缩,尝试压缩消息 - return _check_retry( - remain_try, - 0, - can_retry_msg=( - f"任务-'{task_name}' 模型-'{model_name}'\n" - "请求体过大,尝试压缩消息后重试" - ), - cannot_retry_msg=( - f"任务-'{task_name}' 模型-'{model_name}'\n" - "请求体过大,压缩消息后仍然过大,放弃请求" - ), - can_retry_callable=compress_messages, - messages=messages[0], - ) - # 没有消息可压缩 - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}'\n" - "请求体过大,无法压缩消息,放弃请求。" - ) - return -1, None - elif e.status_code == 429: - # 请求过于频繁 - 让多API Key机制处理,适当延迟后重试 - return _check_retry( - remain_try, - 
min(retry_interval, 5), # 限制最大延迟为5秒,让API Key切换更快生效 - can_retry_msg=( - f"任务-'{task_name}' 模型-'{model_name}'\n" - f"请求过于频繁,多API Key机制会自动切换,{min(retry_interval, 5)}秒后重试" - ), - cannot_retry_msg=( - f"任务-'{task_name}' 模型-'{model_name}'\n" - "请求过于频繁,所有API Key都被限制,放弃请求" - ), - ) - elif e.status_code >= 500: - # 服务器错误 - return _check_retry( - remain_try, - retry_interval, - can_retry_msg=( - f"任务-'{task_name}' 模型-'{model_name}'\n" - f"服务器错误,将于{retry_interval}秒后重试" - ), - cannot_retry_msg=( - f"任务-'{task_name}' 模型-'{model_name}'\n" - "服务器错误,超过最大重试次数,请稍后再试" - ), - ) - else: - # 未知错误 - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}'\n" - f"未知错误,错误代码-{e.status_code},错误信息-{e.message}" - ) - return -1, None - - -def default_exception_handler( - e: Exception, - task_name: str, - model_name: str, - remain_try: int, - retry_interval: int = 10, - messages: tuple[list[Message], bool] | None = None, -) -> tuple[int, list[Message] | None]: - """ - 默认异常处理函数 - :param e: 异常对象 - :param task_name: 任务名称 - :param model_name: 模型名称 - :param remain_try: 剩余尝试次数 - :param retry_interval: 重试间隔 - :param messages: (消息列表, 是否已压缩过) - :return (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) - """ - - if isinstance(e, NetworkConnectionError): # 网络连接错误 - # 网络错误可能是某个API Key的端点问题,给多API Key机制一次快速重试机会 - return _check_retry( - remain_try, - min(retry_interval, 3), # 网络错误时减少等待时间,让API Key切换更快 - can_retry_msg=( - f"任务-'{task_name}' 模型-'{model_name}'\n" - f"连接异常,多API Key机制会尝试其他Key,{min(retry_interval, 3)}秒后重试" - ), - cannot_retry_msg=( - f"任务-'{task_name}' 模型-'{model_name}'\n" - f"连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确" - ), - ) - elif isinstance(e, ReqAbortException): - logger.warning( - f"任务-'{task_name}' 模型-'{model_name}'\n请求被中断,详细信息-{str(e.message)}" - ) - return -1, None # 不再重试请求该模型 - elif isinstance(e, RespNotOkException): - return _handle_resp_not_ok( - e, - task_name, - model_name, - remain_try, - retry_interval, - messages, - ) - elif isinstance(e, RespParseException): - # 响应解析错误 - logger.error( - f"任务-'{task_name}' 模型-'{model_name}'\n" - f"响应解析错误,错误信息-{e.message}\n" - ) - logger.debug(f"附加内容:\n{str(e.ext_info)}") - return -1, None # 不再重试请求该模型 - else: - logger.error( - f"任务-'{task_name}' 模型-'{model_name}'\n未知异常,错误信息-{str(e)}" - ) - return -1, None # 不再重试请求该模型 - - -class ModelRequestHandler: - """ - 模型请求处理器 - """ - - def __init__( - self, - task_name: str, - config: ModuleConfig, - api_client_map: dict[str, BaseClient], - ): - self.task_name: str = task_name - """任务名称""" - - self.client_map: dict[str, BaseClient] = {} - """API客户端列表""" - - self.configs: list[tuple[ModelInfo, ModelUsageArgConfigItem]] = [] - """模型参数配置""" - - self.req_conf: RequestConfig = config.req_conf - """请求配置""" - - # 获取模型与使用配置 - for model_usage in config.task_model_arg_map[task_name].usage: - if model_usage.name not in config.models: - logger.error(f"Model '{model_usage.name}' not found in ModelManager") - raise KeyError(f"Model '{model_usage.name}' not found in ModelManager") - model_info = config.models[model_usage.name] - - if model_info.api_provider not in self.client_map: - # 缓存API客户端 - self.client_map[model_info.api_provider] = api_client_map[ - model_info.api_provider - ] - - self.configs.append((model_info, model_usage)) # 添加模型与使用配置 - - async def get_response( - self, - messages: list[Message], - tool_options: list[ToolOption] | None = None, - response_format: RespFormat | None = None, # 暂不启用 - stream_response_handler: Callable[ - [AsyncStream[ChatCompletionChunk], asyncio.Event | None], APIResponse - ] - | None = None, - async_response_parser: 
Callable[[ChatCompletion], APIResponse] | None = None, - interrupt_flag: asyncio.Event | None = None, - ) -> APIResponse: - """ - 获取对话响应 - :param messages: 消息列表 - :param tool_options: 工具选项列表 - :param response_format: 响应格式 - :param stream_response_handler: 流式响应处理函数(可选) - :param async_response_parser: 响应解析函数(可选) - :param interrupt_flag: 中断信号量(可选,默认为None) - :return: APIResponse - """ - # 遍历可用模型,若获取响应失败,则使用下一个模型继续请求 - for config_item in self.configs: - client = self.client_map[config_item[0].api_provider] - model_info: ModelInfo = config_item[0] - model_usage_config: ModelUsageArgConfigItem = config_item[1] - - remain_try = ( - model_usage_config.max_retry or self.req_conf.max_retry - ) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1 - - compressed_messages = None - retry_interval = self.req_conf.retry_interval - while remain_try > 0: - try: - return await client.get_response( - model_info, - message_list=(compressed_messages or messages), - tool_options=tool_options, - max_tokens=model_usage_config.max_tokens - or self.req_conf.default_max_tokens, - temperature=model_usage_config.temperature - or self.req_conf.default_temperature, - response_format=response_format, - stream_response_handler=stream_response_handler, - async_response_parser=async_response_parser, - interrupt_flag=interrupt_flag, - ) - except Exception as e: - logger.debug(e) - remain_try -= 1 # 剩余尝试次数减1 - - # 处理异常 - handle_res = default_exception_handler( - e, - self.task_name, - model_info.name, - remain_try, - retry_interval=self.req_conf.retry_interval, - messages=(messages, compressed_messages is not None), - ) - - if handle_res[0] == -1: - # 等待间隔为-1,表示不再请求该模型 - remain_try = 0 - elif handle_res[0] != 0: - # 等待间隔不为0,表示需要等待 - await asyncio.sleep(handle_res[0]) - retry_interval *= 2 - - if handle_res[1] is not None: - # 压缩消息 - compressed_messages = handle_res[1] - - logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用") - raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败 - - async def get_embedding( - self, - embedding_input: str, - ) -> APIResponse: - """ - 获取嵌入向量 - :param embedding_input: 嵌入输入 - :return: APIResponse - """ - for config in self.configs: - client = self.client_map[config[0].api_provider] - model_info: ModelInfo = config[0] - model_usage_config: ModelUsageArgConfigItem = config[1] - remain_try = ( - model_usage_config.max_retry or self.req_conf.max_retry - ) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1 - - while remain_try: - try: - return await client.get_embedding( - model_info=model_info, - embedding_input=embedding_input, - ) - except Exception as e: - logger.debug(e) - remain_try -= 1 # 剩余尝试次数减1 - - # 处理异常 - handle_res = default_exception_handler( - e, - self.task_name, - model_info.name, - remain_try, - retry_interval=self.req_conf.retry_interval, - ) - - if handle_res[0] == -1: - # 等待间隔为-1,表示不再请求该模型 - remain_try = 0 - elif handle_res[0] != 0: - # 等待间隔不为0,表示需要等待 - await asyncio.sleep(handle_res[0]) - - logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用") - raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败 diff --git a/src/llm_models/model_client/__init__bak.py b/src/llm_models/model_client/__init__bak.py new file mode 100644 index 000000000..7e57c82d6 --- /dev/null +++ b/src/llm_models/model_client/__init__bak.py @@ -0,0 +1,380 @@ +import asyncio +from typing import Callable, Any + +from openai import AsyncStream +from openai.types.chat import ChatCompletionChunk, ChatCompletion + +from .base_client import BaseClient, APIResponse +from src.config.api_ada_configs import ( + ModelInfo, + ModelUsageArgConfigItem, + RequestConfig, + 
ModuleConfig, +) +from ..exceptions import ( + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, +) +from ..payload_content.message import Message +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolOption +from ..utils import compress_messages +from src.common.logger import get_logger + +logger = get_logger("模型客户端") + + +def _check_retry( + remain_try: int, + retry_interval: int, + can_retry_msg: str, + cannot_retry_msg: str, + can_retry_callable: Callable | None = None, + **kwargs, +) -> tuple[int, Any | None]: + """ + 辅助函数:检查是否可以重试 + :param remain_try: 剩余尝试次数 + :param retry_interval: 重试间隔 + :param can_retry_msg: 可以重试时的提示信息 + :param cannot_retry_msg: 不可以重试时的提示信息 + :return: (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + if remain_try > 0: + # 还有重试机会 + logger.warning(f"{can_retry_msg}") + if can_retry_callable is not None: + return retry_interval, can_retry_callable(**kwargs) + else: + return retry_interval, None + else: + # 达到最大重试次数 + logger.warning(f"{cannot_retry_msg}") + return -1, None # 不再重试请求该模型 + + +def _handle_resp_not_ok( + e: RespNotOkException, + task_name: str, + model_name: str, + remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, +): + """ + 处理响应错误异常 + :param e: 异常对象 + :param task_name: 任务名称 + :param model_name: 模型名称 + :param remain_try: 剩余尝试次数 + :param retry_interval: 重试间隔 + :param messages: (消息列表, 是否已压缩过) + :return: (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + # 响应错误 + if e.status_code in [401, 403]: + # API Key认证错误 - 让多API Key机制处理,给一次重试机会 + if remain_try > 0: + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"API Key认证失败(错误代码-{e.status_code}),多API Key机制会自动切换" + ) + return 0, None # 立即重试,让底层客户端切换API Key + else: + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"所有API Key都认证失败,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None # 不再重试请求该模型 + elif e.status_code in [400, 402, 404]: + # 其他客户端错误(不应该重试) + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"请求失败,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None # 不再重试请求该模型 + elif e.status_code == 413: + if messages and not messages[1]: + # 消息列表不为空且未压缩,尝试压缩消息 + return _check_retry( + remain_try, + 0, + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求体过大,尝试压缩消息后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求体过大,压缩消息后仍然过大,放弃请求" + ), + can_retry_callable=compress_messages, + messages=messages[0], + ) + # 没有消息可压缩 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求体过大,无法压缩消息,放弃请求。" + ) + return -1, None + elif e.status_code == 429: + # 请求过于频繁 - 让多API Key机制处理,适当延迟后重试 + return _check_retry( + remain_try, + min(retry_interval, 5), # 限制最大延迟为5秒,让API Key切换更快生效 + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"请求过于频繁,多API Key机制会自动切换,{min(retry_interval, 5)}秒后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求过于频繁,所有API Key都被限制,放弃请求" + ), + ) + elif e.status_code >= 500: + # 服务器错误 + return _check_retry( + remain_try, + retry_interval, + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"服务器错误,将于{retry_interval}秒后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "服务器错误,超过最大重试次数,请稍后再试" + ), + ) + else: + # 未知错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"未知错误,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None + + +def default_exception_handler( + e: Exception, + task_name: str, + 
model_name: str, + remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, +) -> tuple[int, list[Message] | None]: + """ + 默认异常处理函数 + :param e: 异常对象 + :param task_name: 任务名称 + :param model_name: 模型名称 + :param remain_try: 剩余尝试次数 + :param retry_interval: 重试间隔 + :param messages: (消息列表, 是否已压缩过) + :return (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + + if isinstance(e, NetworkConnectionError): # 网络连接错误 + # 网络错误可能是某个API Key的端点问题,给多API Key机制一次快速重试机会 + return _check_retry( + remain_try, + min(retry_interval, 3), # 网络错误时减少等待时间,让API Key切换更快 + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"连接异常,多API Key机制会尝试其他Key,{min(retry_interval, 3)}秒后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确" + ), + ) + elif isinstance(e, ReqAbortException): + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n请求被中断,详细信息-{str(e.message)}" + ) + return -1, None # 不再重试请求该模型 + elif isinstance(e, RespNotOkException): + return _handle_resp_not_ok( + e, + task_name, + model_name, + remain_try, + retry_interval, + messages, + ) + elif isinstance(e, RespParseException): + # 响应解析错误 + logger.error( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"响应解析错误,错误信息-{e.message}\n" + ) + logger.debug(f"附加内容:\n{str(e.ext_info)}") + return -1, None # 不再重试请求该模型 + else: + logger.error( + f"任务-'{task_name}' 模型-'{model_name}'\n未知异常,错误信息-{str(e)}" + ) + return -1, None # 不再重试请求该模型 + + +class ModelRequestHandler: + """ + 模型请求处理器 + """ + + def __init__( + self, + task_name: str, + config: ModuleConfig, + api_client_map: dict[str, BaseClient], + ): + self.task_name: str = task_name + """任务名称""" + + self.client_map: dict[str, BaseClient] = {} + """API客户端列表""" + + self.configs: list[tuple[ModelInfo, ModelUsageArgConfigItem]] = [] + """模型参数配置""" + + self.req_conf: RequestConfig = config.req_conf + """请求配置""" + + # 获取模型与使用配置 + for model_usage in config.task_model_arg_map[task_name].usage: + if model_usage.name not in config.models: + logger.error(f"Model '{model_usage.name}' not found in ModelManager") + raise KeyError(f"Model '{model_usage.name}' not found in ModelManager") + model_info = config.models[model_usage.name] + + if model_info.api_provider not in self.client_map: + # 缓存API客户端 + self.client_map[model_info.api_provider] = api_client_map[ + model_info.api_provider + ] + + self.configs.append((model_info, model_usage)) # 添加模型与使用配置 + + async def get_response( + self, + messages: list[Message], + tool_options: list[ToolOption] | None = None, + response_format: RespFormat | None = None, # 暂不启用 + stream_response_handler: Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], APIResponse + ] + | None = None, + async_response_parser: Callable[[ChatCompletion], APIResponse] | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """ + 获取对话响应 + :param messages: 消息列表 + :param tool_options: 工具选项列表 + :param response_format: 响应格式 + :param stream_response_handler: 流式响应处理函数(可选) + :param async_response_parser: 响应解析函数(可选) + :param interrupt_flag: 中断信号量(可选,默认为None) + :return: APIResponse + """ + # 遍历可用模型,若获取响应失败,则使用下一个模型继续请求 + for config_item in self.configs: + client = self.client_map[config_item[0].api_provider] + model_info: ModelInfo = config_item[0] + model_usage_config: ModelUsageArgConfigItem = config_item[1] + + remain_try = ( + model_usage_config.max_retry or self.req_conf.max_retry + ) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1 + + compressed_messages = None + retry_interval = 
self.req_conf.retry_interval + while remain_try > 0: + try: + return await client.get_response( + model_info, + message_list=(compressed_messages or messages), + tool_options=tool_options, + max_tokens=model_usage_config.max_tokens + or self.req_conf.default_max_tokens, + temperature=model_usage_config.temperature + or self.req_conf.default_temperature, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + interrupt_flag=interrupt_flag, + ) + except Exception as e: + logger.debug(e) + remain_try -= 1 # 剩余尝试次数减1 + + # 处理异常 + handle_res = default_exception_handler( + e, + self.task_name, + model_info.name, + remain_try, + retry_interval=self.req_conf.retry_interval, + messages=(messages, compressed_messages is not None), + ) + + if handle_res[0] == -1: + # 等待间隔为-1,表示不再请求该模型 + remain_try = 0 + elif handle_res[0] != 0: + # 等待间隔不为0,表示需要等待 + await asyncio.sleep(handle_res[0]) + retry_interval *= 2 + + if handle_res[1] is not None: + # 压缩消息 + compressed_messages = handle_res[1] + + logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用") + raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败 + + async def get_embedding( + self, + embedding_input: str, + ) -> APIResponse: + """ + 获取嵌入向量 + :param embedding_input: 嵌入输入 + :return: APIResponse + """ + for config in self.configs: + client = self.client_map[config[0].api_provider] + model_info: ModelInfo = config[0] + model_usage_config: ModelUsageArgConfigItem = config[1] + remain_try = ( + model_usage_config.max_retry or self.req_conf.max_retry + ) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1 + + while remain_try: + try: + return await client.get_embedding( + model_info=model_info, + embedding_input=embedding_input, + ) + except Exception as e: + logger.debug(e) + remain_try -= 1 # 剩余尝试次数减1 + + # 处理异常 + handle_res = default_exception_handler( + e, + self.task_name, + model_info.name, + remain_try, + retry_interval=self.req_conf.retry_interval, + ) + + if handle_res[0] == -1: + # 等待间隔为-1,表示不再请求该模型 + remain_try = 0 + elif handle_res[0] != 0: + # 等待间隔不为0,表示需要等待 + await asyncio.sleep(handle_res[0]) + + logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用") + raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败 diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index 50a379d34..5089666f1 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -81,10 +81,7 @@ class BaseClient: tuple[APIResponse, tuple[int, int, int]], ] | None = None, - async_response_parser: Callable[ - [ChatCompletion], tuple[APIResponse, tuple[int, int, int]] - ] - | None = None, + async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None, interrupt_flag: asyncio.Event | None = None, ) -> APIResponse: """ @@ -114,3 +111,37 @@ class BaseClient: :return: 嵌入响应 """ raise RuntimeError("This method should be overridden in subclasses") + + +class ClientRegistry: + def __init__(self) -> None: + self.client_registry: dict[str, type[BaseClient]] = {} + + def register_client_class(self, client_type: str): + """ + 注册API客户端类 + :param client_class: API客户端类 + """ + + def decorator(cls: type[BaseClient]) -> type[BaseClient]: + if not issubclass(cls, BaseClient): + raise TypeError(f"{cls.__name__} is not a subclass of BaseClient") + self.client_registry[client_type] = cls + return cls + + return decorator + + def get_client_class(self, client_type: str) -> type[BaseClient]: + """ + 获取注册的API客户端类 + Args: + 
client_type: 客户端类型 + Returns: + type[BaseClient]: 注册的API客户端类 + """ + if client_type not in self.client_registry: + raise KeyError(f"'{client_type}' 类型的 Client 未注册") + return self.client_registry[client_type] + + +client_registry = ClientRegistry() diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index a70458ffe..109fe7593 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -22,7 +22,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDelta from .base_client import APIResponse, UsageRecord from src.config.api_ada_configs import ModelInfo, APIProvider -from . import BaseClient +from .base_client import BaseClient, client_registry from src.common.logger import get_logger from ..exceptions import ( @@ -63,9 +63,7 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara content.append( { "type": "image_url", - "image_url": { - "url": f"data:image/{item[0].lower()};base64,{item[1]}" - }, + "image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"}, } ) elif isinstance(item, str): @@ -120,13 +118,8 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any] if tool_option.params: ret["parameters"] = { "type": "object", - "properties": { - param.name: _convert_tool_param(param) - for param in tool_option.params - }, - "required": [ - param.name for param in tool_option.params if param.required - ], + "properties": {param.name: _convert_tool_param(param) for param in tool_option.params}, + "required": [param.name for param in tool_option.params if param.required], } return ret @@ -190,9 +183,7 @@ def _process_delta( if tool_call_delta.function.arguments: # 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中 - tool_calls_buffer[tool_call_delta.index][2].write( - tool_call_delta.function.arguments - ) + tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments) return in_rc_flag @@ -225,14 +216,12 @@ def _build_stream_api_resp( if not isinstance(arguments, dict): raise RespParseException( None, - "响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n" - f"{raw_arg_data}", + f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_arg_data}", ) except json.JSONDecodeError as e: raise RespParseException( None, - "响应解析失败,无法解析工具调用参数。工具调用参数原始响应:" - f"{raw_arg_data}", + f"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:{raw_arg_data}", ) from e else: arguments_buffer.close() @@ -257,9 +246,7 @@ async def _default_stream_response_handler( _in_rc_flag = False # 标记是否在推理内容块中 _rc_delta_buffer = io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容 _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 - _tool_calls_buffer: list[ - tuple[str, str, io.StringIO] - ] = [] # 工具调用缓冲区,用于存储接收到的工具调用 + _tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 _usage_record = None # 使用情况记录 def _insure_buffer_closed(): @@ -280,7 +267,7 @@ async def _default_stream_response_handler( delta = event.choices[0].delta # 获取当前块的delta内容 - if hasattr(delta, "reasoning_content") and delta.reasoning_content: + if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore # 标记:有独立的推理内容块 _has_rc_attr_flag = True @@ -334,10 +321,10 @@ def _default_normal_response_parser( raise RespParseException(resp, "响应解析失败,缺失choices字段") message_part = resp.choices[0].message - if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: + if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: # type: ignore # 有有效的推理字段 
api_response.content = message_part.content - api_response.reasoning_content = message_part.reasoning_content + api_response.reasoning_content = message_part.reasoning_content # type: ignore elif message_part.content: # 提取推理和内容 match = pattern.match(message_part.content) @@ -358,16 +345,10 @@ def _default_normal_response_parser( try: arguments = json.loads(call.function.arguments) if not isinstance(arguments, dict): - raise RespParseException( - resp, "响应解析失败,工具调用参数无法解析为字典类型" - ) - api_response.tool_calls.append( - ToolCall(call.id, call.function.name, arguments) - ) + raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型") + api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments)) except json.JSONDecodeError as e: - raise RespParseException( - resp, "响应解析失败,无法解析工具调用参数" - ) from e + raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e # 提取Usage信息 if resp.usage: @@ -385,63 +366,15 @@ def _default_normal_response_parser( return api_response, _usage_record +@client_registry.register_client_class("openai") class OpenaiClient(BaseClient): def __init__(self, api_provider: APIProvider): super().__init__(api_provider) - # 不再在初始化时创建固定的client,而是在请求时动态创建 - self._clients_cache = {} # API Key -> AsyncOpenAI client 的缓存 - - def _get_client(self, api_key: str = None) -> AsyncOpenAI: - """获取或创建对应API Key的客户端""" - if api_key is None: - api_key = self.api_provider.get_current_api_key() - - if not api_key: - raise ValueError(f"API Provider '{self.api_provider.name}' 没有可用的API Key") - - # 使用缓存避免重复创建客户端 - if api_key not in self._clients_cache: - self._clients_cache[api_key] = AsyncOpenAI( - base_url=self.api_provider.base_url, - api_key=api_key, - max_retries=0, - ) - - return self._clients_cache[api_key] - - async def _execute_with_fallback(self, func, *args, **kwargs): - """执行请求并在失败时切换API Key""" - current_api_key = self.api_provider.get_current_api_key() - max_attempts = len(self.api_provider.api_keys) if self.api_provider.api_keys else 1 - - for attempt in range(max_attempts): - try: - client = self._get_client(current_api_key) - result = await func(client, *args, **kwargs) - # 成功时重置失败计数 - self.api_provider.reset_key_failures(current_api_key) - return result - - except (APIStatusError, APIConnectionError) as e: - # 记录失败并尝试下一个API Key - logger.warning(f"API Key失败 (尝试 {attempt + 1}/{max_attempts}): {str(e)}") - - if attempt < max_attempts - 1: # 还有重试机会 - next_api_key = self.api_provider.mark_key_failed(current_api_key) - if next_api_key and next_api_key != current_api_key: - current_api_key = next_api_key - logger.info(f"切换到下一个API Key: {current_api_key[:8]}***{current_api_key[-4:]}") - continue - - # 所有API Key都失败了,重新抛出异常 - if isinstance(e, APIStatusError): - raise RespNotOkException(e.status_code, e.message) from e - elif isinstance(e, APIConnectionError): - raise NetworkConnectionError(str(e)) from e - - except Exception as e: - # 其他异常直接抛出 - raise e + self.client: AsyncOpenAI = AsyncOpenAI( + base_url=api_provider.base_url, + api_key=api_provider.api_key, + max_retries=0, + ) async def get_response( self, @@ -456,10 +389,7 @@ class OpenaiClient(BaseClient): tuple[APIResponse, tuple[int, int, int]], ] | None = None, - async_response_parser: Callable[ - [ChatCompletion], tuple[APIResponse, tuple[int, int, int]] - ] - | None = None, + async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None, interrupt_flag: asyncio.Event | None = None, ) -> APIResponse: """ @@ -475,40 +405,6 @@ class OpenaiClient(BaseClient): :param interrupt_flag: 
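# For illustration only (not part of this patch): get_response below wraps the
# SDK call in an asyncio task and polls the optional interrupt_flag event every
# 0.1 s, cancelling the task if the flag is set.  A self-contained sketch of
# that pattern; RuntimeError here stands in for the project's ReqAbortException:
import asyncio

async def run_interruptible(coro, interrupt: asyncio.Event | None = None):
    task = asyncio.create_task(coro)
    while not task.done():
        if interrupt is not None and interrupt.is_set():
            task.cancel()
            raise RuntimeError("request aborted by external signal")
        await asyncio.sleep(0.1)  # re-check task state and interrupt flag
    return task.result()

async def _demo() -> None:
    async def _slow() -> str:
        await asyncio.sleep(0.3)
        return "ok"
    assert await run_interruptible(_slow()) == "ok"

asyncio.run(_demo())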
中断信号量(可选,默认为None) :return: (响应文本, 推理文本, 工具调用, 其他数据) """ - return await self._execute_with_fallback( - self._get_response_internal, - model_info, - message_list, - tool_options, - max_tokens, - temperature, - response_format, - stream_response_handler, - async_response_parser, - interrupt_flag, - ) - - async def _get_response_internal( - self, - client: AsyncOpenAI, - model_info: ModelInfo, - message_list: list[Message], - tool_options: list[ToolOption] | None = None, - max_tokens: int = 1024, - temperature: float = 0.7, - response_format: RespFormat | None = None, - stream_response_handler: Callable[ - [AsyncStream[ChatCompletionChunk], asyncio.Event | None], - tuple[APIResponse, tuple[int, int, int]], - ] - | None = None, - async_response_parser: Callable[ - [ChatCompletion], tuple[APIResponse, tuple[int, int, int]] - ] - | None = None, - interrupt_flag: asyncio.Event | None = None, - ) -> APIResponse: - """内部方法:执行实际的API调用""" if stream_response_handler is None: stream_response_handler = _default_stream_response_handler @@ -518,23 +414,19 @@ class OpenaiClient(BaseClient): # 将messages构造为OpenAI API所需的格式 messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list) # 将tool_options转换为OpenAI API所需的格式 - tools: Iterable[ChatCompletionToolParam] = ( - _convert_tool_options(tool_options) if tool_options else NOT_GIVEN - ) + tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN try: if model_info.force_stream_mode: req_task = asyncio.create_task( - client.chat.completions.create( + self.client.chat.completions.create( model=model_info.model_identifier, messages=messages, tools=tools, temperature=temperature, max_tokens=max_tokens, stream=True, - response_format=response_format.to_dict() - if response_format - else NOT_GIVEN, + response_format=response_format.to_dict() if response_format else NOT_GIVEN, ) ) while not req_task.done(): @@ -544,22 +436,18 @@ class OpenaiClient(BaseClient): raise ReqAbortException("请求被外部信号中断") await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 - resp, usage_record = await stream_response_handler( - req_task.result(), interrupt_flag - ) + resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag) else: # 发送请求并获取响应 req_task = asyncio.create_task( - client.chat.completions.create( + self.client.chat.completions.create( model=model_info.model_identifier, messages=messages, tools=tools, temperature=temperature, max_tokens=max_tokens, stream=False, - response_format=response_format.to_dict() - if response_format - else NOT_GIVEN, + response_format=response_format.to_dict() if response_format else NOT_GIVEN, ) ) while not req_task.done(): @@ -599,21 +487,8 @@ class OpenaiClient(BaseClient): :param embedding_input: 嵌入输入文本 :return: 嵌入响应 """ - return await self._execute_with_fallback( - self._get_embedding_internal, - model_info, - embedding_input, - ) - - async def _get_embedding_internal( - self, - client: AsyncOpenAI, - model_info: ModelInfo, - embedding_input: str, - ) -> APIResponse: - """内部方法:执行实际的嵌入API调用""" try: - raw_response = await client.embeddings.create( + raw_response = await self.client.embeddings.create( model=model_info.model_identifier, input=embedding_input, ) diff --git a/src/llm_models/model_manager.py b/src/llm_models/model_manager.py index 36d63c72e..2db3a6d25 100644 --- a/src/llm_models/model_manager.py +++ b/src/llm_models/model_manager.py @@ -2,7 +2,6 @@ import importlib from typing import Dict from src.config.config import model_config -from 
src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig from src.common.logger import get_logger from .model_client import ModelRequestHandler, BaseClient @@ -10,83 +9,4 @@ from .model_client import ModelRequestHandler, BaseClient logger = get_logger("模型管理器") class ModelManager: - # TODO: 添加读写锁,防止异步刷新配置时发生数据竞争 - - def __init__( - self, - config: ModuleConfig, - ): - self.config: ModuleConfig = config - """配置信息""" - - self.api_client_map: Dict[str, BaseClient] = {} - """API客户端映射表""" - - self._request_handler_cache: Dict[str, ModelRequestHandler] = {} - """ModelRequestHandler缓存,避免重复创建""" - - for provider_name, api_provider in self.config.api_providers.items(): - # 初始化API客户端 - try: - # 根据配置动态加载实现 - client_module = importlib.import_module( - f".model_client.{api_provider.client_type}_client", __package__ - ) - client_class = getattr( - client_module, f"{api_provider.client_type.capitalize()}Client" - ) - if not issubclass(client_class, BaseClient): - raise TypeError( - f"'{client_class.__name__}' is not a subclass of 'BaseClient'" - ) - self.api_client_map[api_provider.name] = client_class( - api_provider - ) # 实例化,放入api_client_map - except ImportError as e: - logger.error(f"Failed to import client module: {e}") - raise ImportError( - f"Failed to import client module for '{provider_name}': {e}" - ) from e - - def __getitem__(self, task_name: str) -> ModelRequestHandler: - """ - 获取任务所需的模型客户端(封装) - 使用缓存机制避免重复创建ModelRequestHandler - :param task_name: 任务名称 - :return: 模型客户端 - """ - if task_name not in self.config.task_model_arg_map: - raise KeyError(f"'{task_name}' not registered in ModelManager") - - # 检查缓存中是否已存在 - if task_name in self._request_handler_cache: - logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}") - return self._request_handler_cache[task_name] - - # 创建新的ModelRequestHandler并缓存 - logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}") - handler = ModelRequestHandler( - task_name=task_name, - config=self.config, - api_client_map=self.api_client_map, - ) - self._request_handler_cache[task_name] = handler - return handler - - def __setitem__(self, task_name: str, value: ModelUsageArgConfig): - """ - 注册任务的模型使用配置 - :param task_name: 任务名称 - :param value: 模型使用配置 - """ - self.config.task_model_arg_map[task_name] = value - - def __contains__(self, task_name: str): - """ - 判断任务是否已注册 - :param task_name: 任务名称 - :return: 是否在模型列表中 - """ - return task_name in self.config.task_model_arg_map - - + \ No newline at end of file diff --git a/src/llm_models/model_manager_bak.py b/src/llm_models/model_manager_bak.py new file mode 100644 index 000000000..36d63c72e --- /dev/null +++ b/src/llm_models/model_manager_bak.py @@ -0,0 +1,92 @@ +import importlib +from typing import Dict + +from src.config.config import model_config +from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig +from src.common.logger import get_logger + +from .model_client import ModelRequestHandler, BaseClient + +logger = get_logger("模型管理器") + +class ModelManager: + # TODO: 添加读写锁,防止异步刷新配置时发生数据竞争 + + def __init__( + self, + config: ModuleConfig, + ): + self.config: ModuleConfig = config + """配置信息""" + + self.api_client_map: Dict[str, BaseClient] = {} + """API客户端映射表""" + + self._request_handler_cache: Dict[str, ModelRequestHandler] = {} + """ModelRequestHandler缓存,避免重复创建""" + + for provider_name, api_provider in self.config.api_providers.items(): + # 初始化API客户端 + try: + # 根据配置动态加载实现 + client_module = importlib.import_module( + f".model_client.{api_provider.client_type}_client", __package__ + ) 
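# For illustration only (not part of this patch): the loop above resolves each
# client implementation purely by naming convention -- module
# ".model_client.<client_type>_client", class "<client_type>.capitalize() + 'Client'"
# (e.g. "openai" -> OpenaiClient) -- via importlib.  A standalone sketch of the
# same resolve-and-verify step, pointed at a stdlib module so it runs anywhere:
import importlib

def load_class_by_convention(module_name: str, class_name: str) -> type:
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    if not isinstance(cls, type):
        raise TypeError(f"'{class_name}' in '{module_name}' is not a class")
    return cls

assert load_class_by_convention("collections", "OrderedDict").__name__ == "OrderedDict"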
+ client_class = getattr( + client_module, f"{api_provider.client_type.capitalize()}Client" + ) + if not issubclass(client_class, BaseClient): + raise TypeError( + f"'{client_class.__name__}' is not a subclass of 'BaseClient'" + ) + self.api_client_map[api_provider.name] = client_class( + api_provider + ) # 实例化,放入api_client_map + except ImportError as e: + logger.error(f"Failed to import client module: {e}") + raise ImportError( + f"Failed to import client module for '{provider_name}': {e}" + ) from e + + def __getitem__(self, task_name: str) -> ModelRequestHandler: + """ + 获取任务所需的模型客户端(封装) + 使用缓存机制避免重复创建ModelRequestHandler + :param task_name: 任务名称 + :return: 模型客户端 + """ + if task_name not in self.config.task_model_arg_map: + raise KeyError(f"'{task_name}' not registered in ModelManager") + + # 检查缓存中是否已存在 + if task_name in self._request_handler_cache: + logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}") + return self._request_handler_cache[task_name] + + # 创建新的ModelRequestHandler并缓存 + logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}") + handler = ModelRequestHandler( + task_name=task_name, + config=self.config, + api_client_map=self.api_client_map, + ) + self._request_handler_cache[task_name] = handler + return handler + + def __setitem__(self, task_name: str, value: ModelUsageArgConfig): + """ + 注册任务的模型使用配置 + :param task_name: 任务名称 + :param value: 模型使用配置 + """ + self.config.task_model_arg_map[task_name] = value + + def __contains__(self, task_name: str): + """ + 判断任务是否已注册 + :param task_name: 任务名称 + :return: 是否在模型列表中 + """ + return task_name in self.config.task_model_arg_map + + diff --git a/src/llm_models/payload_content/__init__.py b/src/llm_models/payload_content/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 805a47343..4602fb751 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,89 +1,39 @@ import re +import copy +import asyncio from datetime import datetime -from typing import Tuple, Union +from typing import Tuple, Union, List, Dict, Optional, Callable, Any from src.common.logger import get_logger import base64 from PIL import Image +from enum import Enum import io from src.common.database.database import db # 确保 db 被导入用于 create_tables from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 -from src.config.config import global_config +from src.config.config import global_config, model_config +from src.config.api_ada_configs import APIProvider, ModelInfo from rich.traceback import install +from .payload_content.message import MessageBuilder, Message +from .payload_content.resp_format import RespFormat +from .payload_content.tool_option import ToolOption, ToolCall +from .model_client.base_client import BaseClient, APIResponse, UsageRecord, client_registry +from .utils import compress_messages + +from .exceptions import ( + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, + PayLoadTooLargeError, + RequestAbortException, + PermissionDeniedException, +) + install(extra_lines=3) logger = get_logger("model_utils") -# 导入具体的异常类型用于精确的异常处理 -try: - from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException - SPECIFIC_EXCEPTIONS_AVAILABLE = True -except ImportError: - logger.warning("无法导入具体异常类型,将使用通用异常处理") - NetworkConnectionError = Exception - ReqAbortException = Exception - RespNotOkException = Exception - RespParseException = Exception - 
SPECIFIC_EXCEPTIONS_AVAILABLE = False - -# 新架构导入 - 使用延迟导入以支持fallback模式 -try: - from .model_manager import ModelManager - from .model_client import ModelRequestHandler - from .payload_content.message import MessageBuilder - - # 不在模块级别初始化ModelManager,延迟到实际使用时 - ModelManager_class = ModelManager - model_manager = None # 延迟初始化 - - # 添加请求处理器缓存,避免重复创建 - _request_handler_cache = {} # 格式: {(model_name, task_name): ModelRequestHandler} - - NEW_ARCHITECTURE_AVAILABLE = True - logger.info("新架构模块导入成功") -except Exception as e: - logger.warning(f"新架构不可用,将使用fallback模式: {str(e)}") - ModelManager_class = None - model_manager = None - ModelRequestHandler = None - MessageBuilder = None - _request_handler_cache = {} - NEW_ARCHITECTURE_AVAILABLE = False - - -class PayLoadTooLargeError(Exception): - """自定义异常类,用于处理请求体过大错误""" - - def __init__(self, message: str): - super().__init__(message) - self.message = message - - def __str__(self): - return "请求体过大,请尝试压缩图片或减少输入内容。" - - -class RequestAbortException(Exception): - """自定义异常类,用于处理请求中断异常""" - - def __init__(self, message: str): - super().__init__(message) - self.message = message - - def __str__(self): - return self.message - - -class PermissionDeniedException(Exception): - """自定义异常类,用于处理访问拒绝的异常""" - - def __init__(self, message: str): - super().__init__(message) - self.message = message - - def __str__(self): - return self.message - - # 常见Error Code Mapping error_code_mapping = { 400: "参数不正确", @@ -97,14 +47,16 @@ error_code_mapping = { } +class RequestType(Enum): + """请求类型枚举""" + + RESPONSE = "response" + EMBEDDING = "embedding" class LLMRequest: - """ - 重构后的LLM请求类,基于新的model_manager和model_client架构 - 保持向后兼容的API接口 - """ - + """LLM请求类""" + # 定义需要转换的模型列表,作为类变量避免重复 MODELS_NEEDING_TRANSFORMATION = [ "o1", @@ -123,252 +75,17 @@ class LLMRequest: "o4-mini-2025-04-16", ] - def __init__(self, model: dict, **kwargs): - """ - 初始化LLM请求实例 - Args: - model: 模型配置字典,兼容旧格式和新格式 - **kwargs: 额外参数 - """ - logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('model_name', model.get('name', 'Unknown'))}") - logger.debug(f"🔍 [模型初始化] 输入的模型配置: {model}") - logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}") - - # 兼容新旧模型配置格式 - # 新格式使用 model_name,旧格式使用 name - self.model_name: str = model.get("model_name", model.get("name", "")) - - # 如果传入的配置不完整,自动从全局配置中获取完整配置 - if not all(key in model for key in ["task_type", "capabilities"]): - logger.debug("🔍 [模型初始化] 检测到不完整的模型配置,尝试获取完整配置") - if (full_model_config := self._get_full_model_config(self.model_name)): - logger.debug("🔍 [模型初始化] 成功获取完整模型配置,合并配置信息") - # 合并配置:运行时参数优先,但添加缺失的配置字段 - model = {**full_model_config, **model} - logger.debug(f"🔍 [模型初始化] 合并后的模型配置: {model}") - else: - logger.warning(f"⚠️ [模型初始化] 无法获取模型 {self.model_name} 的完整配置,使用原始配置") - - # 在新架构中,provider信息从model_config.toml自动获取,不需要在这里设置 - self.provider = model.get("provider", "") # 保留兼容性,但在新架构中不使用 - - # 从全局配置中获取任务配置 - self.request_type = kwargs.pop("request_type", "default") - - # 确定使用哪个任务配置 - task_name = self._determine_task_name(model) - - # 初始化 request_handler - self.request_handler = None - - # 尝试初始化新架构 - if NEW_ARCHITECTURE_AVAILABLE and ModelManager_class is not None: - try: - # 延迟初始化ModelManager - global model_manager, _request_handler_cache - if model_manager is None: - from src.config.config import model_config - model_manager = ModelManager_class(model_config) - logger.debug("🔍 [模型初始化] ModelManager延迟初始化成功") - - # 构建缓存键 - cache_key = (self.model_name, task_name) - - # 检查是否已有缓存的请求处理器 - if cache_key in _request_handler_cache: - self.request_handler = _request_handler_cache[cache_key] - logger.debug(f"🚀 
[性能优化] 从LLMRequest缓存获取请求处理器: {cache_key}") - else: - # 使用新架构获取模型请求处理器 - self.request_handler = model_manager[task_name] - _request_handler_cache[cache_key] = self.request_handler - logger.debug(f"🔧 [性能优化] 创建并缓存LLMRequest请求处理器: {cache_key}") - - logger.debug(f"🔍 [模型初始化] 成功获取模型请求处理器,任务: {task_name}") - self.use_new_architecture = True - except Exception as e: - logger.warning(f"无法使用新架构,任务 {task_name} 初始化失败: {e}") - logger.warning("回退到兼容模式,某些功能可能受限") - self.request_handler = None - self.use_new_architecture = False - else: - logger.warning("新架构不可用,使用兼容模式") - logger.warning("回退到兼容模式,某些功能可能受限") - self.request_handler = None - self.use_new_architecture = False - - # 保存原始参数用于向后兼容 - self.params = kwargs - - # 兼容性属性,从模型配置中提取 - # 新格式和旧格式都支持 - self.enable_thinking = model.get("enable_thinking", False) - self.temp = model.get("temperature", model.get("temp", 0.7)) # 新格式用temperature,旧格式用temp - self.thinking_budget = model.get("thinking_budget", 4096) - self.stream = model.get("stream", False) - self.pri_in = model.get("pri_in", 0) - self.pri_out = model.get("pri_out", 0) - self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length) - - # 记录配置文件中声明了哪些参数(不管值是什么) - self.has_enable_thinking = "enable_thinking" in model - self.has_thinking_budget = "thinking_budget" in model - self.pri_out = model.get("pri_out", 0) - self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length) - - # 记录配置文件中声明了哪些参数(不管值是什么) - self.has_enable_thinking = "enable_thinking" in model - self.has_thinking_budget = "thinking_budget" in model - - logger.debug("🔍 [模型初始化] 模型参数设置完成:") - logger.debug(f" - model_name: {self.model_name}") - logger.debug(f" - provider: {self.provider}") - logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}") - logger.debug(f" - enable_thinking: {self.enable_thinking}") - logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}") - logger.debug(f" - thinking_budget: {self.thinking_budget}") - logger.debug(f" - temp: {self.temp}") - logger.debug(f" - stream: {self.stream}") - logger.debug(f" - max_tokens: {self.max_tokens}") - logger.debug(f" - use_new_architecture: {self.use_new_architecture}") + def __init__(self, task_name: str, request_type: str = "") -> None: + self.task_name = task_name + self.model_for_task = model_config.model_task_config.get_task(task_name) + self.request_type = request_type + self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list} + """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整""" - # 获取数据库实例 + self.pri_in = 0 + self.pri_out = 0 + self._init_database() - - logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}") - - def _determine_task_name(self, model: dict) -> str: - """ - 根据模型配置确定任务名称 - 优先使用配置文件中明确定义的任务类型,避免基于模型名称的脆弱推断 - - Args: - model: 模型配置字典 - Returns: - 任务名称 - """ - # 调试信息:打印模型配置字典的所有键 - logger.debug(f"🔍 [任务确定] 模型配置字典的所有键: {list(model.keys())}") - logger.debug(f"🔍 [任务确定] 模型配置字典内容: {model}") - - # 获取模型名称 - model_name = model.get("model_name", model.get("name", "")) - - # 方法1: 优先使用配置文件中明确定义的 task_type 字段 - if "task_type" in model: - task_type = model["task_type"] - logger.debug(f"🎯 [任务确定] 使用配置中的 task_type: {task_type}") - return task_type - - # 方法2: 使用 capabilities 字段来推断主要任务类型 - if "capabilities" in model: - capabilities = model["capabilities"] - if isinstance(capabilities, list): - # 按优先级顺序检查能力 - if "vision" in capabilities: - logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: vision") - return "vision" - elif 
"embedding" in capabilities: - logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: embedding") - return "embedding" - elif "speech" in capabilities: - logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: speech") - return "speech" - elif "text" in capabilities: - # 如果只有文本能力,则根据request_type细分 - task = "llm_reasoning" if self.request_type == "reasoning" else "llm_normal" - logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 和 request_type {self.request_type} 推断为: {task}") - return task - - # 方法3: 向后兼容 - 基于模型名称的关键字推断(不推荐但保留兼容性) - logger.warning(f"⚠️ [任务确定] 配置中未找到 task_type 或 capabilities,回退到基于模型名称的推断: {model_name}") - logger.warning("⚠️ [建议] 请在 model_config.toml 中为模型添加明确的 task_type 或 capabilities 字段") - - # 保留原有的关键字匹配逻辑作为fallback - if any(keyword in model_name.lower() for keyword in ["vlm", "vision", "gpt-4o", "claude", "vl-"]): - logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: vision") - return "vision" - elif any(keyword in model_name.lower() for keyword in ["embed", "text-embedding", "bge-"]): - logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: embedding") - return "embedding" - elif any(keyword in model_name.lower() for keyword in ["whisper", "speech", "voice"]): - logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: speech") - return "speech" - else: - # 根据request_type确定,映射到配置文件中定义的任务 - task = "llm_reasoning" if self.request_type == "reasoning" else "llm_normal" - logger.debug(f"🎯 [任务确定] 从 request_type {self.request_type} 推断为: {task}") - return task - - def _get_full_model_config(self, model_name: str) -> dict | None: - """ - 根据模型名称从全局配置中获取完整的模型配置 - 现在直接使用已解析的ModelInfo对象,不再读取TOML文件 - - Args: - model_name: 模型名称 - Returns: - 完整的模型配置字典,如果找不到则返回None - """ - try: - from src.config.config import model_config - return self._get_model_config_from_parsed(model_name, model_config) - - except Exception as e: - logger.warning(f"⚠️ [配置查找] 获取模型配置时出错: {str(e)}") - return None - - def _get_model_config_from_parsed(self, model_name: str, model_config) -> dict | None: - """ - 从已解析的配置对象中获取模型配置 - 使用扩展后的ModelInfo类,包含task_type和capabilities字段 - """ - try: - # 直接通过模型名称查找 - if model_name in model_config.models: - model_info = model_config.models[model_name] - logger.debug(f"🔍 [配置查找] 找到模型 {model_name} 的配置对象: {model_info}") - - # 将ModelInfo对象转换为字典 - model_dict = { - "model_identifier": model_info.model_identifier, - "name": model_info.name, - "api_provider": model_info.api_provider, - "price_in": model_info.price_in, - "price_out": model_info.price_out, - "force_stream_mode": model_info.force_stream_mode, - "task_type": model_info.task_type, - "capabilities": model_info.capabilities, - } - - logger.debug(f"🔍 [配置查找] 转换后的模型配置字典: {model_dict}") - return model_dict - - # 如果直接查找失败,尝试通过model_identifier查找 - for name, model_info in model_config.models.items(): - if (model_info.model_identifier == model_name or - hasattr(model_info, 'model_name') and model_info.model_name == model_name): - - logger.debug(f"🔍 [配置查找] 通过标识符找到模型 {model_name} (配置名称: {name})") - # 同样转换为字典 - model_dict = { - "model_identifier": model_info.model_identifier, - "name": model_info.name, - "api_provider": model_info.api_provider, - "price_in": model_info.price_in, - "price_out": model_info.price_out, - "force_stream_mode": model_info.force_stream_mode, - "task_type": model_info.task_type, - "capabilities": model_info.capabilities, - } - - return model_dict - - return None - - except Exception as e: - logger.warning(f"⚠️ [配置查找] 从已解析配置获取模型配置时出错: {str(e)}") - return None @staticmethod def _init_database(): @@ -380,8 +97,394 @@ class LLMRequest: 
except Exception as e: logger.error(f"创建 LLMUsage 表失败: {str(e)}") + async def generate_response_for_image( + self, + prompt: str, + image_base64: str, + image_format: str, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]: + """ + 为图像生成响应 + Args: + prompt (str): 提示词 + image_base64 (str): 图像的Base64编码字符串 + image_format (str): 图像格式(如 'png', 'jpeg' 等) + Returns: + + """ + # 请求体构建 + message_builder = MessageBuilder() + message_builder.add_text_content(prompt) + message_builder.add_image_content(image_base64=image_base64, image_format=image_format) + messages = [message_builder.build()] + + # 模型选择 + model_info, api_provider, client = self._select_model() + + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + if usage := response.usage: + self.pri_in = model_info.price_in + self.pri_out = model_info.price_out + self._record_usage( + model_name=model_info.name, + prompt_tokens=usage.prompt_tokens or 0, + completion_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/chat/completions", + ) + return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None + + async def generate_response_for_voice(self): + pass + + async def generate_response_async( + self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None + ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]: + """ + 异步生成响应 + Args: + prompt (str): 提示词 + temperature (float, optional): 温度参数 + max_tokens (int, optional): 最大token数 + Returns: + Tuple[str, str, Optional[List[Dict[str, Any]]]]: 响应内容、推理内容和工具调用列表 + """ + # 请求体构建 + message_builder = MessageBuilder() + message_builder.add_text_content(prompt) + messages = [message_builder.build()] + + # 模型选择 + model_info, api_provider, client = self._select_model() + + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + content = response.content + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + if usage := response.usage: + self.pri_in = model_info.price_in + self.pri_out = model_info.price_out + self._record_usage( + model_name=model_info.name, + prompt_tokens=usage.prompt_tokens or 0, + completion_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/chat/completions", + ) + if not content: + raise RuntimeError("获取LLM生成内容失败") + + return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None + + async def get_embedding(self, embedding_input: str) -> List[float]: + """获取嵌入向量""" + # 
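# For illustration only (not part of this patch): after this rework, LLMRequest
# is constructed from a task name defined in model_config.toml instead of a
# model dict, and generate_response_async returns (content, reasoning,
# tool_calls).  A hedged caller-side usage sketch -- the task name "llm_normal"
# and request_type value are assumptions, use whatever your model_config defines:
#
#     import asyncio
#     from src.llm_models.utils_model import LLMRequest
#
#     async def main() -> None:
#         llm = LLMRequest(task_name="llm_normal", request_type="demo")
#         content, reasoning, tool_calls = await llm.generate_response_async(
#             "Introduce yourself in one sentence.",
#             temperature=0.7,
#             max_tokens=256,
#         )
#         print(content)
#
#     asyncio.run(main())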
无需构建消息体,直接使用输入文本 + model_info, api_provider, client = self._select_model() + + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.EMBEDDING, + model_info=model_info, + embedding_input=embedding_input, + ) + + embedding = response.embedding + + if response.usage: + self.pri_in = model_info.price_in + self.pri_out = model_info.price_out + self._record_usage( + model_name=model_info.name, + prompt_tokens=response.usage.prompt_tokens or 0, + completion_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/embeddings", + ) + + if not embedding: + raise RuntimeError("获取embedding失败") + + return embedding + + def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: + """ + 根据总tokens和惩罚值选择的模型 + """ + least_used_model_name = min( + self.model_usage, key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + ) + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider)) + return model_info, api_provider, client + + def _convert_tool_calls(self, tool_calls: List[ToolCall]) -> List[Dict[str, Any]]: + """将ToolCall对象转换为Dict列表""" + pass + + async def _execute_request( + self, + api_provider: APIProvider, + client: BaseClient, + request_type: RequestType, + model_info: ModelInfo, + message_list: List[Message] | None = None, + tool_options: list[ToolOption] | None = None, + response_format: RespFormat | None = None, + stream_response_handler: Optional[Callable] = None, + async_response_parser: Optional[Callable] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + embedding_input: str = "", + ) -> APIResponse: + """ + 实际执行请求的方法 + + 包含了重试和异常处理逻辑 + """ + retry_remain = api_provider.max_retry + compressed_messages: Optional[List[Message]] = None + while retry_remain > 0: + try: + if request_type == RequestType.RESPONSE: + assert message_list is not None, "message_list cannot be None for response requests" + return await client.get_response( + model_info=model_info, + message_list=(compressed_messages or message_list), + tool_options=tool_options, + max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, + temperature=self.model_for_task.temperature if temperature is None else temperature, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + ) + elif request_type == RequestType.EMBEDDING: + assert embedding_input, "embedding_input cannot be empty for embedding requests" + return await client.get_embedding(model_info=model_info, embedding_input=embedding_input) + except Exception as e: + logger.debug(f"请求失败: {str(e)}") + # 处理异常 + total_tokens, penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty + 1) + wait_interval, compressed_messages = self._default_exception_handler( + e, + self.task_name, + model_name=model_info.name, + remain_try=retry_remain, + messages=(message_list, compressed_messages is not None), + ) + + if wait_interval == -1: + retry_remain = 0 # 不再重试 + elif wait_interval > 0: + logger.info(f"等待 {wait_interval} 秒后重试...") + await asyncio.sleep(wait_interval) + finally: + # 放在finally防止死循环 + retry_remain -= 1 + logger.error( + f"任务 
'{self.task_name}' 模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次" + ) + raise RuntimeError("请求失败,已达到最大重试次数") + + def _default_exception_handler( + self, + e: Exception, + task_name: str, + model_name: str, + remain_try: int, + retry_interval: int = 10, + messages: Tuple[List[Message], bool] | None = None, + ) -> Tuple[int, List[Message] | None]: + """ + 默认异常处理函数 + Args: + e (Exception): 异常对象 + task_name (str): 任务名称 + model_name (str): 模型名称 + remain_try (int): 剩余尝试次数 + retry_interval (int): 重试间隔 + messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) + Returns: + (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + + if isinstance(e, NetworkConnectionError): # 网络连接错误 + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确", + ) + elif isinstance(e, ReqAbortException): + logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}") + return -1, None # 不再重试请求该模型 + elif isinstance(e, RespNotOkException): + return self._handle_resp_not_ok( + e, + task_name, + model_name, + remain_try, + retry_interval, + messages, + ) + elif isinstance(e, RespParseException): + # 响应解析错误 + logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}") + logger.debug(f"附加内容: {str(e.ext_info)}") + return -1, None # 不再重试请求该模型 + else: + logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") + return -1, None # 不再重试请求该模型 + + def _check_retry( + self, + remain_try: int, + retry_interval: int, + can_retry_msg: str, + cannot_retry_msg: str, + can_retry_callable: Callable | None = None, + **kwargs, + ) -> Tuple[int, List[Message] | None]: + """辅助函数:检查是否可以重试 + Args: + remain_try (int): 剩余尝试次数 + retry_interval (int): 重试间隔 + can_retry_msg (str): 可以重试时的提示信息 + cannot_retry_msg (str): 不可以重试时的提示信息 + can_retry_callable (Callable | None): 可以重试时调用的函数(如果有) + **kwargs: 其他参数 + + Returns: + (Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + if remain_try > 0: + # 还有重试机会 + logger.warning(f"{can_retry_msg}") + if can_retry_callable is not None: + return retry_interval, can_retry_callable(**kwargs) + else: + return retry_interval, None + else: + # 达到最大重试次数 + logger.warning(f"{cannot_retry_msg}") + return -1, None # 不再重试请求该模型 + + def _handle_resp_not_ok( + self, + e: RespNotOkException, + task_name: str, + model_name: str, + remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, + ): + """ + 处理响应错误异常 + Args: + e (RespNotOkException): 响应错误异常对象 + task_name (str): 任务名称 + model_name (str): 模型名称 + remain_try (int): 剩余尝试次数 + retry_interval (int): 重试间隔 + messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) + Returns: + (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + # 响应错误 + if e.status_code in [400, 401, 402, 403, 404]: + # 客户端错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None # 不再重试请求该模型 + elif e.status_code == 413: + if messages and not messages[1]: + # 消息列表不为空且未压缩,尝试压缩消息 + return self._check_retry( + remain_try, + 0, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求", + can_retry_callable=compress_messages, + messages=messages[0], + ) + # 没有消息可压缩 + logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。") + 
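# For illustration only (not part of this patch): the status handling in
# _handle_resp_not_ok (here and just below) boils down to a small
# status-code -> action policy.  A condensed, standalone restatement that
# ignores the remain_try bookkeeping (-1 = drop this model, 0 = compress and
# retry at once, N = wait N seconds and retry); retry_action is an
# illustrative name, not a function in the patch:
def retry_action(status_code: int, can_compress: bool, retry_interval: int = 10) -> int:
    if status_code in (400, 401, 402, 403, 404):
        return -1                         # client error: give up on this model
    if status_code == 413:
        return 0 if can_compress else -1  # payload too large: try compressing once
    if status_code == 429 or status_code >= 500:
        return retry_interval             # rate limit / server error: back off
    return -1                             # unknown status: give up

assert retry_action(401, False) == -1
assert retry_action(413, True) == 0
assert retry_action(503, False) == 10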
return -1, None + elif e.status_code == 429: + # 请求过于频繁 + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求", + ) + elif e.status_code >= 500: + # 服务器错误 + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试", + ) + else: + # 未知错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None + + @staticmethod + def _extract_reasoning(content: str) -> Tuple[str, str]: + """CoT思维链提取,向后兼容""" + match = re.search(r"(?:)?(.*?)", content, re.DOTALL) + content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() + reasoning = match[1].strip() if match else "" + return content, reasoning + def _record_usage( self, + model_name: str, prompt_tokens: int, completion_tokens: int, total_tokens: int, @@ -405,7 +508,7 @@ class LLMRequest: try: # 使用 Peewee 模型创建记录 LLMUsage.create( - model_name=self.model_name, + model_name=model_name, user_id=user_id, request_type=request_type, endpoint=endpoint, @@ -417,7 +520,7 @@ class LLMRequest: timestamp=datetime.now(), # Peewee 会处理 DateTimeField ) logger.debug( - f"Token使用情况 - 模型: {self.model_name}, " + f"Token使用情况 - 模型: {model_name}, " f"用户: {user_id}, 类型: {request_type}, " f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " f"总计: {total_tokens}" @@ -440,384 +543,3 @@ class LLMRequest: input_cost = (prompt_tokens / 1000000) * self.pri_in output_cost = (completion_tokens / 1000000) * self.pri_out return round(input_cost + output_cost, 6) - - @staticmethod - def _extract_reasoning(content: str) -> Tuple[str, str]: - """CoT思维链提取""" - match = re.search(r"(?:)?(.*?)", content, re.DOTALL) - content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() - reasoning = match[1].strip() if match else "" - return content, reasoning - - def _handle_model_exception(self, e: Exception, operation: str) -> None: - """ - 统一的模型异常处理方法 - 根据异常类型提供更精确的错误信息和处理策略 - - Args: - e: 捕获的异常 - operation: 操作类型(用于日志记录) - """ - operation_desc = { - "image": "图片响应生成", - "voice": "语音识别", - "text": "文本响应生成", - "embedding": "向量嵌入获取" - } - - op_name = operation_desc.get(operation, operation) - - if SPECIFIC_EXCEPTIONS_AVAILABLE: - # 使用具体异常类型进行精确处理 - if isinstance(e, NetworkConnectionError): - logger.error(f"模型 {self.model_name} {op_name}失败: 网络连接错误") - raise RuntimeError("网络连接异常,请检查网络连接状态或API服务器地址是否正确") from e - - elif isinstance(e, ReqAbortException): - logger.error(f"模型 {self.model_name} {op_name}失败: 请求被中断") - raise RuntimeError("请求被中断或取消,请稍后重试") from e - - elif isinstance(e, RespNotOkException): - logger.error(f"模型 {self.model_name} {op_name}失败: HTTP响应错误 {e.status_code}") - # 重新抛出原始异常,保留详细的状态码信息 - raise e - - elif isinstance(e, RespParseException): - logger.error(f"模型 {self.model_name} {op_name}失败: 响应解析错误") - raise RuntimeError("API响应格式异常,请检查模型配置或联系管理员") from e - - else: - # 未知异常,使用通用处理 - logger.error(f"模型 {self.model_name} {op_name}失败: 未知错误 {type(e).__name__}: {str(e)}") - self._handle_generic_exception(e, op_name) - else: - # 如果无法导入具体异常,使用通用处理 - logger.error(f"模型 {self.model_name} {op_name}失败: {str(e)}") - self._handle_generic_exception(e, op_name) - - def _handle_generic_exception(self, e: Exception, operation: str) -> None: - """ - 通用异常处理(向后兼容的错误字符串匹配) - - Args: - e: 捕获的异常 - 
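# For illustration only (not part of this patch): _calculate_cost above prices a
# request per million tokens, with separate input (pri_in) and output (pri_out)
# rates taken from the selected model.  A standalone restatement with made-up
# prices (calc_cost is an illustrative name):
def calc_cost(prompt_tokens: int, completion_tokens: int,
              price_in: float, price_out: float) -> float:
    input_cost = (prompt_tokens / 1_000_000) * price_in
    output_cost = (completion_tokens / 1_000_000) * price_out
    return round(input_cost + output_cost, 6)

# e.g. 2,000 prompt tokens and 500 completion tokens at 2.0 / 8.0 per 1M tokens:
assert abs(calc_cost(2_000, 500, 2.0, 8.0) - 0.008) < 1e-9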
operation: 操作描述 - """ - error_str = str(e) - - # 基于错误消息内容的分类处理 - if "401" in error_str or "API key" in error_str or "认证" in error_str: - raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e - elif "429" in error_str or "频繁" in error_str or "rate limit" in error_str: - raise RuntimeError("请求过于频繁,请稍后再试") from e - elif "500" in error_str or "503" in error_str or "服务器" in error_str: - raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e - elif "413" in error_str or "payload" in error_str.lower() or "过大" in error_str: - raise RuntimeError("请求体过大,请尝试压缩图片或减少输入内容") from e - elif "timeout" in error_str.lower() or "超时" in error_str: - raise RuntimeError("请求超时,请检查网络连接或稍后重试") from e - else: - raise RuntimeError(f"模型 {self.model_name} {operation}失败: {str(e)}") from e - - # === 主要API方法 === - # 这些方法提供与新架构的桥接 - - async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple: - """ - 根据输入的提示和图片生成模型的异步响应 - 使用新架构的模型请求处理器 - """ - if not self.use_new_architecture: - raise RuntimeError( - f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。" - ) - - if self.request_handler is None: - raise RuntimeError( - f"模型 {self.model_name} 请求处理器未初始化,无法处理图片请求" - ) - - if MessageBuilder is None: - raise RuntimeError("MessageBuilder不可用,请检查新架构配置") - - try: - # 构建包含图片的消息 - message_builder = MessageBuilder() - message_builder.add_text_content(prompt).add_image_content( - image_format=image_format, - image_base64=image_base64 - ) - messages = [message_builder.build()] - - # 使用新架构发送请求(只传递支持的参数) - response = await self.request_handler.get_response( # type: ignore - messages=messages, - tool_options=None, - response_format=None - ) - - # 新架构返回的是 APIResponse 对象,直接提取内容 - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - - # 从内容中提取标签的推理内容(向后兼容) - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - - # 记录token使用情况 - if response.usage: - self._record_usage( - prompt_tokens=response.usage.prompt_tokens or 0, - completion_tokens=response.usage.completion_tokens or 0, - total_tokens=response.usage.total_tokens or 0, - user_id="system", - request_type=self.request_type, - endpoint="/chat/completions" - ) - - # 返回格式兼容旧版本 - if tool_calls: - return content, reasoning_content, tool_calls - else: - return content, reasoning_content - - except Exception as e: - self._handle_model_exception(e, "image") - # 这行代码永远不会执行,因为_handle_model_exception总是抛出异常 - # 但是为了满足类型检查的要求,我们添加一个不可达的返回语句 - return "", "" # pragma: no cover - - async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple: - """ - 根据输入的语音文件生成模型的异步响应 - 使用新架构的模型请求处理器 - """ - if not self.use_new_architecture: - raise RuntimeError( - f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。" - ) - - if self.request_handler is None: - raise RuntimeError( - f"模型 {self.model_name} 请求处理器未初始化,无法处理语音请求" - ) - - try: - # 构建语音识别请求参数 - # 注意:新架构中的语音识别可能使用不同的方法 - # 这里先使用get_response方法,可能需要根据实际API调整 - response = await self.request_handler.get_response( # type: ignore - messages=[], # 语音识别可能不需要消息 - tool_options=None - ) - - # 新架构返回的是 APIResponse 对象,直接提取文本内容 - return (response.content,) if response.content else ("",) - - except Exception as e: - self._handle_model_exception(e, "voice") - # 不可达的返回语句,仅用于满足类型检查 - return ("",) # pragma: no cover - - async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, 
Tuple]: - """ - 异步方式根据输入的提示生成模型的响应 - 使用新架构的模型请求处理器,如无法使用则抛出错误 - """ - if not self.use_new_architecture: - raise RuntimeError( - f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。" - ) - - if self.request_handler is None: - raise RuntimeError( - f"模型 {self.model_name} 请求处理器未初始化,无法生成响应" - ) - - if MessageBuilder is None: - raise RuntimeError("MessageBuilder不可用,请检查新架构配置") - - try: - # 构建消息 - message_builder = MessageBuilder() - message_builder.add_text_content(prompt) - messages = [message_builder.build()] - - # 使用新架构发送请求(只传递支持的参数) - response = await self.request_handler.get_response( # type: ignore - messages=messages, - tool_options=None, - response_format=None - ) - - # 新架构返回的是 APIResponse 对象,直接提取内容 - content = response.content or "" - reasoning_content = response.reasoning_content or "" - tool_calls = response.tool_calls - - # 从内容中提取标签的推理内容(向后兼容) - if not reasoning_content and content: - content, extracted_reasoning = self._extract_reasoning(content) - reasoning_content = extracted_reasoning - - # 记录token使用情况 - if response.usage: - self._record_usage( - prompt_tokens=response.usage.prompt_tokens or 0, - completion_tokens=response.usage.completion_tokens or 0, - total_tokens=response.usage.total_tokens or 0, - user_id="system", - request_type=self.request_type, - endpoint="/chat/completions" - ) - - # 返回格式兼容旧版本 - if tool_calls: - return content, (reasoning_content, self.model_name, tool_calls) - else: - return content, (reasoning_content, self.model_name) - - except Exception as e: - self._handle_model_exception(e, "text") - # 不可达的返回语句,仅用于满足类型检查 - return "", ("", self.model_name) # pragma: no cover - - async def get_embedding(self, text: str) -> Union[list, None]: - """ - 异步方法:获取文本的embedding向量 - 使用新架构的模型请求处理器 - - Args: - text: 需要获取embedding的文本 - - Returns: - list: embedding向量,如果失败则返回None - """ - if not text: - logger.debug("该消息没有长度,不再发送获取embedding向量的请求") - return None - - if not self.use_new_architecture: - logger.warning(f"模型 {self.model_name} 无法使用新架构,embedding请求将被跳过") - return None - - if self.request_handler is None: - logger.warning(f"模型 {self.model_name} 请求处理器未初始化,embedding请求将被跳过") - return None - - try: - # 构建embedding请求参数 - # 使用新架构的get_embedding方法 - response = await self.request_handler.get_embedding(text) # type: ignore - - # 新架构返回的是 APIResponse 对象,直接提取embedding - if response.embedding: - embedding = response.embedding - - # 记录token使用情况 - if response.usage: - self._record_usage( - prompt_tokens=response.usage.prompt_tokens or 0, - completion_tokens=response.usage.completion_tokens or 0, - total_tokens=response.usage.total_tokens or 0, - user_id="system", - request_type=self.request_type, - endpoint="/embeddings" - ) - - return embedding - else: - logger.warning(f"模型 {self.model_name} 返回的embedding响应为空") - return None - - except Exception as e: - # 对于embedding请求,我们记录错误但不抛出异常,而是返回None - # 这是为了保持与原有行为的兼容性 - try: - self._handle_model_exception(e, "embedding") - except RuntimeError: - # 捕获_handle_model_exception抛出的RuntimeError,转换为警告日志 - logger.warning(f"模型 {self.model_name} embedding请求失败,返回None: {str(e)}") - return None - - -def compress_base64_image_by_scale(base64_data: str, target_size: int = int(0.8 * 1024 * 1024)) -> str: - """压缩base64格式的图片到指定大小 - Args: - base64_data: base64编码的图片数据 - target_size: 目标文件大小(字节),默认0.8MB - Returns: - str: 压缩后的base64图片数据 - """ - try: - # 将base64转换为字节数据 - # 确保base64字符串只包含ASCII字符 - if isinstance(base64_data, str): - base64_data = base64_data.encode("ascii", errors="ignore").decode("ascii") - image_data = base64.b64decode(base64_data) - - 
# 如果已经小于目标大小,直接返回原图 - if len(image_data) <= 2 * 1024 * 1024: - return base64_data - - # 将字节数据转换为图片对象 - img = Image.open(io.BytesIO(image_data)) - - # 获取原始尺寸 - original_width, original_height = img.size - - # 计算缩放比例 - scale = min(1.0, (target_size / len(image_data)) ** 0.5) - - # 计算新的尺寸 - new_width = int(original_width * scale) - new_height = int(original_height * scale) - - # 创建内存缓冲区 - output_buffer = io.BytesIO() - - # 如果是GIF,处理所有帧 - if getattr(img, "is_animated", False): - frames = [] - n_frames = getattr(img, 'n_frames', 1) - for frame_idx in range(n_frames): - img.seek(frame_idx) - new_frame = img.copy() - new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折 - frames.append(new_frame) - - # 保存到缓冲区 - frames[0].save( - output_buffer, - format="GIF", - save_all=True, - append_images=frames[1:], - optimize=True, - duration=img.info.get("duration", 100), - loop=img.info.get("loop", 0), - ) - else: - # 处理静态图片 - resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - - # 保存到缓冲区,保持原始格式 - if img.format == "PNG" and img.mode in ("RGBA", "LA"): - resized_img.save(output_buffer, format="PNG", optimize=True) - else: - resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True) - - # 获取压缩后的数据并转换为base64 - compressed_data = output_buffer.getvalue() - logger.info(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}") - logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB") - - return base64.b64encode(compressed_data).decode("utf-8") - - except Exception as e: - logger.error(f"压缩图片失败: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return base64_data diff --git a/src/llm_models/utils_model_bak.py b/src/llm_models/utils_model_bak.py new file mode 100644 index 000000000..fd78d559b --- /dev/null +++ b/src/llm_models/utils_model_bak.py @@ -0,0 +1,778 @@ +import re +from datetime import datetime +from typing import Tuple, Union +from src.common.logger import get_logger +import base64 +from PIL import Image +import io +from src.common.database.database import db # 确保 db 被导入用于 create_tables +from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 +from src.config.config import global_config +from rich.traceback import install + +from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException, PayLoadTooLargeError, RequestAbortException, PermissionDeniedException +install(extra_lines=3) + +logger = get_logger("model_utils") + +# 导入具体的异常类型用于精确的异常处理 +from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException +SPECIFIC_EXCEPTIONS_AVAILABLE = True + +# 新架构导入 - 使用延迟导入以支持fallback模式 + +from .model_manager_bak import ModelManager +from .model_client import ModelRequestHandler +from .payload_content.message import MessageBuilder + +# 不在模块级别初始化ModelManager,延迟到实际使用时 +ModelManager_class = ModelManager +model_manager = None # 延迟初始化 + +# 添加请求处理器缓存,避免重复创建 +_request_handler_cache = {} # 格式: {(model_name, task_name): ModelRequestHandler} + +NEW_ARCHITECTURE_AVAILABLE = True +logger.info("新架构模块导入成功") + + + + + +# 常见Error Code Mapping +error_code_mapping = { + 400: "参数不正确", + 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确", + 402: "账号余额不足", + 403: "需要实名,或余额不足", + 404: "Not Found", + 429: "请求过于频繁,请稍后再试", + 500: "服务器内部故障", + 503: "服务器负载过高", +} + + + + +class LLMRequest: + """ + 重构后的LLM请求类,基于新的model_manager和model_client架构 + 保持向后兼容的API接口 + """ + + # 
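# For illustration only (not part of this patch): the compress_base64_image_by_scale
# helper removed above shrinks an oversized image by the square root of the
# byte-size ratio, since encoded size grows roughly with pixel count
# (width * height).  The core of that sizing step, standalone
# (scaled_dimensions is an illustrative name):
def scaled_dimensions(width: int, height: int,
                      current_bytes: int, target_bytes: int) -> tuple[int, int]:
    scale = min(1.0, (target_bytes / current_bytes) ** 0.5)
    return int(width * scale), int(height * scale)

# a 4 MB, 2000x1500 image targeted at ~0.8 MB keeps about 44.7% of each side:
assert scaled_dimensions(2000, 1500, 4_000_000, 800_000) == (894, 670)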
定义需要转换的模型列表,作为类变量避免重复 + MODELS_NEEDING_TRANSFORMATION = [ + "o1", + "o1-2024-12-17", + "o1-mini", + "o1-mini-2024-09-12", + "o1-preview", + "o1-preview-2024-09-12", + "o1-pro", + "o1-pro-2025-03-19", + "o3", + "o3-2025-04-16", + "o3-mini", + "o3-mini-2025-01-31", + "o4-mini", + "o4-mini-2025-04-16", + ] + + def __init__(self, model: dict, **kwargs): + """ + 初始化LLM请求实例 + Args: + model: 模型配置字典,兼容旧格式和新格式 + **kwargs: 额外参数 + """ + logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('model_name', model.get('name', 'Unknown'))}") + logger.debug(f"🔍 [模型初始化] 输入的模型配置: {model}") + logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}") + + # 兼容新旧模型配置格式 + # 新格式使用 model_name,旧格式使用 name + self.model_name: str = model.get("model_name", model.get("name", "")) + + # 如果传入的配置不完整,自动从全局配置中获取完整配置 + if not all(key in model for key in ["task_type", "capabilities"]): + logger.debug("🔍 [模型初始化] 检测到不完整的模型配置,尝试获取完整配置") + if (full_model_config := self._get_full_model_config(self.model_name)): + logger.debug("🔍 [模型初始化] 成功获取完整模型配置,合并配置信息") + # 合并配置:运行时参数优先,但添加缺失的配置字段 + model = {**full_model_config, **model} + logger.debug(f"🔍 [模型初始化] 合并后的模型配置: {model}") + else: + logger.warning(f"⚠️ [模型初始化] 无法获取模型 {self.model_name} 的完整配置,使用原始配置") + + # 在新架构中,provider信息从model_config.toml自动获取,不需要在这里设置 + self.provider = model.get("provider", "") # 保留兼容性,但在新架构中不使用 + + # 从全局配置中获取任务配置 + self.request_type = kwargs.pop("request_type", "default") + + # 确定使用哪个任务配置 + task_name = self._determine_task_name(model) + + # 初始化 request_handler + self.request_handler = None + + # 尝试初始化新架构 + if NEW_ARCHITECTURE_AVAILABLE and ModelManager_class is not None: + try: + # 延迟初始化ModelManager + global model_manager, _request_handler_cache + if model_manager is None: + from src.config.config import model_config + model_manager = ModelManager_class(model_config) + logger.debug("🔍 [模型初始化] ModelManager延迟初始化成功") + + # 构建缓存键 + cache_key = (self.model_name, task_name) + + # 检查是否已有缓存的请求处理器 + if cache_key in _request_handler_cache: + self.request_handler = _request_handler_cache[cache_key] + logger.debug(f"🚀 [性能优化] 从LLMRequest缓存获取请求处理器: {cache_key}") + else: + # 使用新架构获取模型请求处理器 + self.request_handler = model_manager[task_name] + _request_handler_cache[cache_key] = self.request_handler + logger.debug(f"🔧 [性能优化] 创建并缓存LLMRequest请求处理器: {cache_key}") + + logger.debug(f"🔍 [模型初始化] 成功获取模型请求处理器,任务: {task_name}") + self.use_new_architecture = True + except Exception as e: + logger.warning(f"无法使用新架构,任务 {task_name} 初始化失败: {e}") + logger.warning("回退到兼容模式,某些功能可能受限") + self.request_handler = None + self.use_new_architecture = False + else: + logger.warning("新架构不可用,使用兼容模式") + logger.warning("回退到兼容模式,某些功能可能受限") + self.request_handler = None + self.use_new_architecture = False + + # 保存原始参数用于向后兼容 + self.params = kwargs + + # 兼容性属性,从模型配置中提取 + # 新格式和旧格式都支持 + self.enable_thinking = model.get("enable_thinking", False) + self.temp = model.get("temperature", model.get("temp", 0.7)) # 新格式用temperature,旧格式用temp + self.thinking_budget = model.get("thinking_budget", 4096) + self.stream = model.get("stream", False) + self.pri_in = model.get("pri_in", 0) + self.pri_out = model.get("pri_out", 0) + self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length) + + # 记录配置文件中声明了哪些参数(不管值是什么) + self.has_enable_thinking = "enable_thinking" in model + self.has_thinking_budget = "thinking_budget" in model + self.pri_out = model.get("pri_out", 0) + self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length) + + # 记录配置文件中声明了哪些参数(不管值是什么) + self.has_enable_thinking = "enable_thinking" in 
model + self.has_thinking_budget = "thinking_budget" in model + + logger.debug("🔍 [模型初始化] 模型参数设置完成:") + logger.debug(f" - model_name: {self.model_name}") + logger.debug(f" - provider: {self.provider}") + logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}") + logger.debug(f" - enable_thinking: {self.enable_thinking}") + logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}") + logger.debug(f" - thinking_budget: {self.thinking_budget}") + logger.debug(f" - temp: {self.temp}") + logger.debug(f" - stream: {self.stream}") + logger.debug(f" - max_tokens: {self.max_tokens}") + logger.debug(f" - use_new_architecture: {self.use_new_architecture}") + + # 获取数据库实例 + self._init_database() + + logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}") + + def _determine_task_name(self, model: dict) -> str: + """ + 根据模型配置确定任务名称 + 优先使用配置文件中明确定义的任务类型,避免基于模型名称的脆弱推断 + + Args: + model: 模型配置字典 + Returns: + 任务名称 + """ + # 调试信息:打印模型配置字典的所有键 + logger.debug(f"🔍 [任务确定] 模型配置字典的所有键: {list(model.keys())}") + logger.debug(f"🔍 [任务确定] 模型配置字典内容: {model}") + + # 获取模型名称 + model_name = model.get("model_name", model.get("name", "")) + + # 方法1: 优先使用配置文件中明确定义的 task_type 字段 + if "task_type" in model: + task_type = model["task_type"] + logger.debug(f"🎯 [任务确定] 使用配置中的 task_type: {task_type}") + return task_type + + # 方法2: 使用 capabilities 字段来推断主要任务类型 + if "capabilities" in model: + capabilities = model["capabilities"] + if isinstance(capabilities, list): + # 按优先级顺序检查能力 + if "vision" in capabilities: + logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: vision") + return "vision" + elif "embedding" in capabilities: + logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: embedding") + return "embedding" + elif "speech" in capabilities: + logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: speech") + return "speech" + elif "text" in capabilities: + # 如果只有文本能力,则根据request_type细分 + task = "llm_reasoning" if self.request_type == "reasoning" else "llm_normal" + logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 和 request_type {self.request_type} 推断为: {task}") + return task + + # 方法3: 向后兼容 - 基于模型名称的关键字推断(不推荐但保留兼容性) + logger.warning(f"⚠️ [任务确定] 配置中未找到 task_type 或 capabilities,回退到基于模型名称的推断: {model_name}") + logger.warning("⚠️ [建议] 请在 model_config.toml 中为模型添加明确的 task_type 或 capabilities 字段") + + # 保留原有的关键字匹配逻辑作为fallback + if any(keyword in model_name.lower() for keyword in ["vlm", "vision", "gpt-4o", "claude", "vl-"]): + logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: vision") + return "vision" + elif any(keyword in model_name.lower() for keyword in ["embed", "text-embedding", "bge-"]): + logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: embedding") + return "embedding" + elif any(keyword in model_name.lower() for keyword in ["whisper", "speech", "voice"]): + logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: speech") + return "speech" + else: + # 根据request_type确定,映射到配置文件中定义的任务 + task = "llm_reasoning" if self.request_type == "reasoning" else "llm_normal" + logger.debug(f"🎯 [任务确定] 从 request_type {self.request_type} 推断为: {task}") + return task + + def _get_full_model_config(self, model_name: str) -> dict | None: + """ + 根据模型名称从全局配置中获取完整的模型配置 + 现在直接使用已解析的ModelInfo对象,不再读取TOML文件 + + Args: + model_name: 模型名称 + Returns: + 完整的模型配置字典,如果找不到则返回None + """ + try: + from src.config.config import model_config + return self._get_model_config_from_parsed(model_name, model_config) + + except Exception as e: + logger.warning(f"⚠️ [配置查找] 获取模型配置时出错: {str(e)}") + return None + + def 
_get_model_config_from_parsed(self, model_name: str, model_config) -> dict | None: + """ + 从已解析的配置对象中获取模型配置 + 使用扩展后的ModelInfo类,包含task_type和capabilities字段 + """ + try: + # 直接通过模型名称查找 + if model_name in model_config.models: + model_info = model_config.models[model_name] + logger.debug(f"🔍 [配置查找] 找到模型 {model_name} 的配置对象: {model_info}") + + # 将ModelInfo对象转换为字典 + model_dict = { + "model_identifier": model_info.model_identifier, + "name": model_info.name, + "api_provider": model_info.api_provider, + "price_in": model_info.price_in, + "price_out": model_info.price_out, + "force_stream_mode": model_info.force_stream_mode, + "task_type": model_info.task_type, + "capabilities": model_info.capabilities, + } + + logger.debug(f"🔍 [配置查找] 转换后的模型配置字典: {model_dict}") + return model_dict + + # 如果直接查找失败,尝试通过model_identifier查找 + for name, model_info in model_config.models.items(): + if (model_info.model_identifier == model_name or + hasattr(model_info, 'model_name') and model_info.model_name == model_name): + + logger.debug(f"🔍 [配置查找] 通过标识符找到模型 {model_name} (配置名称: {name})") + # 同样转换为字典 + model_dict = { + "model_identifier": model_info.model_identifier, + "name": model_info.name, + "api_provider": model_info.api_provider, + "price_in": model_info.price_in, + "price_out": model_info.price_out, + "force_stream_mode": model_info.force_stream_mode, + "task_type": model_info.task_type, + "capabilities": model_info.capabilities, + } + + return model_dict + + return None + + except Exception as e: + logger.warning(f"⚠️ [配置查找] 从已解析配置获取模型配置时出错: {str(e)}") + return None + + @staticmethod + def _init_database(): + """初始化数据库集合""" + try: + # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 + db.create_tables([LLMUsage], safe=True) + # logger.debug("LLMUsage 表已初始化/确保存在。") + except Exception as e: + logger.error(f"创建 LLMUsage 表失败: {str(e)}") + + def _record_usage( + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + user_id: str = "system", + request_type: str | None = None, + endpoint: str = "/chat/completions", + ): + """记录模型使用情况到数据库 + Args: + prompt_tokens: 输入token数 + completion_tokens: 输出token数 + total_tokens: 总token数 + user_id: 用户ID,默认为system + request_type: 请求类型 + endpoint: API端点 + """ + # 如果 request_type 为 None,则使用实例变量中的值 + if request_type is None: + request_type = self.request_type + + try: + # 使用 Peewee 模型创建记录 + LLMUsage.create( + model_name=self.model_name, + user_id=user_id, + request_type=request_type, + endpoint=endpoint, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost=self._calculate_cost(prompt_tokens, completion_tokens), + status="success", + timestamp=datetime.now(), # Peewee 会处理 DateTimeField + ) + logger.debug( + f"Token使用情况 - 模型: {self.model_name}, " + f"用户: {user_id}, 类型: {request_type}, " + f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " + f"总计: {total_tokens}" + ) + except Exception as e: + logger.error(f"记录token使用情况失败: {str(e)}") + + def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float: + """计算API调用成本 + 使用模型的pri_in和pri_out价格计算输入和输出的成本 + + Args: + prompt_tokens: 输入token数量 + completion_tokens: 输出token数量 + + Returns: + float: 总成本(元) + """ + # 使用模型的pri_in和pri_out计算成本 + input_cost = (prompt_tokens / 1000000) * self.pri_in + output_cost = (completion_tokens / 1000000) * self.pri_out + return round(input_cost + output_cost, 6) + + @staticmethod + def _extract_reasoning(content: str) -> Tuple[str, str]: + """CoT思维链提取""" + match = re.search(r"(?:)?(.*?)", content, re.DOTALL) + content = re.sub(r"(?:)?.*?", 
"", content, flags=re.DOTALL, count=1).strip() + reasoning = match[1].strip() if match else "" + return content, reasoning + + def _handle_model_exception(self, e: Exception, operation: str) -> None: + """ + 统一的模型异常处理方法 + 根据异常类型提供更精确的错误信息和处理策略 + + Args: + e: 捕获的异常 + operation: 操作类型(用于日志记录) + """ + operation_desc = { + "image": "图片响应生成", + "voice": "语音识别", + "text": "文本响应生成", + "embedding": "向量嵌入获取" + } + + op_name = operation_desc.get(operation, operation) + + if SPECIFIC_EXCEPTIONS_AVAILABLE: + # 使用具体异常类型进行精确处理 + if isinstance(e, NetworkConnectionError): + logger.error(f"模型 {self.model_name} {op_name}失败: 网络连接错误") + raise RuntimeError("网络连接异常,请检查网络连接状态或API服务器地址是否正确") from e + + elif isinstance(e, ReqAbortException): + logger.error(f"模型 {self.model_name} {op_name}失败: 请求被中断") + raise RuntimeError("请求被中断或取消,请稍后重试") from e + + elif isinstance(e, RespNotOkException): + logger.error(f"模型 {self.model_name} {op_name}失败: HTTP响应错误 {e.status_code}") + # 重新抛出原始异常,保留详细的状态码信息 + raise e + + elif isinstance(e, RespParseException): + logger.error(f"模型 {self.model_name} {op_name}失败: 响应解析错误") + raise RuntimeError("API响应格式异常,请检查模型配置或联系管理员") from e + + else: + # 未知异常,使用通用处理 + logger.error(f"模型 {self.model_name} {op_name}失败: 未知错误 {type(e).__name__}: {str(e)}") + self._handle_generic_exception(e, op_name) + else: + # 如果无法导入具体异常,使用通用处理 + logger.error(f"模型 {self.model_name} {op_name}失败: {str(e)}") + self._handle_generic_exception(e, op_name) + + def _handle_generic_exception(self, e: Exception, operation: str) -> None: + """ + 通用异常处理(向后兼容的错误字符串匹配) + + Args: + e: 捕获的异常 + operation: 操作描述 + """ + error_str = str(e) + + # 基于错误消息内容的分类处理 + if "401" in error_str or "API key" in error_str or "认证" in error_str: + raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e + elif "429" in error_str or "频繁" in error_str or "rate limit" in error_str: + raise RuntimeError("请求过于频繁,请稍后再试") from e + elif "500" in error_str or "503" in error_str or "服务器" in error_str: + raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e + elif "413" in error_str or "payload" in error_str.lower() or "过大" in error_str: + raise RuntimeError("请求体过大,请尝试压缩图片或减少输入内容") from e + elif "timeout" in error_str.lower() or "超时" in error_str: + raise RuntimeError("请求超时,请检查网络连接或稍后重试") from e + else: + raise RuntimeError(f"模型 {self.model_name} {operation}失败: {str(e)}") from e + + # === 主要API方法 === + # 这些方法提供与新架构的桥接 + + async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple: + """ + 根据输入的提示和图片生成模型的异步响应 + 使用新架构的模型请求处理器 + """ + if not self.use_new_architecture: + raise RuntimeError( + f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。" + ) + + if self.request_handler is None: + raise RuntimeError( + f"模型 {self.model_name} 请求处理器未初始化,无法处理图片请求" + ) + + if MessageBuilder is None: + raise RuntimeError("MessageBuilder不可用,请检查新架构配置") + + try: + # 构建包含图片的消息 + message_builder = MessageBuilder() + message_builder.add_text_content(prompt).add_image_content( + image_format=image_format, + image_base64=image_base64 + ) + messages = [message_builder.build()] + + # 使用新架构发送请求(只传递支持的参数) + response = await self.request_handler.get_response( # type: ignore + messages=messages, + tool_options=None, + response_format=None + ) + + # 新架构返回的是 APIResponse 对象,直接提取内容 + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = 
self._extract_reasoning(content) + reasoning_content = extracted_reasoning + + # 记录token使用情况 + if response.usage: + self._record_usage( + prompt_tokens=response.usage.prompt_tokens or 0, + completion_tokens=response.usage.completion_tokens or 0, + total_tokens=response.usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/chat/completions" + ) + + # 返回格式兼容旧版本 + if tool_calls: + return content, reasoning_content, tool_calls + else: + return content, reasoning_content + + except Exception as e: + self._handle_model_exception(e, "image") + # 这行代码永远不会执行,因为_handle_model_exception总是抛出异常 + # 但是为了满足类型检查的要求,我们添加一个不可达的返回语句 + return "", "" # pragma: no cover + + async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple: + """ + 根据输入的语音文件生成模型的异步响应 + 使用新架构的模型请求处理器 + """ + if not self.use_new_architecture: + raise RuntimeError( + f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。" + ) + + if self.request_handler is None: + raise RuntimeError( + f"模型 {self.model_name} 请求处理器未初始化,无法处理语音请求" + ) + + try: + # 构建语音识别请求参数 + # 注意:新架构中的语音识别可能使用不同的方法 + # 这里先使用get_response方法,可能需要根据实际API调整 + response = await self.request_handler.get_response( # type: ignore + messages=[], # 语音识别可能不需要消息 + tool_options=None + ) + + # 新架构返回的是 APIResponse 对象,直接提取文本内容 + return (response.content,) if response.content else ("",) + + except Exception as e: + self._handle_model_exception(e, "voice") + # 不可达的返回语句,仅用于满足类型检查 + return ("",) # pragma: no cover + + async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]: + """ + 异步方式根据输入的提示生成模型的响应 + 使用新架构的模型请求处理器,如无法使用则抛出错误 + """ + if not self.use_new_architecture: + raise RuntimeError( + f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。" + ) + + if self.request_handler is None: + raise RuntimeError( + f"模型 {self.model_name} 请求处理器未初始化,无法生成响应" + ) + + if MessageBuilder is None: + raise RuntimeError("MessageBuilder不可用,请检查新架构配置") + + try: + # 构建消息 + message_builder = MessageBuilder() + message_builder.add_text_content(prompt) + messages = [message_builder.build()] + + # 使用新架构发送请求(只传递支持的参数) + response = await self.request_handler.get_response( # type: ignore + messages=messages, + tool_options=None, + response_format=None + ) + + # 新架构返回的是 APIResponse 对象,直接提取内容 + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + + # 记录token使用情况 + if response.usage: + self._record_usage( + prompt_tokens=response.usage.prompt_tokens or 0, + completion_tokens=response.usage.completion_tokens or 0, + total_tokens=response.usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/chat/completions" + ) + + # 返回格式兼容旧版本 + if tool_calls: + return content, (reasoning_content, self.model_name, tool_calls) + else: + return content, (reasoning_content, self.model_name) + + except Exception as e: + self._handle_model_exception(e, "text") + # 不可达的返回语句,仅用于满足类型检查 + return "", ("", self.model_name) # pragma: no cover + + async def get_embedding(self, text: str) -> Union[list, None]: + """ + 异步方法:获取文本的embedding向量 + 使用新架构的模型请求处理器 + + Args: + text: 需要获取embedding的文本 + + Returns: + list: embedding向量,如果失败则返回None + """ + if not text: + logger.debug("该消息没有长度,不再发送获取embedding向量的请求") + return None + + if not self.use_new_architecture: + 
logger.warning(f"模型 {self.model_name} 无法使用新架构,embedding请求将被跳过") + return None + + if self.request_handler is None: + logger.warning(f"模型 {self.model_name} 请求处理器未初始化,embedding请求将被跳过") + return None + + try: + # 构建embedding请求参数 + # 使用新架构的get_embedding方法 + response = await self.request_handler.get_embedding(text) # type: ignore + + # 新架构返回的是 APIResponse 对象,直接提取embedding + if response.embedding: + embedding = response.embedding + + # 记录token使用情况 + if response.usage: + self._record_usage( + prompt_tokens=response.usage.prompt_tokens or 0, + completion_tokens=response.usage.completion_tokens or 0, + total_tokens=response.usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/embeddings" + ) + + return embedding + else: + logger.warning(f"模型 {self.model_name} 返回的embedding响应为空") + return None + + except Exception as e: + # 对于embedding请求,我们记录错误但不抛出异常,而是返回None + # 这是为了保持与原有行为的兼容性 + try: + self._handle_model_exception(e, "embedding") + except RuntimeError: + # 捕获_handle_model_exception抛出的RuntimeError,转换为警告日志 + logger.warning(f"模型 {self.model_name} embedding请求失败,返回None: {str(e)}") + return None + + +def compress_base64_image_by_scale(base64_data: str, target_size: int = int(0.8 * 1024 * 1024)) -> str: + """压缩base64格式的图片到指定大小 + Args: + base64_data: base64编码的图片数据 + target_size: 目标文件大小(字节),默认0.8MB + Returns: + str: 压缩后的base64图片数据 + """ + try: + # 将base64转换为字节数据 + # 确保base64字符串只包含ASCII字符 + if isinstance(base64_data, str): + base64_data = base64_data.encode("ascii", errors="ignore").decode("ascii") + image_data = base64.b64decode(base64_data) + + # 如果已经小于目标大小,直接返回原图 + if len(image_data) <= 2 * 1024 * 1024: + return base64_data + + # 将字节数据转换为图片对象 + img = Image.open(io.BytesIO(image_data)) + + # 获取原始尺寸 + original_width, original_height = img.size + + # 计算缩放比例 + scale = min(1.0, (target_size / len(image_data)) ** 0.5) + + # 计算新的尺寸 + new_width = int(original_width * scale) + new_height = int(original_height * scale) + + # 创建内存缓冲区 + output_buffer = io.BytesIO() + + # 如果是GIF,处理所有帧 + if getattr(img, "is_animated", False): + frames = [] + n_frames = getattr(img, 'n_frames', 1) + for frame_idx in range(n_frames): + img.seek(frame_idx) + new_frame = img.copy() + new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折 + frames.append(new_frame) + + # 保存到缓冲区 + frames[0].save( + output_buffer, + format="GIF", + save_all=True, + append_images=frames[1:], + optimize=True, + duration=img.info.get("duration", 100), + loop=img.info.get("loop", 0), + ) + else: + # 处理静态图片 + resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # 保存到缓冲区,保持原始格式 + if img.format == "PNG" and img.mode in ("RGBA", "LA"): + resized_img.save(output_buffer, format="PNG", optimize=True) + else: + resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True) + + # 获取压缩后的数据并转换为base64 + compressed_data = output_buffer.getvalue() + logger.info(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}") + logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB") + + return base64.b64encode(compressed_data).decode("utf-8") + + except Exception as e: + logger.error(f"压缩图片失败: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return base64_data diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index fa9466c6d..de154491c 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = 
"5.0.0" +version = "6.0.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 @@ -213,98 +213,10 @@ file_log_level = "DEBUG" # 文件日志级别,可选: DEBUG, INFO, WARNING, ER suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"] # 完全屏蔽的库 library_log_levels = { "aiohttp" = "WARNING"} # 设置特定库的日志级别 -#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env自定义的宏,使用自定义模型则选择定位相似的模型自己填写 - -# stream = : 用于指定模型是否是使用流式输出 -# pri_in = : 用于指定模型输入价格 -# pri_out = : 用于指定模型输出价格 -# temp = : 用于指定模型温度 -# enable_thinking = : 用于指定模型是否启用思考 -# thinking_budget = : 用于指定模型思考最长长度 - [debug] show_prompt = false # 是否显示prompt -[model] -model_max_output_length = 800 # 模型单次返回的最大token数 - -#------------模型任务配置------------ -# 所有模型名称需要对应 model_config.toml 中配置的模型名称 - -[model.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型 -model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 -temperature = 0.2 # 模型温度,新V3建议0.1-0.3 -max_tokens = 800 # 最大输出token数 - -[model.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型 -model_name = "qwen3-8b" # 对应 model_config.toml 中的模型名称 -temperature = 0.7 -max_tokens = 800 -enable_thinking = false # 是否启用思考 - -[model.replyer_1] # 首要回复模型,还用于表达器和表达方式学习 -model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 -temperature = 0.2 # 模型温度,新V3建议0.1-0.3 -max_tokens = 800 - -[model.replyer_2] # 次要回复模型 -model_name = "siliconflow-deepseek-r1" # 对应 model_config.toml 中的模型名称 -temperature = 0.7 # 模型温度 -max_tokens = 800 - -[model.planner] #决策:负责决定麦麦该做什么的模型 -model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 -temperature = 0.3 -max_tokens = 800 - -[model.emotion] #负责麦麦的情绪变化 -model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 -temperature = 0.3 -max_tokens = 800 - -[model.memory] # 记忆模型 -model_name = "qwen3-30b" # 对应 model_config.toml 中的模型名称 -temperature = 0.7 -max_tokens = 800 -enable_thinking = false # 是否启用思考 - -[model.vlm] # 图像识别模型 -model_name = "qwen2.5-vl-72b" # 对应 model_config.toml 中的模型名称 -max_tokens = 800 - -[model.voice] # 语音识别模型 -model_name = "sensevoice-small" # 对应 model_config.toml 中的模型名称 - -[model.tool_use] #工具调用模型,需要使用支持工具调用的模型 -model_name = "qwen3-14b" # 对应 model_config.toml 中的模型名称 -temperature = 0.7 -max_tokens = 800 -enable_thinking = false # 是否启用思考(qwen3 only) - -#嵌入模型 -[model.embedding] -model_name = "bge-m3" # 对应 model_config.toml 中的模型名称 - -#------------LPMM知识库模型------------ - -[model.lpmm_entity_extract] # 实体提取模型 -model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 -temperature = 0.2 -max_tokens = 800 - -[model.lpmm_rdf_build] # RDF构建模型 -model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 -temperature = 0.2 -max_tokens = 800 - -[model.lpmm_qa] # 问答模型 -model_name = "deepseek-r1-distill-qwen-32b" # 对应 model_config.toml 中的模型名称 -temperature = 0.7 -max_tokens = 800 -enable_thinking = false # 是否启用思考 - - [maim_message] auth_token = [] # 认证令牌,用于API验证,为空则不启用验证 # 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器 @@ -320,8 +232,4 @@ key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效 enable = true [experimental] #实验性功能 -enable_friend_chat = false # 是否启用好友聊天 - - - - +enable_friend_chat = false # 是否启用好友聊天 \ No newline at end of file diff --git a/template/model_config_template.toml b/template/model_config_template.toml index 8ab187626..ff392b054 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "0.2.1" +version = "1.0.0" # 配置文件版本号迭代规则同bot_config.toml # @@ -42,53 +42,31 @@ version = "0.2.1" # - 
未配置新字段时会自动回退到基于模型名称的推断 [request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释) -#max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) -#timeout = 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒) -#retry_interval = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒) -#default_temperature = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值) -#default_max_tokens = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值) +max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) +timeout = 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒) +retry_interval = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒) +default_temperature = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值) +default_max_tokens = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值) [[api_providers]] # API服务提供商(可以配置多个) name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名) base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL # 支持多个API Key,实现自动切换和负载均衡 -api_keys = [ # API Key列表(多个key支持错误自动切换和负载均衡) - "sk-your-first-key-here", - "sk-your-second-key-here", - "sk-your-third-key-here" -] -# 向后兼容:如果只有一个key,也可以使用单个key字段 -#key = "******" # API Key (可选,默认为None) +api_key = "sk-your-first-key-here" client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini") [[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"gemini" name = "Google" base_url = "https://api.google.com/v1" -# Google API同样支持多key配置 -api_keys = [ - "your-google-api-key-1", - "your-google-api-key-2" -] +api_key = "your-google-api-key-1" client_type = "gemini" -[[api_providers]] -name = "SiliconFlow" -base_url = "https://api.siliconflow.cn/v1" -# 单个key的示例(向后兼容) -key = "******" -# -#[[api_providers]] -#name = "LocalHost" -#base_url = "https://localhost:8888" -#key = "lm-studio" - [[models]] # 模型(可以配置多个) # 模型标识符(API服务商提供的模型标识符) model_identifier = "deepseek-chat" # 模型名称(可随意命名,在bot_config.toml中需使用这个命名) -#(可选,若无该字段,则将自动使用model_identifier填充) name = "deepseek-v3" # API服务商名称(对应在api_providers中配置的服务商名称) api_provider = "DeepSeek" @@ -111,20 +89,15 @@ price_out = 8.0 model_identifier = "deepseek-reasoner" name = "deepseek-r1" api_provider = "DeepSeek" -# 推理模型的配置示例 -task_type = "llm_reasoning" -capabilities = ["text", "tool_calling", "reasoning"] -# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) -model_flags = [ "text", "tool_calling", "reasoning",] price_in = 4.0 price_out = 16.0 +has_thinking = true # 有无思考参数 +enable_thinking = true # 是否启用思考 [[models]] model_identifier = "Pro/deepseek-ai/DeepSeek-V3" name = "siliconflow-deepseek-v3" api_provider = "SiliconFlow" -task_type = "llm_normal" -capabilities = ["text", "tool_calling"] price_in = 2.0 price_out = 8.0 @@ -132,8 +105,6 @@ price_out = 8.0 model_identifier = "Pro/deepseek-ai/DeepSeek-R1" name = "siliconflow-deepseek-r1" api_provider = "SiliconFlow" -task_type = "llm_reasoning" -capabilities = ["text", "tool_calling", "reasoning"] price_in = 4.0 price_out = 16.0 @@ -141,8 +112,6 @@ price_out = 16.0 model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" name = "deepseek-r1-distill-qwen-32b" api_provider = "SiliconFlow" -task_type = "llm_reasoning" -capabilities = ["text", "tool_calling", "reasoning"] price_in = 4.0 price_out = 16.0 @@ -150,8 +119,6 @@ price_out = 16.0 model_identifier = "Qwen/Qwen3-8B" name = "qwen3-8b" api_provider = "SiliconFlow" -task_type = "llm_normal" -capabilities = ["text"] price_in = 0 price_out = 0 @@ -159,8 +126,6 @@ price_out = 0 model_identifier = "Qwen/Qwen3-14B" name = "qwen3-14b" api_provider = "SiliconFlow" -task_type = "llm_normal" -capabilities = ["text", "tool_calling"] price_in = 0.5 
price_out = 2.0 @@ -168,8 +133,6 @@ price_out = 2.0 model_identifier = "Qwen/Qwen3-30B-A3B" name = "qwen3-30b" api_provider = "SiliconFlow" -task_type = "llm_normal" -capabilities = ["text", "tool_calling"] price_in = 0.7 price_out = 2.8 @@ -177,11 +140,6 @@ price_out = 2.8 model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct" name = "qwen2.5-vl-72b" api_provider = "SiliconFlow" -# 视觉模型的配置示例 -task_type = "vision" -capabilities = ["vision", "text"] -# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) -model_flags = [ "vision", "text",] price_in = 4.13 price_out = 4.13 @@ -189,11 +147,6 @@ price_out = 4.13 model_identifier = "FunAudioLLM/SenseVoiceSmall" name = "sensevoice-small" api_provider = "SiliconFlow" -# 语音模型的配置示例 -task_type = "speech" -capabilities = ["speech"] -# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) -model_flags = [ "audio",] price_in = 0 price_out = 0 @@ -210,11 +163,73 @@ price_in = 0 price_out = 0 -[task_model_usage] -llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0} -llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0} -embedding = "siliconflow-bge-m3" -#schedule = [ -# "deepseek-v3", -# "deepseek-r1", -#] \ No newline at end of file +[model.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型 +model_list = ["siliconflow-deepseek-v3","qwen3-8b"] +temperature = 0.2 # 模型温度,新V3建议0.1-0.3 +max_tokens = 800 # 最大输出token数 + +[model.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型 +model_name = "qwen3-8b" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 +max_tokens = 800 + +[model.replyer_1] # 首要回复模型,还用于表达器和表达方式学习 +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.2 # 模型温度,新V3建议0.1-0.3 +max_tokens = 800 + +[model.replyer_2] # 次要回复模型 +model_name = "siliconflow-deepseek-r1" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 # 模型温度 +max_tokens = 800 + +[model.planner] #决策:负责决定麦麦该做什么的模型 +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.3 +max_tokens = 800 + +[model.emotion] #负责麦麦的情绪变化 +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.3 +max_tokens = 800 + +[model.memory] # 记忆模型 +model_name = "qwen3-30b" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 +max_tokens = 800 +enable_thinking = false # 是否启用思考 + +[model.vlm] # 图像识别模型 +model_name = "qwen2.5-vl-72b" # 对应 model_config.toml 中的模型名称 +max_tokens = 800 + +[model.voice] # 语音识别模型 +model_name = "sensevoice-small" # 对应 model_config.toml 中的模型名称 + +[model.tool_use] #工具调用模型,需要使用支持工具调用的模型 +model_name = "qwen3-14b" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 +max_tokens = 800 +enable_thinking = false # 是否启用思考(qwen3 only) + +#嵌入模型 +[model.embedding] +model_name = "bge-m3" # 对应 model_config.toml 中的模型名称 + +#------------LPMM知识库模型------------ + +[model.lpmm_entity_extract] # 实体提取模型 +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.2 +max_tokens = 800 + +[model.lpmm_rdf_build] # RDF构建模型 +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.2 +max_tokens = 800 + +[model.lpmm_qa] # 问答模型 +model_name = "deepseek-r1-distill-qwen-32b" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 +max_tokens = 800 +enable_thinking = false # 是否启用思考 \ No newline at end of file
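
Note on the cost bookkeeping introduced in src/llm_models/utils_model.py: `_calculate_cost` charges per million tokens using the model's `price_in`/`price_out` values from template/model_config_template.toml. The sketch below is a minimal, standalone restatement of that formula for illustration only; the function name `estimate_cost` and the token counts are made up for the example, and the 2.0/8.0 prices are the ones listed for the deepseek-v3 entry in the template.

```python
# Minimal sketch of the per-call cost math shown in _calculate_cost above.
# Prices are 元 per million tokens, as in the template's price_in/price_out fields.

def estimate_cost(prompt_tokens: int, completion_tokens: int,
                  price_in: float, price_out: float) -> float:
    """Same formula as the patch: (tokens / 1_000_000) * price, rounded to 6 places."""
    input_cost = (prompt_tokens / 1_000_000) * price_in
    output_cost = (completion_tokens / 1_000_000) * price_out
    return round(input_cost + output_cost, 6)


if __name__ == "__main__":
    # Hypothetical call against deepseek-v3 (price_in=2.0, price_out=8.0):
    # 1200/1e6 * 2.0 = 0.0024 and 350/1e6 * 8.0 = 0.0028, so 0.0052 in total.
    print(estimate_cost(1200, 350, price_in=2.0, price_out=8.0))  # -> 0.0052
```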
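
For reference, `_determine_task_name` resolves a model's task in three steps: an explicit `task_type` in the config wins, otherwise `capabilities` are checked in priority order (vision, embedding, speech, then text split by `request_type`), and only as a last resort does it fall back to keyword matching on the model name. The standalone function below restates that ordering outside the class; the name `resolve_task` and the example dicts are illustrative and not part of the patch.

```python
# Standalone restatement of the task-resolution priority implemented by
# _determine_task_name in src/llm_models/utils_model.py.

def resolve_task(model: dict, request_type: str = "normal") -> str:
    # 1. An explicit task_type in model_config.toml always wins.
    if "task_type" in model:
        return model["task_type"]
    # 2. Otherwise infer from capabilities, highest-priority capability first.
    caps = model.get("capabilities", [])
    for cap in ("vision", "embedding", "speech"):
        if cap in caps:
            return cap
    if "text" in caps:
        return "llm_reasoning" if request_type == "reasoning" else "llm_normal"
    # 3. Last resort: keyword matching on the model name (kept for backward compatibility).
    name = model.get("model_name", model.get("name", "")).lower()
    if any(k in name for k in ("vlm", "vision", "gpt-4o", "claude", "vl-")):
        return "vision"
    if any(k in name for k in ("embed", "text-embedding", "bge-")):
        return "embedding"
    if any(k in name for k in ("whisper", "speech", "voice")):
        return "speech"
    return "llm_reasoning" if request_type == "reasoning" else "llm_normal"


# A vision-capable entry resolves via capabilities even without task_type:
print(resolve_task({"name": "qwen2.5-vl-72b", "capabilities": ["vision", "text"]}))   # vision
# A text-only entry falls back to the request type:
print(resolve_task({"name": "qwen3-8b", "capabilities": ["text"]}, "reasoning"))      # llm_reasoning
```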