From 909e47bcee95b1f5f43a19520240d9f3019c2bb1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com>
Date: Fri, 25 Jul 2025 13:21:48 +0800
Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E9=87=8D=E6=9E=84llmrequest?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
debug_config.py | 111 ++
src/chat/maibot_llmreq/__init__.py | 19 -
src/chat/maibot_llmreq/config/parser.py | 267 ----
.../maibot_llmreq/tests/test_config_load.py | 84 --
src/config/config.py | 5 +-
.../maibot_llmreq => llm_models}/LICENSE | 0
.../config => llm_models}/__init__.py | 0
.../exceptions.py | 0
.../model_client/__init__.py | 10 +-
.../model_client/base_client.py | 2 +-
.../model_client/gemini_client.py | 2 +-
.../model_client/openai_client.py | 2 +-
.../model_manager.py | 12 +-
.../payload_content/message.py | 0
.../payload_content/resp_format.py | 0
.../payload_content/tool_option.py | 0
src/llm_models/temp.py | 8 +
.../usage_statistic.py | 136 +-
.../maibot_llmreq => llm_models}/utils.py | 4 +-
src/llm_models/utils_model.py | 1110 +++++------------
template/compare/model_config_template.toml | 77 ++
21 files changed, 612 insertions(+), 1237 deletions(-)
create mode 100644 debug_config.py
delete mode 100644 src/chat/maibot_llmreq/__init__.py
delete mode 100644 src/chat/maibot_llmreq/config/parser.py
delete mode 100644 src/chat/maibot_llmreq/tests/test_config_load.py
rename src/{chat/maibot_llmreq => llm_models}/LICENSE (100%)
rename src/{chat/maibot_llmreq/config => llm_models}/__init__.py (100%)
rename src/{chat/maibot_llmreq => llm_models}/exceptions.py (100%)
rename src/{chat/maibot_llmreq => llm_models}/model_client/__init__.py (98%)
rename src/{chat/maibot_llmreq => llm_models}/model_client/base_client.py (98%)
rename src/{chat/maibot_llmreq => llm_models}/model_client/gemini_client.py (99%)
rename src/{chat/maibot_llmreq => llm_models}/model_client/openai_client.py (99%)
rename src/{chat/maibot_llmreq => llm_models}/model_manager.py (92%)
rename src/{chat/maibot_llmreq => llm_models}/payload_content/message.py (100%)
rename src/{chat/maibot_llmreq => llm_models}/payload_content/resp_format.py (100%)
rename src/{chat/maibot_llmreq => llm_models}/payload_content/tool_option.py (100%)
create mode 100644 src/llm_models/temp.py
rename src/{chat/maibot_llmreq => llm_models}/usage_statistic.py (52%)
rename src/{chat/maibot_llmreq => llm_models}/utils.py (98%)
create mode 100644 template/compare/model_config_template.toml
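The net effect of this patch: the request stack moves from src/chat/maibot_llmreq to src/llm_models, and configuration now comes from config/model_config.toml via src.config.config instead of environment variables. A minimal sketch of the resulting call path, assuming the lookup semantics exercised by debug_config.py below (not part of the patched code itself):

    from src.config.config import model_config
    from src.llm_models.model_manager import ModelManager

    manager = ModelManager(model_config)  # one API client per provider declared in model_config.toml
    handler = manager["llm_normal"]       # task name -> ModelRequestHandler, per [task_model_usage]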
diff --git a/debug_config.py b/debug_config.py
new file mode 100644
index 000000000..a2b960e5c
--- /dev/null
+++ b/debug_config.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+"""
+调试配置加载问题,查看API provider的配置是否正确传递
+"""
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+def debug_config_loading():
+ try:
+ # 临时配置API key
+ import toml
+ config_path = "config/model_config.toml"
+
+ with open(config_path, 'r', encoding='utf-8') as f:
+ config = toml.load(f)
+
+ original_keys = {}
+ for provider in config['api_providers']:
+ original_keys[provider['name']] = provider['api_key']
+ provider['api_key'] = f"sk-test-key-for-{provider['name'].lower()}-12345"
+
+ with open(config_path, 'w', encoding='utf-8') as f:
+ toml.dump(config, f)
+
+ print("✅ 配置了测试API key")
+
+ try:
+ # 清空缓存
+ modules_to_remove = [
+ 'src.config.config',
+ 'src.config.api_ada_configs',
+ 'src.llm_models.model_manager',
+ 'src.llm_models.model_client',
+ 'src.llm_models.utils_model'
+ ]
+ for module in modules_to_remove:
+ if module in sys.modules:
+ del sys.modules[module]
+
+ # 导入配置
+ from src.config.config import model_config
+ print("\n🔍 调试配置加载:")
+ print(f"model_config类型: {type(model_config)}")
+
+ # 检查API providers
+ if hasattr(model_config, 'api_providers'):
+ print(f"API providers数量: {len(model_config.api_providers)}")
+ for name, provider in model_config.api_providers.items():
+ print(f" - {name}: {provider.base_url}")
+ print(f" API key: {provider.api_key[:10]}...{provider.api_key[-5:] if len(provider.api_key) > 15 else provider.api_key}")
+ print(f" Client type: {provider.client_type}")
+
+ # 检查模型配置
+ if hasattr(model_config, 'models'):
+ print(f"模型数量: {len(model_config.models)}")
+ for name, model in model_config.models.items():
+ print(f" - {name}: {model.model_identifier} (提供商: {model.api_provider})")
+
+ # 检查任务配置
+ if hasattr(model_config, 'task_model_arg_map'):
+ print(f"任务配置数量: {len(model_config.task_model_arg_map)}")
+ for task_name, task_config in model_config.task_model_arg_map.items():
+ print(f" - {task_name}: {task_config}")
+
+ # 尝试初始化ModelManager
+ print("\n🔍 调试ModelManager初始化:")
+ from src.llm_models.model_manager import ModelManager
+
+ try:
+ model_manager = ModelManager(model_config)
+ print("✅ ModelManager初始化成功")
+
+ # 检查API客户端映射
+ print(f"API客户端数量: {len(model_manager.api_client_map)}")
+ for name, client in model_manager.api_client_map.items():
+ print(f" - {name}: {type(client).__name__}")
+ if hasattr(client, 'client') and hasattr(client.client, 'api_key'):
+ api_key = client.client.api_key
+ print(f" Client API key: {api_key[:10]}...{api_key[-5:] if len(api_key) > 15 else api_key}")
+
+ # 尝试获取任务处理器
+ try:
+ handler = model_manager["llm_normal"]
+ print("✅ 成功获取llm_normal任务处理器")
+ print(f"任务处理器类型: {type(handler).__name__}")
+ except Exception as e:
+ print(f"❌ 获取任务处理器失败: {e}")
+
+ except Exception as e:
+ print(f"❌ ModelManager初始化失败: {e}")
+ import traceback
+ traceback.print_exc()
+
+ finally:
+ # 恢复配置
+ for provider in config['api_providers']:
+ provider['api_key'] = original_keys[provider['name']]
+
+ with open(config_path, 'w', encoding='utf-8') as f:
+ toml.dump(config, f)
+ print("\n✅ 配置已恢复")
+
+ except Exception as e:
+ print(f"❌ 调试失败: {e}")
+ import traceback
+ traceback.print_exc()
+
+if __name__ == "__main__":
+ debug_config_loading()
diff --git a/src/chat/maibot_llmreq/__init__.py b/src/chat/maibot_llmreq/__init__.py
deleted file mode 100644
index aab812cfa..000000000
--- a/src/chat/maibot_llmreq/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import loguru
-
-type LoguruLogger = loguru.Logger
-
-_logger: LoguruLogger = loguru.logger
-
-
-def init_logger(
- logger: LoguruLogger | None = None,
-):
- """
- 对LLMRequest模块进行配置
- :param logger: 日志对象
- """
- global _logger # 申明使用全局变量
- if logger:
- _logger = logger
- else:
- _logger.warning("Warning: No logger provided, using default logger.")
diff --git a/src/chat/maibot_llmreq/config/parser.py b/src/chat/maibot_llmreq/config/parser.py
deleted file mode 100644
index a6877835d..000000000
--- a/src/chat/maibot_llmreq/config/parser.py
+++ /dev/null
@@ -1,267 +0,0 @@
-import os
-from typing import Any, Dict, List
-
-import tomli
-from packaging import version
-from packaging.specifiers import SpecifierSet
-from packaging.version import Version, InvalidVersion
-
-from .. import _logger as logger
-
-from .config import (
- ModelUsageArgConfigItem,
- ModelUsageArgConfig,
- APIProvider,
- ModelInfo,
- NEWEST_VER,
- ModuleConfig,
-)
-
-
-def _get_config_version(toml: Dict) -> Version:
- """提取配置文件的 SpecifierSet 版本数据
- Args:
- toml[dict]: 输入的配置文件字典
- Returns:
- Version
- """
-
- if "inner" in toml and "version" in toml["inner"]:
- config_version: str = toml["inner"]["version"]
- else:
- config_version = "0.0.0" # 默认版本
-
- try:
- ver = version.parse(config_version)
- except InvalidVersion as e:
- logger.error(
- "配置文件中 inner段 的 version 键是错误的版本描述\n"
- f"请检查配置文件,当前 version 键: {config_version}\n"
- f"错误信息: {e}"
- )
- raise InvalidVersion(
- "配置文件中 inner段 的 version 键是错误的版本描述\n"
- ) from e
-
- return ver
-
-
-def _request_conf(parent: Dict, config: ModuleConfig):
- request_conf_config = parent.get("request_conf")
- config.req_conf.max_retry = request_conf_config.get(
- "max_retry", config.req_conf.max_retry
- )
- config.req_conf.timeout = request_conf_config.get(
- "timeout", config.req_conf.timeout
- )
- config.req_conf.retry_interval = request_conf_config.get(
- "retry_interval", config.req_conf.retry_interval
- )
- config.req_conf.default_temperature = request_conf_config.get(
- "default_temperature", config.req_conf.default_temperature
- )
- config.req_conf.default_max_tokens = request_conf_config.get(
- "default_max_tokens", config.req_conf.default_max_tokens
- )
-
-
-def _api_providers(parent: Dict, config: ModuleConfig):
- api_providers_config = parent.get("api_providers")
- for provider in api_providers_config:
- name = provider.get("name", None)
- base_url = provider.get("base_url", None)
- api_key = provider.get("api_key", None)
- client_type = provider.get("client_type", "openai")
-
- if name in config.api_providers: # 查重
- logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
- raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
-
- if name and base_url:
- config.api_providers[name] = APIProvider(
- name=name,
- base_url=base_url,
- api_key=api_key,
- client_type=client_type,
- )
- else:
- logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
- raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
-
-
-def _models(parent: Dict, config: ModuleConfig):
- models_config = parent.get("models")
- for model in models_config:
- model_identifier = model.get("model_identifier", None)
- name = model.get("name", model_identifier)
- api_provider = model.get("api_provider", None)
- price_in = model.get("price_in", 0.0)
- price_out = model.get("price_out", 0.0)
- force_stream_mode = model.get("force_stream_mode", False)
-
- if name in config.models: # 查重
- logger.error(f"重复的模型名称: {name},请检查配置文件。")
- raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
-
- if model_identifier and api_provider:
- # 检查API提供商是否存在
- if api_provider not in config.api_providers:
- logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
- raise ValueError(
- f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
- )
- config.models[name] = ModelInfo(
- name=name,
- model_identifier=model_identifier,
- api_provider=api_provider,
- price_in=price_in,
- price_out=price_out,
- force_stream_mode=force_stream_mode,
- )
- else:
- logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
- raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
-
-
-def _task_model_usage(parent: Dict, config: ModuleConfig):
- model_usage_configs = parent.get("task_model_usage")
- config.task_model_arg_map = {}
- for task_name, item in model_usage_configs.items():
- if task_name in config.task_model_arg_map:
- logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
- raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
-
- usage = []
- if isinstance(item, Dict):
- if "model" in item:
- usage.append(
- ModelUsageArgConfigItem(
- name=item["model"],
- temperature=item.get("temperature", None),
- max_tokens=item.get("max_tokens", None),
- max_retry=item.get("max_retry", None),
- )
- )
- else:
- logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
- raise ValueError(
- f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
- )
- elif isinstance(item, List):
- for model in item:
- if isinstance(model, Dict):
- usage.append(
- ModelUsageArgConfigItem(
- name=model["model"],
- temperature=model.get("temperature", None),
- max_tokens=model.get("max_tokens", None),
- max_retry=model.get("max_retry", None),
- )
- )
- elif isinstance(model, str):
- usage.append(
- ModelUsageArgConfigItem(
- name=model,
- temperature=None,
- max_tokens=None,
- max_retry=None,
- )
- )
- else:
- logger.error(
- f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
- )
- raise ValueError(
- f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
- )
- elif isinstance(item, str):
- usage.append(
- ModelUsageArgConfigItem(
- name=item,
- temperature=None,
- max_tokens=None,
- max_retry=None,
- )
- )
-
- config.task_model_arg_map[task_name] = ModelUsageArgConfig(
- name=task_name,
- usage=usage,
- )
-
-
-def load_config(config_path: str) -> ModuleConfig:
- """从TOML配置文件加载配置"""
- config = ModuleConfig()
-
- include_configs: Dict[str, Dict[str, Any]] = {
- "request_conf": {
- "func": _request_conf,
- "support": ">=0.0.0",
- "necessary": False,
- },
- "api_providers": {"func": _api_providers, "support": ">=0.0.0"},
- "models": {"func": _models, "support": ">=0.0.0"},
- "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
- }
-
- if os.path.exists(config_path):
- with open(config_path, "rb") as f:
- try:
- toml_dict = tomli.load(f)
- except tomli.TOMLDecodeError as e:
- logger.critical(
- f"配置文件model_list.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}"
- )
- exit(1)
-
- # 获取配置文件版本
- config.INNER_VERSION = _get_config_version(toml_dict)
-
- # 检查版本
- if config.INNER_VERSION > Version(NEWEST_VER):
- logger.warning(
- f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
- )
-
- # 解析配置文件
- # 如果在配置中找到了需要的项,调用对应项的闭包函数处理
- for key in include_configs:
- if key in toml_dict:
- group_specifier_set: SpecifierSet = SpecifierSet(
- include_configs[key]["support"]
- )
-
- # 检查配置文件版本是否在支持范围内
- if config.INNER_VERSION in group_specifier_set:
- # 如果版本在支持范围内,检查是否存在通知
- if "notice" in include_configs[key]:
- logger.warning(include_configs[key]["notice"])
- # 调用闭包函数处理配置
- (include_configs[key]["func"])(toml_dict, config)
- else:
- # 如果版本不在支持范围内,崩溃并提示用户
- logger.error(
- f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
- f"当前程序仅支持以下版本范围: {group_specifier_set}"
- )
- raise InvalidVersion(
- f"当前程序仅支持以下版本范围: {group_specifier_set}"
- )
-
- # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理
- elif (
- "necessary" in include_configs[key]
- and include_configs[key].get("necessary") is False
- ):
- # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
- if key == "keywords_reaction":
- pass
- else:
- # 如果用户根本没有需要的配置项,提示缺少配置
- logger.error(f"配置文件中缺少必需的字段: '{key}'")
- raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
-
- logger.success(f"成功加载配置文件: {config_path}")
-
- return config
diff --git a/src/chat/maibot_llmreq/tests/test_config_load.py b/src/chat/maibot_llmreq/tests/test_config_load.py
deleted file mode 100644
index 7553cb91c..000000000
--- a/src/chat/maibot_llmreq/tests/test_config_load.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import pytest
-from packaging.version import InvalidVersion
-
-from src import maibot_llmreq
-from src.maibot_llmreq.config.parser import _get_config_version, load_config
-
-
-class TestConfigLoad:
- def test_loads_valid_version_from_toml(self):
- maibot_llmreq.init_logger()
-
- toml_data = {"inner": {"version": "1.2.3"}}
- version = _get_config_version(toml_data)
- assert str(version) == "1.2.3"
-
- def test_handles_missing_version_key(self):
- maibot_llmreq.init_logger()
-
- toml_data = {}
- version = _get_config_version(toml_data)
- assert str(version) == "0.0.0"
-
- def test_raises_error_for_invalid_version(self):
- maibot_llmreq.init_logger()
-
- toml_data = {"inner": {"version": "invalid_version"}}
- with pytest.raises(InvalidVersion):
- _get_config_version(toml_data)
-
- def test_loads_complete_config_successfully(self, tmp_path):
- maibot_llmreq.init_logger()
-
- config_path = tmp_path / "config.toml"
- config_path.write_text("""
- [inner]
- version = "0.1.0"
-
- [request_conf]
- max_retry = 5
- timeout = 10
-
- [[api_providers]]
- name = "provider1"
- base_url = "https://api.example.com"
- api_key = "key123"
-
- [[api_providers]]
- name = "provider2"
- base_url = "https://api.example2.com"
- api_key = "key456"
-
- [[models]]
- model_identifier = "model1"
- api_provider = "provider1"
-
- [[models]]
- model_identifier = "model2"
- api_provider = "provider2"
-
- [task_model_usage]
- task1 = { model = "model1" }
- task2 = "model1"
- task3 = [
- "model1",
- { model = "model2", temperature = 0.5 }
- ]
- """)
- config = load_config(str(config_path))
- assert config.req_conf.max_retry == 5
- assert config.req_conf.timeout == 10
- assert "provider1" in config.api_providers
- assert "model1" in config.models
- assert "task1" in config.task_model_arg_map
-
- def test_raises_error_for_missing_required_field(self, tmp_path):
- maibot_llmreq.init_logger()
-
- config_path = tmp_path / "config.toml"
- config_path.write_text("""
- [inner]
- version = "1.0.0"
- """)
- with pytest.raises(KeyError):
- load_config(str(config_path))
diff --git a/src/config/config.py b/src/config/config.py
index bd2d58f04..95ad198a1 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -13,7 +13,6 @@ from packaging.version import Version, InvalidVersion
from typing import Any, Dict, List
from src.common.logger import get_logger
-from src.common.message import api
from src.config.config_base import ConfigBase
from src.config.official_configs import (
BotConfig,
@@ -314,7 +313,7 @@ def api_ada_load_config(config_path: str) -> ModuleConfig:
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
- logger.success(f"成功加载配置文件: {config_path}")
+ logger.info(f"成功加载配置文件: {config_path}")
return config
@@ -653,4 +652,4 @@ update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
-logger.info("非常的新鲜,非常的美味!")
+logger.info("非常的新鲜,非常的美味!")
\ No newline at end of file
diff --git a/src/chat/maibot_llmreq/LICENSE b/src/llm_models/LICENSE
similarity index 100%
rename from src/chat/maibot_llmreq/LICENSE
rename to src/llm_models/LICENSE
diff --git a/src/chat/maibot_llmreq/config/__init__.py b/src/llm_models/__init__.py
similarity index 100%
rename from src/chat/maibot_llmreq/config/__init__.py
rename to src/llm_models/__init__.py
diff --git a/src/chat/maibot_llmreq/exceptions.py b/src/llm_models/exceptions.py
similarity index 100%
rename from src/chat/maibot_llmreq/exceptions.py
rename to src/llm_models/exceptions.py
diff --git a/src/chat/maibot_llmreq/model_client/__init__.py b/src/llm_models/model_client/__init__.py
similarity index 98%
rename from src/chat/maibot_llmreq/model_client/__init__.py
rename to src/llm_models/model_client/__init__.py
index 9dc28d07d..ebe802df2 100644
--- a/src/chat/maibot_llmreq/model_client/__init__.py
+++ b/src/llm_models/model_client/__init__.py
@@ -5,8 +5,7 @@ from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from .base_client import BaseClient, APIResponse
-from .. import _logger as logger
-from ..config.config import (
+from src.config.api_ada_configs import (
ModelInfo,
ModelUsageArgConfigItem,
RequestConfig,
@@ -22,6 +21,9 @@ from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption
from ..utils import compress_messages
+from src.common.logger import get_logger
+
+logger = get_logger("模型客户端")
def _check_retry(
@@ -288,7 +290,7 @@ class ModelRequestHandler:
interrupt_flag=interrupt_flag,
)
except Exception as e:
- logger.trace(e)
+ logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
@@ -340,7 +342,7 @@ class ModelRequestHandler:
embedding_input=embedding_input,
)
except Exception as e:
- logger.trace(e)
+ logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
diff --git a/src/chat/maibot_llmreq/model_client/base_client.py b/src/llm_models/model_client/base_client.py
similarity index 98%
rename from src/chat/maibot_llmreq/model_client/base_client.py
rename to src/llm_models/model_client/base_client.py
index ed877a6c9..50a379d34 100644
--- a/src/chat/maibot_llmreq/model_client/base_client.py
+++ b/src/llm_models/model_client/base_client.py
@@ -5,7 +5,7 @@ from typing import Callable, Any
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
-from ..config.config import ModelInfo, APIProvider
+from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
diff --git a/src/chat/maibot_llmreq/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py
similarity index 99%
rename from src/chat/maibot_llmreq/model_client/gemini_client.py
rename to src/llm_models/model_client/gemini_client.py
index 75d2767e0..1861ca1d5 100644
--- a/src/chat/maibot_llmreq/model_client/gemini_client.py
+++ b/src/llm_models/model_client/gemini_client.py
@@ -15,7 +15,7 @@ from google.genai.errors import (
)
from .base_client import APIResponse, UsageRecord
-from ..config.config import ModelInfo, APIProvider
+from src.config.api_ada_configs import ModelInfo, APIProvider
from . import BaseClient
from ..exceptions import (
diff --git a/src/chat/maibot_llmreq/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py
similarity index 99%
rename from src/chat/maibot_llmreq/model_client/openai_client.py
rename to src/llm_models/model_client/openai_client.py
index db256b2d4..e5da59022 100644
--- a/src/chat/maibot_llmreq/model_client/openai_client.py
+++ b/src/llm_models/model_client/openai_client.py
@@ -21,7 +21,7 @@ from openai.types.chat import (
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from .base_client import APIResponse, UsageRecord
-from ..config.config import ModelInfo, APIProvider
+from src.config.api_ada_configs import ModelInfo, APIProvider
from . import BaseClient
from ..exceptions import (
diff --git a/src/chat/maibot_llmreq/model_manager.py b/src/llm_models/model_manager.py
similarity index 92%
rename from src/chat/maibot_llmreq/model_manager.py
rename to src/llm_models/model_manager.py
index 3056b187a..5d983849b 100644
--- a/src/chat/maibot_llmreq/model_manager.py
+++ b/src/llm_models/model_manager.py
@@ -1,15 +1,13 @@
import importlib
from typing import Dict
+from src.config.config import model_config
+from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
+from src.common.logger import get_logger
-from .config.config import (
- ModelUsageArgConfig,
- ModuleConfig,
-)
-
-from . import _logger as logger
from .model_client import ModelRequestHandler, BaseClient
+logger = get_logger("模型管理器")
class ModelManager:
# TODO: 添加读写锁,防止异步刷新配置时发生数据竞争
@@ -77,3 +75,5 @@ class ModelManager:
:return: 是否在模型列表中
"""
return task_name in self.config.task_model_arg_map
+
+
diff --git a/src/chat/maibot_llmreq/payload_content/message.py b/src/llm_models/payload_content/message.py
similarity index 100%
rename from src/chat/maibot_llmreq/payload_content/message.py
rename to src/llm_models/payload_content/message.py
diff --git a/src/chat/maibot_llmreq/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py
similarity index 100%
rename from src/chat/maibot_llmreq/payload_content/resp_format.py
rename to src/llm_models/payload_content/resp_format.py
diff --git a/src/chat/maibot_llmreq/payload_content/tool_option.py b/src/llm_models/payload_content/tool_option.py
similarity index 100%
rename from src/chat/maibot_llmreq/payload_content/tool_option.py
rename to src/llm_models/payload_content/tool_option.py
diff --git a/src/llm_models/temp.py b/src/llm_models/temp.py
new file mode 100644
index 000000000..89755a314
--- /dev/null
+++ b/src/llm_models/temp.py
@@ -0,0 +1,8 @@
+
+import sys
+import os
+sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
+
+from src.config.config import model_config
+print(f"当前模型配置: {model_config}")
+print(model_config.req_conf.default_max_tokens)
\ No newline at end of file
diff --git a/src/chat/maibot_llmreq/usage_statistic.py b/src/llm_models/usage_statistic.py
similarity index 52%
rename from src/chat/maibot_llmreq/usage_statistic.py
rename to src/llm_models/usage_statistic.py
index 3c5490e3e..176c4b7b1 100644
--- a/src/chat/maibot_llmreq/usage_statistic.py
+++ b/src/llm_models/usage_statistic.py
@@ -2,10 +2,11 @@ from datetime import datetime
from enum import Enum
from typing import Tuple
-from pymongo.synchronous.database import Database
+from src.common.logger import get_logger
+from src.config.api_ada_configs import ModelInfo
+from src.common.database.database_model import LLMUsage
-from . import _logger as logger
-from .config.config import ModelInfo
+logger = get_logger("模型使用统计")
class ReqType(Enum):
@@ -29,33 +30,21 @@ class UsageCallStatus(Enum):
class ModelUsageStatistic:
- db: Database | None = None
+ """
+ 模型使用统计类 - 使用SQLite+Peewee
+ """
- def __init__(self, db: Database):
- if db is None:
- logger.warning(
- "Warning: No database provided, ModelUsageStatistic will not work."
- )
- return
- if self._init_database(db):
- # 成功初始化
- self.db = db
-
- @staticmethod
- def _init_database(db: Database):
+ def __init__(self):
"""
- 初始化数据库相关索引
+ 初始化统计类
+ 由于使用Peewee ORM,不需要传入数据库实例
"""
+ # 确保表已经创建
try:
- db.llm_usage.create_index([("timestamp", 1)])
- db.llm_usage.create_index([("model_name", 1)])
- db.llm_usage.create_index([("task_name", 1)])
- db.llm_usage.create_index([("request_type", 1)])
- db.llm_usage.create_index([("status", 1)])
- return True
+ from src.common.database.database import db
+ db.create_tables([LLMUsage], safe=True)
except Exception as e:
- logger.error(f"创建数据库索引失败: {e}")
- return False
+ logger.error(f"创建LLMUsage表失败: {e}")
@staticmethod
def _calculate_cost(
@@ -67,6 +56,7 @@ class ModelUsageStatistic:
Args:
prompt_tokens: 输入token数量
completion_tokens: 输出token数量
+ model_info: 模型信息
Returns:
float: 总成本(元)
@@ -81,46 +71,50 @@ class ModelUsageStatistic:
model_name: str,
task_name: str = "N/A",
request_type: ReqType = ReqType.CHAT,
- ) -> str | None:
+ user_id: str = "system",
+ endpoint: str = "/chat/completions",
+ ) -> int | None:
"""
创建模型使用情况记录
- :param model_name: 模型名
- :param task_name: 任务名称
- :param request_type: 请求类型,默认为Chat
- :return:
- """
- if self.db is None:
- return None # 如果没有数据库连接,则不记录使用情况
+ Args:
+ model_name: 模型名
+ task_name: 任务名称
+ request_type: 请求类型,默认为Chat
+ user_id: 用户ID,默认为system
+ endpoint: API端点
+
+ Returns:
+ int | None: 返回记录ID,失败返回None
+ """
try:
- usage_data = {
- "model_name": model_name,
- "task_name": task_name,
- "request_type": request_type.value,
- "prompt_tokens": 0,
- "completion_tokens": 0,
- "total_tokens": 0,
- "cost": 0.0,
- "status": "processing",
- "timestamp": datetime.now(),
- "ext_msg": None,
- }
- result = self.db.llm_usage.insert_one(usage_data)
+ usage_record = LLMUsage.create(
+ model_name=model_name,
+ user_id=user_id,
+ request_type=request_type.value,
+ endpoint=endpoint,
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ cost=0.0,
+ status=UsageCallStatus.PROCESSING.value,
+ timestamp=datetime.now(),
+ )
logger.trace(
f"创建了一条模型使用情况记录 - 模型: {model_name}, "
- f"子任务: {task_name}, 类型: {request_type}"
- f"记录ID: {str(result.inserted_id)}"
+ f"子任务: {task_name}, 类型: {request_type.value}, "
+ f"用户: {user_id}, 记录ID: {usage_record.id}"
)
- return str(result.inserted_id)
+ return usage_record.id
except Exception as e:
logger.error(f"创建模型使用情况记录失败: {str(e)}")
return None
def update_usage(
self,
- record_id: str | None,
+ record_id: int | None,
model_info: ModelInfo,
usage_data: Tuple[int, int, int] | None = None,
stat: UsageCallStatus = UsageCallStatus.SUCCESS,
@@ -136,9 +130,6 @@ class ModelUsageStatistic:
stat: 任务调用状态
ext_msg: 额外信息
"""
- if self.db is None:
- return # 如果没有数据库连接,则不记录使用情况
-
if not record_id:
logger.error("更新模型使用情况失败: record_id不能为空")
return
@@ -153,28 +144,27 @@ class ModelUsageStatistic:
total_tokens = usage_data[2] if usage_data else 0
try:
- self.db.llm_usage.update_one(
- {"_id": record_id},
- {
- "$set": {
- "status": stat.value,
- "ext_msg": ext_msg,
- "prompt_tokens": prompt_tokens,
- "completion_tokens": completion_tokens,
- "total_tokens": total_tokens,
- "cost": self._calculate_cost(
- prompt_tokens, completion_tokens, model_info
- )
- if usage_data
- else 0.0,
- }
- },
- )
+ # 使用Peewee更新记录
+ update_query = LLMUsage.update(
+ status=stat.value,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ cost=self._calculate_cost(
+ prompt_tokens, completion_tokens, model_info
+ ) if usage_data else 0.0,
+ ).where(LLMUsage.id == record_id)
+
+ updated_count = update_query.execute()
+
+ if updated_count == 0:
+ logger.warning(f"记录ID {record_id} 不存在,无法更新")
+ return
- logger.trace(
+ logger.debug(
f"Token使用情况 - 模型: {model_info.name}, "
- f"记录ID: {record_id}, "
- f"任务状态: {stat.value}, 额外信息: {ext_msg if ext_msg else 'N/A'}, "
+ f"记录ID: {record_id}, "
+ f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
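Usage statistics now persist through Peewee/SQLite instead of Mongo. A minimal sketch of the resulting call flow, assuming only the signatures shown in this hunk (model_info is a ModelInfo from src.config.api_ada_configs and is not constructed here):

    from src.llm_models.usage_statistic import ModelUsageStatistic, ReqType, UsageCallStatus

    usage_stat = ModelUsageStatistic()           # no Mongo Database handle needed any more
    record_id = usage_stat.create_usage_record(  # returns the integer primary key, or None on failure
        "model1", task_name="llm_normal", request_type=ReqType.CHAT
    )
    # after the request completes, report (prompt, completion, total) token counts
    usage_stat.update_usage(record_id, model_info, usage_data=(100, 50, 150), stat=UsageCallStatus.SUCCESS)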
diff --git a/src/chat/maibot_llmreq/utils.py b/src/llm_models/utils.py
similarity index 98%
rename from src/chat/maibot_llmreq/utils.py
rename to src/llm_models/utils.py
index f8bf4fb39..352df5a43 100644
--- a/src/chat/maibot_llmreq/utils.py
+++ b/src/llm_models/utils.py
@@ -3,9 +3,11 @@ import io
from PIL import Image
-from . import _logger as logger
+from src.common.logger import get_logger
from .payload_content.message import Message, MessageBuilder
+logger = get_logger("消息压缩工具")
+
def compress_messages(
messages: list[Message], img_target_size: int = 1 * 1024 * 1024
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index c994cd173..ff03b2788 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -1,26 +1,39 @@
-import asyncio
-import json
import re
from datetime import datetime
-from typing import Tuple, Union, Dict, Any, Callable
-import aiohttp
-from aiohttp.client import ClientResponse
+from typing import Tuple, Union, Dict, Any
from src.common.logger import get_logger
import base64
from PIL import Image
import io
-import os
import copy # 添加copy模块用于深拷贝
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from src.config.config import global_config
-from src.common.tcp_connector import get_tcp_connector
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("model_utils")
+# 新架构导入 - 使用延迟导入以支持fallback模式
+try:
+ from .model_manager import ModelManager
+ from .model_client import ModelRequestHandler
+ from .payload_content.message import MessageBuilder
+
+ # 不在模块级别初始化ModelManager,延迟到实际使用时
+ ModelManager_class = ModelManager
+ model_manager = None # 延迟初始化
+ NEW_ARCHITECTURE_AVAILABLE = True
+ logger.info("新架构模块导入成功")
+except Exception as e:
+ logger.warning(f"新架构不可用,将使用fallback模式: {str(e)}")
+ ModelManager_class = None
+ model_manager = None
+ ModelRequestHandler = None
+ MessageBuilder = None
+ NEW_ARCHITECTURE_AVAILABLE = False
+
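+# 补充说明(推测的动机):ModelManager 依赖 src.config.config 中的 model_config,
+# 若在模块导入时就实例化,可能在配置加载完成前失败或形成循环导入,
+# 故延迟到首次使用时初始化;NEW_ARCHITECTURE_AVAILABLE 供后续代码选择 fallback 分支。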
class PayLoadTooLargeError(Exception):
"""自定义异常类,用于处理请求体过大错误"""
@@ -36,10 +49,9 @@ class PayLoadTooLargeError(Exception):
class RequestAbortException(Exception):
"""自定义异常类,用于处理请求中断异常"""
- def __init__(self, message: str, response: ClientResponse):
+ def __init__(self, message: str):
super().__init__(message)
self.message = message
- self.response = response
def __str__(self):
return self.message
@@ -59,7 +71,7 @@ class PermissionDeniedException(Exception):
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
- 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
+ 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
@@ -82,19 +94,25 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
and isinstance(safe_payload, dict)
and "messages" in safe_payload
and len(safe_payload["messages"]) > 0
+ and isinstance(safe_payload["messages"][0], dict)
+ and "content" in safe_payload["messages"][0]
):
- if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
- content = safe_payload["messages"][0]["content"]
- if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
- # 只修改拷贝的对象,用于安全的日志记录
- safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
- f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
- f"{image_base64[:10]}...{image_base64[-10:]}"
- )
+ content = safe_payload["messages"][0]["content"]
+ if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+ # 只修改拷贝的对象,用于安全的日志记录
+ safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
+ f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
+ f"{image_base64[:10]}...{image_base64[-10:]}"
+ )
return safe_payload
class LLMRequest:
+ """
+ 重构后的LLM请求类,基于新的model_manager和model_client架构
+ 保持向后兼容的API接口
+ """
+
# 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
"o1",
@@ -114,42 +132,78 @@ class LLMRequest:
]
def __init__(self, model: dict, **kwargs):
- # 将大写的配置键转换为小写并从config中获取实际值
- logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('name', 'Unknown')}")
+ """
+ 初始化LLM请求实例
+ Args:
+ model: 模型配置字典,兼容旧格式和新格式
+ **kwargs: 额外参数
+ """
+ logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('model_name', model.get('name', 'Unknown'))}")
logger.debug(f"🔍 [模型初始化] 模型配置: {model}")
logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}")
- try:
- # print(f"model['provider']: {model['provider']}")
- self.api_key = os.environ[f"{model['provider']}_KEY"]
- self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
- logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL")
- except AttributeError as e:
- logger.error(f"原始 model dict 信息:{model}")
- logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
- raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
- except KeyError:
- logger.warning(
- f"找不到{model['provider']}_KEY或{model['provider']}_BASE_URL环境变量,请检查配置文件或环境变量设置。"
- )
- self.model_name: str = model["name"]
- self.params = kwargs
-
- # 记录配置文件中声明了哪些参数(不管值是什么)
- self.has_enable_thinking = "enable_thinking" in model
- self.has_thinking_budget = "thinking_budget" in model
+ # 兼容新旧模型配置格式
+ # 新格式使用 model_name,旧格式使用 name
+ self.model_name: str = model.get("model_name", model.get("name", ""))
+ self.provider = model.get("provider", "")
+ # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
+ self.request_type = kwargs.pop("request_type", "default")
+
+ # 确定使用哪个任务配置
+ task_name = self._determine_task_name(model)
+
+ # 尝试初始化新架构
+ if NEW_ARCHITECTURE_AVAILABLE and ModelManager_class is not None:
+ try:
+ # 延迟初始化ModelManager
+ global model_manager
+ if model_manager is None:
+ from src.config.config import model_config
+ model_manager = ModelManager_class(model_config)
+ logger.debug("🔍 [模型初始化] ModelManager延迟初始化成功")
+
+ # 使用新架构获取模型请求处理器
+ self.request_handler = model_manager[task_name]
+ logger.debug(f"🔍 [模型初始化] 成功获取模型请求处理器,任务: {task_name}")
+ self.use_new_architecture = True
+ except Exception as e:
+ logger.warning(f"无法使用新架构,任务 {task_name} 初始化失败: {e}")
+ logger.warning("回退到兼容模式,某些功能可能受限")
+ self.request_handler = None
+ self.use_new_architecture = False
+ else:
+ logger.warning("新架构不可用,使用兼容模式")
+ logger.warning("回退到兼容模式,某些功能可能受限")
+ self.request_handler = None
+ self.use_new_architecture = False
+
+ # 保存原始参数用于向后兼容
+ self.params = kwargs
+
+ # 兼容性属性,从模型配置中提取
+ # 新格式和旧格式都支持
self.enable_thinking = model.get("enable_thinking", False)
- self.temp = model.get("temp", 0.7)
+ self.temp = model.get("temperature", model.get("temp", 0.7)) # 新格式用temperature,旧格式用temp
self.thinking_budget = model.get("thinking_budget", 4096)
self.stream = model.get("stream", False)
self.pri_in = model.get("pri_in", 0)
self.pri_out = model.get("pri_out", 0)
self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
- # print(f"max_tokens: {self.max_tokens}")
- logger.debug(f"🔍 [模型初始化] 模型参数设置完成:")
+ # 记录配置文件中声明了哪些参数(不管值是什么)
+ self.has_enable_thinking = "enable_thinking" in model
+ self.has_thinking_budget = "thinking_budget" in model
+
+ logger.debug("🔍 [模型初始化] 模型参数设置完成:")
logger.debug(f" - model_name: {self.model_name}")
+ logger.debug(f" - provider: {self.provider}")
logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}")
logger.debug(f" - enable_thinking: {self.enable_thinking}")
logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}")
@@ -157,15 +211,40 @@ class LLMRequest:
logger.debug(f" - temp: {self.temp}")
logger.debug(f" - stream: {self.stream}")
logger.debug(f" - max_tokens: {self.max_tokens}")
- logger.debug(f" - base_url: {self.base_url}")
+ logger.debug(f" - use_new_architecture: {self.use_new_architecture}")
# 获取数据库实例
self._init_database()
-
- # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
- self.request_type = kwargs.pop("request_type", "default")
+
logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}")
+ def _determine_task_name(self, model: dict) -> str:
+ """
+ 根据模型配置确定任务名称
+ Args:
+ model: 模型配置字典
+ Returns:
+ 任务名称
+ """
+ # 兼容新旧格式的模型名称
+ model_name = model.get("model_name", model.get("name", ""))
+
+ # 根据模型名称推断任务类型
+ if any(keyword in model_name.lower() for keyword in ["vlm", "vision", "gpt-4o", "claude", "vl-"]):
+ return "vision"
+ elif any(keyword in model_name.lower() for keyword in ["embed", "text-embedding", "bge-"]):
+ return "embedding"
+ elif any(keyword in model_name.lower() for keyword in ["whisper", "speech", "voice"]):
+ return "speech"
+ else:
+ # 根据request_type确定,映射到配置文件中定义的任务
+ if self.request_type in ["memory", "emotion"]:
+ return "llm_normal" # 映射到配置中的llm_normal任务
+ elif self.request_type in ["reasoning"]:
+ return "llm_reasoning" # 映射到配置中的llm_reasoning任务
+ else:
+ return "llm_normal" # 默认使用llm_normal任务
+
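+ # _determine_task_name 用法示意(假设的输入,仅作说明):
+ #   {"model_name": "bge-large-zh"}  -> "embedding"
+ #   {"name": "whisper-1"}           -> "speech"
+ #   其余模型名按 request_type 映射到 llm_normal / llm_reasoning
+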
@staticmethod
def _init_database():
"""初始化数据库集合"""
@@ -237,660 +316,6 @@ class LLMRequest:
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
- async def _prepare_request(
- self,
- endpoint: str,
- prompt: str = None,
- image_base64: str = None,
- image_format: str = None,
- file_bytes: bytes = None,
- file_format: str = None,
- payload: dict = None,
- retry_policy: dict = None,
- ) -> Dict[str, Any]:
- """配置请求参数
- Args:
- endpoint: API端点路径 (如 "chat/completions")
- prompt: prompt文本
- image_base64: 图片的base64编码
- image_format: 图片格式
- file_bytes: 文件的二进制数据
- file_format: 文件格式
- payload: 请求体数据
- retry_policy: 自定义重试策略
- request_type: 请求类型
- """
-
- # 合并重试策略
- default_retry = {
- "max_retries": 3,
- "base_wait": 10,
- "retry_codes": [429, 413, 500, 503],
- "abort_codes": [400, 401, 402, 403],
- }
- policy = {**default_retry, **(retry_policy or {})}
-
- api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
-
- stream_mode = self.stream
-
- # 构建请求体
- if image_base64:
- payload = await self._build_payload(prompt, image_base64, image_format)
- elif file_bytes:
- payload = await self._build_formdata_payload(file_bytes, file_format)
- elif payload is None:
- payload = await self._build_payload(prompt)
-
- if not file_bytes:
- if stream_mode:
- payload["stream"] = stream_mode
-
- if self.temp != 0.7:
- payload["temperature"] = self.temp
-
- # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false)
- if self.has_enable_thinking:
- payload["enable_thinking"] = self.enable_thinking
-
- # 添加thinking_budget参数(只有配置文件中声明了才添加)
- if self.has_thinking_budget:
- payload["thinking_budget"] = self.thinking_budget
-
- if self.max_tokens:
- payload["max_tokens"] = self.max_tokens
-
- # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
- # payload["max_tokens"] = global_config.model.model_max_output_length
- # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
- if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
- payload["max_completion_tokens"] = payload.pop("max_tokens")
-
- return {
- "policy": policy,
- "payload": payload,
- "api_url": api_url,
- "stream_mode": stream_mode,
- "image_base64": image_base64, # 保留必要的exception处理所需的原始数据
- "image_format": image_format,
- "file_bytes": file_bytes,
- "file_format": file_format,
- "prompt": prompt,
- }
-
- async def _execute_request(
- self,
- endpoint: str,
- prompt: str = None,
- image_base64: str = None,
- image_format: str = None,
- file_bytes: bytes = None,
- file_format: str = None,
- payload: dict = None,
- retry_policy: dict = None,
- response_handler: Callable = None,
- user_id: str = "system",
- request_type: str = None,
- ):
- """统一请求执行入口
- Args:
- endpoint: API端点路径 (如 "chat/completions")
- prompt: prompt文本
- image_base64: 图片的base64编码
- image_format: 图片格式
- file_bytes: 文件的二进制数据
- file_format: 文件格式
- payload: 请求体数据
- retry_policy: 自定义重试策略
- response_handler: 自定义响应处理器
- user_id: 用户ID
- request_type: 请求类型
- """
- # 获取请求配置
- request_content = await self._prepare_request(
- endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy
- )
- if request_type is None:
- request_type = self.request_type
- for retry in range(request_content["policy"]["max_retries"]):
- try:
- # 使用上下文管理器处理会话
- if file_bytes:
- headers = await self._build_headers(is_formdata=True)
- else:
- headers = await self._build_headers(is_formdata=False)
- # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
- if request_content["stream_mode"]:
- headers["Accept"] = "text/event-stream"
-
- # 添加请求发送前的调试信息
- logger.debug(f"🔍 [请求调试] 模型 {self.model_name} 准备发送请求")
- logger.debug(f"🔍 [请求调试] API URL: {request_content['api_url']}")
- logger.debug(f"🔍 [请求调试] 请求头: {await self._build_headers(no_key=True, is_formdata=file_bytes is not None)}")
-
- if not file_bytes:
- # 安全地记录请求体(隐藏敏感信息)
- safe_payload = await _safely_record(request_content, request_content["payload"])
- logger.debug(f"🔍 [请求调试] 请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}")
- else:
- logger.debug(f"🔍 [请求调试] 文件上传请求,文件格式: {request_content['file_format']}")
-
- async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
- post_kwargs = {"headers": headers}
- # form-data数据上传方式不同
- if file_bytes:
- post_kwargs["data"] = request_content["payload"]
- else:
- post_kwargs["json"] = request_content["payload"]
-
- async with session.post(request_content["api_url"], **post_kwargs) as response:
- handled_result = await self._handle_response(
- response, request_content, retry, response_handler, user_id, request_type, endpoint
- )
- return handled_result
-
- except Exception as e:
- handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
- retry += count_delta # 降级不计入重试次数
- if handled_payload:
- # 如果降级成功,重新构建请求体
- request_content["payload"] = handled_payload
- continue
-
- logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
- raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")
-
- async def _handle_response(
- self,
- response: ClientResponse,
- request_content: Dict[str, Any],
- retry_count: int,
- response_handler: Callable,
- user_id,
- request_type,
- endpoint,
- ):
- policy = request_content["policy"]
- stream_mode = request_content["stream_mode"]
- if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
- await self._handle_error_response(response, retry_count, policy)
- return None
-
- response.raise_for_status()
- result = {}
- if stream_mode:
- # 将流式输出转化为非流式输出
- result = await self._handle_stream_output(response)
- else:
- result = await response.json()
- return (
- response_handler(result)
- if response_handler
- else self._default_response_handler(result, user_id, request_type, endpoint)
- )
-
- async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
- flag_delta_content_finished = False
- accumulated_content = ""
- usage = None # 初始化usage变量,避免未定义错误
- reasoning_content = ""
- content = ""
- tool_calls = None # 初始化工具调用变量
-
- async for line_bytes in response.content:
- try:
- line = line_bytes.decode("utf-8").strip()
- if not line:
- continue
- if line.startswith("data:"):
- data_str = line[5:].strip()
- if data_str == "[DONE]":
- break
- try:
- chunk = json.loads(data_str)
- if flag_delta_content_finished:
- chunk_usage = chunk.get("usage", None)
- if chunk_usage:
- usage = chunk_usage # 获取token用量
- else:
- delta = chunk["choices"][0]["delta"]
- delta_content = delta.get("content")
- if delta_content is None:
- delta_content = ""
- accumulated_content += delta_content
-
- # 提取工具调用信息
- if "tool_calls" in delta:
- if tool_calls is None:
- tool_calls = delta["tool_calls"]
- else:
- # 合并工具调用信息
- tool_calls.extend(delta["tool_calls"])
-
- # 检测流式输出文本是否结束
- finish_reason = chunk["choices"][0].get("finish_reason")
- if delta.get("reasoning_content", None):
- reasoning_content += delta["reasoning_content"]
- if finish_reason == "stop" or finish_reason == "tool_calls":
- chunk_usage = chunk.get("usage", None)
- if chunk_usage:
- usage = chunk_usage
- break
- # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
- flag_delta_content_finished = True
- except Exception as e:
- logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}")
- except Exception as e:
- if isinstance(e, GeneratorExit):
- log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
- else:
- log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
- logger.warning(log_content)
- # 确保资源被正确清理
- try:
- await response.release()
- except Exception as cleanup_error:
- logger.error(f"清理资源时发生错误: {cleanup_error}")
- # 返回已经累积的内容
- content = accumulated_content
- if not content:
- content = accumulated_content
- think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
- if think_match:
- reasoning_content = think_match.group(1).strip()
- content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
-
- # 构建消息对象
- message = {
- "content": content,
- "reasoning_content": reasoning_content,
- }
-
- # 如果有工具调用,添加到消息中
- if tool_calls:
- message["tool_calls"] = tool_calls
-
- result = {
- "choices": [{"message": message}],
- "usage": usage,
- }
- return result
-
- async def _handle_error_response(self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]):
- if response.status in policy["retry_codes"]:
- wait_time = policy["base_wait"] * (2**retry_count)
- logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
- if response.status == 413:
- logger.warning("请求体过大,尝试压缩...")
- raise PayLoadTooLargeError("请求体过大")
- elif response.status in [500, 503]:
- logger.error(
- f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
- )
- raise RuntimeError("服务器负载过高,模型回复失败QAQ")
- else:
- logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
- raise RuntimeError("请求限制(429)")
- elif response.status in policy["abort_codes"]:
- # 特别处理400错误,添加详细调试信息
- if response.status == 400:
- logger.error(f"🔍 [调试信息] 模型 {self.model_name} 参数错误 (400) - 开始详细诊断")
- logger.error(f"🔍 [调试信息] 模型名称: {self.model_name}")
- logger.error(f"🔍 [调试信息] API地址: {self.base_url}")
- logger.error(f"🔍 [调试信息] 模型配置参数:")
- logger.error(f" - enable_thinking: {self.enable_thinking}")
- logger.error(f" - temp: {self.temp}")
- logger.error(f" - thinking_budget: {self.thinking_budget}")
- logger.error(f" - stream: {self.stream}")
- logger.error(f" - max_tokens: {self.max_tokens}")
- logger.error(f" - pri_in: {self.pri_in}")
- logger.error(f" - pri_out: {self.pri_out}")
- logger.error(f"🔍 [调试信息] 原始params: {self.params}")
-
- # 尝试获取服务器返回的详细错误信息
- try:
- error_text = await response.text()
- logger.error(f"🔍 [调试信息] 服务器返回的原始错误内容: {error_text}")
-
- try:
- error_json = json.loads(error_text)
- logger.error(f"🔍 [调试信息] 解析后的错误JSON: {json.dumps(error_json, indent=2, ensure_ascii=False)}")
- except json.JSONDecodeError:
- logger.error(f"🔍 [调试信息] 错误响应不是有效的JSON格式")
- except Exception as e:
- logger.error(f"🔍 [调试信息] 无法读取错误响应内容: {str(e)}")
-
- raise RequestAbortException("参数错误,请检查调试信息", response)
- elif response.status != 403:
- raise RequestAbortException("请求出现错误,中断处理", response)
- else:
- raise PermissionDeniedException("模型禁止访问")
-
- async def _handle_exception(
- self, exception, retry_count: int, request_content: Dict[str, Any]
- ) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
- policy = request_content["policy"]
- payload = request_content["payload"]
- wait_time = policy["base_wait"] * (2**retry_count)
- keep_request = False
- if retry_count < policy["max_retries"] - 1:
- keep_request = True
- if isinstance(exception, RequestAbortException):
- response = exception.response
- logger.error(
- f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
- )
-
- # 如果是400错误,额外输出请求体信息用于调试
- if response.status == 400:
- logger.error(f"🔍 [异常调试] 400错误 - 请求体调试信息:")
- try:
- safe_payload = await _safely_record(request_content, payload)
- logger.error(f"🔍 [异常调试] 发送的请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}")
- except Exception as debug_error:
- logger.error(f"🔍 [异常调试] 无法安全记录请求体: {str(debug_error)}")
- logger.error(f"🔍 [异常调试] 原始payload类型: {type(payload)}")
- if isinstance(payload, dict):
- logger.error(f"🔍 [异常调试] 原始payload键: {list(payload.keys())}")
-
- # print(request_content)
- # print(response)
- # 尝试获取并记录服务器返回的详细错误信息
- try:
- error_json = await response.json()
- if error_json and isinstance(error_json, list) and len(error_json) > 0:
- # 处理多个错误的情况
- for error_item in error_json:
- if "error" in error_item and isinstance(error_item["error"], dict):
- error_obj: dict = error_item["error"]
- error_code = error_obj.get("code")
- error_message = error_obj.get("message")
- error_status = error_obj.get("status")
- logger.error(
- f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
- )
- elif isinstance(error_json, dict) and "error" in error_json:
- # 处理单个错误对象的情况
- error_obj = error_json.get("error", {})
- error_code = error_obj.get("code")
- error_message = error_obj.get("message")
- error_status = error_obj.get("status")
- logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}")
- else:
- # 记录原始错误响应内容
- logger.error(f"服务器错误响应: {error_json}")
- except Exception as e:
- logger.warning(f"无法解析服务器错误响应: {str(e)}")
- raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
-
- elif isinstance(exception, PermissionDeniedException):
- # 只针对硅基流动的V3和R1进行降级处理
- if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
- old_model_name = self.model_name
- self.model_name = self.model_name[4:] # 移除"Pro/"前缀
- logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
-
- # 对全局配置进行更新
- if global_config.model.replyer_2.get("name") == old_model_name:
- global_config.model.replyer_2["name"] = self.model_name
- logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
- if global_config.model.replyer_1.get("name") == old_model_name:
- global_config.model.replyer_1["name"] = self.model_name
- logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
-
- if payload and "model" in payload:
- payload["model"] = self.model_name
-
- await asyncio.sleep(wait_time)
- return payload, -1
- raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}")
-
- elif isinstance(exception, PayLoadTooLargeError):
- if keep_request:
- image_base64 = request_content["image_base64"]
- compressed_image_base64 = compress_base64_image_by_scale(image_base64)
- new_payload = await self._build_payload(
- request_content["prompt"], compressed_image_base64, request_content["image_format"]
- )
- return new_payload, 0
- else:
- return None, 0
-
- elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError):
- if keep_request:
- logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}")
- await asyncio.sleep(wait_time)
- return None, 0
- else:
- logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}")
- raise RuntimeError(f"网络请求失败: {str(exception)}")
-
- elif isinstance(exception, aiohttp.ClientResponseError):
- # 处理aiohttp抛出的,除了policy中的status的响应错误
- if keep_request:
- logger.error(
- f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {exception.status}, 错误: {exception.message}"
- )
- try:
- error_text = await exception.response.text()
- error_json = json.loads(error_text)
- if isinstance(error_json, list) and len(error_json) > 0:
- # 处理多个错误的情况
- for error_item in error_json:
- if "error" in error_item and isinstance(error_item["error"], dict):
- error_obj = error_item["error"]
- logger.error(
- f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
- f"状态={error_obj.get('status')}, "
- f"消息={error_obj.get('message')}"
- )
- elif isinstance(error_json, dict) and "error" in error_json:
- error_obj = error_json.get("error", {})
- logger.error(
- f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
- f"状态={error_obj.get('status')}, "
- f"消息={error_obj.get('message')}"
- )
- else:
- logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
- except (json.JSONDecodeError, TypeError) as json_err:
- logger.warning(
- f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
- )
- except Exception as parse_err:
- logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
-
- await asyncio.sleep(wait_time)
- return None, 0
- else:
- logger.critical(
- f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
- )
- # 安全地检查和记录请求详情
- handled_payload = await _safely_record(request_content, payload)
- logger.critical(
- f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}"
- )
- raise RuntimeError(
- f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
- )
-
- else:
- if keep_request:
- logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}")
- await asyncio.sleep(wait_time)
- return None, 0
- else:
- logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
- # 安全地检查和记录请求详情
- handled_payload = await _safely_record(request_content, payload)
- logger.critical(
- f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}"
- )
- raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
-
- async def _transform_parameters(self, params: dict) -> dict:
- """
- 根据模型名称转换参数:
- - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数,
- 并将 'max_tokens' 重命名为 'max_completion_tokens'
- """
- # 复制一份参数,避免直接修改原始数据
- new_params = dict(params)
-
- logger.debug(f"🔍 [参数转换] 模型 {self.model_name} 开始参数转换")
- logger.debug(f"🔍 [参数转换] 是否为CoT模型: {self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION}")
- logger.debug(f"🔍 [参数转换] CoT模型列表: {self.MODELS_NEEDING_TRANSFORMATION}")
-
- if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
- logger.debug(f"🔍 [参数转换] 检测到CoT模型,开始参数转换")
- # 删除 'temperature' 参数(如果存在),但避免删除我们在_build_payload中添加的自定义温度
- if "temperature" in new_params and new_params["temperature"] == 0.7:
- removed_temp = new_params.pop("temperature")
- logger.debug(f"🔍 [参数转换] 移除默认temperature参数: {removed_temp}")
- # 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
- if "max_tokens" in new_params:
- old_value = new_params["max_tokens"]
- new_params["max_completion_tokens"] = new_params.pop("max_tokens")
- logger.debug(f"🔍 [参数转换] 参数重命名: max_tokens({old_value}) -> max_completion_tokens({new_params['max_completion_tokens']})")
- else:
- logger.debug(f"🔍 [参数转换] 非CoT模型,无需参数转换")
-
- logger.debug(f"🔍 [参数转换] 转换前参数: {params}")
- logger.debug(f"🔍 [参数转换] 转换后参数: {new_params}")
- return new_params
-
- async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData:
- """构建form-data请求体"""
- # 目前只适配了音频文件
- # 如果后续要支持其他类型的文件,可以在这里添加更多的处理逻辑
- data = aiohttp.FormData()
- content_type_list = {
- "wav": "audio/wav",
- "mp3": "audio/mpeg",
- "ogg": "audio/ogg",
- "flac": "audio/flac",
- "aac": "audio/aac",
- }
-
- content_type = content_type_list.get(file_format)
- if not content_type:
- logger.warning(f"暂不支持的文件类型: {file_format}")
-
- data.add_field(
- "file",
- io.BytesIO(file_bytes),
- filename=f"file.{file_format}",
- content_type=f"{content_type}", # 根据实际文件类型设置
- )
- data.add_field("model", self.model_name)
- return data
-
-    async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
-        """Build the request body"""
-        # Copy the parameters to avoid mutating self.params directly
-        logger.debug(f"🔍 [payload build] model {self.model_name}: starting to build the request body")
-        logger.debug(f"🔍 [payload build] original self.params: {self.params}")
-
-        params_copy = await self._transform_parameters(self.params)
-        logger.debug(f"🔍 [payload build] transformed params_copy: {params_copy}")
-
-        if image_base64:
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": prompt},
-                        {
-                            "type": "image_url",
-                            "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
-                        },
-                    ],
-                }
-            ]
-        else:
-            messages = [{"role": "user", "content": prompt}]
-
-        payload = {
-            "model": self.model_name,
-            "messages": messages,
-            **params_copy,
-        }
-
-        logger.debug(f"🔍 [payload build] base payload built: {list(payload.keys())}")
-
-        # Add the temperature parameter (only if it differs from the default 0.7)
-        if self.temp != 0.7:
-            payload["temperature"] = self.temp
-            logger.debug(f"🔍 [payload build] added temperature parameter: {self.temp}")
-
-        # Add the enable_thinking parameter (only when declared in the config file, whether true or false)
-        if self.has_enable_thinking:
-            payload["enable_thinking"] = self.enable_thinking
-            logger.debug(f"🔍 [payload build] added enable_thinking parameter: {self.enable_thinking}")
-
-        # Add the thinking_budget parameter (only when declared in the config file)
-        if self.has_thinking_budget:
-            payload["thinking_budget"] = self.thinking_budget
-            logger.debug(f"🔍 [payload build] added thinking_budget parameter: {self.thinking_budget}")
-
-        if self.max_tokens:
-            payload["max_tokens"] = self.max_tokens
-            logger.debug(f"🔍 [payload build] added max_tokens parameter: {self.max_tokens}")
-
-        # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
-        #     payload["max_tokens"] = global_config.model.model_max_output_length
-        # If 'max_tokens' is still present in the payload and needs conversion, re-check it here
-        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
-            old_value = payload["max_tokens"]
-            payload["max_completion_tokens"] = payload.pop("max_tokens")
-            logger.debug(f"🔍 [payload build] CoT model parameter conversion: max_tokens({old_value}) -> max_completion_tokens({payload['max_completion_tokens']})")
-
-        logger.debug(f"🔍 [payload build] final payload keys: {list(payload.keys())}")
-        return payload
-
-    def _default_response_handler(
-        self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
-    ) -> Tuple:
-        """Default response parsing"""
-        if "choices" in result and result["choices"]:
-            message = result["choices"][0]["message"]
-            content = message.get("content", "")
-            content, reasoning = self._extract_reasoning(content)
-            reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
-            if not reasoning_content:
-                reasoning_content = message.get("reasoning_content", "")
-            if not reasoning_content:
-                reasoning_content = reasoning
-
-            # Extract tool call information
-            tool_calls = message.get("tool_calls", None)
-
-            # Record token usage
-            usage = result.get("usage", {})
-            if usage:
-                prompt_tokens = usage.get("prompt_tokens", 0)
-                completion_tokens = usage.get("completion_tokens", 0)
-                total_tokens = usage.get("total_tokens", 0)
-                self._record_usage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=total_tokens,
-                    user_id=user_id,
-                    request_type=request_type if request_type is not None else self.request_type,
-                    endpoint=endpoint,
-                )
-
-            # Return tool_calls only when it exists and is non-empty
-            if tool_calls:
-                logger.debug(f"Tool calls detected: {tool_calls}")
-                return content, reasoning_content, tool_calls
-            else:
-                return content, reasoning_content
-        elif "text" in result and result["text"]:
-            return result["text"]
-        return "No result returned", ""
-
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取"""
@@ -902,61 +327,183 @@ class LLMRequest:
reasoning = ""
return content, reasoning
-    async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict:
-        """Build the request headers"""
-        if no_key:
-            if is_formdata:
-                return {"Authorization": "Bearer **********"}
-            return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
-        else:
-            if is_formdata:
-                return {"Authorization": f"Bearer {self.api_key}"}
-            return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-        # The masked no_key variant keeps people from screenshotting their own key
+    # === Main API methods ===
+    # These methods provide the bridge to the new architecture
async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple:
- """根据输入的提示和图片生成模型的异步响应"""
-
- response = await self._execute_request(
- endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format
- )
- # 根据返回值的长度决定怎么处理
- if len(response) == 3:
- content, reasoning_content, tool_calls = response
- return content, reasoning_content, tool_calls
- else:
- content, reasoning_content = response
- return content, reasoning_content
+ """
+ 根据输入的提示和图片生成模型的异步响应
+ 使用新架构的模型请求处理器
+ """
+ if not self.use_new_architecture:
+ raise RuntimeError(
+ f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
+ )
+
+ if MessageBuilder is None:
+ raise RuntimeError("MessageBuilder不可用,请检查新架构配置")
+
+ try:
+ # 构建包含图片的消息
+ message_builder = MessageBuilder()
+ message_builder.add_text_content(prompt).add_image_content(
+ image_format=image_format,
+ image_base64=image_base64
+ )
+ messages = [message_builder.build()]
+
+ # 使用新架构发送请求(只传递支持的参数)
+ response = await self.request_handler.get_response(
+ messages=messages,
+ tool_options=None,
+ response_format=None
+ )
+
+ # 新架构返回的是 APIResponse 对象,直接提取内容
+ content = response.content or ""
+ reasoning_content = response.reasoning_content or ""
+ tool_calls = response.tool_calls
+
+ # 从内容中提取标签的推理内容(向后兼容)
+ if not reasoning_content and content:
+ content, extracted_reasoning = self._extract_reasoning(content)
+ reasoning_content = extracted_reasoning
+
+ # 记录token使用情况
+ if response.usage:
+ self._record_usage(
+ prompt_tokens=response.usage.prompt_tokens or 0,
+ completion_tokens=response.usage.completion_tokens or 0,
+ total_tokens=response.usage.total_tokens or 0,
+ user_id="system",
+ request_type=self.request_type,
+ endpoint="/chat/completions"
+ )
+
+ # 返回格式兼容旧版本
+ if tool_calls:
+ return content, reasoning_content, tool_calls
+ else:
+ return content, reasoning_content
+
+ except Exception as e:
+ logger.error(f"模型 {self.model_name} 图片响应生成失败: {str(e)}")
+ # 向后兼容的异常处理
+ if "401" in str(e) or "API key" in str(e):
+ raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e
+ elif "429" in str(e):
+ raise RuntimeError("请求过于频繁,请稍后再试") from e
+ elif "500" in str(e) or "503" in str(e):
+ raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e
+ else:
+ raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(e)}") from e
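+
+    # Note: matching error substrings ("401"/"API key", "429", "500"/"503") above is a
+    # backward-compatibility heuristic; if the new architecture exposes typed exceptions,
+    # matching on those would be more robust.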
async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple:
- """根据输入的语音文件生成模型的异步响应"""
- response = await self._execute_request(
- endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav"
- )
- return response
+ """
+ 根据输入的语音文件生成模型的异步响应
+ 使用新架构的模型请求处理器
+ """
+ if not self.use_new_architecture:
+ raise RuntimeError(
+ f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
+ )
+
+ try:
+ # 构建语音识别请求参数
+ # 注意:新架构中的语音识别可能使用不同的方法
+ # 这里先使用get_response方法,可能需要根据实际API调整
+ response = await self.request_handler.get_response(
+ messages=[], # 语音识别可能不需要消息
+ tool_options=None
+ )
+
+ # 新架构返回的是 APIResponse 对象,直接提取文本内容
+ if response.content:
+ return response.content
+ else:
+ return ""
+
+ except Exception as e:
+ logger.error(f"模型 {self.model_name} 语音识别失败: {str(e)}")
+ # 向后兼容的异常处理
+ if "401" in str(e) or "API key" in str(e):
+ raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e
+ elif "429" in str(e):
+ raise RuntimeError("请求过于频繁,请稍后再试") from e
+ elif "500" in str(e) or "503" in str(e):
+ raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e
+ else:
+ raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(e)}") from e
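+
+    # TODO: replace this stopgap once the new architecture exposes a dedicated
+    # transcription entry point; calling get_response with empty messages only
+    # approximates the old /audio/transcriptions behavior.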
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
- """异步方式根据输入的提示生成模型的响应"""
- # 构建请求体,不硬编码max_tokens
- data = {
- "model": self.model_name,
- "messages": [{"role": "user", "content": prompt}],
- **self.params,
- **kwargs,
- }
-
- response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
- # 原样返回响应,不做处理
-
- if len(response) == 3:
- content, reasoning_content, tool_calls = response
- return content, (reasoning_content, self.model_name, tool_calls)
- else:
- content, reasoning_content = response
- return content, (reasoning_content, self.model_name)
+ """
+ 异步方式根据输入的提示生成模型的响应
+ 使用新架构的模型请求处理器,如无法使用则抛出错误
+ """
+ if not self.use_new_architecture:
+ raise RuntimeError(
+ f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
+ )
+
+ if MessageBuilder is None:
+ raise RuntimeError("MessageBuilder不可用,请检查新架构配置")
+
+ try:
+ # 构建消息
+ message_builder = MessageBuilder()
+ message_builder.add_text_content(prompt)
+ messages = [message_builder.build()]
+
+ # 使用新架构发送请求(只传递支持的参数)
+ response = await self.request_handler.get_response(
+ messages=messages,
+ tool_options=None,
+ response_format=None
+ )
+
+ # 新架构返回的是 APIResponse 对象,直接提取内容
+ content = response.content or ""
+ reasoning_content = response.reasoning_content or ""
+ tool_calls = response.tool_calls
+
+ # 从内容中提取标签的推理内容(向后兼容)
+ if not reasoning_content and content:
+ content, extracted_reasoning = self._extract_reasoning(content)
+ reasoning_content = extracted_reasoning
+
+ # 记录token使用情况
+ if response.usage:
+ self._record_usage(
+ prompt_tokens=response.usage.prompt_tokens or 0,
+ completion_tokens=response.usage.completion_tokens or 0,
+ total_tokens=response.usage.total_tokens or 0,
+ user_id="system",
+ request_type=self.request_type,
+ endpoint="/chat/completions"
+ )
+
+ # 返回格式兼容旧版本
+ if tool_calls:
+ return content, (reasoning_content, self.model_name, tool_calls)
+ else:
+ return content, (reasoning_content, self.model_name)
+
+ except Exception as e:
+ logger.error(f"模型 {self.model_name} 生成响应失败: {str(e)}")
+ # 向后兼容的异常处理
+ if "401" in str(e) or "API key" in str(e):
+ raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e
+ elif "429" in str(e):
+ raise RuntimeError("请求过于频繁,请稍后再试") from e
+ elif "500" in str(e) or "503" in str(e):
+ raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e
+ else:
+ raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(e)}") from e
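+
+    # Compatibility note: callers still receive the legacy return shapes:
+    # (content, (reasoning_content, model_name)) or, when tool calls are present,
+    # (content, (reasoning_content, model_name, tool_calls)).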
async def get_embedding(self, text: str) -> Union[list, None]:
- """异步方法:获取文本的embedding向量
+ """
+ 异步方法:获取文本的embedding向量
+ 使用新架构的模型请求处理器
Args:
            text: the text to get the embedding for
@@ -964,42 +511,51 @@ class LLMRequest:
Returns:
            list: the embedding vector, or None on failure
"""
-
if len(text) < 1:
logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
return None
-        def embedding_handler(result):
-            """Handle the response"""
-            if "data" in result and len(result["data"]) > 0:
-                # Extract token usage information
-                usage = result.get("usage", {})
-                if usage:
-                    prompt_tokens = usage.get("prompt_tokens", 0)
-                    completion_tokens = usage.get("completion_tokens", 0)
-                    total_tokens = usage.get("total_tokens", 0)
-                    # Record token usage
-                    self._record_usage(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=total_tokens,
-                        user_id="system",  # adjust user_id as needed
-                        # request_type="embedding",  # request type: embedding
-                        request_type=self.request_type,  # request type: text
-                        endpoint="/embeddings",  # API endpoint
-                    )
-                    return result["data"][0].get("embedding", None)
-                return result["data"][0].get("embedding", None)
+        if not self.use_new_architecture:
+            logger.warning(f"Model {self.model_name} cannot use the new architecture; the embedding request will be skipped")
return None
-        embedding = await self._execute_request(
-            endpoint="/embeddings",
-            prompt=text,
-            payload={"model": self.model_name, "input": text, "encoding_format": "float"},
-            retry_policy={"max_retries": 2, "base_wait": 6},
-            response_handler=embedding_handler,
-        )
-        return embedding
+        try:
+            # Build the embedding request
+            # using the new architecture's get_embedding method
+            response = await self.request_handler.get_embedding(text)
+
+            # The new architecture returns an APIResponse object; extract the embedding directly
+            if response.embedding:
+                embedding = response.embedding
+
+                # Record token usage
+                if response.usage:
+                    self._record_usage(
+                        prompt_tokens=response.usage.prompt_tokens or 0,
+                        completion_tokens=response.usage.completion_tokens or 0,
+                        total_tokens=response.usage.total_tokens or 0,
+                        user_id="system",
+                        request_type=self.request_type,
+                        endpoint="/embeddings"
+                    )
+
+                return embedding
+            else:
+                logger.warning(f"Model {self.model_name} returned an empty embedding response")
+                return None
+
+        except Exception as e:
+            logger.error(f"Model {self.model_name} failed to get the embedding: {str(e)}")
+            # Backward-compatible exception handling
+            if "401" in str(e) or "API key" in str(e):
+                raise RuntimeError("API key error: authentication failed; please check that the API key in config/model_config.toml is correct") from e
+            elif "429" in str(e):
+                raise RuntimeError("Too many requests; please try again later") from e
+            elif "500" in str(e) or "503" in str(e):
+                raise RuntimeError("Server overloaded; the model failed to reply QAQ") from e
+            else:
+                logger.warning(f"Model {self.model_name} embedding request failed; returning None: {str(e)}")
+                return None
def compress_base64_image_by_scale(base64_data: str, target_size: float = 0.8 * 1024 * 1024) -> str:
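
For reference, a minimal caller-side sketch of the bridged API above. The LLMRequest constructor arguments shown are assumptions for illustration; only the return shapes are taken from the code in this patch:

# Hypothetical caller of the bridged API; constructor arguments are assumed.
import asyncio
from src.llm_models.utils_model import LLMRequest

async def main():
    llm = LLMRequest(model="deepseek-v3", request_type="llm_normal")  # assumed signature
    content, extra = await llm.generate_response_async("Hello!")
    if len(extra) == 3:
        reasoning_content, model_name, tool_calls = extra  # tool-call path
    else:
        reasoning_content, model_name = extra
    print(model_name, content)

asyncio.run(main())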
diff --git a/template/compare/model_config_template.toml b/template/compare/model_config_template.toml
new file mode 100644
index 000000000..f9055fcea
--- /dev/null
+++ b/template/compare/model_config_template.toml
@@ -0,0 +1,77 @@
+[inner]
+version = "0.1.0"
+
+# The version number of this config file follows the same iteration rules as bot_config.toml
+
+[request_conf] # Request configuration (all values below are defaults; to change one, uncomment the corresponding entry)
+#max_retry = 2 # Maximum number of retries (how many times a single failed model API call is retried)
+#timeout = 10 # API call timeout (a request exceeding this duration is treated as "request timed out"; unit: seconds)
+#retry_interval = 10 # Retry interval (the wait between retries after a failed API call; unit: seconds)
+#default_temperature = 0.7 # Default temperature (used when bot_config.toml does not set a temperature parameter)
+#default_max_tokens = 1024 # Default maximum output tokens (used when bot_config.toml does not set a max_tokens parameter)
+
+
+[[api_providers]] # API providers (multiple entries allowed)
+name = "DeepSeek" # Provider name (any name you like; reference it in the api_provider field of a models entry)
+base_url = "https://api.deepseek.cn" # The provider's base URL
+key = "******" # API key (optional; defaults to None)
+client_type = "openai" # Request client (optional; defaults to "openai"; set to "google" for Gemini and other Google-family models)
+
+#[[api_providers]] # Special case: Google's Gemini uses its own API, which is incompatible with the OpenAI format, so client_type must be "google"
+#name = "Google"
+#base_url = "https://api.google.com"
+#key = "******"
+#client_type = "google"
+#
+#[[api_providers]]
+#name = "SiliconFlow"
+#base_url = "https://api.siliconflow.cn"
+#key = "******"
+#
+#[[api_providers]]
+#name = "LocalHost"
+#base_url = "https://localhost:8888"
+#key = "lm-studio"
+
+
+[[models]] # Models (multiple entries allowed)
+# Model identifier (the identifier assigned by the API provider)
+model_identifier = "deepseek-chat"
+# Model name (any name you like; reference it by this name in bot_config.toml)
+# (optional; if omitted, model_identifier is used automatically)
+name = "deepseek-v3"
+# Provider name (must match a provider configured in api_providers)
+api_provider = "DeepSeek"
+# Input price (for API usage statistics; unit: CNY per million tokens) (optional; defaults to 0)
+price_in = 2.0
+# Output price (for API usage statistics; unit: CNY per million tokens) (optional; defaults to 0)
+price_out = 8.0
+# Force streaming output mode (if the model does not support non-streaming output, uncomment this to force streaming)
+# (optional; defaults to false)
+#force_stream_mode = true
+
+#[[models]]
+#model_identifier = "deepseek-reasoner"
+#name = "deepseek-r1"
+#api_provider = "DeepSeek"
+#model_flags = ["text", "tool_calling", "reasoning"]
+#price_in = 4.0
+#price_out = 16.0
+#
+#[[models]]
+#model_identifier = "BAAI/bge-m3"
+#name = "siliconflow-bge-m3"
+#api_provider = "SiliconFlow"
+#model_flags = ["text", "embedding"]
+#price_in = 0
+#price_out = 0
+
+
+[task_model_usage]
+#llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0}
+#llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0}
+#embedding = "siliconflow-bge-m3"
+#schedule = [
+#    "deepseek-v3",
+#    "deepseek-r1",
+#]
\ No newline at end of file
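
As a sanity check for the template above, the file can be loaded and a model resolved against its provider with a few lines of Python. This is a minimal sketch; resolve_model is a hypothetical helper for illustration, not part of the shipped code:

# Hypothetical helper: resolve a model entry and its provider from model_config.toml.
import toml

def resolve_model(config: dict, name: str) -> dict:
    # "name" is optional in [[models]]; fall back to model_identifier, as the template notes.
    model = next(m for m in config["models"]
                 if m.get("name", m["model_identifier"]) == name)
    provider = next(p for p in config["api_providers"]
                    if p["name"] == model["api_provider"])
    return {**model,
            "base_url": provider["base_url"],
            "client_type": provider.get("client_type", "openai")}

config = toml.load("config/model_config.toml")
print(resolve_model(config, "deepseek-v3")["base_url"])  # https://api.deepseek.cn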