初步重构llmrequest
This commit is contained in:
111
debug_config.py
Normal file
111
debug_config.py
Normal file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
调试配置加载问题,查看API provider的配置是否正确传递
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
def _mask_api_key(api_key):
    """Return a display-safe rendering of an API key.

    Keys longer than 15 characters are shown as their first 10 and last 5
    characters joined by ``...``.  Shorter keys are fully hidden, because
    printing a prefix and suffix of a short key would reveal all of it.
    A missing/empty key is reported as ``(none)`` instead of raising.
    """
    if not api_key:
        return "(none)"
    if len(api_key) > 15:
        return f"{api_key[:10]}...{api_key[-5:]}"
    return "***"


def debug_config_loading():
    """Temporarily inject test API keys and print how the config is loaded.

    Steps:
      1. Read ``config/model_config.toml`` and keep its raw text.
      2. Overwrite every provider's ``api_key`` with a dummy test key.
      3. Drop the cached config/model modules from ``sys.modules`` and
         re-import ``model_config``, printing providers, models and tasks.
      4. Build a ``ModelManager`` and print its API client map.
      5. Always write the original file text back.

    All exceptions are caught and printed with a traceback; nothing is
    returned and nothing is re-raised.
    """
    import traceback

    try:
        # Imported lazily so that a missing ``toml`` package is reported
        # through the same error path as any other failure.
        import toml

        config_path = "config/model_config.toml"

        # Keep the raw text (newline="" disables newline translation) so the
        # file can be restored exactly, including comments and formatting
        # that a toml.load/toml.dump round trip would throw away.
        with open(config_path, 'r', encoding='utf-8', newline='') as f:
            original_text = f.read()
        config = toml.loads(original_text)

        for provider in config['api_providers']:
            provider['api_key'] = f"sk-test-key-for-{provider['name'].lower()}-12345"

        try:
            # The write happens inside the try so that a failure half-way
            # through still triggers the restore in ``finally``.
            with open(config_path, 'w', encoding='utf-8') as f:
                toml.dump(config, f)

            print("✅ 配置了测试API key")

            # Drop cached modules so the import below re-reads the file.
            modules_to_remove = [
                'src.config.config',
                'src.config.api_ada_configs',
                'src.llm_models.model_manager',
                'src.llm_models.model_client',
                'src.llm_models.utils_model'
            ]
            for module in modules_to_remove:
                sys.modules.pop(module, None)

            # Import the configuration object under test.
            from src.config.config import model_config
            print("\n🔍 调试配置加载:")
            print(f"model_config类型: {type(model_config)}")

            # Inspect the API providers.
            if hasattr(model_config, 'api_providers'):
                print(f"API providers数量: {len(model_config.api_providers)}")
                for name, provider in model_config.api_providers.items():
                    print(f" - {name}: {provider.base_url}")
                    print(f" API key: {_mask_api_key(provider.api_key)}")
                    print(f" Client type: {provider.client_type}")

            # Inspect the model definitions.
            if hasattr(model_config, 'models'):
                print(f"模型数量: {len(model_config.models)}")
                for name, model in model_config.models.items():
                    print(f" - {name}: {model.model_identifier} (提供商: {model.api_provider})")

            # Inspect the per-task model mapping.
            if hasattr(model_config, 'task_model_arg_map'):
                print(f"任务配置数量: {len(model_config.task_model_arg_map)}")
                for task_name, task_config in model_config.task_model_arg_map.items():
                    print(f" - {task_name}: {task_config}")

            # Try to build the ModelManager from the loaded configuration.
            print("\n🔍 调试ModelManager初始化:")
            from src.llm_models.model_manager import ModelManager

            try:
                model_manager = ModelManager(model_config)
                print("✅ ModelManager初始化成功")

                # Inspect the provider-name -> client mapping.
                print(f"API客户端数量: {len(model_manager.api_client_map)}")
                for name, client in model_manager.api_client_map.items():
                    print(f" - {name}: {type(client).__name__}")
                    if hasattr(client, 'client') and hasattr(client.client, 'api_key'):
                        print(f" Client API key: {_mask_api_key(client.client.api_key)}")

                # Try to fetch one task handler.
                try:
                    handler = model_manager["llm_normal"]
                    print("✅ 成功获取llm_normal任务处理器")
                    print(f"任务处理器类型: {type(handler).__name__}")
                except Exception as e:
                    print(f"❌ 获取任务处理器失败: {e}")

            except Exception as e:
                print(f"❌ ModelManager初始化失败: {e}")
                traceback.print_exc()

        finally:
            # Restore the untouched original text rather than re-dumping the
            # parsed dict, so the user's file is left exactly as it was.
            with open(config_path, 'w', encoding='utf-8', newline='') as f:
                f.write(original_text)
            print("\n✅ 配置已恢复")

    except Exception as e:
        print(f"❌ 调试失败: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    debug_config_loading()
|
||||
@@ -1,19 +0,0 @@
|
||||
import loguru
|
||||
|
||||
type LoguruLogger = loguru.Logger
|
||||
|
||||
_logger: LoguruLogger = loguru.logger
|
||||
|
||||
|
||||
def init_logger(logger: LoguruLogger | None = None):
    """Configure the logger used by the LLMRequest module.

    :param logger: logger object to install as the module-level ``_logger``.
        When it is omitted (or falsy) the current ``_logger`` is kept and a
        warning is emitted through it.
    """
    global _logger  # rebinding the module-level logger

    if not logger:
        _logger.warning("Warning: No logger provided, using default logger.")
        return

    _logger = logger
|
||||
@@ -1,267 +0,0 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import tomli
|
||||
from packaging import version
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import Version, InvalidVersion
|
||||
|
||||
from .. import _logger as logger
|
||||
|
||||
from .config import (
|
||||
ModelUsageArgConfigItem,
|
||||
ModelUsageArgConfig,
|
||||
APIProvider,
|
||||
ModelInfo,
|
||||
NEWEST_VER,
|
||||
ModuleConfig,
|
||||
)
|
||||
|
||||
|
||||
def _get_config_version(toml: Dict) -> Version:
    """Extract the configuration file version as a ``Version`` object.

    Args:
        toml: parsed configuration dictionary.

    Returns:
        The parsed value of ``toml["inner"]["version"]``, or version
        ``0.0.0`` when that key is not present.

    Raises:
        InvalidVersion: the version string cannot be parsed; the problem is
            logged before the exception is raised.
    """
    has_version = "inner" in toml and "version" in toml["inner"]
    # Fall back to the default version when the file does not declare one.
    config_version: str = toml["inner"]["version"] if has_version else "0.0.0"

    try:
        return version.parse(config_version)
    except InvalidVersion as e:
        logger.error(
            "配置文件中 inner段 的 version 键是错误的版本描述\n"
            f"请检查配置文件,当前 version 键: {config_version}\n"
            f"错误信息: {e}"
        )
        raise InvalidVersion(
            "配置文件中 inner段 的 version 键是错误的版本描述\n"
        ) from e
|
||||
|
||||
|
||||
def _request_conf(parent: Dict, config: ModuleConfig):
    """Copy the ``request_conf`` section of the parsed TOML onto ``config.req_conf``.

    Only the keys present in the section are overridden; every other field
    keeps the value already stored on ``config.req_conf``.  A missing or
    empty ``request_conf`` section is treated as "no overrides" instead of
    raising ``AttributeError``, matching the section being optional.

    Args:
        parent: the whole parsed configuration dictionary.
        config: module configuration whose ``req_conf`` attribute is updated
            in place.
    """
    section = parent.get("request_conf") or {}
    req_conf = config.req_conf

    for field in (
        "max_retry",
        "timeout",
        "retry_interval",
        "default_temperature",
        "default_max_tokens",
    ):
        # Fall back to the current value so absent keys change nothing.
        setattr(req_conf, field, section.get(field, getattr(req_conf, field)))
|
||||
|
||||
|
||||
def _api_providers(parent: Dict, config: ModuleConfig):
    """Register every entry of the ``api_providers`` list on ``config.api_providers``.

    Each entry needs a ``name`` and a ``base_url``; ``api_key`` defaults to
    ``None`` and ``client_type`` defaults to ``"openai"``.

    Raises:
        KeyError: a provider name is already registered.
        ValueError: an entry has no ``name`` or no ``base_url``.
    """
    for entry in parent.get("api_providers"):
        name = entry.get("name")
        base_url = entry.get("base_url")

        # Duplicate detection runs before the completeness check.
        if name in config.api_providers:
            logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
            raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")

        if not (name and base_url):
            logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
            raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")

        config.api_providers[name] = APIProvider(
            name=name,
            base_url=base_url,
            api_key=entry.get("api_key"),
            client_type=entry.get("client_type", "openai"),
        )
|
||||
|
||||
|
||||
def _models(parent: Dict, config: ModuleConfig):
    """Register every entry of the ``models`` list on ``config.models``.

    ``name`` falls back to ``model_identifier``; ``price_in`` and
    ``price_out`` default to ``0.0`` and ``force_stream_mode`` to ``False``.

    Raises:
        KeyError: a model name is already registered.
        ValueError: an entry lacks ``model_identifier`` or ``api_provider``,
            or names a provider that is not in ``config.api_providers``.
    """
    for entry in parent.get("models"):
        model_identifier = entry.get("model_identifier")
        name = entry.get("name", model_identifier)
        api_provider = entry.get("api_provider")

        # Duplicate detection runs before the completeness check.
        if name in config.models:
            logger.error(f"重复的模型名称: {name},请检查配置文件。")
            raise KeyError(f"重复的模型名称: {name},请检查配置文件。")

        if not (model_identifier and api_provider):
            logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
            raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")

        # The referenced provider must have been declared beforehand.
        if api_provider not in config.api_providers:
            logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
            raise ValueError(
                f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
            )

        config.models[name] = ModelInfo(
            name=name,
            model_identifier=model_identifier,
            api_provider=api_provider,
            price_in=entry.get("price_in", 0.0),
            price_out=entry.get("price_out", 0.0),
            force_stream_mode=entry.get("force_stream_mode", False),
        )
|
||||
|
||||
|
||||
def _task_model_usage(parent: Dict, config: ModuleConfig):
    """Build ``config.task_model_arg_map`` from the ``task_model_usage`` table.

    Each task value may be:
      * a string  - a model name with no per-task overrides;
      * a table   - must contain ``model``; ``temperature``, ``max_tokens``
        and ``max_retry`` are optional overrides;
      * a list    - whose elements are strings or tables as above.

    Any other value type produces a task with an empty usage list.

    Raises:
        KeyError: a task name is already present in the map, or a table
            inside a list has no ``model`` key.
        ValueError: a top-level table has no ``model`` key, or a list
            element is neither a string nor a table.
    """

    def _from_table(table):
        # Table form: model name plus optional per-task overrides.
        return ModelUsageArgConfigItem(
            name=table["model"],
            temperature=table.get("temperature"),
            max_tokens=table.get("max_tokens"),
            max_retry=table.get("max_retry"),
        )

    def _from_name(model_name):
        # Bare string form: no overrides at all.
        return ModelUsageArgConfigItem(
            name=model_name,
            temperature=None,
            max_tokens=None,
            max_retry=None,
        )

    def _reject(name):
        logger.error(f"子任务 {name} 的模型配置不合法,请检查配置文件。")
        raise ValueError(f"子任务 {name} 的模型配置不合法,请检查配置文件。")

    usage_section = parent.get("task_model_usage")
    config.task_model_arg_map = {}

    for task_name, item in usage_section.items():
        if task_name in config.task_model_arg_map:
            logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
            raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")

        usage = []
        if isinstance(item, str):
            usage.append(_from_name(item))
        elif isinstance(item, dict):
            if "model" not in item:
                _reject(task_name)
            usage.append(_from_table(item))
        elif isinstance(item, list):
            for entry in item:
                if isinstance(entry, dict):
                    usage.append(_from_table(entry))
                elif isinstance(entry, str):
                    usage.append(_from_name(entry))
                else:
                    _reject(task_name)

        config.task_model_arg_map[task_name] = ModelUsageArgConfig(
            name=task_name,
            usage=usage,
        )
|
||||
|
||||
|
||||
def load_config(config_path: str) -> ModuleConfig:
    """Load the module configuration from a TOML file.

    A fresh ``ModuleConfig`` is created and, when ``config_path`` exists,
    filled section by section using the handlers in ``include_configs``.
    When the file does not exist the default ``ModuleConfig`` is returned
    unchanged.

    Args:
        config_path: path of the TOML configuration file.

    Returns:
        The populated (or default) ``ModuleConfig``.

    Raises:
        SystemExit: the file is not valid TOML (exit status 1, after a
            critical log entry).
        InvalidVersion: the file version is unparsable, or outside the
            range a present section supports.
        KeyError: a section not marked ``"necessary": False`` is missing.
    """
    config = ModuleConfig()

    # Section name -> handler, supported version range and optional flags.
    include_configs: Dict[str, Dict[str, Any]] = {
        "request_conf": {
            "func": _request_conf,
            "support": ">=0.0.0",
            "necessary": False,
        },
        "api_providers": {"func": _api_providers, "support": ">=0.0.0"},
        "models": {"func": _models, "support": ">=0.0.0"},
        "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
    }

    if os.path.exists(config_path):
        with open(config_path, "rb") as f:
            try:
                toml_dict = tomli.load(f)
            except tomli.TOMLDecodeError as e:
                logger.critical(
                    f"配置文件model_list.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}"
                )
                # Raise SystemExit directly: the ``exit`` helper is injected
                # by the ``site`` module and is not available in every
                # interpreter setup.
                raise SystemExit(1)

        # Determine the version declared by the configuration file.
        config.INNER_VERSION = _get_config_version(toml_dict)

        # A file newer than this program understands is only warned about.
        if config.INNER_VERSION > Version(NEWEST_VER):
            logger.warning(
                f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
            )

        # Dispatch every known section to its handler.
        for key, spec in include_configs.items():
            if key in toml_dict:
                group_specifier_set: SpecifierSet = SpecifierSet(spec["support"])

                # The file version must lie within the section's supported range.
                if config.INNER_VERSION not in group_specifier_set:
                    logger.error(
                        f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
                        f"当前程序仅支持以下版本范围: {group_specifier_set}"
                    )
                    raise InvalidVersion(
                        f"当前程序仅支持以下版本范围: {group_specifier_set}"
                    )

                # Emit the section's notice, if one is configured.
                if "notice" in spec:
                    logger.warning(spec["notice"])

                spec["func"](toml_dict, config)

            elif spec.get("necessary") is False:
                # Explicitly optional section that is absent: nothing to do.
                continue

            else:
                # A required section is missing from the file.
                logger.error(f"配置文件中缺少必需的字段: '{key}'")
                raise KeyError(f"配置文件中缺少必需的字段: '{key}'")

        # NOTE(review): assumed to be logged only when the file was actually
        # read; the original nesting of this statement should be confirmed.
        logger.success(f"成功加载配置文件: {config_path}")

    return config
|
||||
@@ -1,84 +0,0 @@
|
||||
import pytest
|
||||
from packaging.version import InvalidVersion
|
||||
|
||||
from src import maibot_llmreq
|
||||
from src.maibot_llmreq.config.parser import _get_config_version, load_config
|
||||
|
||||
|
||||
class TestConfigLoad:
    """Tests for version extraction and for loading a whole TOML config file."""

    def test_loads_valid_version_from_toml(self):
        """A well-formed ``inner.version`` is returned and round-trips via ``str``."""
        maibot_llmreq.init_logger()

        parsed = _get_config_version({"inner": {"version": "1.2.3"}})

        assert str(parsed) == "1.2.3"

    def test_handles_missing_version_key(self):
        """An empty mapping falls back to version ``0.0.0``."""
        maibot_llmreq.init_logger()

        parsed = _get_config_version({})

        assert str(parsed) == "0.0.0"

    def test_raises_error_for_invalid_version(self):
        """An unparsable version string raises ``InvalidVersion``."""
        maibot_llmreq.init_logger()

        broken = {"inner": {"version": "invalid_version"}}

        with pytest.raises(InvalidVersion):
            _get_config_version(broken)

    def test_loads_complete_config_successfully(self, tmp_path):
        """A file containing every section is loaded into the config object."""
        maibot_llmreq.init_logger()

        toml_text = """
[inner]
version = "0.1.0"

[request_conf]
max_retry = 5
timeout = 10

[[api_providers]]
name = "provider1"
base_url = "https://api.example.com"
api_key = "key123"

[[api_providers]]
name = "provider2"
base_url = "https://api.example2.com"
api_key = "key456"

[[models]]
model_identifier = "model1"
api_provider = "provider1"

[[models]]
model_identifier = "model2"
api_provider = "provider2"

[task_model_usage]
task1 = { model = "model1" }
task2 = "model1"
task3 = [
"model1",
{ model = "model2", temperature = 0.5 }
]
"""
        config_file = tmp_path / "config.toml"
        config_file.write_text(toml_text)

        loaded = load_config(str(config_file))

        assert loaded.req_conf.max_retry == 5
        assert loaded.req_conf.timeout == 10
        assert "provider1" in loaded.api_providers
        assert "model1" in loaded.models
        assert "task1" in loaded.task_model_arg_map

    def test_raises_error_for_missing_required_field(self, tmp_path):
        """A file with only ``[inner]`` raises ``KeyError`` when loaded."""
        maibot_llmreq.init_logger()

        toml_text = """
[inner]
version = "1.0.0"
"""
        config_file = tmp_path / "config.toml"
        config_file.write_text(toml_text)

        with pytest.raises(KeyError):
            load_config(str(config_file))
|
||||
@@ -13,7 +13,6 @@ from packaging.version import Version, InvalidVersion
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message import api
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
@@ -314,7 +313,7 @@ def api_ada_load_config(config_path: str) -> ModuleConfig:
|
||||
logger.error(f"配置文件中缺少必需的字段: '{key}'")
|
||||
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
|
||||
|
||||
logger.success(f"成功加载配置文件: {config_path}")
|
||||
logger.info(f"成功加载配置文件: {config_path}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletion
|
||||
|
||||
from .base_client import BaseClient, APIResponse
|
||||
from .. import _logger as logger
|
||||
from ..config.config import (
|
||||
from src.config.api_ada_configs import (
|
||||
ModelInfo,
|
||||
ModelUsageArgConfigItem,
|
||||
RequestConfig,
|
||||
@@ -22,6 +21,9 @@ from ..payload_content.message import Message
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolOption
|
||||
from ..utils import compress_messages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("模型客户端")
|
||||
|
||||
|
||||
def _check_retry(
|
||||
@@ -288,7 +290,7 @@ class ModelRequestHandler:
|
||||
interrupt_flag=interrupt_flag,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.trace(e)
|
||||
logger.debug(e)
|
||||
remain_try -= 1 # 剩余尝试次数减1
|
||||
|
||||
# 处理异常
|
||||
@@ -340,7 +342,7 @@ class ModelRequestHandler:
|
||||
embedding_input=embedding_input,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.trace(e)
|
||||
logger.debug(e)
|
||||
remain_try -= 1 # 剩余尝试次数减1
|
||||
|
||||
# 处理异常
|
||||
@@ -5,7 +5,7 @@ from typing import Callable, Any
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletion
|
||||
|
||||
from ..config.config import ModelInfo, APIProvider
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from ..payload_content.message import Message
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolOption, ToolCall
|
||||
@@ -15,7 +15,7 @@ from google.genai.errors import (
|
||||
)
|
||||
|
||||
from .base_client import APIResponse, UsageRecord
|
||||
from ..config.config import ModelInfo, APIProvider
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from . import BaseClient
|
||||
|
||||
from ..exceptions import (
|
||||
@@ -21,7 +21,7 @@ from openai.types.chat import (
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
|
||||
from .base_client import APIResponse, UsageRecord
|
||||
from ..config.config import ModelInfo, APIProvider
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from . import BaseClient
|
||||
|
||||
from ..exceptions import (
|
||||
@@ -1,15 +1,13 @@
|
||||
import importlib
|
||||
from typing import Dict
|
||||
|
||||
from src.config.config import model_config
|
||||
from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .config.config import (
|
||||
ModelUsageArgConfig,
|
||||
ModuleConfig,
|
||||
)
|
||||
|
||||
from . import _logger as logger
|
||||
from .model_client import ModelRequestHandler, BaseClient
|
||||
|
||||
logger = get_logger("模型管理器")
|
||||
|
||||
class ModelManager:
|
||||
# TODO: 添加读写锁,防止异步刷新配置时发生数据竞争
|
||||
@@ -77,3 +75,5 @@ class ModelManager:
|
||||
:return: 是否在模型列表中
|
||||
"""
|
||||
return task_name in self.config.task_model_arg_map
|
||||
|
||||
|
||||
8
src/llm_models/temp.py
Normal file
8
src/llm_models/temp.py
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
import os
import sys

# Put the directory two levels above this file on the import path so that
# the ``src`` package can be imported when the script is run directly.
_two_levels_up = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(_two_levels_up)

from src.config.config import model_config

# Print the loaded model configuration, then one request default from it.
print(f"当前模型配置: {model_config}")
print(model_config.req_conf.default_max_tokens)
|
||||
@@ -2,10 +2,11 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Tuple
|
||||
|
||||
from pymongo.synchronous.database import Database
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import ModelInfo
|
||||
from src.common.database.database_model import LLMUsage
|
||||
|
||||
from . import _logger as logger
|
||||
from .config.config import ModelInfo
|
||||
logger = get_logger("模型使用统计")
|
||||
|
||||
|
||||
class ReqType(Enum):
|
||||
@@ -29,33 +30,21 @@ class UsageCallStatus(Enum):
|
||||
|
||||
|
||||
class ModelUsageStatistic:
|
||||
db: Database | None = None
|
||||
|
||||
def __init__(self, db: Database):
|
||||
if db is None:
|
||||
logger.warning(
|
||||
"Warning: No database provided, ModelUsageStatistic will not work."
|
||||
)
|
||||
return
|
||||
if self._init_database(db):
|
||||
# 成功初始化
|
||||
self.db = db
|
||||
|
||||
@staticmethod
|
||||
def _init_database(db: Database):
|
||||
"""
|
||||
初始化数据库相关索引
|
||||
模型使用统计类 - 使用SQLite+Peewee
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化统计类
|
||||
由于使用Peewee ORM,不需要传入数据库实例
|
||||
"""
|
||||
# 确保表已经创建
|
||||
try:
|
||||
db.llm_usage.create_index([("timestamp", 1)])
|
||||
db.llm_usage.create_index([("model_name", 1)])
|
||||
db.llm_usage.create_index([("task_name", 1)])
|
||||
db.llm_usage.create_index([("request_type", 1)])
|
||||
db.llm_usage.create_index([("status", 1)])
|
||||
return True
|
||||
from src.common.database.database import db
|
||||
db.create_tables([LLMUsage], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库索引失败: {e}")
|
||||
return False
|
||||
logger.error(f"创建LLMUsage表失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _calculate_cost(
|
||||
@@ -67,6 +56,7 @@ class ModelUsageStatistic:
|
||||
Args:
|
||||
prompt_tokens: 输入token数量
|
||||
completion_tokens: 输出token数量
|
||||
model_info: 模型信息
|
||||
|
||||
Returns:
|
||||
float: 总成本(元)
|
||||
@@ -81,46 +71,50 @@ class ModelUsageStatistic:
|
||||
model_name: str,
|
||||
task_name: str = "N/A",
|
||||
request_type: ReqType = ReqType.CHAT,
|
||||
) -> str | None:
|
||||
user_id: str = "system",
|
||||
endpoint: str = "/chat/completions",
|
||||
) -> int | None:
|
||||
"""
|
||||
创建模型使用情况记录
|
||||
:param model_name: 模型名
|
||||
:param task_name: 任务名称
|
||||
:param request_type: 请求类型,默认为Chat
|
||||
:return:
|
||||
"""
|
||||
if self.db is None:
|
||||
return None # 如果没有数据库连接,则不记录使用情况
|
||||
|
||||
Args:
|
||||
model_name: 模型名
|
||||
task_name: 任务名称
|
||||
request_type: 请求类型,默认为Chat
|
||||
user_id: 用户ID,默认为system
|
||||
endpoint: API端点
|
||||
|
||||
Returns:
|
||||
int | None: 返回记录ID,失败返回None
|
||||
"""
|
||||
try:
|
||||
usage_data = {
|
||||
"model_name": model_name,
|
||||
"task_name": task_name,
|
||||
"request_type": request_type.value,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"cost": 0.0,
|
||||
"status": "processing",
|
||||
"timestamp": datetime.now(),
|
||||
"ext_msg": None,
|
||||
}
|
||||
result = self.db.llm_usage.insert_one(usage_data)
|
||||
usage_record = LLMUsage.create(
|
||||
model_name=model_name,
|
||||
user_id=user_id,
|
||||
request_type=request_type.value,
|
||||
endpoint=endpoint,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
cost=0.0,
|
||||
status=UsageCallStatus.PROCESSING.value,
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
|
||||
logger.trace(
|
||||
f"创建了一条模型使用情况记录 - 模型: {model_name}, "
|
||||
f"子任务: {task_name}, 类型: {request_type}"
|
||||
f"记录ID: {str(result.inserted_id)}"
|
||||
f"子任务: {task_name}, 类型: {request_type.value}, "
|
||||
f"用户: {user_id}, 记录ID: {usage_record.id}"
|
||||
)
|
||||
|
||||
return str(result.inserted_id)
|
||||
return usage_record.id
|
||||
except Exception as e:
|
||||
logger.error(f"创建模型使用情况记录失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def update_usage(
|
||||
self,
|
||||
record_id: str | None,
|
||||
record_id: int | None,
|
||||
model_info: ModelInfo,
|
||||
usage_data: Tuple[int, int, int] | None = None,
|
||||
stat: UsageCallStatus = UsageCallStatus.SUCCESS,
|
||||
@@ -136,9 +130,6 @@ class ModelUsageStatistic:
|
||||
stat: 任务调用状态
|
||||
ext_msg: 额外信息
|
||||
"""
|
||||
if self.db is None:
|
||||
return # 如果没有数据库连接,则不记录使用情况
|
||||
|
||||
if not record_id:
|
||||
logger.error("更新模型使用情况失败: record_id不能为空")
|
||||
return
|
||||
@@ -153,28 +144,27 @@ class ModelUsageStatistic:
|
||||
total_tokens = usage_data[2] if usage_data else 0
|
||||
|
||||
try:
|
||||
self.db.llm_usage.update_one(
|
||||
{"_id": record_id},
|
||||
{
|
||||
"$set": {
|
||||
"status": stat.value,
|
||||
"ext_msg": ext_msg,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"cost": self._calculate_cost(
|
||||
# 使用Peewee更新记录
|
||||
update_query = LLMUsage.update(
|
||||
status=stat.value,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost=self._calculate_cost(
|
||||
prompt_tokens, completion_tokens, model_info
|
||||
)
|
||||
if usage_data
|
||||
else 0.0,
|
||||
}
|
||||
},
|
||||
)
|
||||
) if usage_data else 0.0,
|
||||
).where(LLMUsage.id == record_id)
|
||||
|
||||
logger.trace(
|
||||
updated_count = update_query.execute()
|
||||
|
||||
if updated_count == 0:
|
||||
logger.warning(f"记录ID {record_id} 不存在,无法更新")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Token使用情况 - 模型: {model_info.name}, "
|
||||
f"记录ID: {record_id}, "
|
||||
f"任务状态: {stat.value}, 额外信息: {ext_msg if ext_msg else 'N/A'}, "
|
||||
f"记录ID: {record_id}, "
|
||||
f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, "
|
||||
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
|
||||
f"总计: {total_tokens}"
|
||||
)
|
||||
@@ -3,9 +3,11 @@ import io
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from . import _logger as logger
|
||||
from src.common.logger import get_logger
|
||||
from .payload_content.message import Message, MessageBuilder
|
||||
|
||||
logger = get_logger("消息压缩工具")
|
||||
|
||||
|
||||
def compress_messages(
|
||||
messages: list[Message], img_target_size: int = 1 * 1024 * 1024
|
||||
File diff suppressed because it is too large
Load Diff
77
template/compare/model_config_template.toml
Normal file
77
template/compare/model_config_template.toml
Normal file
@@ -0,0 +1,77 @@
|
||||
[inner]
|
||||
version = "0.1.0"
|
||||
|
||||
# 配置文件版本号迭代规则同bot_config.toml
|
||||
|
||||
[request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释)
|
||||
#max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数)
|
||||
#timeout = 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)
|
||||
#retry_interval = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒)
|
||||
#default_temperature = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值)
|
||||
#default_max_tokens = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值)
|
||||
|
||||
|
||||
[[api_providers]] # API服务提供商(可以配置多个)
|
||||
name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名)
|
||||
base_url = "https://api.deepseek.cn" # API服务商的BaseURL
|
||||
key = "******" # API Key (可选,默认为None)
|
||||
client_type = "openai" # 请求客户端(可选,默认值为"openai",使用Gemini等Google系模型时请配置为"google")
|
||||
|
||||
#[[api_providers]] # 特殊:Google的Gemini使用特殊API,与OpenAI格式不兼容,需要配置client_type为"google"
|
||||
#name = "Google"
|
||||
#base_url = "https://api.google.com"
|
||||
#key = "******"
|
||||
#client_type = "google"
|
||||
#
|
||||
#[[api_providers]]
|
||||
#name = "SiliconFlow"
|
||||
#base_url = "https://api.siliconflow.cn"
|
||||
#key = "******"
|
||||
#
|
||||
#[[api_providers]]
|
||||
#name = "LocalHost"
|
||||
#base_url = "https://localhost:8888"
|
||||
#key = "lm-studio"
|
||||
|
||||
|
||||
[[models]] # 模型(可以配置多个)
|
||||
# 模型标识符(API服务商提供的模型标识符)
|
||||
model_identifier = "deepseek-chat"
|
||||
# 模型名称(可随意命名,在bot_config.toml中需使用这个命名)
|
||||
#(可选,若无该字段,则将自动使用model_identifier填充)
|
||||
name = "deepseek-v3"
|
||||
# API服务商名称(对应在api_providers中配置的服务商名称)
|
||||
api_provider = "DeepSeek"
|
||||
# 输入价格(用于API调用统计,单位:元/兆token)(可选,若无该字段,默认值为0)
|
||||
price_in = 2.0
|
||||
# 输出价格(用于API调用统计,单位:元/兆token)(可选,若无该字段,默认值为0)
|
||||
price_out = 8.0
|
||||
# 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出)
|
||||
#(可选,若无该字段,默认值为false)
|
||||
#force_stream_mode = true
|
||||
|
||||
#[[models]]
|
||||
#model_identifier = "deepseek-reasoner"
|
||||
#name = "deepseek-r1"
|
||||
#api_provider = "DeepSeek"
|
||||
#model_flags = ["text", "tool_calling", "reasoning"]
|
||||
#price_in = 4.0
|
||||
#price_out = 16.0
|
||||
#
|
||||
#[[models]]
|
||||
#model_identifier = "BAAI/bge-m3"
|
||||
#name = "siliconflow-bge-m3"
|
||||
#api_provider = "SiliconFlow"
|
||||
#model_flags = ["text", "embedding"]
|
||||
#price_in = 0
|
||||
#price_out = 0
|
||||
|
||||
|
||||
[task_model_usage]
|
||||
#llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0}
|
||||
#llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0}
|
||||
#embedding = "siliconflow-bge-m3"
|
||||
#schedule = [
|
||||
# "deepseek-v3",
|
||||
# "deepseek-r1",
|
||||
#]
|
||||
Reference in New Issue
Block a user