From 909e47bcee95b1f5f43a19520240d9f3019c2bb1 Mon Sep 17 00:00:00 2001
From: 墨梓柒 <1787882683@qq.com>
Date: Fri, 25 Jul 2025 13:21:48 +0800
Subject: [PATCH] Initial refactoring of llmrequest
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 debug_config.py                               |  111 ++
 src/chat/maibot_llmreq/__init__.py            |   19 -
 src/chat/maibot_llmreq/config/parser.py       |  267 ----
 .../maibot_llmreq/tests/test_config_load.py   |   84 --
 src/config/config.py                          |    5 +-
 .../maibot_llmreq => llm_models}/LICENSE      |    0
 .../config => llm_models}/__init__.py         |    0
 .../exceptions.py                             |    0
 .../model_client/__init__.py                  |   10 +-
 .../model_client/base_client.py               |    2 +-
 .../model_client/gemini_client.py             |    2 +-
 .../model_client/openai_client.py             |    2 +-
 .../model_manager.py                          |   12 +-
 .../payload_content/message.py                |    0
 .../payload_content/resp_format.py            |    0
 .../payload_content/tool_option.py            |    0
 src/llm_models/temp.py                        |    8 +
 .../usage_statistic.py                        |  136 +-
 .../maibot_llmreq => llm_models}/utils.py     |    4 +-
 src/llm_models/utils_model.py                 | 1110 +++++------------
 template/compare/model_config_template.toml   |   77 ++
 21 files changed, 612 insertions(+), 1237 deletions(-)
 create mode 100644 debug_config.py
 delete mode 100644 src/chat/maibot_llmreq/__init__.py
 delete mode 100644 src/chat/maibot_llmreq/config/parser.py
 delete mode 100644 src/chat/maibot_llmreq/tests/test_config_load.py
 rename src/{chat/maibot_llmreq => llm_models}/LICENSE (100%)
 rename src/{chat/maibot_llmreq/config => llm_models}/__init__.py (100%)
 rename src/{chat/maibot_llmreq => llm_models}/exceptions.py (100%)
 rename src/{chat/maibot_llmreq => llm_models}/model_client/__init__.py (98%)
 rename src/{chat/maibot_llmreq => llm_models}/model_client/base_client.py (98%)
 rename src/{chat/maibot_llmreq => llm_models}/model_client/gemini_client.py (99%)
 rename src/{chat/maibot_llmreq => llm_models}/model_client/openai_client.py (99%)
 rename src/{chat/maibot_llmreq => llm_models}/model_manager.py (92%)
 rename src/{chat/maibot_llmreq => llm_models}/payload_content/message.py (100%)
 rename src/{chat/maibot_llmreq => llm_models}/payload_content/resp_format.py (100%)
 rename src/{chat/maibot_llmreq => llm_models}/payload_content/tool_option.py (100%)
 create mode 100644 src/llm_models/temp.py
 rename src/{chat/maibot_llmreq => llm_models}/usage_statistic.py (52%)
 rename src/{chat/maibot_llmreq => llm_models}/utils.py (98%)
 create mode 100644 template/compare/model_config_template.toml

diff --git a/debug_config.py b/debug_config.py
new file mode 100644
index 000000000..a2b960e5c
--- /dev/null
+++ b/debug_config.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+"""
+调试配置加载问题,查看API provider的配置是否正确传递
+"""
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+def debug_config_loading():
+    try:
+        # 临时配置API key
+        import toml
+        config_path = "config/model_config.toml"
+
+        with open(config_path, 'r', encoding='utf-8') as f:
+            config = toml.load(f)
+
+        original_keys = {}
+        for provider in config['api_providers']:
+            original_keys[provider['name']] = provider['api_key']
+            provider['api_key'] = f"sk-test-key-for-{provider['name'].lower()}-12345"
+
+        with open(config_path, 'w', encoding='utf-8') as f:
+            toml.dump(config, f)
+
+        print("✅ 配置了测试API key")
+
+        try:
+            # 清空缓存
+            modules_to_remove = [
+                'src.config.config',
+                'src.config.api_ada_configs',
+                'src.llm_models.model_manager',
+                'src.llm_models.model_client',
+                'src.llm_models.utils_model'
+            ]
+            for module in modules_to_remove:
+                if 
module in sys.modules: + del sys.modules[module] + + # 导入配置 + from src.config.config import model_config + print("\n🔍 调试配置加载:") + print(f"model_config类型: {type(model_config)}") + + # 检查API providers + if hasattr(model_config, 'api_providers'): + print(f"API providers数量: {len(model_config.api_providers)}") + for name, provider in model_config.api_providers.items(): + print(f" - {name}: {provider.base_url}") + print(f" API key: {provider.api_key[:10]}...{provider.api_key[-5:] if len(provider.api_key) > 15 else provider.api_key}") + print(f" Client type: {provider.client_type}") + + # 检查模型配置 + if hasattr(model_config, 'models'): + print(f"模型数量: {len(model_config.models)}") + for name, model in model_config.models.items(): + print(f" - {name}: {model.model_identifier} (提供商: {model.api_provider})") + + # 检查任务配置 + if hasattr(model_config, 'task_model_arg_map'): + print(f"任务配置数量: {len(model_config.task_model_arg_map)}") + for task_name, task_config in model_config.task_model_arg_map.items(): + print(f" - {task_name}: {task_config}") + + # 尝试初始化ModelManager + print("\n🔍 调试ModelManager初始化:") + from src.llm_models.model_manager import ModelManager + + try: + model_manager = ModelManager(model_config) + print("✅ ModelManager初始化成功") + + # 检查API客户端映射 + print(f"API客户端数量: {len(model_manager.api_client_map)}") + for name, client in model_manager.api_client_map.items(): + print(f" - {name}: {type(client).__name__}") + if hasattr(client, 'client') and hasattr(client.client, 'api_key'): + api_key = client.client.api_key + print(f" Client API key: {api_key[:10]}...{api_key[-5:] if len(api_key) > 15 else api_key}") + + # 尝试获取任务处理器 + try: + handler = model_manager["llm_normal"] + print("✅ 成功获取llm_normal任务处理器") + print(f"任务处理器类型: {type(handler).__name__}") + except Exception as e: + print(f"❌ 获取任务处理器失败: {e}") + + except Exception as e: + print(f"❌ ModelManager初始化失败: {e}") + import traceback + traceback.print_exc() + + finally: + # 恢复配置 + for provider in config['api_providers']: + provider['api_key'] = original_keys[provider['name']] + + with open(config_path, 'w', encoding='utf-8') as f: + toml.dump(config, f) + print("\n✅ 配置已恢复") + + except Exception as e: + print(f"❌ 调试失败: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + debug_config_loading() diff --git a/src/chat/maibot_llmreq/__init__.py b/src/chat/maibot_llmreq/__init__.py deleted file mode 100644 index aab812cfa..000000000 --- a/src/chat/maibot_llmreq/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -import loguru - -type LoguruLogger = loguru.Logger - -_logger: LoguruLogger = loguru.logger - - -def init_logger( - logger: LoguruLogger | None = None, -): - """ - 对LLMRequest模块进行配置 - :param logger: 日志对象 - """ - global _logger # 申明使用全局变量 - if logger: - _logger = logger - else: - _logger.warning("Warning: No logger provided, using default logger.") diff --git a/src/chat/maibot_llmreq/config/parser.py b/src/chat/maibot_llmreq/config/parser.py deleted file mode 100644 index a6877835d..000000000 --- a/src/chat/maibot_llmreq/config/parser.py +++ /dev/null @@ -1,267 +0,0 @@ -import os -from typing import Any, Dict, List - -import tomli -from packaging import version -from packaging.specifiers import SpecifierSet -from packaging.version import Version, InvalidVersion - -from .. 
import _logger as logger - -from .config import ( - ModelUsageArgConfigItem, - ModelUsageArgConfig, - APIProvider, - ModelInfo, - NEWEST_VER, - ModuleConfig, -) - - -def _get_config_version(toml: Dict) -> Version: - """提取配置文件的 SpecifierSet 版本数据 - Args: - toml[dict]: 输入的配置文件字典 - Returns: - Version - """ - - if "inner" in toml and "version" in toml["inner"]: - config_version: str = toml["inner"]["version"] - else: - config_version = "0.0.0" # 默认版本 - - try: - ver = version.parse(config_version) - except InvalidVersion as e: - logger.error( - "配置文件中 inner段 的 version 键是错误的版本描述\n" - f"请检查配置文件,当前 version 键: {config_version}\n" - f"错误信息: {e}" - ) - raise InvalidVersion( - "配置文件中 inner段 的 version 键是错误的版本描述\n" - ) from e - - return ver - - -def _request_conf(parent: Dict, config: ModuleConfig): - request_conf_config = parent.get("request_conf") - config.req_conf.max_retry = request_conf_config.get( - "max_retry", config.req_conf.max_retry - ) - config.req_conf.timeout = request_conf_config.get( - "timeout", config.req_conf.timeout - ) - config.req_conf.retry_interval = request_conf_config.get( - "retry_interval", config.req_conf.retry_interval - ) - config.req_conf.default_temperature = request_conf_config.get( - "default_temperature", config.req_conf.default_temperature - ) - config.req_conf.default_max_tokens = request_conf_config.get( - "default_max_tokens", config.req_conf.default_max_tokens - ) - - -def _api_providers(parent: Dict, config: ModuleConfig): - api_providers_config = parent.get("api_providers") - for provider in api_providers_config: - name = provider.get("name", None) - base_url = provider.get("base_url", None) - api_key = provider.get("api_key", None) - client_type = provider.get("client_type", "openai") - - if name in config.api_providers: # 查重 - logger.error(f"重复的API提供商名称: {name},请检查配置文件。") - raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。") - - if name and base_url: - config.api_providers[name] = APIProvider( - name=name, - base_url=base_url, - api_key=api_key, - client_type=client_type, - ) - else: - logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") - raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") - - -def _models(parent: Dict, config: ModuleConfig): - models_config = parent.get("models") - for model in models_config: - model_identifier = model.get("model_identifier", None) - name = model.get("name", model_identifier) - api_provider = model.get("api_provider", None) - price_in = model.get("price_in", 0.0) - price_out = model.get("price_out", 0.0) - force_stream_mode = model.get("force_stream_mode", False) - - if name in config.models: # 查重 - logger.error(f"重复的模型名称: {name},请检查配置文件。") - raise KeyError(f"重复的模型名称: {name},请检查配置文件。") - - if model_identifier and api_provider: - # 检查API提供商是否存在 - if api_provider not in config.api_providers: - logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。") - raise ValueError( - f"未声明的API提供商 '{api_provider}' ,请检查配置文件。" - ) - config.models[name] = ModelInfo( - name=name, - model_identifier=model_identifier, - api_provider=api_provider, - price_in=price_in, - price_out=price_out, - force_stream_mode=force_stream_mode, - ) - else: - logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。") - raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。") - - -def _task_model_usage(parent: Dict, config: ModuleConfig): - model_usage_configs = parent.get("task_model_usage") - config.task_model_arg_map = {} - for task_name, item in model_usage_configs.items(): - if task_name in config.task_model_arg_map: - logger.error(f"子任务 {task_name} 已存在,请检查配置文件。") - raise KeyError(f"子任务 
{task_name} 已存在,请检查配置文件。") - - usage = [] - if isinstance(item, Dict): - if "model" in item: - usage.append( - ModelUsageArgConfigItem( - name=item["model"], - temperature=item.get("temperature", None), - max_tokens=item.get("max_tokens", None), - max_retry=item.get("max_retry", None), - ) - ) - else: - logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。") - raise ValueError( - f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" - ) - elif isinstance(item, List): - for model in item: - if isinstance(model, Dict): - usage.append( - ModelUsageArgConfigItem( - name=model["model"], - temperature=model.get("temperature", None), - max_tokens=model.get("max_tokens", None), - max_retry=model.get("max_retry", None), - ) - ) - elif isinstance(model, str): - usage.append( - ModelUsageArgConfigItem( - name=model, - temperature=None, - max_tokens=None, - max_retry=None, - ) - ) - else: - logger.error( - f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" - ) - raise ValueError( - f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" - ) - elif isinstance(item, str): - usage.append( - ModelUsageArgConfigItem( - name=item, - temperature=None, - max_tokens=None, - max_retry=None, - ) - ) - - config.task_model_arg_map[task_name] = ModelUsageArgConfig( - name=task_name, - usage=usage, - ) - - -def load_config(config_path: str) -> ModuleConfig: - """从TOML配置文件加载配置""" - config = ModuleConfig() - - include_configs: Dict[str, Dict[str, Any]] = { - "request_conf": { - "func": _request_conf, - "support": ">=0.0.0", - "necessary": False, - }, - "api_providers": {"func": _api_providers, "support": ">=0.0.0"}, - "models": {"func": _models, "support": ">=0.0.0"}, - "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"}, - } - - if os.path.exists(config_path): - with open(config_path, "rb") as f: - try: - toml_dict = tomli.load(f) - except tomli.TOMLDecodeError as e: - logger.critical( - f"配置文件model_list.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}" - ) - exit(1) - - # 获取配置文件版本 - config.INNER_VERSION = _get_config_version(toml_dict) - - # 检查版本 - if config.INNER_VERSION > Version(NEWEST_VER): - logger.warning( - f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。" - ) - - # 解析配置文件 - # 如果在配置中找到了需要的项,调用对应项的闭包函数处理 - for key in include_configs: - if key in toml_dict: - group_specifier_set: SpecifierSet = SpecifierSet( - include_configs[key]["support"] - ) - - # 检查配置文件版本是否在支持范围内 - if config.INNER_VERSION in group_specifier_set: - # 如果版本在支持范围内,检查是否存在通知 - if "notice" in include_configs[key]: - logger.warning(include_configs[key]["notice"]) - # 调用闭包函数处理配置 - (include_configs[key]["func"])(toml_dict, config) - else: - # 如果版本不在支持范围内,崩溃并提示用户 - logger.error( - f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n" - f"当前程序仅支持以下版本范围: {group_specifier_set}" - ) - raise InvalidVersion( - f"当前程序仅支持以下版本范围: {group_specifier_set}" - ) - - # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理 - elif ( - "necessary" in include_configs[key] - and include_configs[key].get("necessary") is False - ): - # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理 - if key == "keywords_reaction": - pass - else: - # 如果用户根本没有需要的配置项,提示缺少配置 - logger.error(f"配置文件中缺少必需的字段: '{key}'") - raise KeyError(f"配置文件中缺少必需的字段: '{key}'") - - logger.success(f"成功加载配置文件: {config_path}") - - return config diff --git a/src/chat/maibot_llmreq/tests/test_config_load.py b/src/chat/maibot_llmreq/tests/test_config_load.py deleted file mode 100644 index 7553cb91c..000000000 --- a/src/chat/maibot_llmreq/tests/test_config_load.py +++ /dev/null @@ -1,84 +0,0 @@ -import pytest -from packaging.version import 
InvalidVersion - -from src import maibot_llmreq -from src.maibot_llmreq.config.parser import _get_config_version, load_config - - -class TestConfigLoad: - def test_loads_valid_version_from_toml(self): - maibot_llmreq.init_logger() - - toml_data = {"inner": {"version": "1.2.3"}} - version = _get_config_version(toml_data) - assert str(version) == "1.2.3" - - def test_handles_missing_version_key(self): - maibot_llmreq.init_logger() - - toml_data = {} - version = _get_config_version(toml_data) - assert str(version) == "0.0.0" - - def test_raises_error_for_invalid_version(self): - maibot_llmreq.init_logger() - - toml_data = {"inner": {"version": "invalid_version"}} - with pytest.raises(InvalidVersion): - _get_config_version(toml_data) - - def test_loads_complete_config_successfully(self, tmp_path): - maibot_llmreq.init_logger() - - config_path = tmp_path / "config.toml" - config_path.write_text(""" - [inner] - version = "0.1.0" - - [request_conf] - max_retry = 5 - timeout = 10 - - [[api_providers]] - name = "provider1" - base_url = "https://api.example.com" - api_key = "key123" - - [[api_providers]] - name = "provider2" - base_url = "https://api.example2.com" - api_key = "key456" - - [[models]] - model_identifier = "model1" - api_provider = "provider1" - - [[models]] - model_identifier = "model2" - api_provider = "provider2" - - [task_model_usage] - task1 = { model = "model1" } - task2 = "model1" - task3 = [ - "model1", - { model = "model2", temperature = 0.5 } - ] - """) - config = load_config(str(config_path)) - assert config.req_conf.max_retry == 5 - assert config.req_conf.timeout == 10 - assert "provider1" in config.api_providers - assert "model1" in config.models - assert "task1" in config.task_model_arg_map - - def test_raises_error_for_missing_required_field(self, tmp_path): - maibot_llmreq.init_logger() - - config_path = tmp_path / "config.toml" - config_path.write_text(""" - [inner] - version = "1.0.0" - """) - with pytest.raises(KeyError): - load_config(str(config_path)) diff --git a/src/config/config.py b/src/config/config.py index bd2d58f04..95ad198a1 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -13,7 +13,6 @@ from packaging.version import Version, InvalidVersion from typing import Any, Dict, List from src.common.logger import get_logger -from src.common.message import api from src.config.config_base import ConfigBase from src.config.official_configs import ( BotConfig, @@ -314,7 +313,7 @@ def api_ada_load_config(config_path: str) -> ModuleConfig: logger.error(f"配置文件中缺少必需的字段: '{key}'") raise KeyError(f"配置文件中缺少必需的字段: '{key}'") - logger.success(f"成功加载配置文件: {config_path}") + logger.info(f"成功加载配置文件: {config_path}") return config @@ -653,4 +652,4 @@ update_model_config() logger.info("正在品鉴配置文件...") global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml")) model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml")) -logger.info("非常的新鲜,非常的美味!") +logger.info("非常的新鲜,非常的美味!") \ No newline at end of file diff --git a/src/chat/maibot_llmreq/LICENSE b/src/llm_models/LICENSE similarity index 100% rename from src/chat/maibot_llmreq/LICENSE rename to src/llm_models/LICENSE diff --git a/src/chat/maibot_llmreq/config/__init__.py b/src/llm_models/__init__.py similarity index 100% rename from src/chat/maibot_llmreq/config/__init__.py rename to src/llm_models/__init__.py diff --git a/src/chat/maibot_llmreq/exceptions.py b/src/llm_models/exceptions.py similarity index 100% rename from src/chat/maibot_llmreq/exceptions.py 
rename to src/llm_models/exceptions.py diff --git a/src/chat/maibot_llmreq/model_client/__init__.py b/src/llm_models/model_client/__init__.py similarity index 98% rename from src/chat/maibot_llmreq/model_client/__init__.py rename to src/llm_models/model_client/__init__.py index 9dc28d07d..ebe802df2 100644 --- a/src/chat/maibot_llmreq/model_client/__init__.py +++ b/src/llm_models/model_client/__init__.py @@ -5,8 +5,7 @@ from openai import AsyncStream from openai.types.chat import ChatCompletionChunk, ChatCompletion from .base_client import BaseClient, APIResponse -from .. import _logger as logger -from ..config.config import ( +from src.config.api_ada_configs import ( ModelInfo, ModelUsageArgConfigItem, RequestConfig, @@ -22,6 +21,9 @@ from ..payload_content.message import Message from ..payload_content.resp_format import RespFormat from ..payload_content.tool_option import ToolOption from ..utils import compress_messages +from src.common.logger import get_logger + +logger = get_logger("模型客户端") def _check_retry( @@ -288,7 +290,7 @@ class ModelRequestHandler: interrupt_flag=interrupt_flag, ) except Exception as e: - logger.trace(e) + logger.debug(e) remain_try -= 1 # 剩余尝试次数减1 # 处理异常 @@ -340,7 +342,7 @@ class ModelRequestHandler: embedding_input=embedding_input, ) except Exception as e: - logger.trace(e) + logger.debug(e) remain_try -= 1 # 剩余尝试次数减1 # 处理异常 diff --git a/src/chat/maibot_llmreq/model_client/base_client.py b/src/llm_models/model_client/base_client.py similarity index 98% rename from src/chat/maibot_llmreq/model_client/base_client.py rename to src/llm_models/model_client/base_client.py index ed877a6c9..50a379d34 100644 --- a/src/chat/maibot_llmreq/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -5,7 +5,7 @@ from typing import Callable, Any from openai import AsyncStream from openai.types.chat import ChatCompletionChunk, ChatCompletion -from ..config.config import ModelInfo, APIProvider +from src.config.api_ada_configs import ModelInfo, APIProvider from ..payload_content.message import Message from ..payload_content.resp_format import RespFormat from ..payload_content.tool_option import ToolOption, ToolCall diff --git a/src/chat/maibot_llmreq/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py similarity index 99% rename from src/chat/maibot_llmreq/model_client/gemini_client.py rename to src/llm_models/model_client/gemini_client.py index 75d2767e0..1861ca1d5 100644 --- a/src/chat/maibot_llmreq/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -15,7 +15,7 @@ from google.genai.errors import ( ) from .base_client import APIResponse, UsageRecord -from ..config.config import ModelInfo, APIProvider +from src.config.api_ada_configs import ModelInfo, APIProvider from . import BaseClient from ..exceptions import ( diff --git a/src/chat/maibot_llmreq/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py similarity index 99% rename from src/chat/maibot_llmreq/model_client/openai_client.py rename to src/llm_models/model_client/openai_client.py index db256b2d4..e5da59022 100644 --- a/src/chat/maibot_llmreq/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -21,7 +21,7 @@ from openai.types.chat import ( from openai.types.chat.chat_completion_chunk import ChoiceDelta from .base_client import APIResponse, UsageRecord -from ..config.config import ModelInfo, APIProvider +from src.config.api_ada_configs import ModelInfo, APIProvider from . 
import BaseClient from ..exceptions import ( diff --git a/src/chat/maibot_llmreq/model_manager.py b/src/llm_models/model_manager.py similarity index 92% rename from src/chat/maibot_llmreq/model_manager.py rename to src/llm_models/model_manager.py index 3056b187a..5d983849b 100644 --- a/src/chat/maibot_llmreq/model_manager.py +++ b/src/llm_models/model_manager.py @@ -1,15 +1,13 @@ import importlib from typing import Dict +from src.config.config import model_config +from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig +from src.common.logger import get_logger -from .config.config import ( - ModelUsageArgConfig, - ModuleConfig, -) - -from . import _logger as logger from .model_client import ModelRequestHandler, BaseClient +logger = get_logger("模型管理器") class ModelManager: # TODO: 添加读写锁,防止异步刷新配置时发生数据竞争 @@ -77,3 +75,5 @@ class ModelManager: :return: 是否在模型列表中 """ return task_name in self.config.task_model_arg_map + + diff --git a/src/chat/maibot_llmreq/payload_content/message.py b/src/llm_models/payload_content/message.py similarity index 100% rename from src/chat/maibot_llmreq/payload_content/message.py rename to src/llm_models/payload_content/message.py diff --git a/src/chat/maibot_llmreq/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py similarity index 100% rename from src/chat/maibot_llmreq/payload_content/resp_format.py rename to src/llm_models/payload_content/resp_format.py diff --git a/src/chat/maibot_llmreq/payload_content/tool_option.py b/src/llm_models/payload_content/tool_option.py similarity index 100% rename from src/chat/maibot_llmreq/payload_content/tool_option.py rename to src/llm_models/payload_content/tool_option.py diff --git a/src/llm_models/temp.py b/src/llm_models/temp.py new file mode 100644 index 000000000..89755a314 --- /dev/null +++ b/src/llm_models/temp.py @@ -0,0 +1,8 @@ + +import sys +import os +sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) + +from src.config.config import model_config +print(f"当前模型配置: {model_config}") +print(model_config.req_conf.default_max_tokens) \ No newline at end of file diff --git a/src/chat/maibot_llmreq/usage_statistic.py b/src/llm_models/usage_statistic.py similarity index 52% rename from src/chat/maibot_llmreq/usage_statistic.py rename to src/llm_models/usage_statistic.py index 3c5490e3e..176c4b7b1 100644 --- a/src/chat/maibot_llmreq/usage_statistic.py +++ b/src/llm_models/usage_statistic.py @@ -2,10 +2,11 @@ from datetime import datetime from enum import Enum from typing import Tuple -from pymongo.synchronous.database import Database +from src.common.logger import get_logger +from src.config.api_ada_configs import ModelInfo +from src.common.database.database_model import LLMUsage -from . import _logger as logger -from .config.config import ModelInfo +logger = get_logger("模型使用统计") class ReqType(Enum): @@ -29,33 +30,21 @@ class UsageCallStatus(Enum): class ModelUsageStatistic: - db: Database | None = None + """ + 模型使用统计类 - 使用SQLite+Peewee + """ - def __init__(self, db: Database): - if db is None: - logger.warning( - "Warning: No database provided, ModelUsageStatistic will not work." 
- ) - return - if self._init_database(db): - # 成功初始化 - self.db = db - - @staticmethod - def _init_database(db: Database): + def __init__(self): """ - 初始化数据库相关索引 + 初始化统计类 + 由于使用Peewee ORM,不需要传入数据库实例 """ + # 确保表已经创建 try: - db.llm_usage.create_index([("timestamp", 1)]) - db.llm_usage.create_index([("model_name", 1)]) - db.llm_usage.create_index([("task_name", 1)]) - db.llm_usage.create_index([("request_type", 1)]) - db.llm_usage.create_index([("status", 1)]) - return True + from src.common.database.database import db + db.create_tables([LLMUsage], safe=True) except Exception as e: - logger.error(f"创建数据库索引失败: {e}") - return False + logger.error(f"创建LLMUsage表失败: {e}") @staticmethod def _calculate_cost( @@ -67,6 +56,7 @@ class ModelUsageStatistic: Args: prompt_tokens: 输入token数量 completion_tokens: 输出token数量 + model_info: 模型信息 Returns: float: 总成本(元) @@ -81,46 +71,50 @@ class ModelUsageStatistic: model_name: str, task_name: str = "N/A", request_type: ReqType = ReqType.CHAT, - ) -> str | None: + user_id: str = "system", + endpoint: str = "/chat/completions", + ) -> int | None: """ 创建模型使用情况记录 - :param model_name: 模型名 - :param task_name: 任务名称 - :param request_type: 请求类型,默认为Chat - :return: - """ - if self.db is None: - return None # 如果没有数据库连接,则不记录使用情况 + Args: + model_name: 模型名 + task_name: 任务名称 + request_type: 请求类型,默认为Chat + user_id: 用户ID,默认为system + endpoint: API端点 + + Returns: + int | None: 返回记录ID,失败返回None + """ try: - usage_data = { - "model_name": model_name, - "task_name": task_name, - "request_type": request_type.value, - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - "cost": 0.0, - "status": "processing", - "timestamp": datetime.now(), - "ext_msg": None, - } - result = self.db.llm_usage.insert_one(usage_data) + usage_record = LLMUsage.create( + model_name=model_name, + user_id=user_id, + request_type=request_type.value, + endpoint=endpoint, + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost=0.0, + status=UsageCallStatus.PROCESSING.value, + timestamp=datetime.now(), + ) logger.trace( f"创建了一条模型使用情况记录 - 模型: {model_name}, " - f"子任务: {task_name}, 类型: {request_type}" - f"记录ID: {str(result.inserted_id)}" + f"子任务: {task_name}, 类型: {request_type.value}, " + f"用户: {user_id}, 记录ID: {usage_record.id}" ) - return str(result.inserted_id) + return usage_record.id except Exception as e: logger.error(f"创建模型使用情况记录失败: {str(e)}") return None def update_usage( self, - record_id: str | None, + record_id: int | None, model_info: ModelInfo, usage_data: Tuple[int, int, int] | None = None, stat: UsageCallStatus = UsageCallStatus.SUCCESS, @@ -136,9 +130,6 @@ class ModelUsageStatistic: stat: 任务调用状态 ext_msg: 额外信息 """ - if self.db is None: - return # 如果没有数据库连接,则不记录使用情况 - if not record_id: logger.error("更新模型使用情况失败: record_id不能为空") return @@ -153,28 +144,27 @@ class ModelUsageStatistic: total_tokens = usage_data[2] if usage_data else 0 try: - self.db.llm_usage.update_one( - {"_id": record_id}, - { - "$set": { - "status": stat.value, - "ext_msg": ext_msg, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "cost": self._calculate_cost( - prompt_tokens, completion_tokens, model_info - ) - if usage_data - else 0.0, - } - }, - ) + # 使用Peewee更新记录 + update_query = LLMUsage.update( + status=stat.value, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost=self._calculate_cost( + prompt_tokens, completion_tokens, model_info + ) if usage_data else 0.0, + ).where(LLMUsage.id == record_id) + 
+ updated_count = update_query.execute() + + if updated_count == 0: + logger.warning(f"记录ID {record_id} 不存在,无法更新") + return - logger.trace( + logger.debug( f"Token使用情况 - 模型: {model_info.name}, " - f"记录ID: {record_id}, " - f"任务状态: {stat.value}, 额外信息: {ext_msg if ext_msg else 'N/A'}, " + f"记录ID: {record_id}, " + f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, " f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " f"总计: {total_tokens}" ) diff --git a/src/chat/maibot_llmreq/utils.py b/src/llm_models/utils.py similarity index 98% rename from src/chat/maibot_llmreq/utils.py rename to src/llm_models/utils.py index f8bf4fb39..352df5a43 100644 --- a/src/chat/maibot_llmreq/utils.py +++ b/src/llm_models/utils.py @@ -3,9 +3,11 @@ import io from PIL import Image -from . import _logger as logger +from src.common.logger import get_logger from .payload_content.message import Message, MessageBuilder +logger = get_logger("消息压缩工具") + def compress_messages( messages: list[Message], img_target_size: int = 1 * 1024 * 1024 diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index c994cd173..ff03b2788 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,26 +1,39 @@ -import asyncio -import json import re from datetime import datetime -from typing import Tuple, Union, Dict, Any, Callable -import aiohttp -from aiohttp.client import ClientResponse +from typing import Tuple, Union, Dict, Any from src.common.logger import get_logger import base64 from PIL import Image import io -import os import copy # 添加copy模块用于深拷贝 from src.common.database.database import db # 确保 db 被导入用于 create_tables from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 from src.config.config import global_config -from src.common.tcp_connector import get_tcp_connector from rich.traceback import install install(extra_lines=3) logger = get_logger("model_utils") +# 新架构导入 - 使用延迟导入以支持fallback模式 +try: + from .model_manager import ModelManager + from .model_client import ModelRequestHandler + from .payload_content.message import MessageBuilder + + # 不在模块级别初始化ModelManager,延迟到实际使用时 + ModelManager_class = ModelManager + model_manager = None # 延迟初始化 + NEW_ARCHITECTURE_AVAILABLE = True + logger.info("新架构模块导入成功") +except Exception as e: + logger.warning(f"新架构不可用,将使用fallback模式: {str(e)}") + ModelManager_class = None + model_manager = None + ModelRequestHandler = None + MessageBuilder = None + NEW_ARCHITECTURE_AVAILABLE = False + class PayLoadTooLargeError(Exception): """自定义异常类,用于处理请求体过大错误""" @@ -36,10 +49,9 @@ class PayLoadTooLargeError(Exception): class RequestAbortException(Exception): """自定义异常类,用于处理请求中断异常""" - def __init__(self, message: str, response: ClientResponse): + def __init__(self, message: str): super().__init__(message) self.message = message - self.response = response def __str__(self): return self.message @@ -59,7 +71,7 @@ class PermissionDeniedException(Exception): # 常见Error Code Mapping error_code_mapping = { 400: "参数不正确", - 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~", + 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确", 402: "账号余额不足", 403: "需要实名,或余额不足", 404: "Not Found", @@ -82,19 +94,25 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any and isinstance(safe_payload, dict) and "messages" in safe_payload and len(safe_payload["messages"]) > 0 + and isinstance(safe_payload["messages"][0], dict) + and "content" in safe_payload["messages"][0] ): - if isinstance(safe_payload["messages"][0], dict) and "content" in 
safe_payload["messages"][0]: - content = safe_payload["messages"][0]["content"] - if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]: - # 只修改拷贝的对象,用于安全的日志记录 - safe_payload["messages"][0]["content"][1]["image_url"]["url"] = ( - f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," - f"{image_base64[:10]}...{image_base64[-10:]}" - ) + content = safe_payload["messages"][0]["content"] + if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]: + # 只修改拷贝的对象,用于安全的日志记录 + safe_payload["messages"][0]["content"][1]["image_url"]["url"] = ( + f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," + f"{image_base64[:10]}...{image_base64[-10:]}" + ) return safe_payload class LLMRequest: + """ + 重构后的LLM请求类,基于新的model_manager和model_client架构 + 保持向后兼容的API接口 + """ + # 定义需要转换的模型列表,作为类变量避免重复 MODELS_NEEDING_TRANSFORMATION = [ "o1", @@ -114,42 +132,78 @@ class LLMRequest: ] def __init__(self, model: dict, **kwargs): - # 将大写的配置键转换为小写并从config中获取实际值 - logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('name', 'Unknown')}") + """ + 初始化LLM请求实例 + Args: + model: 模型配置字典,兼容旧格式和新格式 + **kwargs: 额外参数 + """ + logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('model_name', model.get('name', 'Unknown'))}") logger.debug(f"🔍 [模型初始化] 模型配置: {model}") logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}") - try: - # print(f"model['provider']: {model['provider']}") - self.api_key = os.environ[f"{model['provider']}_KEY"] - self.base_url = os.environ[f"{model['provider']}_BASE_URL"] - logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL") - except AttributeError as e: - logger.error(f"原始 model dict 信息:{model}") - logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") - raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e - except KeyError: - logger.warning( - f"找不到{model['provider']}_KEY或{model['provider']}_BASE_URL环境变量,请检查配置文件或环境变量设置。" - ) - self.model_name: str = model["name"] - self.params = kwargs - - # 记录配置文件中声明了哪些参数(不管值是什么) - self.has_enable_thinking = "enable_thinking" in model - self.has_thinking_budget = "thinking_budget" in model + # 兼容新旧模型配置格式 + # 新格式使用 model_name,旧格式使用 name + self.model_name: str = model.get("model_name", model.get("name", "")) + self.provider = model.get("provider", "") + # 从全局配置中获取任务配置 + self.request_type = kwargs.pop("request_type", "default") + + # 确定使用哪个任务配置 + task_name = self._determine_task_name(model) + + # 尝试初始化新架构 + if NEW_ARCHITECTURE_AVAILABLE and ModelManager_class is not None: + try: + # 延迟初始化ModelManager + global model_manager + if model_manager is None: + from src.config.config import model_config + model_manager = ModelManager_class(model_config) + logger.debug("🔍 [模型初始化] ModelManager延迟初始化成功") + + # 使用新架构获取模型请求处理器 + self.request_handler = model_manager[task_name] + logger.debug(f"🔍 [模型初始化] 成功获取模型请求处理器,任务: {task_name}") + self.use_new_architecture = True + except Exception as e: + logger.warning(f"无法使用新架构,任务 {task_name} 初始化失败: {e}") + logger.warning("回退到兼容模式,某些功能可能受限") + self.request_handler = None + self.use_new_architecture = False + else: + logger.warning("新架构不可用,使用兼容模式") + logger.warning("回退到兼容模式,某些功能可能受限") + self.request_handler = None + self.use_new_architecture = False + + # 保存原始参数用于向后兼容 + self.params = kwargs + + # 兼容性属性,从模型配置中提取 + # 新格式和旧格式都支持 self.enable_thinking = model.get("enable_thinking", False) - self.temp = model.get("temp", 0.7) + self.temp = model.get("temperature", model.get("temp", 0.7)) # 新格式用temperature,旧格式用temp self.thinking_budget = model.get("thinking_budget", 4096) 
         self.stream = model.get("stream", False)
         self.pri_in = model.get("pri_in", 0)
         self.pri_out = model.get("pri_out", 0)
         self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
-        # print(f"max_tokens: {self.max_tokens}")
-        logger.debug(f"🔍 [模型初始化] 模型参数设置完成:")
+        # 记录配置文件中声明了哪些参数(不管值是什么)
+        self.has_enable_thinking = "enable_thinking" in model
+        self.has_thinking_budget = "thinking_budget" in model
+
+        logger.debug("🔍 [模型初始化] 模型参数设置完成:")
         logger.debug(f"  - model_name: {self.model_name}")
+        logger.debug(f"  - provider: {self.provider}")
         logger.debug(f"  - has_enable_thinking: {self.has_enable_thinking}")
         logger.debug(f"  - enable_thinking: {self.enable_thinking}")
         logger.debug(f"  - has_thinking_budget: {self.has_thinking_budget}")
@@ -157,15 +211,40 @@ class LLMRequest:
         logger.debug(f"  - temp: {self.temp}")
         logger.debug(f"  - stream: {self.stream}")
         logger.debug(f"  - max_tokens: {self.max_tokens}")
-        logger.debug(f"  - base_url: {self.base_url}")
+        logger.debug(f"  - use_new_architecture: {self.use_new_architecture}")

         # 获取数据库实例
         self._init_database()
-
-        # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
-        self.request_type = kwargs.pop("request_type", "default")
+        logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}")
+
+    def _determine_task_name(self, model: dict) -> str:
+        """
+        根据模型配置确定任务名称
+        Args:
+            model: 模型配置字典
+        Returns:
+            任务名称
+        """
+        # 兼容新旧格式的模型名称
+        model_name = model.get("model_name", model.get("name", ""))
+
+        # 根据模型名称推断任务类型
+        if any(keyword in model_name.lower() for keyword in ["vlm", "vision", "gpt-4o", "claude", "vl-"]):
+            return "vision"
+        elif any(keyword in model_name.lower() for keyword in ["embed", "text-embedding", "bge-"]):
+            return "embedding"
+        elif any(keyword in model_name.lower() for keyword in ["whisper", "speech", "voice"]):
+            return "speech"
+        else:
+            # 根据request_type确定,映射到配置文件中定义的任务
+            if self.request_type in ["memory", "emotion"]:
+                return "llm_normal"  # 映射到配置中的llm_normal任务
+            elif self.request_type in ["reasoning"]:
+                return "llm_reasoning"  # 映射到配置中的llm_reasoning任务
+            else:
+                return "llm_normal"  # 默认使用llm_normal任务
+
     @staticmethod
     def _init_database():
         """初始化数据库集合"""
@@ -237,660 +316,6 @@ class LLMRequest:
             output_cost = (completion_tokens / 1000000) * self.pri_out
         return round(input_cost + output_cost, 6)

-    async def _prepare_request(
-        self,
-        endpoint: str,
-        prompt: str = None,
-        image_base64: str = None,
-        image_format: str = None,
-        file_bytes: bytes = None,
-        file_format: str = None,
-        payload: dict = None,
-        retry_policy: dict = None,
-    ) -> Dict[str, Any]:
-        """配置请求参数
-        Args:
-            endpoint: API端点路径 (如 "chat/completions")
-            prompt: prompt文本
-            image_base64: 图片的base64编码
-            image_format: 图片格式
-            file_bytes: 文件的二进制数据
-            file_format: 文件格式
-            payload: 请求体数据
-            retry_policy: 自定义重试策略
-            request_type: 请求类型
-        """
-
-        # 合并重试策略
-        default_retry = {
-            "max_retries": 3,
-            "base_wait": 10,
-            "retry_codes": [429, 413, 500, 503],
-            "abort_codes": [400, 401, 402, 403],
-        }
-        policy = {**default_retry, **(retry_policy or {})}
-
-        api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
-
-        stream_mode = self.stream
-
-        # 构建请求体
-        if image_base64:
-            payload = await self._build_payload(prompt, image_base64, image_format)
-        elif file_bytes:
-            payload = await 
self._build_formdata_payload(file_bytes, file_format) - elif payload is None: - payload = await self._build_payload(prompt) - - if not file_bytes: - if stream_mode: - payload["stream"] = stream_mode - - if self.temp != 0.7: - payload["temperature"] = self.temp - - # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false) - if self.has_enable_thinking: - payload["enable_thinking"] = self.enable_thinking - - # 添加thinking_budget参数(只有配置文件中声明了才添加) - if self.has_thinking_budget: - payload["thinking_budget"] = self.thinking_budget - - if self.max_tokens: - payload["max_tokens"] = self.max_tokens - - # if "max_tokens" not in payload and "max_completion_tokens" not in payload: - # payload["max_tokens"] = global_config.model.model_max_output_length - # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: - payload["max_completion_tokens"] = payload.pop("max_tokens") - - return { - "policy": policy, - "payload": payload, - "api_url": api_url, - "stream_mode": stream_mode, - "image_base64": image_base64, # 保留必要的exception处理所需的原始数据 - "image_format": image_format, - "file_bytes": file_bytes, - "file_format": file_format, - "prompt": prompt, - } - - async def _execute_request( - self, - endpoint: str, - prompt: str = None, - image_base64: str = None, - image_format: str = None, - file_bytes: bytes = None, - file_format: str = None, - payload: dict = None, - retry_policy: dict = None, - response_handler: Callable = None, - user_id: str = "system", - request_type: str = None, - ): - """统一请求执行入口 - Args: - endpoint: API端点路径 (如 "chat/completions") - prompt: prompt文本 - image_base64: 图片的base64编码 - image_format: 图片格式 - file_bytes: 文件的二进制数据 - file_format: 文件格式 - payload: 请求体数据 - retry_policy: 自定义重试策略 - response_handler: 自定义响应处理器 - user_id: 用户ID - request_type: 请求类型 - """ - # 获取请求配置 - request_content = await self._prepare_request( - endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy - ) - if request_type is None: - request_type = self.request_type - for retry in range(request_content["policy"]["max_retries"]): - try: - # 使用上下文管理器处理会话 - if file_bytes: - headers = await self._build_headers(is_formdata=True) - else: - headers = await self._build_headers(is_formdata=False) - # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响 - if request_content["stream_mode"]: - headers["Accept"] = "text/event-stream" - - # 添加请求发送前的调试信息 - logger.debug(f"🔍 [请求调试] 模型 {self.model_name} 准备发送请求") - logger.debug(f"🔍 [请求调试] API URL: {request_content['api_url']}") - logger.debug(f"🔍 [请求调试] 请求头: {await self._build_headers(no_key=True, is_formdata=file_bytes is not None)}") - - if not file_bytes: - # 安全地记录请求体(隐藏敏感信息) - safe_payload = await _safely_record(request_content, request_content["payload"]) - logger.debug(f"🔍 [请求调试] 请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}") - else: - logger.debug(f"🔍 [请求调试] 文件上传请求,文件格式: {request_content['file_format']}") - - async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: - post_kwargs = {"headers": headers} - # form-data数据上传方式不同 - if file_bytes: - post_kwargs["data"] = request_content["payload"] - else: - post_kwargs["json"] = request_content["payload"] - - async with session.post(request_content["api_url"], **post_kwargs) as response: - handled_result = await self._handle_response( - response, request_content, retry, response_handler, user_id, request_type, endpoint - ) - return handled_result - - except Exception as e: - handled_payload, 
count_delta = await self._handle_exception(e, retry, request_content) - retry += count_delta # 降级不计入重试次数 - if handled_payload: - # 如果降级成功,重新构建请求体 - request_content["payload"] = handled_payload - continue - - logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败") - raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败") - - async def _handle_response( - self, - response: ClientResponse, - request_content: Dict[str, Any], - retry_count: int, - response_handler: Callable, - user_id, - request_type, - endpoint, - ): - policy = request_content["policy"] - stream_mode = request_content["stream_mode"] - if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]: - await self._handle_error_response(response, retry_count, policy) - return None - - response.raise_for_status() - result = {} - if stream_mode: - # 将流式输出转化为非流式输出 - result = await self._handle_stream_output(response) - else: - result = await response.json() - return ( - response_handler(result) - if response_handler - else self._default_response_handler(result, user_id, request_type, endpoint) - ) - - async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]: - flag_delta_content_finished = False - accumulated_content = "" - usage = None # 初始化usage变量,避免未定义错误 - reasoning_content = "" - content = "" - tool_calls = None # 初始化工具调用变量 - - async for line_bytes in response.content: - try: - line = line_bytes.decode("utf-8").strip() - if not line: - continue - if line.startswith("data:"): - data_str = line[5:].strip() - if data_str == "[DONE]": - break - try: - chunk = json.loads(data_str) - if flag_delta_content_finished: - chunk_usage = chunk.get("usage", None) - if chunk_usage: - usage = chunk_usage # 获取token用量 - else: - delta = chunk["choices"][0]["delta"] - delta_content = delta.get("content") - if delta_content is None: - delta_content = "" - accumulated_content += delta_content - - # 提取工具调用信息 - if "tool_calls" in delta: - if tool_calls is None: - tool_calls = delta["tool_calls"] - else: - # 合并工具调用信息 - tool_calls.extend(delta["tool_calls"]) - - # 检测流式输出文本是否结束 - finish_reason = chunk["choices"][0].get("finish_reason") - if delta.get("reasoning_content", None): - reasoning_content += delta["reasoning_content"] - if finish_reason == "stop" or finish_reason == "tool_calls": - chunk_usage = chunk.get("usage", None) - if chunk_usage: - usage = chunk_usage - break - # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk - flag_delta_content_finished = True - except Exception as e: - logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}") - except Exception as e: - if isinstance(e, GeneratorExit): - log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..." 
-                else:
-                    log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
-                logger.warning(log_content)
-                # 确保资源被正确清理
-                try:
-                    await response.release()
-                except Exception as cleanup_error:
-                    logger.error(f"清理资源时发生错误: {cleanup_error}")
-                # 返回已经累积的内容
-                content = accumulated_content
-        if not content:
-            content = accumulated_content
-        think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
-        if think_match:
-            reasoning_content = think_match.group(1).strip()
-            content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
-
-        # 构建消息对象
-        message = {
-            "content": content,
-            "reasoning_content": reasoning_content,
-        }
-
-        # 如果有工具调用,添加到消息中
-        if tool_calls:
-            message["tool_calls"] = tool_calls
-
-        result = {
-            "choices": [{"message": message}],
-            "usage": usage,
-        }
-        return result
-
-    async def _handle_error_response(self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]):
-        if response.status in policy["retry_codes"]:
-            wait_time = policy["base_wait"] * (2**retry_count)
-            logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
-            if response.status == 413:
-                logger.warning("请求体过大,尝试压缩...")
-                raise PayLoadTooLargeError("请求体过大")
-            elif response.status in [500, 503]:
-                logger.error(
-                    f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
-                )
-                raise RuntimeError("服务器负载过高,模型回复失败QAQ")
-            else:
-                logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
-                raise RuntimeError("请求限制(429)")
-        elif response.status in policy["abort_codes"]:
-            # 特别处理400错误,添加详细调试信息
-            if response.status == 400:
-                logger.error(f"🔍 [调试信息] 模型 {self.model_name} 参数错误 (400) - 开始详细诊断")
-                logger.error(f"🔍 [调试信息] 模型名称: {self.model_name}")
-                logger.error(f"🔍 [调试信息] API地址: {self.base_url}")
-                logger.error(f"🔍 [调试信息] 模型配置参数:")
-                logger.error(f"  - enable_thinking: {self.enable_thinking}")
-                logger.error(f"  - temp: {self.temp}")
-                logger.error(f"  - thinking_budget: {self.thinking_budget}")
-                logger.error(f"  - stream: {self.stream}")
-                logger.error(f"  - max_tokens: {self.max_tokens}")
-                logger.error(f"  - pri_in: {self.pri_in}")
-                logger.error(f"  - pri_out: {self.pri_out}")
-                logger.error(f"🔍 [调试信息] 原始params: {self.params}")
-
-                # 尝试获取服务器返回的详细错误信息
-                try:
-                    error_text = await response.text()
-                    logger.error(f"🔍 [调试信息] 服务器返回的原始错误内容: {error_text}")
-
-                    try:
-                        error_json = json.loads(error_text)
-                        logger.error(f"🔍 [调试信息] 解析后的错误JSON: {json.dumps(error_json, indent=2, ensure_ascii=False)}")
-                    except json.JSONDecodeError:
-                        logger.error(f"🔍 [调试信息] 错误响应不是有效的JSON格式")
-                except Exception as e:
-                    logger.error(f"🔍 [调试信息] 无法读取错误响应内容: {str(e)}")
-
-                raise RequestAbortException("参数错误,请检查调试信息", response)
-            elif response.status != 403:
-                raise RequestAbortException("请求出现错误,中断处理", response)
-            else:
-                raise PermissionDeniedException("模型禁止访问")
-
-    async def _handle_exception(
-        self, exception, retry_count: int, request_content: Dict[str, Any]
-    ) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
-        policy = request_content["policy"]
-        payload = request_content["payload"]
-        wait_time = policy["base_wait"] * (2**retry_count)
-        keep_request = False
-        if retry_count < policy["max_retries"] - 1:
-            keep_request = True
-        if isinstance(exception, RequestAbortException):
-            response = exception.response
-            logger.error(
-                f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
-            )
-
-            # 如果是400错误,额外输出请求体信息用于调试
-            if response.status == 400:
-                logger.error(f"🔍 [异常调试] 400错误 - 请求体调试信息:")
-                try:
-                    safe_payload = await _safely_record(request_content, payload) 
- logger.error(f"🔍 [异常调试] 发送的请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}") - except Exception as debug_error: - logger.error(f"🔍 [异常调试] 无法安全记录请求体: {str(debug_error)}") - logger.error(f"🔍 [异常调试] 原始payload类型: {type(payload)}") - if isinstance(payload, dict): - logger.error(f"🔍 [异常调试] 原始payload键: {list(payload.keys())}") - - # print(request_content) - # print(response) - # 尝试获取并记录服务器返回的详细错误信息 - try: - error_json = await response.json() - if error_json and isinstance(error_json, list) and len(error_json) > 0: - # 处理多个错误的情况 - for error_item in error_json: - if "error" in error_item and isinstance(error_item["error"], dict): - error_obj: dict = error_item["error"] - error_code = error_obj.get("code") - error_message = error_obj.get("message") - error_status = error_obj.get("status") - logger.error( - f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}" - ) - elif isinstance(error_json, dict) and "error" in error_json: - # 处理单个错误对象的情况 - error_obj = error_json.get("error", {}) - error_code = error_obj.get("code") - error_message = error_obj.get("message") - error_status = error_obj.get("status") - logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}") - else: - # 记录原始错误响应内容 - logger.error(f"服务器错误响应: {error_json}") - except Exception as e: - logger.warning(f"无法解析服务器错误响应: {str(e)}") - raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") - - elif isinstance(exception, PermissionDeniedException): - # 只针对硅基流动的V3和R1进行降级处理 - if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/": - old_model_name = self.model_name - self.model_name = self.model_name[4:] # 移除"Pro/"前缀 - logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}") - - # 对全局配置进行更新 - if global_config.model.replyer_2.get("name") == old_model_name: - global_config.model.replyer_2["name"] = self.model_name - logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}") - if global_config.model.replyer_1.get("name") == old_model_name: - global_config.model.replyer_1["name"] = self.model_name - logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}") - - if payload and "model" in payload: - payload["model"] = self.model_name - - await asyncio.sleep(wait_time) - return payload, -1 - raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}") - - elif isinstance(exception, PayLoadTooLargeError): - if keep_request: - image_base64 = request_content["image_base64"] - compressed_image_base64 = compress_base64_image_by_scale(image_base64) - new_payload = await self._build_payload( - request_content["prompt"], compressed_image_base64, request_content["image_format"] - ) - return new_payload, 0 - else: - return None, 0 - - elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError): - if keep_request: - logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}") - await asyncio.sleep(wait_time) - return None, 0 - else: - logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}") - raise RuntimeError(f"网络请求失败: {str(exception)}") - - elif isinstance(exception, aiohttp.ClientResponseError): - # 处理aiohttp抛出的,除了policy中的status的响应错误 - if keep_request: - logger.error( - f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 
状态码: {exception.status}, 错误: {exception.message}" - ) - try: - error_text = await exception.response.text() - error_json = json.loads(error_text) - if isinstance(error_json, list) and len(error_json) > 0: - # 处理多个错误的情况 - for error_item in error_json: - if "error" in error_item and isinstance(error_item["error"], dict): - error_obj = error_item["error"] - logger.error( - f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, " - f"状态={error_obj.get('status')}, " - f"消息={error_obj.get('message')}" - ) - elif isinstance(error_json, dict) and "error" in error_json: - error_obj = error_json.get("error", {}) - logger.error( - f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, " - f"状态={error_obj.get('status')}, " - f"消息={error_obj.get('message')}" - ) - else: - logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}") - except (json.JSONDecodeError, TypeError) as json_err: - logger.warning( - f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}" - ) - except Exception as parse_err: - logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}") - - await asyncio.sleep(wait_time) - return None, 0 - else: - logger.critical( - f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}" - ) - # 安全地检查和记录请求详情 - handled_payload = await _safely_record(request_content, payload) - logger.critical( - f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}" - ) - raise RuntimeError( - f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}" - ) - - else: - if keep_request: - logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}") - await asyncio.sleep(wait_time) - return None, 0 - else: - logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}") - # 安全地检查和记录请求详情 - handled_payload = await _safely_record(request_content, payload) - logger.critical( - f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}" - ) - raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") - - async def _transform_parameters(self, params: dict) -> dict: - """ - 根据模型名称转换参数: - - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数, - 并将 'max_tokens' 重命名为 'max_completion_tokens' - """ - # 复制一份参数,避免直接修改原始数据 - new_params = dict(params) - - logger.debug(f"🔍 [参数转换] 模型 {self.model_name} 开始参数转换") - logger.debug(f"🔍 [参数转换] 是否为CoT模型: {self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION}") - logger.debug(f"🔍 [参数转换] CoT模型列表: {self.MODELS_NEEDING_TRANSFORMATION}") - - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION: - logger.debug(f"🔍 [参数转换] 检测到CoT模型,开始参数转换") - # 删除 'temperature' 参数(如果存在),但避免删除我们在_build_payload中添加的自定义温度 - if "temperature" in new_params and new_params["temperature"] == 0.7: - removed_temp = new_params.pop("temperature") - logger.debug(f"🔍 [参数转换] 移除默认temperature参数: {removed_temp}") - # 如果存在 'max_tokens',则重命名为 'max_completion_tokens' - if "max_tokens" in new_params: - old_value = new_params["max_tokens"] - new_params["max_completion_tokens"] = new_params.pop("max_tokens") - logger.debug(f"🔍 [参数转换] 参数重命名: max_tokens({old_value}) -> max_completion_tokens({new_params['max_completion_tokens']})") - else: - logger.debug(f"🔍 [参数转换] 非CoT模型,无需参数转换") - - logger.debug(f"🔍 [参数转换] 转换前参数: {params}") - logger.debug(f"🔍 [参数转换] 转换后参数: {new_params}") - return new_params - - async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData: - """构建form-data请求体""" - # 目前只适配了音频文件 - # 
如果后续要支持其他类型的文件,可以在这里添加更多的处理逻辑 - data = aiohttp.FormData() - content_type_list = { - "wav": "audio/wav", - "mp3": "audio/mpeg", - "ogg": "audio/ogg", - "flac": "audio/flac", - "aac": "audio/aac", - } - - content_type = content_type_list.get(file_format) - if not content_type: - logger.warning(f"暂不支持的文件类型: {file_format}") - - data.add_field( - "file", - io.BytesIO(file_bytes), - filename=f"file.{file_format}", - content_type=f"{content_type}", # 根据实际文件类型设置 - ) - data.add_field("model", self.model_name) - return data - - async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict: - """构建请求体""" - # 复制一份参数,避免直接修改 self.params - logger.debug(f"🔍 [参数构建] 模型 {self.model_name} 开始构建请求体") - logger.debug(f"🔍 [参数构建] 原始self.params: {self.params}") - - params_copy = await self._transform_parameters(self.params) - logger.debug(f"🔍 [参数构建] 转换后的params_copy: {params_copy}") - - if image_base64: - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}, - }, - ], - } - ] - else: - messages = [{"role": "user", "content": prompt}] - - payload = { - "model": self.model_name, - "messages": messages, - **params_copy, - } - - logger.debug(f"🔍 [参数构建] 基础payload构建完成: {list(payload.keys())}") - - # 添加temp参数(如果不是默认值0.7) - if self.temp != 0.7: - payload["temperature"] = self.temp - logger.debug(f"🔍 [参数构建] 添加temperature参数: {self.temp}") - - # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false) - if self.has_enable_thinking: - payload["enable_thinking"] = self.enable_thinking - logger.debug(f"🔍 [参数构建] 添加enable_thinking参数: {self.enable_thinking}") - - # 添加thinking_budget参数(只有配置文件中声明了才添加) - if self.has_thinking_budget: - payload["thinking_budget"] = self.thinking_budget - logger.debug(f"🔍 [参数构建] 添加thinking_budget参数: {self.thinking_budget}") - - if self.max_tokens: - payload["max_tokens"] = self.max_tokens - logger.debug(f"🔍 [参数构建] 添加max_tokens参数: {self.max_tokens}") - - # if "max_tokens" not in payload and "max_completion_tokens" not in payload: - # payload["max_tokens"] = global_config.model.model_max_output_length - # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: - old_value = payload["max_tokens"] - payload["max_completion_tokens"] = payload.pop("max_tokens") - logger.debug(f"🔍 [参数构建] CoT模型参数转换: max_tokens({old_value}) -> max_completion_tokens({payload['max_completion_tokens']})") - - logger.debug(f"🔍 [参数构建] 最终payload键列表: {list(payload.keys())}") - return payload - - def _default_response_handler( - self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions" - ) -> Tuple: - """默认响应解析""" - if "choices" in result and result["choices"]: - message = result["choices"][0]["message"] - content = message.get("content", "") - content, reasoning = self._extract_reasoning(content) - reasoning_content = message.get("model_extra", {}).get("reasoning_content", "") - if not reasoning_content: - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - reasoning_content = reasoning - - # 提取工具调用信息 - tool_calls = message.get("tool_calls", None) - - # 记录token使用情况 - usage = result.get("usage", {}) - if usage: - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - total_tokens = usage.get("total_tokens", 0) - self._record_usage( - 
prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-                user_id=user_id,
-                request_type=request_type if request_type is not None else self.request_type,
-                endpoint=endpoint,
-            )
-
-            # 只有当tool_calls存在且不为空时才返回
-            if tool_calls:
-                logger.debug(f"检测到工具调用: {tool_calls}")
-                return content, reasoning_content, tool_calls
-            else:
-                return content, reasoning_content
-        elif "text" in result and result["text"]:
-            return result["text"]
-        return "没有返回结果", ""
-
     @staticmethod
     def _extract_reasoning(content: str) -> Tuple[str, str]:
         """CoT思维链提取"""
@@ -902,61 +327,183 @@ class LLMRequest:
             reasoning = ""
         return content, reasoning

-    async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict:
-        """构建请求头"""
-        if no_key:
-            if is_formdata:
-                return {"Authorization": "Bearer **********"}
-            return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
-        else:
-            if is_formdata:
-                return {"Authorization": f"Bearer {self.api_key}"}
-            return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-        # 防止小朋友们截图自己的key
+    # === 主要API方法 ===
+    # 这些方法提供与新架构的桥接

     async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple:
-        """根据输入的提示和图片生成模型的异步响应"""
-
-        response = await self._execute_request(
-            endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format
-        )
-        # 根据返回值的长度决定怎么处理
-        if len(response) == 3:
-            content, reasoning_content, tool_calls = response
-            return content, reasoning_content, tool_calls
-        else:
-            content, reasoning_content = response
-            return content, reasoning_content
+        """
+        根据输入的提示和图片生成模型的异步响应
+        使用新架构的模型请求处理器
+        """
+        if not self.use_new_architecture:
+            raise RuntimeError(
+                f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
+            )
+
+        if MessageBuilder is None:
+            raise RuntimeError("MessageBuilder不可用,请检查新架构配置")
+
+        try:
+            # 构建包含图片的消息
+            message_builder = MessageBuilder()
+            message_builder.add_text_content(prompt).add_image_content(
+                image_format=image_format,
+                image_base64=image_base64
+            )
+            messages = [message_builder.build()]
+
+            # 使用新架构发送请求(只传递支持的参数)
+            response = await self.request_handler.get_response(
+                messages=messages,
+                tool_options=None,
+                response_format=None
+            )
+
+            # 新架构返回的是 APIResponse 对象,直接提取内容
+            content = response.content or ""
+            reasoning_content = response.reasoning_content or ""
+            tool_calls = response.tool_calls
+
+            # 从内容中提取思维链标签内的推理内容(向后兼容)
+            if not reasoning_content and content:
+                content, extracted_reasoning = self._extract_reasoning(content)
+                reasoning_content = extracted_reasoning
+
+            # 记录token使用情况
+            if response.usage:
+                self._record_usage(
+                    prompt_tokens=response.usage.prompt_tokens or 0,
+                    completion_tokens=response.usage.completion_tokens or 0,
+                    total_tokens=response.usage.total_tokens or 0,
+                    user_id="system",
+                    request_type=self.request_type,
+                    endpoint="/chat/completions"
+                )
+
+            # 返回格式兼容旧版本
+            if tool_calls:
+                return content, reasoning_content, tool_calls
+            else:
+                return content, reasoning_content
+
+        except Exception as e:
+            logger.error(f"模型 {self.model_name} 图片响应生成失败: {str(e)}")
+            # 向后兼容的异常处理
+            if "401" in str(e) or "API key" in str(e):
+                raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e
+            elif "429" in str(e):
+                raise RuntimeError("请求过于频繁,请稍后再试") from e
+            elif "500" in str(e) or "503" in str(e):
+                raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e
+            else:
+                raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(e)}") from e

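[Reviewer note] A minimal, hypothetical call-site sketch for the bridged image path above. The constructor arguments to LLMRequest are assumptions, not taken from this diff; the point is the variable-arity return kept for backward compatibility (two elements normally, three when tool calls are present).

import asyncio
from src.llm_models.utils_model import LLMRequest

async def demo_image() -> None:
    llm = LLMRequest(model="vlm_normal")  # hypothetical constructor arguments
    result = await llm.generate_response_for_image(
        prompt="Describe this image",
        image_base64="<base64 data>",  # placeholder, not real image data
        image_format="png",
    )
    # Unpack without assuming arity: a third element (tool_calls) may be present.
    content, reasoning_content = result[0], result[1]
    tool_calls = result[2] if len(result) == 3 else None
    print(content, tool_calls)

asyncio.run(demo_image())
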
     async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple:
-        """根据输入的语音文件生成模型的异步响应"""
-        response = await self._execute_request(
-            endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav"
-        )
-        return response
+        """
+        根据输入的语音文件生成模型的异步响应
+        使用新架构的模型请求处理器
+        """
+        if not self.use_new_architecture:
+            raise RuntimeError(
+                f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
+            )
+
+        try:
+            # 构建语音识别请求参数
+            # 注意:新架构中的语音识别可能使用不同的方法,需要根据实际API调整
+            # FIXME: 当前实现没有把 voice_bytes 传给请求处理器,语音内容不会到达模型
+            response = await self.request_handler.get_response(
+                messages=[],  # 语音识别可能不需要消息
+                tool_options=None
+            )
+
+            # 新架构返回的是 APIResponse 对象,直接提取文本内容
+            if response.content:
+                return response.content
+            else:
+                return ""
+
+        except Exception as e:
+            logger.error(f"模型 {self.model_name} 语音识别失败: {str(e)}")
+            # 向后兼容的异常处理
+            if "401" in str(e) or "API key" in str(e):
+                raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e
+            elif "429" in str(e):
+                raise RuntimeError("请求过于频繁,请稍后再试") from e
+            elif "500" in str(e) or "503" in str(e):
+                raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e
+            else:
+                raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(e)}") from e

     async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
-        """异步方式根据输入的提示生成模型的响应"""
-        # 构建请求体,不硬编码max_tokens
-        data = {
-            "model": self.model_name,
-            "messages": [{"role": "user", "content": prompt}],
-            **self.params,
-            **kwargs,
-        }
-
-        response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
-        # 原样返回响应,不做处理
-
-        if len(response) == 3:
-            content, reasoning_content, tool_calls = response
-            return content, (reasoning_content, self.model_name, tool_calls)
-        else:
-            content, reasoning_content = response
-            return content, (reasoning_content, self.model_name)
+        """
+        异步方式根据输入的提示生成模型的响应
+        使用新架构的模型请求处理器,如无法使用则抛出错误
+        """
+        if not self.use_new_architecture:
+            raise RuntimeError(
+                f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
+            )
+
+        if MessageBuilder is None:
+            raise RuntimeError("MessageBuilder不可用,请检查新架构配置")
+
+        try:
+            # 构建消息
+            message_builder = MessageBuilder()
+            message_builder.add_text_content(prompt)
+            messages = [message_builder.build()]
+
+            # 使用新架构发送请求(只传递支持的参数;注意:**kwargs 中的额外参数暂未透传)
+            response = await self.request_handler.get_response(
+                messages=messages,
+                tool_options=None,
+                response_format=None
+            )
+
+            # 新架构返回的是 APIResponse 对象,直接提取内容
+            content = response.content or ""
+            reasoning_content = response.reasoning_content or ""
+            tool_calls = response.tool_calls
+
+            # 从内容中提取思维链标签内的推理内容(向后兼容)
+            if not reasoning_content and content:
+                content, extracted_reasoning = self._extract_reasoning(content)
+                reasoning_content = extracted_reasoning
+
+            # 记录token使用情况
+            if response.usage:
+                self._record_usage(
+                    prompt_tokens=response.usage.prompt_tokens or 0,
+                    completion_tokens=response.usage.completion_tokens or 0,
+                    total_tokens=response.usage.total_tokens or 0,
+                    user_id="system",
+                    request_type=self.request_type,
+                    endpoint="/chat/completions"
+                )
+
+            # 返回格式兼容旧版本
+            if tool_calls:
+                return content, (reasoning_content, self.model_name, tool_calls)
+            else:
+                return content, (reasoning_content, self.model_name)
+
+        except Exception as e:
+            logger.error(f"模型 {self.model_name} 生成响应失败: {str(e)}")
+            # 向后兼容的异常处理
+            if "401" in str(e) or "API key" in str(e):
+                raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e
+            elif "429" in str(e):
+                raise RuntimeError("请求过于频繁,请稍后再试") from e
+            elif "500" in str(e) or "503" in str(e):
+                raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e
+            else:
+                raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(e)}") from e

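[Reviewer note] The text path keeps the old nested return shape, (content, (reasoning_content, model_name[, tool_calls])). A runnable sketch of unpacking it, reusing the hypothetical constructor from the previous example:

import asyncio
from src.llm_models.utils_model import LLMRequest

async def demo_text() -> None:
    llm = LLMRequest(model="llm_normal")  # hypothetical constructor arguments
    content, meta = await llm.generate_response_async("你好")
    # meta is (reasoning_content, model_name) or (reasoning_content, model_name, tool_calls)
    reasoning_content, model_name = meta[0], meta[1]
    tool_calls = meta[2] if len(meta) == 3 else None
    print(content, model_name, tool_calls)

asyncio.run(demo_text())
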
raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e + else: + raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(e)}") from e async def get_embedding(self, text: str) -> Union[list, None]: - """异步方法:获取文本的embedding向量 + """ + 异步方法:获取文本的embedding向量 + 使用新架构的模型请求处理器 Args: text: 需要获取embedding的文本 @@ -964,42 +511,51 @@ class LLMRequest: Returns: list: embedding向量,如果失败则返回None """ - if len(text) < 1: logger.debug("该消息没有长度,不再发送获取embedding向量的请求") return None - def embedding_handler(result): - """处理响应""" - if "data" in result and len(result["data"]) > 0: - # 提取 token 使用信息 - usage = result.get("usage", {}) - if usage: - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - total_tokens = usage.get("total_tokens", 0) - # 记录 token 使用情况 - self._record_usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - user_id="system", # 可以根据需要修改 user_id - # request_type="embedding", # 请求类型为 embedding - request_type=self.request_type, # 请求类型为 text - endpoint="/embeddings", # API 端点 - ) - return result["data"][0].get("embedding", None) - return result["data"][0].get("embedding", None) + if not self.use_new_architecture: + logger.warning(f"模型 {self.model_name} 无法使用新架构,embedding请求将被跳过") return None - embedding = await self._execute_request( - endpoint="/embeddings", - prompt=text, - payload={"model": self.model_name, "input": text, "encoding_format": "float"}, - retry_policy={"max_retries": 2, "base_wait": 6}, - response_handler=embedding_handler, - ) - return embedding + try: + # 构建embedding请求参数 + # 使用新架构的get_embedding方法 + response = await self.request_handler.get_embedding(text) + + # 新架构返回的是 APIResponse 对象,直接提取embedding + if response.embedding: + embedding = response.embedding + + # 记录token使用情况 + if response.usage: + self._record_usage( + prompt_tokens=response.usage.prompt_tokens or 0, + completion_tokens=response.usage.completion_tokens or 0, + total_tokens=response.usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/embeddings" + ) + + return embedding + else: + logger.warning(f"模型 {self.model_name} 返回的embedding响应为空") + return None + + except Exception as e: + logger.error(f"模型 {self.model_name} 获取embedding失败: {str(e)}") + # 向后兼容的异常处理 + if "401" in str(e) or "API key" in str(e): + raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e + elif "429" in str(e): + raise RuntimeError("请求过于频繁,请稍后再试") from e + elif "500" in str(e) or "503" in str(e): + raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e + else: + logger.warning(f"模型 {self.model_name} embedding请求失败,返回None: {str(e)}") + return None def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str: diff --git a/template/compare/model_config_template.toml b/template/compare/model_config_template.toml new file mode 100644 index 000000000..f9055fcea --- /dev/null +++ b/template/compare/model_config_template.toml @@ -0,0 +1,77 @@ +[inner] +version = "0.1.0" + +# 配置文件版本号迭代规则同bot_config.toml + +[request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释) +#max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) +#timeout = 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒) +#retry_interval = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒) +#default_temperature = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值) +#default_max_tokens = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值) + + +[[api_providers]] # API服务提供商(可以配置多个) +name = "DeepSeek" # 
diff --git a/template/compare/model_config_template.toml b/template/compare/model_config_template.toml
new file mode 100644
index 000000000..f9055fcea
--- /dev/null
+++ b/template/compare/model_config_template.toml
@@ -0,0 +1,77 @@
+[inner]
+version = "0.1.0"
+
+# 配置文件版本号迭代规则同bot_config.toml
+
+[request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释)
+#max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数)
+#timeout = 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)
+#retry_interval = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒)
+#default_temperature = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值)
+#default_max_tokens = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值)
+
+
+[[api_providers]] # API服务提供商(可以配置多个)
+name = "DeepSeek" # API服务商名称(可随意命名,在models的api_provider中需使用这个命名)
+base_url = "https://api.deepseek.com" # API服务商的BaseURL
+key = "******" # API Key (可选,默认为None)
+client_type = "openai" # 请求客户端(可选,默认值为"openai",使用Gemini等Google系模型时请配置为"google")
+
+#[[api_providers]] # 特殊:Google的Gemini使用特殊API,与OpenAI格式不兼容,需要将client_type配置为"google"
+#name = "Google"
+#base_url = "https://generativelanguage.googleapis.com"
+#key = "******"
+#client_type = "google"
+#
+#[[api_providers]]
+#name = "SiliconFlow"
+#base_url = "https://api.siliconflow.cn"
+#key = "******"
+#
+#[[api_providers]]
+#name = "LocalHost"
+#base_url = "http://localhost:8888"
+#key = "lm-studio"
+
+
+[[models]] # 模型(可以配置多个)
+# 模型标识符(API服务商提供的模型标识符)
+model_identifier = "deepseek-chat"
+# 模型名称(可随意命名,在bot_config.toml中需使用这个命名)
+#(可选,若无该字段,则将自动使用model_identifier填充)
+name = "deepseek-v3"
+# API服务商名称(对应在api_providers中配置的服务商名称)
+api_provider = "DeepSeek"
+# 输入价格(用于API调用统计,单位:元/百万token)(可选,若无该字段,默认值为0)
+price_in = 2.0
+# 输出价格(用于API调用统计,单位:元/百万token)(可选,若无该字段,默认值为0)
+price_out = 8.0
+# 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出)
+#(可选,若无该字段,默认值为false)
+#force_stream_mode = true
+
+#[[models]]
+#model_identifier = "deepseek-reasoner"
+#name = "deepseek-r1"
+#api_provider = "DeepSeek"
+#model_flags = ["text", "tool_calling", "reasoning"]
+#price_in = 4.0
+#price_out = 16.0
+#
+#[[models]]
+#model_identifier = "BAAI/bge-m3"
+#name = "siliconflow-bge-m3"
+#api_provider = "SiliconFlow"
+#model_flags = ["text", "embedding"]
+#price_in = 0
+#price_out = 0
+
+
+[task_model_usage]
+#llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0}
+#llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0}
+#embedding = "siliconflow-bge-m3"
+#schedule = [
+#    "deepseek-v3",
+#    "deepseek-r1",
+#]
\ No newline at end of file
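[Reviewer note] A runnable sketch of how the three layers of the template above (api_providers, models, task_model_usage) resolve to a concrete endpoint. The resolve_task helper is hypothetical, written only to illustrate the indirection; it is not part of this PR, and field names are taken from the template.

import tomllib  # Python 3.11+; older interpreters can use the third-party "toml" package

def resolve_task(config_path: str, task: str) -> dict:
    with open(config_path, "rb") as f:
        cfg = tomllib.load(f)

    entry = cfg["task_model_usage"][task]
    if isinstance(entry, str):      # bare model name, e.g. embedding
        model_name = entry
    elif isinstance(entry, list):   # fallback list, e.g. schedule
        model_name = entry[0]
    else:                           # inline table with per-task overrides
        model_name = entry["model"]

    # "name" is optional on a model and defaults to model_identifier.
    models = {m.get("name", m["model_identifier"]): m for m in cfg["models"]}
    providers = {p["name"]: p for p in cfg["api_providers"]}

    model = models[model_name]
    provider = providers[model["api_provider"]]
    return {"model": model["model_identifier"], "base_url": provider["base_url"]}

print(resolve_task("config/model_config.toml", "llm_normal"))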