From c4e76b45dcce5618b5218cbde9f12822b42a3915 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Tue, 22 Jul 2025 23:54:05 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8A=8A=20API=20ada=E5=85=88=E6=8F=92?= =?UTF-8?q?=E8=BF=9B=E6=9D=A5=EF=BC=8C=E5=88=AB=E7=AE=A1=E8=83=BD=E4=B8=8D?= =?UTF-8?q?=E8=83=BD=E7=94=A8=EF=BC=8C=E5=85=88=E6=8F=92=E8=BF=9B=E6=9D=A5?= =?UTF-8?q?=E5=86=8D=E8=AF=B4=EF=BC=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/maibot_llmreq/LICENSE | 21 + src/chat/maibot_llmreq/__init__.py | 19 + src/chat/maibot_llmreq/config/__init__.py | 0 src/chat/maibot_llmreq/config/config.py | 76 +++ src/chat/maibot_llmreq/config/parser.py | 267 +++++++++ src/chat/maibot_llmreq/exceptions.py | 69 +++ .../maibot_llmreq/model_client/__init__.py | 363 ++++++++++++ .../maibot_llmreq/model_client/base_client.py | 116 ++++ .../model_client/gemini_client.py | 481 +++++++++++++++ .../model_client/openai_client.py | 548 ++++++++++++++++++ src/chat/maibot_llmreq/model_manager.py | 79 +++ .../maibot_llmreq/payload_content/message.py | 104 ++++ .../payload_content/resp_format.py | 223 +++++++ .../payload_content/tool_option.py | 155 +++++ .../maibot_llmreq/tests/test_config_load.py | 84 +++ src/chat/maibot_llmreq/usage_statistic.py | 182 ++++++ src/chat/maibot_llmreq/utils.py | 150 +++++ template/model_config_template.toml | 77 +++ 18 files changed, 3014 insertions(+) create mode 100644 src/chat/maibot_llmreq/LICENSE create mode 100644 src/chat/maibot_llmreq/__init__.py create mode 100644 src/chat/maibot_llmreq/config/__init__.py create mode 100644 src/chat/maibot_llmreq/config/config.py create mode 100644 src/chat/maibot_llmreq/config/parser.py create mode 100644 src/chat/maibot_llmreq/exceptions.py create mode 100644 src/chat/maibot_llmreq/model_client/__init__.py create mode 100644 src/chat/maibot_llmreq/model_client/base_client.py create mode 100644 src/chat/maibot_llmreq/model_client/gemini_client.py create mode 100644 src/chat/maibot_llmreq/model_client/openai_client.py create mode 100644 src/chat/maibot_llmreq/model_manager.py create mode 100644 src/chat/maibot_llmreq/payload_content/message.py create mode 100644 src/chat/maibot_llmreq/payload_content/resp_format.py create mode 100644 src/chat/maibot_llmreq/payload_content/tool_option.py create mode 100644 src/chat/maibot_llmreq/tests/test_config_load.py create mode 100644 src/chat/maibot_llmreq/usage_statistic.py create mode 100644 src/chat/maibot_llmreq/utils.py create mode 100644 template/model_config_template.toml diff --git a/src/chat/maibot_llmreq/LICENSE b/src/chat/maibot_llmreq/LICENSE new file mode 100644 index 000000000..8b3236ed5 --- /dev/null +++ b/src/chat/maibot_llmreq/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Mai.To.The.Gate + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/chat/maibot_llmreq/__init__.py b/src/chat/maibot_llmreq/__init__.py new file mode 100644 index 000000000..aab812cfa --- /dev/null +++ b/src/chat/maibot_llmreq/__init__.py @@ -0,0 +1,19 @@ +import loguru + +type LoguruLogger = loguru.Logger + +_logger: LoguruLogger = loguru.logger + + +def init_logger( + logger: LoguruLogger | None = None, +): + """ + 对LLMRequest模块进行配置 + :param logger: 日志对象 + """ + global _logger # 申明使用全局变量 + if logger: + _logger = logger + else: + _logger.warning("Warning: No logger provided, using default logger.") diff --git a/src/chat/maibot_llmreq/config/__init__.py b/src/chat/maibot_llmreq/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/chat/maibot_llmreq/config/config.py b/src/chat/maibot_llmreq/config/config.py new file mode 100644 index 000000000..59b3d2b67 --- /dev/null +++ b/src/chat/maibot_llmreq/config/config.py @@ -0,0 +1,76 @@ +from dataclasses import dataclass, field +from typing import List, Dict + +from packaging.version import Version + +NEWEST_VER = "0.1.0" # 当前支持的最新版本 + + +@dataclass +class APIProvider: + name: str = "" # API提供商名称 + base_url: str = "" # API基础URL + api_key: str = field(repr=False, default="") # API密钥 + client_type: str = "openai" # 客户端类型(如openai/google等,默认为openai) + + +@dataclass +class ModelInfo: + model_identifier: str = "" # 模型标识符(用于URL调用) + name: str = "" # 模型名称(用于模块调用) + api_provider: str = "" # API提供商(如OpenAI、Azure等) + + # 以下用于模型计费 + price_in: float = 0.0 # 每M token输入价格 + price_out: float = 0.0 # 每M token输出价格 + + force_stream_mode: bool = False # 是否强制使用流式输出模式 + + +@dataclass +class RequestConfig: + max_retry: int = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) + timeout: int = ( + 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒) + ) + retry_interval: int = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒) + default_temperature: float = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值) + default_max_tokens: int = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值) + + +@dataclass +class ModelUsageArgConfigItem: + """模型使用的配置类 + 该类用于加载和存储子任务模型使用的配置 + """ + + name: str = "" # 模型名称 + temperature: float | None = None # 温度 + max_tokens: int | None = None # 最大token数 + max_retry: int | None = None # 调用失败时的最大重试次数 + + +@dataclass +class ModelUsageArgConfig: + """子任务使用模型的配置类 + 该类用于加载和存储子任务使用的模型配置 + """ + + name: str = "" # 任务名称 + usage: List[ModelUsageArgConfigItem] = field( + default_factory=lambda: [] + ) # 任务使用的模型列表 + + +@dataclass +class ModuleConfig: + INNER_VERSION: Version | None = None # 配置文件版本 + + req_conf: RequestConfig = field(default_factory=lambda: RequestConfig()) # 请求配置 + api_providers: Dict[str, APIProvider] = field( + default_factory=lambda: {} + ) # API提供商列表 + models: Dict[str, ModelInfo] = field(default_factory=lambda: {}) # 模型列表 + task_model_arg_map: Dict[str, ModelUsageArgConfig] = field( + default_factory=lambda: {} + ) diff --git a/src/chat/maibot_llmreq/config/parser.py b/src/chat/maibot_llmreq/config/parser.py new file mode 100644 index 000000000..a6877835d --- /dev/null +++ b/src/chat/maibot_llmreq/config/parser.py @@ 
-0,0 +1,267 @@ +import os +from typing import Any, Dict, List + +import tomli +from packaging import version +from packaging.specifiers import SpecifierSet +from packaging.version import Version, InvalidVersion + +from .. import _logger as logger + +from .config import ( + ModelUsageArgConfigItem, + ModelUsageArgConfig, + APIProvider, + ModelInfo, + NEWEST_VER, + ModuleConfig, +) + + +def _get_config_version(toml: Dict) -> Version: + """提取配置文件的 SpecifierSet 版本数据 + Args: + toml[dict]: 输入的配置文件字典 + Returns: + Version + """ + + if "inner" in toml and "version" in toml["inner"]: + config_version: str = toml["inner"]["version"] + else: + config_version = "0.0.0" # 默认版本 + + try: + ver = version.parse(config_version) + except InvalidVersion as e: + logger.error( + "配置文件中 inner段 的 version 键是错误的版本描述\n" + f"请检查配置文件,当前 version 键: {config_version}\n" + f"错误信息: {e}" + ) + raise InvalidVersion( + "配置文件中 inner段 的 version 键是错误的版本描述\n" + ) from e + + return ver + + +def _request_conf(parent: Dict, config: ModuleConfig): + request_conf_config = parent.get("request_conf") + config.req_conf.max_retry = request_conf_config.get( + "max_retry", config.req_conf.max_retry + ) + config.req_conf.timeout = request_conf_config.get( + "timeout", config.req_conf.timeout + ) + config.req_conf.retry_interval = request_conf_config.get( + "retry_interval", config.req_conf.retry_interval + ) + config.req_conf.default_temperature = request_conf_config.get( + "default_temperature", config.req_conf.default_temperature + ) + config.req_conf.default_max_tokens = request_conf_config.get( + "default_max_tokens", config.req_conf.default_max_tokens + ) + + +def _api_providers(parent: Dict, config: ModuleConfig): + api_providers_config = parent.get("api_providers") + for provider in api_providers_config: + name = provider.get("name", None) + base_url = provider.get("base_url", None) + api_key = provider.get("api_key", None) + client_type = provider.get("client_type", "openai") + + if name in config.api_providers: # 查重 + logger.error(f"重复的API提供商名称: {name},请检查配置文件。") + raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。") + + if name and base_url: + config.api_providers[name] = APIProvider( + name=name, + base_url=base_url, + api_key=api_key, + client_type=client_type, + ) + else: + logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") + raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") + + +def _models(parent: Dict, config: ModuleConfig): + models_config = parent.get("models") + for model in models_config: + model_identifier = model.get("model_identifier", None) + name = model.get("name", model_identifier) + api_provider = model.get("api_provider", None) + price_in = model.get("price_in", 0.0) + price_out = model.get("price_out", 0.0) + force_stream_mode = model.get("force_stream_mode", False) + + if name in config.models: # 查重 + logger.error(f"重复的模型名称: {name},请检查配置文件。") + raise KeyError(f"重复的模型名称: {name},请检查配置文件。") + + if model_identifier and api_provider: + # 检查API提供商是否存在 + if api_provider not in config.api_providers: + logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。") + raise ValueError( + f"未声明的API提供商 '{api_provider}' ,请检查配置文件。" + ) + config.models[name] = ModelInfo( + name=name, + model_identifier=model_identifier, + api_provider=api_provider, + price_in=price_in, + price_out=price_out, + force_stream_mode=force_stream_mode, + ) + else: + logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。") + raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。") + + +def _task_model_usage(parent: Dict, config: ModuleConfig): + model_usage_configs = 
parent.get("task_model_usage") + config.task_model_arg_map = {} + for task_name, item in model_usage_configs.items(): + if task_name in config.task_model_arg_map: + logger.error(f"子任务 {task_name} 已存在,请检查配置文件。") + raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。") + + usage = [] + if isinstance(item, Dict): + if "model" in item: + usage.append( + ModelUsageArgConfigItem( + name=item["model"], + temperature=item.get("temperature", None), + max_tokens=item.get("max_tokens", None), + max_retry=item.get("max_retry", None), + ) + ) + else: + logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。") + raise ValueError( + f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" + ) + elif isinstance(item, List): + for model in item: + if isinstance(model, Dict): + usage.append( + ModelUsageArgConfigItem( + name=model["model"], + temperature=model.get("temperature", None), + max_tokens=model.get("max_tokens", None), + max_retry=model.get("max_retry", None), + ) + ) + elif isinstance(model, str): + usage.append( + ModelUsageArgConfigItem( + name=model, + temperature=None, + max_tokens=None, + max_retry=None, + ) + ) + else: + logger.error( + f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" + ) + raise ValueError( + f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" + ) + elif isinstance(item, str): + usage.append( + ModelUsageArgConfigItem( + name=item, + temperature=None, + max_tokens=None, + max_retry=None, + ) + ) + + config.task_model_arg_map[task_name] = ModelUsageArgConfig( + name=task_name, + usage=usage, + ) + + +def load_config(config_path: str) -> ModuleConfig: + """从TOML配置文件加载配置""" + config = ModuleConfig() + + include_configs: Dict[str, Dict[str, Any]] = { + "request_conf": { + "func": _request_conf, + "support": ">=0.0.0", + "necessary": False, + }, + "api_providers": {"func": _api_providers, "support": ">=0.0.0"}, + "models": {"func": _models, "support": ">=0.0.0"}, + "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"}, + } + + if os.path.exists(config_path): + with open(config_path, "rb") as f: + try: + toml_dict = tomli.load(f) + except tomli.TOMLDecodeError as e: + logger.critical( + f"配置文件model_list.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}" + ) + exit(1) + + # 获取配置文件版本 + config.INNER_VERSION = _get_config_version(toml_dict) + + # 检查版本 + if config.INNER_VERSION > Version(NEWEST_VER): + logger.warning( + f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。" + ) + + # 解析配置文件 + # 如果在配置中找到了需要的项,调用对应项的闭包函数处理 + for key in include_configs: + if key in toml_dict: + group_specifier_set: SpecifierSet = SpecifierSet( + include_configs[key]["support"] + ) + + # 检查配置文件版本是否在支持范围内 + if config.INNER_VERSION in group_specifier_set: + # 如果版本在支持范围内,检查是否存在通知 + if "notice" in include_configs[key]: + logger.warning(include_configs[key]["notice"]) + # 调用闭包函数处理配置 + (include_configs[key]["func"])(toml_dict, config) + else: + # 如果版本不在支持范围内,崩溃并提示用户 + logger.error( + f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n" + f"当前程序仅支持以下版本范围: {group_specifier_set}" + ) + raise InvalidVersion( + f"当前程序仅支持以下版本范围: {group_specifier_set}" + ) + + # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理 + elif ( + "necessary" in include_configs[key] + and include_configs[key].get("necessary") is False + ): + # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理 + if key == "keywords_reaction": + pass + else: + # 如果用户根本没有需要的配置项,提示缺少配置 + logger.error(f"配置文件中缺少必需的字段: '{key}'") + raise KeyError(f"配置文件中缺少必需的字段: '{key}'") + + logger.success(f"成功加载配置文件: {config_path}") + + return config diff --git a/src/chat/maibot_llmreq/exceptions.py 
b/src/chat/maibot_llmreq/exceptions.py new file mode 100644 index 000000000..0ced8dd14 --- /dev/null +++ b/src/chat/maibot_llmreq/exceptions.py @@ -0,0 +1,69 @@ +from typing import Any + + +# 常见Error Code Mapping (以OpenAI API为例) +error_code_mapping = { + 400: "参数不正确", + 401: "API-Key错误,认证失败,请检查/config/model_list.toml中的配置是否正确", + 402: "账号余额不足", + 403: "模型拒绝访问,可能需要实名或余额不足", + 404: "Not Found", + 413: "请求体过大,请尝试压缩图片或减少输入内容", + 429: "请求过于频繁,请稍后再试", + 500: "服务器内部故障", + 503: "服务器负载过高", +} + + +class NetworkConnectionError(Exception): + """连接异常,常见于网络问题或服务器不可用""" + + def __init__(self): + super().__init__() + + def __str__(self): + return "连接异常,请检查网络连接状态或URL是否正确" + + +class ReqAbortException(Exception): + """请求异常退出,常见于请求被中断或取消""" + + def __init__(self, message: str | None = None): + super().__init__(message) + self.message = message + + def __str__(self): + return self.message or "请求因未知原因异常终止" + + +class RespNotOkException(Exception): + """请求响应异常,见于请求未能成功响应(非 '200 OK')""" + + def __init__(self, status_code: int, message: str | None = None): + super().__init__(message) + self.status_code = status_code + self.message = message + + def __str__(self): + if self.status_code in error_code_mapping: + return error_code_mapping[self.status_code] + elif self.message: + return self.message + else: + return f"未知的异常响应代码:{self.status_code}" + + +class RespParseException(Exception): + """响应解析错误,常见于响应格式不正确或解析方法不匹配""" + + def __init__(self, ext_info: Any, message: str | None = None): + super().__init__(message) + self.ext_info = ext_info + self.message = message + + def __str__(self): + return ( + self.message + if self.message + else "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法" + ) diff --git a/src/chat/maibot_llmreq/model_client/__init__.py b/src/chat/maibot_llmreq/model_client/__init__.py new file mode 100644 index 000000000..9dc28d07d --- /dev/null +++ b/src/chat/maibot_llmreq/model_client/__init__.py @@ -0,0 +1,363 @@ +import asyncio +from typing import Callable, Any + +from openai import AsyncStream +from openai.types.chat import ChatCompletionChunk, ChatCompletion + +from .base_client import BaseClient, APIResponse +from .. 
import _logger as logger +from ..config.config import ( + ModelInfo, + ModelUsageArgConfigItem, + RequestConfig, + ModuleConfig, +) +from ..exceptions import ( + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, +) +from ..payload_content.message import Message +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolOption +from ..utils import compress_messages + + +def _check_retry( + remain_try: int, + retry_interval: int, + can_retry_msg: str, + cannot_retry_msg: str, + can_retry_callable: Callable | None = None, + **kwargs, +) -> tuple[int, Any | None]: + """ + 辅助函数:检查是否可以重试 + :param remain_try: 剩余尝试次数 + :param retry_interval: 重试间隔 + :param can_retry_msg: 可以重试时的提示信息 + :param cannot_retry_msg: 不可以重试时的提示信息 + :return: (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + if remain_try > 0: + # 还有重试机会 + logger.warning(f"{can_retry_msg}") + if can_retry_callable is not None: + return retry_interval, can_retry_callable(**kwargs) + else: + return retry_interval, None + else: + # 达到最大重试次数 + logger.warning(f"{cannot_retry_msg}") + return -1, None # 不再重试请求该模型 + + +def _handle_resp_not_ok( + e: RespNotOkException, + task_name: str, + model_name: str, + remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, +): + """ + 处理响应错误异常 + :param e: 异常对象 + :param task_name: 任务名称 + :param model_name: 模型名称 + :param remain_try: 剩余尝试次数 + :param retry_interval: 重试间隔 + :param messages: (消息列表, 是否已压缩过) + :return: (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + # 响应错误 + if e.status_code in [400, 401, 402, 403, 404]: + # 客户端错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"请求失败,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None # 不再重试请求该模型 + elif e.status_code == 413: + if messages and not messages[1]: + # 消息列表不为空且未压缩,尝试压缩消息 + return _check_retry( + remain_try, + 0, + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求体过大,尝试压缩消息后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求体过大,压缩消息后仍然过大,放弃请求" + ), + can_retry_callable=compress_messages, + messages=messages[0], + ) + # 没有消息可压缩 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求体过大,无法压缩消息,放弃请求。" + ) + return -1, None + elif e.status_code == 429: + # 请求过于频繁 + return _check_retry( + remain_try, + retry_interval, + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"请求过于频繁,将于{retry_interval}秒后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求过于频繁,超过最大重试次数,放弃请求" + ), + ) + elif e.status_code >= 500: + # 服务器错误 + return _check_retry( + remain_try, + retry_interval, + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"服务器错误,将于{retry_interval}秒后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "服务器错误,超过最大重试次数,请稍后再试" + ), + ) + else: + # 未知错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"未知错误,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None + + +def default_exception_handler( + e: Exception, + task_name: str, + model_name: str, + remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, +) -> tuple[int, list[Message] | None]: + """ + 默认异常处理函数 + :param e: 异常对象 + :param task_name: 任务名称 + :param model_name: 模型名称 + :param remain_try: 剩余尝试次数 + :param retry_interval: 重试间隔 + :param messages: (消息列表, 是否已压缩过) + :return (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + + if isinstance(e, NetworkConnectionError): # 网络连接错误 + return 
_check_retry( + remain_try, + retry_interval, + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"连接异常,将于{retry_interval}秒后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确" + ), + ) + elif isinstance(e, ReqAbortException): + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n请求被中断,详细信息-{str(e.message)}" + ) + return -1, None # 不再重试请求该模型 + elif isinstance(e, RespNotOkException): + return _handle_resp_not_ok( + e, + task_name, + model_name, + remain_try, + retry_interval, + messages, + ) + elif isinstance(e, RespParseException): + # 响应解析错误 + logger.error( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"响应解析错误,错误信息-{e.message}\n" + ) + logger.debug(f"附加内容:\n{str(e.ext_info)}") + return -1, None # 不再重试请求该模型 + else: + logger.error( + f"任务-'{task_name}' 模型-'{model_name}'\n未知异常,错误信息-{str(e)}" + ) + return -1, None # 不再重试请求该模型 + + +class ModelRequestHandler: + """ + 模型请求处理器 + """ + + def __init__( + self, + task_name: str, + config: ModuleConfig, + api_client_map: dict[str, BaseClient], + ): + self.task_name: str = task_name + """任务名称""" + + self.client_map: dict[str, BaseClient] = {} + """API客户端列表""" + + self.configs: list[tuple[ModelInfo, ModelUsageArgConfigItem]] = [] + """模型参数配置""" + + self.req_conf: RequestConfig = config.req_conf + """请求配置""" + + # 获取模型与使用配置 + for model_usage in config.task_model_arg_map[task_name].usage: + if model_usage.name not in config.models: + logger.error(f"Model '{model_usage.name}' not found in ModelManager") + raise KeyError(f"Model '{model_usage.name}' not found in ModelManager") + model_info = config.models[model_usage.name] + + if model_info.api_provider not in self.client_map: + # 缓存API客户端 + self.client_map[model_info.api_provider] = api_client_map[ + model_info.api_provider + ] + + self.configs.append((model_info, model_usage)) # 添加模型与使用配置 + + async def get_response( + self, + messages: list[Message], + tool_options: list[ToolOption] | None = None, + response_format: RespFormat | None = None, # 暂不启用 + stream_response_handler: Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], APIResponse + ] + | None = None, + async_response_parser: Callable[[ChatCompletion], APIResponse] | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """ + 获取对话响应 + :param messages: 消息列表 + :param tool_options: 工具选项列表 + :param response_format: 响应格式 + :param stream_response_handler: 流式响应处理函数(可选) + :param async_response_parser: 响应解析函数(可选) + :param interrupt_flag: 中断信号量(可选,默认为None) + :return: APIResponse + """ + # 遍历可用模型,若获取响应失败,则使用下一个模型继续请求 + for config_item in self.configs: + client = self.client_map[config_item[0].api_provider] + model_info: ModelInfo = config_item[0] + model_usage_config: ModelUsageArgConfigItem = config_item[1] + + remain_try = ( + model_usage_config.max_retry or self.req_conf.max_retry + ) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1 + + compressed_messages = None + retry_interval = self.req_conf.retry_interval + while remain_try > 0: + try: + return await client.get_response( + model_info, + message_list=(compressed_messages or messages), + tool_options=tool_options, + max_tokens=model_usage_config.max_tokens + or self.req_conf.default_max_tokens, + temperature=model_usage_config.temperature + or self.req_conf.default_temperature, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + interrupt_flag=interrupt_flag, + ) + except Exception as e: + logger.trace(e) + remain_try -= 1 # 
剩余尝试次数减1 + + # 处理异常 + handle_res = default_exception_handler( + e, + self.task_name, + model_info.name, + remain_try, + retry_interval=self.req_conf.retry_interval, + messages=(messages, compressed_messages is not None), + ) + + if handle_res[0] == -1: + # 等待间隔为-1,表示不再请求该模型 + remain_try = 0 + elif handle_res[0] != 0: + # 等待间隔不为0,表示需要等待 + await asyncio.sleep(handle_res[0]) + retry_interval *= 2 + + if handle_res[1] is not None: + # 压缩消息 + compressed_messages = handle_res[1] + + logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用") + raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败 + + async def get_embedding( + self, + embedding_input: str, + ) -> APIResponse: + """ + 获取嵌入向量 + :param embedding_input: 嵌入输入 + :return: APIResponse + """ + for config in self.configs: + client = self.client_map[config[0].api_provider] + model_info: ModelInfo = config[0] + model_usage_config: ModelUsageArgConfigItem = config[1] + remain_try = ( + model_usage_config.max_retry or self.req_conf.max_retry + ) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1 + + while remain_try: + try: + return await client.get_embedding( + model_info=model_info, + embedding_input=embedding_input, + ) + except Exception as e: + logger.trace(e) + remain_try -= 1 # 剩余尝试次数减1 + + # 处理异常 + handle_res = default_exception_handler( + e, + self.task_name, + model_info.name, + remain_try, + retry_interval=self.req_conf.retry_interval, + ) + + if handle_res[0] == -1: + # 等待间隔为-1,表示不再请求该模型 + remain_try = 0 + elif handle_res[0] != 0: + # 等待间隔不为0,表示需要等待 + await asyncio.sleep(handle_res[0]) + + logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用") + raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败 diff --git a/src/chat/maibot_llmreq/model_client/base_client.py b/src/chat/maibot_llmreq/model_client/base_client.py new file mode 100644 index 000000000..ed877a6c9 --- /dev/null +++ b/src/chat/maibot_llmreq/model_client/base_client.py @@ -0,0 +1,116 @@ +import asyncio +from dataclasses import dataclass +from typing import Callable, Any + +from openai import AsyncStream +from openai.types.chat import ChatCompletionChunk, ChatCompletion + +from ..config.config import ModelInfo, APIProvider +from ..payload_content.message import Message +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolOption, ToolCall + + +@dataclass +class UsageRecord: + """ + 使用记录类 + """ + + model_name: str + """模型名称""" + + provider_name: str + """提供商名称""" + + prompt_tokens: int + """提示token数""" + + completion_tokens: int + """完成token数""" + + total_tokens: int + """总token数""" + + +@dataclass +class APIResponse: + """ + API响应类 + """ + + content: str | None = None + """响应内容""" + + reasoning_content: str | None = None + """推理内容""" + + tool_calls: list[ToolCall] | None = None + """工具调用 [(工具名称, 工具参数), ...]""" + + embedding: list[float] | None = None + """嵌入向量""" + + usage: UsageRecord | None = None + """使用情况 (prompt_tokens, completion_tokens, total_tokens)""" + + raw_data: Any = None + """响应原始数据""" + + +class BaseClient: + """ + 基础客户端 + """ + + api_provider: APIProvider + + def __init__(self, api_provider: APIProvider): + self.api_provider = api_provider + + async def get_response( + self, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + response_format: RespFormat | None = None, + stream_response_handler: Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], + tuple[APIResponse, tuple[int, int, int]], + ] + | None = None, + 
async_response_parser: Callable[ + [ChatCompletion], tuple[APIResponse, tuple[int, int, int]] + ] + | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """ + 获取对话响应 + :param model_info: 模型信息 + :param message_list: 对话体 + :param tool_options: 工具选项(可选,默认为None) + :param max_tokens: 最大token数(可选,默认为1024) + :param temperature: 温度(可选,默认为0.7) + :param response_format: 响应格式(可选,默认为 NotGiven ) + :param stream_response_handler: 流式响应处理函数(可选) + :param async_response_parser: 响应解析函数(可选) + :param interrupt_flag: 中断信号量(可选,默认为None) + :return: (响应文本, 推理文本, 工具调用, 其他数据) + """ + raise RuntimeError("This method should be overridden in subclasses") + + async def get_embedding( + self, + model_info: ModelInfo, + embedding_input: str, + ) -> APIResponse: + """ + 获取文本嵌入 + :param model_info: 模型信息 + :param embedding_input: 嵌入输入文本 + :return: 嵌入响应 + """ + raise RuntimeError("This method should be overridden in subclasses") diff --git a/src/chat/maibot_llmreq/model_client/gemini_client.py b/src/chat/maibot_llmreq/model_client/gemini_client.py new file mode 100644 index 000000000..75d2767e0 --- /dev/null +++ b/src/chat/maibot_llmreq/model_client/gemini_client.py @@ -0,0 +1,481 @@ +import asyncio +import io +from collections.abc import Iterable +from typing import Callable, Iterator, TypeVar, AsyncIterator + +from google import genai +from google.genai import types +from google.genai.types import FunctionDeclaration, GenerateContentResponse +from google.genai.errors import ( + ClientError, + ServerError, + UnknownFunctionCallArgumentError, + UnsupportedFunctionError, + FunctionInvocationError, +) + +from .base_client import APIResponse, UsageRecord +from ..config.config import ModelInfo, APIProvider +from . import BaseClient + +from ..exceptions import ( + RespParseException, + NetworkConnectionError, + RespNotOkException, + ReqAbortException, +) +from ..payload_content.message import Message, RoleType +from ..payload_content.resp_format import RespFormat, RespFormatType +from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall + +T = TypeVar("T") + + +def _convert_messages( + messages: list[Message], +) -> tuple[list[types.Content], list[str] | None]: + """ + 转换消息格式 - 将消息转换为Gemini API所需的格式 + :param messages: 消息列表 + :return: 转换后的消息列表(和可能存在的system消息) + """ + + def _convert_message_item(message: Message) -> types.Content: + """ + 转换单个消息格式,除了system和tool类型的消息 + :param message: 消息对象 + :return: 转换后的消息字典 + """ + + # 将openai格式的角色重命名为gemini格式的角色 + if message.role == RoleType.Assistant: + role = "model" + elif message.role == RoleType.User: + role = "user" + + # 添加Content + content: types.Part | list + if isinstance(message.content, str): + content = types.Part.from_text(message.content) + elif isinstance(message.content, list): + content = [] + for item in message.content: + if isinstance(item, tuple): + content.append( + types.Part.from_bytes( + data=item[1], mime_type=f"image/{item[0].lower()}" + ) + ) + elif isinstance(item, str): + content.append(types.Part.from_text(item)) + else: + raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") + + return types.Content(role=role, content=content) + + temp_list: list[types.Content] = [] + system_instructions: list[str] = [] + for message in messages: + if message.role == RoleType.System: + if isinstance(message.content, str): + system_instructions.append(message.content) + else: + raise RuntimeError("你tm怎么往system里面塞图片base64?") + elif message.role == RoleType.Tool: + if not message.tool_call_id: + raise 
ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") + else: + temp_list.append(_convert_message_item(message)) + if system_instructions: + # 如果有system消息,就把它加上去 + ret: tuple = (temp_list, system_instructions) + else: + # 如果没有system消息,就直接返回 + ret: tuple = (temp_list, None) + + return ret + + +def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]: + """ + 转换工具选项格式 - 将工具选项转换为Gemini API所需的格式 + :param tool_options: 工具选项列表 + :return: 转换后的工具对象列表 + """ + + def _convert_tool_param(tool_option_param: ToolParam) -> dict: + """ + 转换单个工具参数格式 + :param tool_option_param: 工具参数对象 + :return: 转换后的工具参数字典 + """ + return { + "type": tool_option_param.param_type.value, + "description": tool_option_param.description, + } + + def _convert_tool_option_item(tool_option: ToolOption) -> FunctionDeclaration: + """ + 转换单个工具项格式 + :param tool_option: 工具选项对象 + :return: 转换后的Gemini工具选项对象 + """ + ret = { + "name": tool_option.name, + "description": tool_option.description, + } + if tool_option.params: + ret["parameters"] = { + "type": "object", + "properties": { + param.name: _convert_tool_param(param) + for param in tool_option.params + }, + "required": [ + param.name for param in tool_option.params if param.required + ], + } + ret1 = types.FunctionDeclaration(**ret) + return ret1 + + return [_convert_tool_option_item(tool_option) for tool_option in tool_options] + + +def _process_delta( + delta: GenerateContentResponse, + fc_delta_buffer: io.StringIO, + tool_calls_buffer: list[tuple[str, str, dict]], +): + if not hasattr(delta, "candidates") or len(delta.candidates) == 0: + raise RespParseException(delta, "响应解析失败,缺失candidates字段") + + if delta.text: + fc_delta_buffer.write(delta.text) + + if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的 + for call in delta.function_calls: + try: + if not isinstance( + call.args, dict + ): # gemini返回的function call参数就是dict格式的了 + raise RespParseException( + delta, "响应解析失败,工具调用参数无法解析为字典类型" + ) + tool_calls_buffer.append( + ( + call.id, + call.name, + call.args, + ) + ) + except Exception as e: + raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e + + +def _build_stream_api_resp( + _fc_delta_buffer: io.StringIO, + _tool_calls_buffer: list[tuple[str, str, dict]], +) -> APIResponse: + resp = APIResponse() + + if _fc_delta_buffer.tell() > 0: + # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 + resp.content = _fc_delta_buffer.getvalue() + _fc_delta_buffer.close() + if len(_tool_calls_buffer) > 0: + # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 + resp.tool_calls = [] + for call_id, function_name, arguments_buffer in _tool_calls_buffer: + if arguments_buffer is not None: + arguments = arguments_buffer + if not isinstance(arguments, dict): + raise RespParseException( + None, + "响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n" + f"{arguments_buffer}", + ) + else: + arguments = None + + resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) + + return resp + + +async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]: + """ + 将迭代器转换为异步迭代器 + :param iterable: 迭代器对象 + :return: 异步迭代器对象 + """ + for item in iterable: + await asyncio.sleep(0) + yield item + + +async def _default_stream_response_handler( + resp_stream: Iterator[GenerateContentResponse], + interrupt_flag: asyncio.Event | None, +) -> tuple[APIResponse, tuple[int, int, int]]: + """ + 流式响应处理函数 - 处理Gemini API的流式响应 + :param resp_stream: 流式响应对象,是一个神秘的iterator,我完全不知道这个玩意能不能跑,不过遍历一遍之后它就空了,如果跑不了一点的话可以考虑改成别的东西 + :return: APIResponse对象 + """ + _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 + _tool_calls_buffer: 
list[ + tuple[str, str, dict] + ] = [] # 工具调用缓冲区,用于存储接收到的工具调用 + _usage_record = None # 使用情况记录 + + def _insure_buffer_closed(): + if _fc_delta_buffer and not _fc_delta_buffer.closed: + _fc_delta_buffer.close() + + async for chunk in _to_async_iterable(resp_stream): + # 检查是否有中断量 + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量被设置,则抛出ReqAbortException + raise ReqAbortException("请求被外部信号中断") + + _process_delta( + chunk, + _fc_delta_buffer, + _tool_calls_buffer, + ) + + if chunk.usage_metadata: + # 如果有使用情况,则将其存储在APIResponse对象中 + _usage_record = ( + chunk.usage_metadata.prompt_token_count, + chunk.usage_metadata.candidates_token_count + + chunk.usage_metadata.thoughts_token_count, + chunk.usage_metadata.total_token_count, + ) + try: + return _build_stream_api_resp( + _fc_delta_buffer, + _tool_calls_buffer, + ), _usage_record + except Exception: + # 确保缓冲区被关闭 + _insure_buffer_closed() + raise + + +def _default_normal_response_parser( + resp: GenerateContentResponse, +) -> tuple[APIResponse, tuple[int, int, int]]: + """ + 解析对话补全响应 - 将Gemini API响应解析为APIResponse对象 + :param resp: 响应对象 + :return: APIResponse对象 + """ + api_response = APIResponse() + + if not hasattr(resp, "candidates") or len(resp.candidates) == 0: + raise RespParseException(resp, "响应解析失败,缺失candidates字段") + + if resp.text: + api_response.content = resp.text + + if resp.function_calls: + api_response.tool_calls = [] + for call in resp.function_calls: + try: + if not isinstance(call.args, dict): + raise RespParseException( + resp, "响应解析失败,工具调用参数无法解析为字典类型" + ) + api_response.tool_calls.append(ToolCall(call.id, call.name, call.args)) + except Exception as e: + raise RespParseException( + resp, "响应解析失败,无法解析工具调用参数" + ) from e + + if resp.usage_metadata: + _usage_record = ( + resp.usage_metadata.prompt_token_count, + resp.usage_metadata.candidates_token_count + + resp.usage_metadata.thoughts_token_count, + resp.usage_metadata.total_token_count, + ) + else: + _usage_record = None + + api_response.raw_data = resp + + return api_response, _usage_record + + +class GeminiClient(BaseClient): + client: genai.Client + + def __init__(self, api_provider: APIProvider): + super().__init__(api_provider) + self.client = genai.Client( + api_key=api_provider.api_key, + ) # 这里和openai不一样,gemini会自己决定自己是否需要retry + + async def get_response( + self, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + thinking_budget: int = 0, + response_format: RespFormat | None = None, + stream_response_handler: Callable[ + [Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse + ] + | None = None, + async_response_parser: Callable[[GenerateContentResponse], APIResponse] + | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """ + 获取对话响应 + :param model_info: 模型信息 + :param message_list: 对话体 + :param tool_options: 工具选项(可选,默认为None) + :param max_tokens: 最大token数(可选,默认为1024) + :param temperature: 温度(可选,默认为0.7) + :param thinking_budget: 思考预算(可选,默认为0) + :param response_format: 响应格式(默认为text/plain,如果是输入的JSON Schema则必须遵守OpenAPI3.0格式,理论上和openai是一样的,暂不支持其它相应格式输入) + :param stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler) + :param async_response_parser: 响应解析函数(可选,默认为default_response_parser) + :param interrupt_flag: 中断信号量(可选,默认为None) + :return: (响应文本, 推理文本, 工具调用, 其他数据) + """ + if stream_response_handler is None: + stream_response_handler = _default_stream_response_handler + + if async_response_parser is None: + 
async_response_parser = _default_normal_response_parser + + # 将messages构造为Gemini API所需的格式 + messages = _convert_messages(message_list) + # 将tool_options转换为Gemini API所需的格式 + tools = _convert_tool_options(tool_options) if tool_options else None + # 将response_format转换为Gemini API所需的格式 + generation_config_dict = { + "max_output_tokens": max_tokens, + "temperature": temperature, + "response_modalities": ["TEXT"], # 暂时只支持文本输出 + } + if "2.5" in model_info.model_identifier.lower(): + # 我偷个懒,在这里识别一下2.5然后开摆,反正现在只有2.5支持思维链,然后我测试之后发现它不返回思考内容,反正我也怕他有朝一日返回了,我决定干掉任何有关的思维内容 + generation_config_dict["thinking_config"] = types.ThinkingConfig( + thinking_budget=thinking_budget, include_thoughts=False + ) + if tools: + generation_config_dict["tools"] = types.Tool(tools) + if messages[1]: + # 如果有system消息,则将其添加到配置中 + generation_config_dict["system_instructions"] = messages[1] + if response_format and response_format.format_type == RespFormatType.TEXT: + generation_config_dict["response_mime_type"] = "text/plain" + elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA): + generation_config_dict["response_mime_type"] = "application/json" + generation_config_dict["response_schema"] = response_format.to_dict() + + generation_config = types.GenerateContentConfig(**generation_config_dict) + + try: + if model_info.force_stream_mode: + req_task = asyncio.create_task( + self.client.aio.models.generate_content_stream( + model=model_info.model_identifier, + contents=messages[0], + config=generation_config, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 + resp, usage_record = await stream_response_handler( + req_task.result(), interrupt_flag + ) + else: + req_task = asyncio.create_task( + self.client.aio.models.generate_content( + model=model_info.model_identifier, + contents=messages[0], + config=generation_config, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态 + + resp, usage_record = async_response_parser(req_task.result()) + except (ClientError, ServerError) as e: + # 重封装ClientError和ServerError为RespNotOkException + raise RespNotOkException(e.status_code, e.message) + except ( + UnknownFunctionCallArgumentError, + UnsupportedFunctionError, + FunctionInvocationError, + ) as e: + raise ValueError("工具类型错误:请检查工具选项和参数:" + str(e)) + except Exception as e: + raise NetworkConnectionError() from e + + if usage_record: + resp.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=usage_record[0], + completion_tokens=usage_record[1], + total_tokens=usage_record[2], + ) + + async def get_embedding( + self, + model_info: ModelInfo, + embedding_input: str, + ) -> APIResponse: + """ + 获取文本嵌入 + :param model_info: 模型信息 + :param embedding_input: 嵌入输入文本 + :return: 嵌入响应 + """ + try: + raw_response: types.EmbedContentResponse = ( + await self.client.aio.models.embed_content( + model=model_info.model_identifier, + contents=embedding_input, + config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"), + ) + ) + except (ClientError, ServerError) as e: + # 重封装ClientError和ServerError为RespNotOkException + raise RespNotOkException(e.status_code) + except Exception as e: + raise 
NetworkConnectionError() from e + + response = APIResponse() + + # 解析嵌入响应和使用情况 + if hasattr(raw_response, "embeddings"): + response.embedding = raw_response.embeddings[0].values + else: + raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段") + + response.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=len(embedding_input), + completion_tokens=0, + total_tokens=len(embedding_input), + ) + + return response diff --git a/src/chat/maibot_llmreq/model_client/openai_client.py b/src/chat/maibot_llmreq/model_client/openai_client.py new file mode 100644 index 000000000..db256b2d4 --- /dev/null +++ b/src/chat/maibot_llmreq/model_client/openai_client.py @@ -0,0 +1,548 @@ +import asyncio +import io +import json +import re +from collections.abc import Iterable +from typing import Callable, Any + +from openai import ( + AsyncOpenAI, + APIConnectionError, + APIStatusError, + NOT_GIVEN, + AsyncStream, +) +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionToolParam, +) +from openai.types.chat.chat_completion_chunk import ChoiceDelta + +from .base_client import APIResponse, UsageRecord +from ..config.config import ModelInfo, APIProvider +from . import BaseClient + +from ..exceptions import ( + RespParseException, + NetworkConnectionError, + RespNotOkException, + ReqAbortException, +) +from ..payload_content.message import Message, RoleType +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall + + +def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]: + """ + 转换消息格式 - 将消息转换为OpenAI API所需的格式 + :param messages: 消息列表 + :return: 转换后的消息列表 + """ + + def _convert_message_item(message: Message) -> ChatCompletionMessageParam: + """ + 转换单个消息格式 + :param message: 消息对象 + :return: 转换后的消息字典 + """ + + # 添加Content + content: str | list[dict[str, Any]] + if isinstance(message.content, str): + content = message.content + elif isinstance(message.content, list): + content = [] + for item in message.content: + if isinstance(item, tuple): + content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/{item[0].lower()};base64,{item[1]}" + }, + } + ) + elif isinstance(item, str): + content.append({"type": "text", "text": item}) + else: + raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") + + ret = { + "role": message.role.value, + "content": content, + } + + # 添加工具调用ID + if message.role == RoleType.Tool: + if not message.tool_call_id: + raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") + ret["tool_call_id"] = message.tool_call_id + + return ret + + return [_convert_message_item(message) for message in messages] + + +def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]: + """ + 转换工具选项格式 - 将工具选项转换为OpenAI API所需的格式 + :param tool_options: 工具选项列表 + :return: 转换后的工具选项列表 + """ + + def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, str]: + """ + 转换单个工具参数格式 + :param tool_option_param: 工具参数对象 + :return: 转换后的工具参数字典 + """ + return { + "type": tool_option_param.param_type.value, + "description": tool_option_param.description, + } + + def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]: + """ + 转换单个工具项格式 + :param tool_option: 工具选项对象 + :return: 转换后的工具选项字典 + """ + ret: dict[str, Any] = { + "name": tool_option.name, + "description": tool_option.description, + } + if tool_option.params: + ret["parameters"] = { + "type": 
"object", + "properties": { + param.name: _convert_tool_param(param) + for param in tool_option.params + }, + "required": [ + param.name for param in tool_option.params if param.required + ], + } + return ret + + return [ + { + "type": "function", + "function": _convert_tool_option_item(tool_option), + } + for tool_option in tool_options + ] + + +def _process_delta( + delta: ChoiceDelta, + has_rc_attr_flag: bool, + in_rc_flag: bool, + rc_delta_buffer: io.StringIO, + fc_delta_buffer: io.StringIO, + tool_calls_buffer: list[tuple[str, str, io.StringIO]], +) -> bool: + # 接收content + if has_rc_attr_flag: + # 有独立的推理内容块,则无需考虑content内容的判读 + if hasattr(delta, "reasoning_content") and delta.reasoning_content: + # 如果有推理内容,则将其写入推理内容缓冲区 + assert isinstance(delta.reasoning_content, str) + rc_delta_buffer.write(delta.reasoning_content) + elif delta.content: + # 如果有正式内容,则将其写入正式内容缓冲区 + fc_delta_buffer.write(delta.content) + elif hasattr(delta, "content") and delta.content is not None: + # 没有独立的推理内容块,但有正式内容 + if in_rc_flag: + # 当前在推理内容块中 + if delta.content == "": + # 如果当前内容是,则将其视为推理内容的结束标记,退出推理内容块 + in_rc_flag = False + else: + # 其他情况视为推理内容,加入推理内容缓冲区 + rc_delta_buffer.write(delta.content) + elif delta.content == "" and not fc_delta_buffer.getvalue(): + # 如果当前内容是,且正式内容缓冲区为空,说明为输出的首个token + # 则将其视为推理内容的开始标记,进入推理内容块 + in_rc_flag = True + else: + # 其他情况视为正式内容,加入正式内容缓冲区 + fc_delta_buffer.write(delta.content) + # 接收tool_calls + if hasattr(delta, "tool_calls") and delta.tool_calls: + tool_call_delta = delta.tool_calls[0] + + if tool_call_delta.index >= len(tool_calls_buffer): + # 调用索引号大于等于缓冲区长度,说明是新的工具调用 + tool_calls_buffer.append( + ( + tool_call_delta.id, + tool_call_delta.function.name, + io.StringIO(), + ) + ) + + if tool_call_delta.function.arguments: + # 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中 + tool_calls_buffer[tool_call_delta.index][2].write( + tool_call_delta.function.arguments + ) + + return in_rc_flag + + +def _build_stream_api_resp( + _fc_delta_buffer: io.StringIO, + _rc_delta_buffer: io.StringIO, + _tool_calls_buffer: list[tuple[str, str, io.StringIO]], +) -> APIResponse: + resp = APIResponse() + + if _rc_delta_buffer.tell() > 0: + # 如果推理内容缓冲区不为空,则将其写入APIResponse对象 + resp.reasoning_content = _rc_delta_buffer.getvalue() + _rc_delta_buffer.close() + if _fc_delta_buffer.tell() > 0: + # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 + resp.content = _fc_delta_buffer.getvalue() + _fc_delta_buffer.close() + if _tool_calls_buffer: + # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 + resp.tool_calls = [] + for call_id, function_name, arguments_buffer in _tool_calls_buffer: + if arguments_buffer.tell() > 0: + # 如果参数串缓冲区不为空,则解析为JSON对象 + raw_arg_data = arguments_buffer.getvalue() + arguments_buffer.close() + try: + arguments = json.loads(raw_arg_data) + if not isinstance(arguments, dict): + raise RespParseException( + None, + "响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n" + f"{raw_arg_data}", + ) + except json.JSONDecodeError as e: + raise RespParseException( + None, + "响应解析失败,无法解析工具调用参数。工具调用参数原始响应:" + f"{raw_arg_data}", + ) from e + else: + arguments_buffer.close() + arguments = None + + resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) + + return resp + + +async def _default_stream_response_handler( + resp_stream: AsyncStream[ChatCompletionChunk], + interrupt_flag: asyncio.Event | None, +) -> tuple[APIResponse, tuple[int, int, int]]: + """ + 流式响应处理函数 - 处理OpenAI API的流式响应 + :param resp_stream: 流式响应对象 + :return: APIResponse对象 + """ + + _has_rc_attr_flag = False # 标记是否有独立的推理内容块 + _in_rc_flag = False # 标记是否在推理内容块中 + _rc_delta_buffer = 
io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容 + _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 + _tool_calls_buffer: list[ + tuple[str, str, io.StringIO] + ] = [] # 工具调用缓冲区,用于存储接收到的工具调用 + _usage_record = None # 使用情况记录 + + def _insure_buffer_closed(): + # 确保缓冲区被关闭 + if _rc_delta_buffer and not _rc_delta_buffer.closed: + _rc_delta_buffer.close() + if _fc_delta_buffer and not _fc_delta_buffer.closed: + _fc_delta_buffer.close() + for _, _, buffer in _tool_calls_buffer: + if buffer and not buffer.closed: + buffer.close() + + async for event in resp_stream: + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量被设置,则抛出ReqAbortException + _insure_buffer_closed() + raise ReqAbortException("请求被外部信号中断") + + delta = event.choices[0].delta # 获取当前块的delta内容 + + if hasattr(delta, "reasoning_content") and delta.reasoning_content: + # 标记:有独立的推理内容块 + _has_rc_attr_flag = True + + _in_rc_flag = _process_delta( + delta, + _has_rc_attr_flag, + _in_rc_flag, + _rc_delta_buffer, + _fc_delta_buffer, + _tool_calls_buffer, + ) + + if event.usage: + # 如果有使用情况,则将其存储在APIResponse对象中 + _usage_record = ( + event.usage.prompt_tokens, + event.usage.completion_tokens, + event.usage.total_tokens, + ) + + try: + return _build_stream_api_resp( + _fc_delta_buffer, + _rc_delta_buffer, + _tool_calls_buffer, + ), _usage_record + except Exception: + # 确保缓冲区被关闭 + _insure_buffer_closed() + raise + + +pattern = re.compile( + r"(?P.*?)(?P.*)|(?P.*)|(?P.+)", + re.DOTALL, +) +"""用于解析推理内容的正则表达式""" + + +def _default_normal_response_parser( + resp: ChatCompletion, +) -> tuple[APIResponse, tuple[int, int, int]]: + """ + 解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象 + :param resp: 响应对象 + :return: APIResponse对象 + """ + api_response = APIResponse() + + if not hasattr(resp, "choices") or len(resp.choices) == 0: + raise RespParseException(resp, "响应解析失败,缺失choices字段") + message_part = resp.choices[0].message + + if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: + # 有有效的推理字段 + api_response.content = message_part.content + api_response.reasoning_content = message_part.reasoning_content + elif message_part.content: + # 提取推理和内容 + match = pattern.match(message_part.content) + if not match: + raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容") + if match.group("think") is not None: + result = match.group("think").strip(), match.group("content").strip() + elif match.group("think_unclosed") is not None: + result = match.group("think_unclosed").strip(), None + else: + result = None, match.group("content_only").strip() + api_response.reasoning_content, api_response.content = result + + # 提取工具调用 + if message_part.tool_calls: + api_response.tool_calls = [] + for call in message_part.tool_calls: + try: + arguments = json.loads(call.function.arguments) + if not isinstance(arguments, dict): + raise RespParseException( + resp, "响应解析失败,工具调用参数无法解析为字典类型" + ) + api_response.tool_calls.append( + ToolCall(call.id, call.function.name, arguments) + ) + except json.JSONDecodeError as e: + raise RespParseException( + resp, "响应解析失败,无法解析工具调用参数" + ) from e + + # 提取Usage信息 + if resp.usage: + _usage_record = ( + resp.usage.prompt_tokens, + resp.usage.completion_tokens, + resp.usage.total_tokens, + ) + else: + _usage_record = None + + # 将原始响应存储在原始数据中 + api_response.raw_data = resp + + return api_response, _usage_record + + +class OpenaiClient(BaseClient): + def __init__(self, api_provider: APIProvider): + super().__init__(api_provider) + self.client: AsyncOpenAI = AsyncOpenAI( + base_url=api_provider.base_url, + api_key=api_provider.api_key, + 
max_retries=0,
+    )
+
+    async def get_response(
+        self,
+        model_info: ModelInfo,
+        message_list: list[Message],
+        tool_options: list[ToolOption] | None = None,
+        max_tokens: int = 1024,
+        temperature: float = 0.7,
+        response_format: RespFormat | None = None,
+        stream_response_handler: Callable[
+            [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
+            tuple[APIResponse, tuple[int, int, int]],
+        ]
+        | None = None,
+        async_response_parser: Callable[
+            [ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
+        ]
+        | None = None,
+        interrupt_flag: asyncio.Event | None = None,
+    ) -> APIResponse:
+        """
+        获取对话响应
+        :param model_info: 模型信息
+        :param message_list: 对话体
+        :param tool_options: 工具选项(可选,默认为None)
+        :param max_tokens: 最大token数(可选,默认为1024)
+        :param temperature: 温度(可选,默认为0.7)
+        :param response_format: 响应格式(可选,默认为None)
+        :param stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler)
+        :param async_response_parser: 响应解析函数(可选,默认为default_response_parser)
+        :param interrupt_flag: 中断信号量(可选,默认为None)
+        :return: APIResponse对象(包含响应文本、推理文本、工具调用与token使用情况)
+        """
+        if stream_response_handler is None:
+            stream_response_handler = _default_stream_response_handler
+
+        if async_response_parser is None:
+            async_response_parser = _default_normal_response_parser
+
+        # 将messages构造为OpenAI API所需的格式
+        messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
+        # 将tool_options转换为OpenAI API所需的格式
+        tools: Iterable[ChatCompletionToolParam] = (
+            _convert_tool_options(tool_options) if tool_options else NOT_GIVEN
+        )
+
+        try:
+            if model_info.force_stream_mode:
+                req_task = asyncio.create_task(
+                    self.client.chat.completions.create(
+                        model=model_info.model_identifier,
+                        messages=messages,
+                        tools=tools,
+                        temperature=temperature,
+                        max_tokens=max_tokens,
+                        stream=True,
+                        response_format=response_format.to_dict()
+                        if response_format
+                        else NOT_GIVEN,
+                    )
+                )
+                while not req_task.done():
+                    if interrupt_flag and interrupt_flag.is_set():
+                        # 如果中断量存在且被设置,则取消任务并抛出异常
+                        req_task.cancel()
+                        raise ReqAbortException("请求被外部信号中断")
+                    await asyncio.sleep(0.1)  # 等待0.1秒后再次检查任务&中断信号量状态
+
+                resp, usage_record = await stream_response_handler(
+                    req_task.result(), interrupt_flag
+                )
+            else:
+                # 发送请求并获取响应
+                req_task = asyncio.create_task(
+                    self.client.chat.completions.create(
+                        model=model_info.model_identifier,
+                        messages=messages,
+                        tools=tools,
+                        temperature=temperature,
+                        max_tokens=max_tokens,
+                        stream=False,
+                        response_format=response_format.to_dict()
+                        if response_format
+                        else NOT_GIVEN,
+                    )
+                )
+                while not req_task.done():
+                    if interrupt_flag and interrupt_flag.is_set():
+                        # 如果中断量存在且被设置,则取消任务并抛出异常
+                        req_task.cancel()
+                        raise ReqAbortException("请求被外部信号中断")
+                    await asyncio.sleep(0.5)  # 等待0.5秒后再次检查任务&中断信号量状态
+
+                resp, usage_record = async_response_parser(req_task.result())
+        except APIConnectionError as e:
+            # 重封装APIConnectionError为NetworkConnectionError
+            raise NetworkConnectionError() from e
+        except APIStatusError as e:
+            # 重封装APIStatusError为RespNotOkException
+            raise RespNotOkException(e.status_code, e.message) from e
+
+        if usage_record:
+            resp.usage = UsageRecord(
+                model_name=model_info.name,
+                provider_name=model_info.api_provider,
+                prompt_tokens=usage_record[0],
+                completion_tokens=usage_record[1],
+                total_tokens=usage_record[2],
+            )
+
+        return resp
+
+    async def get_embedding(
+        self,
+        model_info: ModelInfo,
+        embedding_input: str,
+    ) -> APIResponse:
+        """
+        获取文本嵌入
+        :param model_info: 模型信息
+        :param embedding_input: 嵌入输入文本
+        :return: 嵌入响应
+        """
+        try:
+            raw_response = await self.client.embeddings.create(
+                model=model_info.model_identifier,
+                input=embedding_input,
+            )
+        except APIConnectionError as e:
+            raise NetworkConnectionError() from e
+        except APIStatusError as e:
+            # 重封装APIStatusError为RespNotOkException
+            raise RespNotOkException(e.status_code) from e
+
+        response = APIResponse()
+
+        # 解析嵌入响应
+        if len(raw_response.data) > 0:
+            response.embedding = raw_response.data[0].embedding
+        else:
+            raise RespParseException(
+                raw_response,
+                "响应解析失败,缺失嵌入数据。",
+            )
+
+        # 解析使用情况
+        if hasattr(raw_response, "usage"):
+            response.usage = UsageRecord(
+                model_name=model_info.name,
+                provider_name=model_info.api_provider,
+                prompt_tokens=raw_response.usage.prompt_tokens,
+                completion_tokens=raw_response.usage.completion_tokens,
+                total_tokens=raw_response.usage.total_tokens,
+            )
+
+        return response
diff --git a/src/chat/maibot_llmreq/model_manager.py b/src/chat/maibot_llmreq/model_manager.py
new file mode 100644
index 000000000..3056b187a
--- /dev/null
+++ b/src/chat/maibot_llmreq/model_manager.py
@@ -0,0 +1,79 @@
+import importlib
+from typing import Dict
+
+
+from .config.config import (
+    ModelUsageArgConfig,
+    ModuleConfig,
+)
+
+from . import _logger as logger
+from .model_client import ModelRequestHandler, BaseClient
+
+
+class ModelManager:
+    # TODO: 添加读写锁,防止异步刷新配置时发生数据竞争
+
+    def __init__(
+        self,
+        config: ModuleConfig,
+    ):
+        self.config: ModuleConfig = config
+        """配置信息"""
+
+        self.api_client_map: Dict[str, BaseClient] = {}
+        """API客户端映射表"""
+
+        for provider_name, api_provider in self.config.api_providers.items():
+            # 初始化API客户端
+            try:
+                # 根据配置动态加载实现
+                client_module = importlib.import_module(
+                    f".model_client.{api_provider.client_type}_client", __package__
+                )
+                client_class = getattr(
+                    client_module, f"{api_provider.client_type.capitalize()}Client"
+                )
+                if not issubclass(client_class, BaseClient):
+                    raise TypeError(
+                        f"'{client_class.__name__}' is not a subclass of 'BaseClient'"
+                    )
+                self.api_client_map[api_provider.name] = client_class(
+                    api_provider
+                )  # 实例化,放入api_client_map
+            except (ImportError, AttributeError) as e:
+                # 模块缺失抛出ImportError,模块中缺少对应类抛出AttributeError,一并处理
+                logger.error(f"Failed to load client for '{provider_name}': {e}")
+                raise ImportError(
+                    f"Failed to import client module for '{provider_name}': {e}"
+                ) from e
+
+    def __getitem__(self, task_name: str) -> ModelRequestHandler:
+        """
+        获取任务所需的模型客户端(封装)
+        :param task_name: 任务名称
+        :return: 模型客户端
+        """
+        if task_name not in self.config.task_model_arg_map:
+            raise KeyError(f"'{task_name}' not registered in ModelManager")
+
+        return ModelRequestHandler(
+            task_name=task_name,
+            config=self.config,
+            api_client_map=self.api_client_map,
+        )
+
+    def __setitem__(self, task_name: str, value: ModelUsageArgConfig):
+        """
+        注册任务的模型使用配置
+        :param task_name: 任务名称
+        :param value: 模型使用配置
+        """
+        self.config.task_model_arg_map[task_name] = value
+
+    def __contains__(self, task_name: str):
+        """
+        判断任务是否已注册
+        :param task_name: 任务名称
+        :return: 是否在模型列表中
+        """
+        return task_name in self.config.task_model_arg_map
diff --git a/src/chat/maibot_llmreq/payload_content/message.py b/src/chat/maibot_llmreq/payload_content/message.py
new file mode 100644
index 000000000..26202ca11
--- /dev/null
+++ b/src/chat/maibot_llmreq/payload_content/message.py
@@ -0,0 +1,104 @@
+from enum import Enum
+
+
+# 设计这系列类的目的是为未来可能的扩展做准备
+
+
+class RoleType(Enum):
+    System = "system"
+    User = "user"
+    Assistant = "assistant"
+    Tool = "tool"
+
+
+SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]
+
+
+class Message:
+    def __init__(
+        self,
+        role: RoleType,
+        content: str | list[tuple[str, str] | str],
+        tool_call_id: str | None = None,
+    ):
+        """
+        初始化消息对象
+        (不应直接修改Message类,而应使用MessageBuilder类来构建对象)
+        """
+        self.role: RoleType = role
+        self.content: str | list[tuple[str, str] | str] = content
+        self.tool_call_id: str | None = tool_call_id
+
+
+class MessageBuilder:
+    def __init__(self):
+        self.__role: RoleType = RoleType.User
+        self.__content: list[tuple[str, str] | str] = []
+        self.__tool_call_id: str | None = None
+
+    def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
+        """
+        设置角色(默认为User)
+        :param role: 角色
+        :return: MessageBuilder对象
+        """
+        self.__role = role
+        return self
+
+    def add_text_content(self, text: str) -> "MessageBuilder":
+        """
+        添加文本内容
+        :param text: 文本内容
+        :return: MessageBuilder对象
+        """
+        self.__content.append(text)
+        return self
+
+    def add_image_content(
+        self, image_format: str, image_base64: str
+    ) -> "MessageBuilder":
+        """
+        添加图片内容
+        :param image_format: 图片格式
+        :param image_base64: 图片的base64编码
+        :return: MessageBuilder对象
+        """
+        if image_format.lower() not in SUPPORTED_IMAGE_FORMATS:
+            raise ValueError("不受支持的图片格式")
+        if not image_base64:
+            raise ValueError("图片的base64编码不能为空")
+        self.__content.append((image_format, image_base64))
+        return self
+
+    def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
+        """
+        添加工具调用指令(调用时请确保已设置为Tool角色)
+        :param tool_call_id: 工具调用指令的id
+        :return: MessageBuilder对象
+        """
+        if self.__role != RoleType.Tool:
+            raise ValueError("仅当角色为Tool时才能添加工具调用ID")
+        if not tool_call_id:
+            raise ValueError("工具调用ID不能为空")
+        self.__tool_call_id = tool_call_id
+        return self
+
+    def build(self) -> Message:
+        """
+        构建消息对象
+        :return: Message对象
+        """
+        if len(self.__content) == 0:
+            raise ValueError("内容不能为空")
+        if self.__role == RoleType.Tool and self.__tool_call_id is None:
+            raise ValueError("Tool角色的工具调用ID不能为空")
+
+        return Message(
+            role=self.__role,
+            content=(
+                self.__content[0]
+                if (len(self.__content) == 1 and isinstance(self.__content[0], str))
+                else self.__content
+            ),
+            tool_call_id=self.__tool_call_id,
+        )
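A minimal sketch of how MessageBuilder is meant to be used, based on the class above; the message text, the truncated base64 payload, and the call id are placeholders:

    from src.chat.maibot_llmreq.payload_content.message import MessageBuilder, RoleType

    # A user message carrying text plus one base64-encoded image
    user_msg = (
        MessageBuilder()
        .set_role(RoleType.User)
        .add_text_content("描述一下这张图片")
        .add_image_content("png", "iVBORw0KGgo...")  # truncated base64, illustrative only
        .build()
    )

    # A tool-result message; add_tool_call requires the Tool role to be set first
    tool_msg = (
        MessageBuilder()
        .set_role(RoleType.Tool)
        .add_text_content('{"temperature": 25}')
        .add_tool_call("call_0001")  # hypothetical call id
        .build()
    )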
+ """ + + +def _json_schema_type_check(instance) -> str | None: + if "name" not in instance: + return "schema必须包含'name'字段" + elif not isinstance(instance["name"], str) or instance["name"].strip() == "": + return "schema的'name'字段必须是非空字符串" + if "description" in instance and ( + not isinstance(instance["description"], str) + or instance["description"].strip() == "" + ): + return "schema的'description'字段只能填入非空字符串" + if "schema" not in instance: + return "schema必须包含'schema'字段" + elif not isinstance(instance["schema"], dict): + return "schema的'schema'字段必须是字典,详见https://json-schema.org/" + if "strict" in instance and not isinstance(instance["strict"], bool): + return "schema的'strict'字段只能填入布尔值" + + return None + + +def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]: + """ + 递归移除JSON Schema中的title字段 + """ + if isinstance(schema, list): + # 如果当前Schema是列表,则对所有dict/list子元素递归调用 + for idx, item in enumerate(schema): + if isinstance(item, (dict, list)): + schema[idx] = _remove_title(item) + elif isinstance(schema, dict): + # 是字典,移除title字段,并对所有dict/list子元素递归调用 + if "title" in schema: + del schema["title"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_title(value) + + return schema + + +def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]: + """ + 链接JSON Schema中的definitions字段 + """ + + def link_definitions_recursive( + path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any] + ) -> dict[str, Any]: + """ + 递归链接JSON Schema中的definitions字段 + :param path: 当前路径 + :param sub_schema: 子Schema + :param defs: Schema定义集 + :return: + """ + if isinstance(sub_schema, list): + # 如果当前Schema是列表,则遍历每个元素 + for i in range(len(sub_schema)): + if isinstance(sub_schema[i], dict): + sub_schema[i] = link_definitions_recursive( + f"{path}/{str(i)}", sub_schema[i], defs + ) + else: + # 否则为字典 + if "$defs" in sub_schema: + # 如果当前Schema有$def字段,则将其添加到defs中 + key_prefix = f"{path}/$defs/" + for key, value in sub_schema["$defs"].items(): + def_key = key_prefix + key + if def_key not in defs: + defs[def_key] = value + del sub_schema["$defs"] + if "$ref" in sub_schema: + # 如果当前Schema有$ref字段,则将其替换为defs中的定义 + def_key = sub_schema["$ref"] + if def_key in defs: + sub_schema = defs[def_key] + else: + raise ValueError(f"Schema中引用的定义'{def_key}'不存在") + # 遍历键值对 + for key, value in sub_schema.items(): + if isinstance(value, (dict, list)): + # 如果当前值是字典或列表,则递归调用 + sub_schema[key] = link_definitions_recursive( + f"{path}/{key}", value, defs + ) + + return sub_schema + + return link_definitions_recursive("#", schema, {}) + + +def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]: + """ + 递归移除JSON Schema中的$defs字段 + """ + if isinstance(schema, list): + # 如果当前Schema是列表,则对所有dict/list子元素递归调用 + for idx, item in enumerate(schema): + if isinstance(item, (dict, list)): + schema[idx] = _remove_title(item) + elif isinstance(schema, dict): + # 是字典,移除title字段,并对所有dict/list子元素递归调用 + if "$defs" in schema: + del schema["$defs"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_title(value) + + return schema + + +class RespFormat: + """ + 响应格式 + """ + + @staticmethod + def _generate_schema_from_model(schema): + json_schema = { + "name": schema.__name__, + "schema": _remove_defs( + _link_definitions(_remove_title(schema.model_json_schema())) + ), + "strict": False, + } + if schema.__doc__: + json_schema["description"] = schema.__doc__ + return json_schema + + def __init__( + self, + format_type: RespFormatType = 
RespFormatType.TEXT, + schema: type | JsonSchema | None = None, + ): + """ + 响应格式 + :param format_type: 响应格式类型(默认为文本) + :param schema: 模板类或JsonSchema(仅当format_type为JSON Schema时有效) + """ + self.format_type: RespFormatType = format_type + + if format_type == RespFormatType.JSON_SCHEMA: + if schema is None: + raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空") + if isinstance(schema, dict): + if check_msg := _json_schema_type_check(schema): + raise ValueError(f"schema格式不正确,{check_msg}") + + self.schema = schema + elif issubclass(schema, BaseModel): + try: + json_schema = self._generate_schema_from_model(schema) + + self.schema = json_schema + except Exception as e: + raise ValueError( + f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n" + f"{schema.__name__}:\n" + ) from e + else: + raise ValueError("schema必须是BaseModel的子类或JsonSchema") + else: + self.schema = None + + def to_dict(self): + """ + 将响应格式转换为字典 + :return: 字典 + """ + if self.schema: + return { + "format_type": self.format_type.value, + "schema": self.schema, + } + else: + return { + "format_type": self.format_type.value, + } diff --git a/src/chat/maibot_llmreq/payload_content/tool_option.py b/src/chat/maibot_llmreq/payload_content/tool_option.py new file mode 100644 index 000000000..8a9bbdb31 --- /dev/null +++ b/src/chat/maibot_llmreq/payload_content/tool_option.py @@ -0,0 +1,155 @@ +from enum import Enum + + +class ToolParamType(Enum): + """ + 工具调用参数类型 + """ + + String = "string" # 字符串 + Int = "integer" # 整型 + Float = "float" # 浮点型 + Boolean = "bool" # 布尔型 + + +class ToolParam: + """ + 工具调用参数 + """ + + def __init__( + self, name: str, param_type: ToolParamType, description: str, required: bool + ): + """ + 初始化工具调用参数 + (不应直接修改ToolParam类,而应使用ToolOptionBuilder类来构建对象) + :param name: 参数名称 + :param param_type: 参数类型 + :param description: 参数描述 + :param required: 是否必填 + """ + self.name: str = name + self.param_type: ToolParamType = param_type + self.description: str = description + self.required: bool = required + + +class ToolOption: + """ + 工具调用项 + """ + + def __init__( + self, + name: str, + description: str, + params: list[ToolParam] | None = None, + ): + """ + 初始化工具调用项 + (不应直接修改ToolOption类,而应使用ToolOptionBuilder类来构建对象) + :param name: 工具名称 + :param description: 工具描述 + :param params: 工具参数列表 + """ + self.name: str = name + self.description: str = description + self.params: list[ToolParam] | None = params + + +class ToolOptionBuilder: + """ + 工具调用项构建器 + """ + + def __init__(self): + self.__name: str = "" + self.__description: str = "" + self.__params: list[ToolParam] = [] + + def set_name(self, name: str) -> "ToolOptionBuilder": + """ + 设置工具名称 + :param name: 工具名称 + :return: ToolBuilder实例 + """ + if not name: + raise ValueError("工具名称不能为空") + self.__name = name + return self + + def set_description(self, description: str) -> "ToolOptionBuilder": + """ + 设置工具描述 + :param description: 工具描述 + :return: ToolBuilder实例 + """ + if not description: + raise ValueError("工具描述不能为空") + self.__description = description + return self + + def add_param( + self, + name: str, + param_type: ToolParamType, + description: str, + required: bool = False, + ) -> "ToolOptionBuilder": + """ + 添加工具参数 + :param name: 参数名称 + :param param_type: 参数类型 + :param description: 参数描述 + :param required: 是否必填(默认为False) + :return: ToolBuilder实例 + """ + if not name or not description: + raise ValueError("参数名称/描述不能为空") + + self.__params.append( + ToolParam( + name=name, + param_type=param_type, + description=description, + required=required, + ) + ) + + return self + + def 
build(self): + """ + 构建工具调用项 + :return: 工具调用项 + """ + if self.__name == "" or self.__description == "": + raise ValueError("工具名称/描述不能为空") + + return ToolOption( + name=self.__name, + description=self.__description, + params=None if len(self.__params) == 0 else self.__params, + ) + + +class ToolCall: + """ + 来自模型反馈的工具调用 + """ + + def __init__( + self, + call_id: str, + func_name: str, + args: dict | None = None, + ): + """ + 初始化工具调用 + :param call_id: 工具调用ID + :param func_name: 要调用的函数名称 + :param args: 工具调用参数 + """ + self.call_id: str = call_id + self.func_name: str = func_name + self.args: dict | None = args diff --git a/src/chat/maibot_llmreq/tests/test_config_load.py b/src/chat/maibot_llmreq/tests/test_config_load.py new file mode 100644 index 000000000..7553cb91c --- /dev/null +++ b/src/chat/maibot_llmreq/tests/test_config_load.py @@ -0,0 +1,84 @@ +import pytest +from packaging.version import InvalidVersion + +from src import maibot_llmreq +from src.maibot_llmreq.config.parser import _get_config_version, load_config + + +class TestConfigLoad: + def test_loads_valid_version_from_toml(self): + maibot_llmreq.init_logger() + + toml_data = {"inner": {"version": "1.2.3"}} + version = _get_config_version(toml_data) + assert str(version) == "1.2.3" + + def test_handles_missing_version_key(self): + maibot_llmreq.init_logger() + + toml_data = {} + version = _get_config_version(toml_data) + assert str(version) == "0.0.0" + + def test_raises_error_for_invalid_version(self): + maibot_llmreq.init_logger() + + toml_data = {"inner": {"version": "invalid_version"}} + with pytest.raises(InvalidVersion): + _get_config_version(toml_data) + + def test_loads_complete_config_successfully(self, tmp_path): + maibot_llmreq.init_logger() + + config_path = tmp_path / "config.toml" + config_path.write_text(""" + [inner] + version = "0.1.0" + + [request_conf] + max_retry = 5 + timeout = 10 + + [[api_providers]] + name = "provider1" + base_url = "https://api.example.com" + api_key = "key123" + + [[api_providers]] + name = "provider2" + base_url = "https://api.example2.com" + api_key = "key456" + + [[models]] + model_identifier = "model1" + api_provider = "provider1" + + [[models]] + model_identifier = "model2" + api_provider = "provider2" + + [task_model_usage] + task1 = { model = "model1" } + task2 = "model1" + task3 = [ + "model1", + { model = "model2", temperature = 0.5 } + ] + """) + config = load_config(str(config_path)) + assert config.req_conf.max_retry == 5 + assert config.req_conf.timeout == 10 + assert "provider1" in config.api_providers + assert "model1" in config.models + assert "task1" in config.task_model_arg_map + + def test_raises_error_for_missing_required_field(self, tmp_path): + maibot_llmreq.init_logger() + + config_path = tmp_path / "config.toml" + config_path.write_text(""" + [inner] + version = "1.0.0" + """) + with pytest.raises(KeyError): + load_config(str(config_path)) diff --git a/src/chat/maibot_llmreq/usage_statistic.py b/src/chat/maibot_llmreq/usage_statistic.py new file mode 100644 index 000000000..3c5490e3e --- /dev/null +++ b/src/chat/maibot_llmreq/usage_statistic.py @@ -0,0 +1,182 @@ +from datetime import datetime +from enum import Enum +from typing import Tuple + +from pymongo.synchronous.database import Database + +from . 
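A minimal sketch combining the payload pieces above: a JSON-Schema response format generated from a Pydantic model, and a tool option built with ToolOptionBuilder. ExampleReply and get_weather are hypothetical names:

    from pydantic import BaseModel

    from src.chat.maibot_llmreq.payload_content.resp_format import RespFormat, RespFormatType
    from src.chat.maibot_llmreq.payload_content.tool_option import ToolOptionBuilder, ToolParamType

    class ExampleReply(BaseModel):
        """Schema for a structured answer (hypothetical)."""
        answer: str

    # JSON-Schema response format auto-generated from the Pydantic model
    resp_format = RespFormat(RespFormatType.JSON_SCHEMA, schema=ExampleReply)

    # A tool option with one required string parameter
    weather_tool = (
        ToolOptionBuilder()
        .set_name("get_weather")
        .set_description("查询指定城市的天气")
        .add_param("city", ToolParamType.String, "城市名称", required=True)
        .build()
    )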
diff --git a/src/chat/maibot_llmreq/usage_statistic.py b/src/chat/maibot_llmreq/usage_statistic.py
new file mode 100644
index 000000000..3c5490e3e
--- /dev/null
+++ b/src/chat/maibot_llmreq/usage_statistic.py
@@ -0,0 +1,182 @@
+from datetime import datetime
+from enum import Enum
+from typing import Tuple
+
+from bson import ObjectId
+from pymongo.database import Database
+
+from . import _logger as logger
+from .config.config import ModelInfo
+
+
+class ReqType(Enum):
+    """
+    请求类型
+    """
+
+    CHAT = "chat"  # 对话请求
+    EMBEDDING = "embedding"  # 嵌入请求
+
+
+class UsageCallStatus(Enum):
+    """
+    任务调用状态
+    """
+
+    PROCESSING = "processing"  # 处理中
+    SUCCESS = "success"  # 成功
+    FAILURE = "failure"  # 失败
+    CANCELED = "canceled"  # 取消
+
+
+class ModelUsageStatistic:
+    db: Database | None = None
+
+    def __init__(self, db: Database):
+        if db is None:
+            logger.warning(
+                "Warning: No database provided, ModelUsageStatistic will not work."
+            )
+            return
+        if self._init_database(db):
+            # 成功初始化
+            self.db = db
+
+    @staticmethod
+    def _init_database(db: Database):
+        """
+        初始化数据库相关索引
+        """
+        try:
+            db.llm_usage.create_index([("timestamp", 1)])
+            db.llm_usage.create_index([("model_name", 1)])
+            db.llm_usage.create_index([("task_name", 1)])
+            db.llm_usage.create_index([("request_type", 1)])
+            db.llm_usage.create_index([("status", 1)])
+            return True
+        except Exception as e:
+            logger.error(f"创建数据库索引失败: {e}")
+            return False
+
+    @staticmethod
+    def _calculate_cost(
+        prompt_tokens: int, completion_tokens: int, model_info: ModelInfo
+    ) -> float:
+        """计算API调用成本
+        使用模型的price_in和price_out价格计算输入和输出的成本
+
+        Args:
+            prompt_tokens: 输入token数量
+            completion_tokens: 输出token数量
+            model_info: 模型信息(提供单价)
+
+        Returns:
+            float: 总成本(元)
+        """
+        # 使用模型的price_in和price_out计算成本(单价为每百万token)
+        input_cost = (prompt_tokens / 1000000) * model_info.price_in
+        output_cost = (completion_tokens / 1000000) * model_info.price_out
+        return round(input_cost + output_cost, 6)
+
+    def create_usage(
+        self,
+        model_name: str,
+        task_name: str = "N/A",
+        request_type: ReqType = ReqType.CHAT,
+    ) -> str | None:
+        """
+        创建模型使用情况记录
+        :param model_name: 模型名
+        :param task_name: 任务名称
+        :param request_type: 请求类型,默认为Chat
+        :return: 记录ID(无数据库连接或写入失败时为None)
+        """
+        if self.db is None:
+            return None  # 如果没有数据库连接,则不记录使用情况
+
+        try:
+            usage_data = {
+                "model_name": model_name,
+                "task_name": task_name,
+                "request_type": request_type.value,
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0,
+                "cost": 0.0,
+                "status": "processing",
+                "timestamp": datetime.now(),
+                "ext_msg": None,
+            }
+            result = self.db.llm_usage.insert_one(usage_data)
+
+            logger.trace(
+                f"创建了一条模型使用情况记录 - 模型: {model_name}, "
+                f"子任务: {task_name}, 类型: {request_type.value}, "
+                f"记录ID: {str(result.inserted_id)}"
+            )
+
+            return str(result.inserted_id)
+        except Exception as e:
+            logger.error(f"创建模型使用情况记录失败: {str(e)}")
+            return None
+
+    def update_usage(
+        self,
+        record_id: str | None,
+        model_info: ModelInfo,
+        usage_data: Tuple[int, int, int] | None = None,
+        stat: UsageCallStatus = UsageCallStatus.SUCCESS,
+        ext_msg: str | None = None,
+    ):
+        """
+        更新模型使用情况
+
+        Args:
+            record_id: 记录ID
+            model_info: 模型信息
+            usage_data: 使用情况数据(输入token数量, 输出token数量, 总token数量)
+            stat: 任务调用状态
+            ext_msg: 额外信息
+        """
+        if self.db is None:
+            return  # 如果没有数据库连接,则不记录使用情况
+
+        if not record_id:
+            logger.error("更新模型使用情况失败: record_id不能为空")
+            return
+
+        if usage_data and len(usage_data) != 3:
+            logger.error("更新模型使用情况失败: usage_data的长度不正确,应该为3个元素")
+            return
+
+        # 提取使用情况数据
+        prompt_tokens = usage_data[0] if usage_data else 0
+        completion_tokens = usage_data[1] if usage_data else 0
+        total_tokens = usage_data[2] if usage_data else 0
+
+        try:
+            self.db.llm_usage.update_one(
+                # insert_one生成的_id是ObjectId,查询前需将字符串ID转换回ObjectId
+                {"_id": ObjectId(record_id)},
+                {
+                    "$set": {
+                        "status": stat.value,
+                        "ext_msg": ext_msg,
+                        "prompt_tokens": prompt_tokens,
+                        "completion_tokens": completion_tokens,
+                        "total_tokens": total_tokens,
+                        "cost": self._calculate_cost(
+                            prompt_tokens, completion_tokens, model_info
+                        )
+                        if usage_data
+                        else 0.0,
+                    }
+                },
+            )
+
+            logger.trace(
+                f"Token使用情况 - 模型: {model_info.name}, "
+                f"记录ID: {record_id}, "
+                f"任务状态: {stat.value}, 额外信息: {ext_msg if ext_msg else 'N/A'}, "
+                f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
+                f"总计: {total_tokens}"
+            )
+        except Exception as e:
+            logger.error(f"记录token使用情况失败: {str(e)}")
diff --git a/src/chat/maibot_llmreq/utils.py b/src/chat/maibot_llmreq/utils.py
new file mode 100644
index 000000000..f8bf4fb39
--- /dev/null
+++ b/src/chat/maibot_llmreq/utils.py
@@ -0,0 +1,150 @@
+import base64
+import io
+
+from PIL import Image
+
+from . import _logger as logger
+from .payload_content.message import Message, MessageBuilder
+
+
+def compress_messages(
+    messages: list[Message], img_target_size: int = 1 * 1024 * 1024
+) -> list[Message]:
+    """
+    压缩消息列表中的图片
+    :param messages: 消息列表
+    :param img_target_size: 图片目标大小,默认1MB
+    :return: 压缩后的消息列表
+    """
+
+    def reformat_static_image(image_data: bytes) -> bytes:
+        """
+        将静态图片转换为JPEG格式
+        :param image_data: 图片数据
+        :return: 转换后的图片数据
+        """
+        try:
+            # Image.open需要文件对象,bytes需先用BytesIO包装
+            image = Image.open(io.BytesIO(image_data))
+
+            if image.format and (
+                image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]
+            ):
+                # 静态图像,转换为JPEG格式(JPEG不支持透明通道,先转为RGB)
+                reformated_image_data = io.BytesIO()
+                image.convert("RGB").save(
+                    reformated_image_data, format="JPEG", quality=95, optimize=True
+                )
+                image_data = reformated_image_data.getvalue()
+
+            return image_data
+        except Exception as e:
+            logger.error(f"图片转换格式失败: {str(e)}")
+            return image_data
+
+    def rescale_image(
+        image_data: bytes, scale: float
+    ) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
+        """
+        缩放图片
+        :param image_data: 图片数据
+        :param scale: 缩放比例
+        :return: (缩放后的图片数据, 原始尺寸, 新尺寸);失败时返回原数据,尺寸为None
+        """
+        try:
+            # Image.open需要文件对象,bytes需先用BytesIO包装
+            image = Image.open(io.BytesIO(image_data))
+
+            # 原始尺寸
+            original_size = (image.width, image.height)
+
+            # 计算新的尺寸
+            new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
+
+            output_buffer = io.BytesIO()
+
+            if getattr(image, "is_animated", False):
+                # 动态图片,处理所有帧
+                frames = []
+                new_size = (new_size[0] // 2, new_size[1] // 2)  # 动图,缩放尺寸再打折
+                for frame_idx in range(getattr(image, "n_frames", 1)):
+                    image.seek(frame_idx)
+                    new_frame = image.copy()
+                    new_frame = new_frame.resize(new_size, Image.Resampling.LANCZOS)
+                    frames.append(new_frame)
+
+                # 保存到缓冲区
+                frames[0].save(
+                    output_buffer,
+                    format="GIF",
+                    save_all=True,
+                    append_images=frames[1:],
+                    optimize=True,
+                    duration=image.info.get("duration", 100),
+                    loop=image.info.get("loop", 0),
+                )
+            else:
+                # 静态图片,直接缩放保存(JPEG不支持透明通道,先转为RGB)
+                resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
+                resized_image.convert("RGB").save(
+                    output_buffer, format="JPEG", quality=95, optimize=True
+                )
+
+            return output_buffer.getvalue(), original_size, new_size
+
+        except Exception as e:
+            logger.error(f"图片缩放失败: {str(e)}")
+            import traceback
+
+            logger.error(traceback.format_exc())
+            return image_data, None, None
+
+    def compress_base64_image(
+        base64_data: str, target_size: int = 1 * 1024 * 1024
+    ) -> str:
+        """
+        压缩base64编码的图片至目标大小以内
+        :param base64_data: 图片的base64编码
+        :param target_size: 编码后的目标大小
+        :return: 压缩后的base64编码
+        """
+        original_b64_data_size = len(base64_data)  # 计算原始数据大小
+
+        image_data = base64.b64decode(base64_data)
+
+        # 先尝试转换格式为JPEG
+        image_data = reformat_static_image(image_data)
+        base64_data = base64.b64encode(image_data).decode("utf-8")
+        if len(base64_data) <= target_size:
+            # 如果转换后小于目标大小,直接返回
+            logger.info(
+                f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB"
+            )
+            return base64_data
+
+        # 如果转换后仍然大于目标大小,进行尺寸压缩
+        scale = min(1.0, target_size / len(base64_data))
+        image_data, original_size, new_size = rescale_image(image_data, scale)
+        base64_data = base64.b64encode(image_data).decode("utf-8")
+
+        if original_size and new_size:
+            logger.info(
+                f"压缩图片: {original_size[0]}x{original_size[1]} -> {new_size[0]}x{new_size[1]}\n"
+                f"压缩前大小: {original_b64_data_size / 1024:.1f}KB, 压缩后大小: {len(base64_data) / 1024:.1f}KB"
+            )
+
+        return base64_data
+
+    compressed_messages = []
+    for message in messages:
+        if isinstance(message.content, list):
+            # 检查content,如有图片则压缩
+            message_builder = MessageBuilder()
+            message_builder.set_role(message.role)  # 保留原消息的角色,避免被重置为User
+            for content_item in message.content:
+                if isinstance(content_item, tuple):
+                    # 图片,进行压缩
+                    message_builder.add_image_content(
+                        content_item[0],
+                        compress_base64_image(
+                            content_item[1], target_size=img_target_size
+                        ),
+                    )
+                else:
+                    message_builder.add_text_content(content_item)
+            if message.tool_call_id:
+                # 保留原消息的工具调用ID(仅Tool角色消息存在)
+                message_builder.add_tool_call(message.tool_call_id)
+            compressed_messages.append(message_builder.build())
+        else:
+            compressed_messages.append(message)
+
+    return compressed_messages
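A minimal sketch of the intended ModelUsageStatistic record lifecycle, assuming a reachable local MongoDB; the database, model, and task names are placeholders:

    from pymongo import MongoClient

    from src.chat.maibot_llmreq.config.config import ModelInfo
    from src.chat.maibot_llmreq.usage_statistic import (
        ModelUsageStatistic,
        ReqType,
        UsageCallStatus,
    )

    statistic = ModelUsageStatistic(MongoClient("mongodb://localhost:27017")["maibot"])
    model_info = ModelInfo(
        model_identifier="deepseek-chat",
        name="deepseek-v3",
        api_provider="DeepSeek",
        price_in=2.0,
        price_out=8.0,
    )

    # 1. Open a "processing" record before the API call
    record_id = statistic.create_usage(
        model_info.name, task_name="llm_normal", request_type=ReqType.CHAT
    )

    # 2. Close it with token counts once the call returns
    statistic.update_usage(
        record_id,
        model_info,
        usage_data=(120, 45, 165),  # (prompt, completion, total) tokens
        stat=UsageCallStatus.SUCCESS,
    )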
diff --git a/template/model_config_template.toml b/template/model_config_template.toml
new file mode 100644
index 000000000..f9055fcea
--- /dev/null
+++ b/template/model_config_template.toml
@@ -0,0 +1,77 @@
+[inner]
+version = "0.1.0"
+
+# 配置文件版本号迭代规则同bot_config.toml
+
+[request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释)
+#max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数)
+#timeout = 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)
+#retry_interval = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒)
+#default_temperature = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值)
+#default_max_tokens = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值)
+
+
+[[api_providers]] # API服务提供商(可以配置多个)
+name = "DeepSeek" # API服务商名称(可随意命名,在models的api_provider中需使用这个命名)
+base_url = "https://api.deepseek.com" # API服务商的BaseURL
+api_key = "******" # API Key(可选,默认为None)
+client_type = "openai" # 请求客户端(可选,默认值为"openai",使用Gemini等Google系模型时请配置为"google")
+
+#[[api_providers]] # 特殊:Google的Gemini使用特殊API,与OpenAI格式不兼容,需要配置client_type为"google"
+#name = "Google"
+#base_url = "https://api.google.com"
+#api_key = "******"
+#client_type = "google"
+#
+#[[api_providers]]
+#name = "SiliconFlow"
+#base_url = "https://api.siliconflow.cn"
+#api_key = "******"
+#
+#[[api_providers]]
+#name = "LocalHost"
+#base_url = "https://localhost:8888"
+#api_key = "lm-studio"
+
+
+[[models]] # 模型(可以配置多个)
+# 模型标识符(API服务商提供的模型标识符)
+model_identifier = "deepseek-chat"
+# 模型名称(可随意命名,在bot_config.toml中需使用这个命名)
+#(可选,若无该字段,则将自动使用model_identifier填充)
+name = "deepseek-v3"
+# API服务商名称(对应在api_providers中配置的服务商名称)
+api_provider = "DeepSeek"
+# 输入价格(用于API调用统计,单位:元/兆token)(可选,若无该字段,默认值为0)
+price_in = 2.0
+# 输出价格(用于API调用统计,单位:元/兆token)(可选,若无该字段,默认值为0)
+price_out = 8.0
+# 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出)
+#(可选,若无该字段,默认值为false)
+#force_stream_mode = true
+
+#[[models]]
+#model_identifier = "deepseek-reasoner"
+#name = "deepseek-r1"
+#api_provider = "DeepSeek"
+#model_flags = ["text", "tool_calling", "reasoning"]
+#price_in = 4.0
+#price_out = 16.0
+#
+#[[models]]
+#model_identifier = "BAAI/bge-m3"
+#name = "siliconflow-bge-m3"
+#api_provider = "SiliconFlow"
+#model_flags = ["text", "embedding"]
+#price_in = 0
+#price_out = 0
+
+
+[task_model_usage]
+#llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0}
+#llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0}
+#embedding = "siliconflow-bge-m3"
+#schedule = [
+#    "deepseek-v3",
+#    "deepseek-r1",
+#]
\ No newline at end of file
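Finally, a sketch of how the new module is meant to be wired together at startup, based on the APIs added above. The config path and task name are placeholders, and the call surface of the returned ModelRequestHandler is defined in model_client/__init__.py, which this excerpt does not reproduce:

    from src.chat import maibot_llmreq
    from src.chat.maibot_llmreq.config.config import (
        ModelUsageArgConfig,
        ModelUsageArgConfigItem,
    )
    from src.chat.maibot_llmreq.config.parser import load_config
    from src.chat.maibot_llmreq.model_manager import ModelManager

    maibot_llmreq.init_logger()

    config = load_config("config/model_config.toml")  # placeholder path
    manager = ModelManager(config)

    # Register a task at runtime (equivalent to a [task_model_usage] entry)
    manager["summarize"] = ModelUsageArgConfig(
        name="summarize",
        usage=[ModelUsageArgConfigItem(name="deepseek-v3", temperature=0.3, max_tokens=512)],
    )

    if "summarize" in manager:
        handler = manager["summarize"]  # ModelRequestHandler bound to this task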