Merge pull request #1148 from MaiM-with-u/dev-api-ada

Add API Adapter
This commit is contained in:
墨梓柒
2025-07-29 10:24:03 +08:00
committed by GitHub
24 changed files with 4290 additions and 1169 deletions

bot.py
View File

@@ -74,36 +74,6 @@ def easter_egg():
print(rainbow_text)
def scan_provider(env_config: dict):
provider = {}
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
# 避免 GPG_KEY 这样的变量干扰检查
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
# 遍历 env_config 的所有键
for key in env_config:
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
# 提取 provider 名称
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
# 初始化 provider 的字典(如果尚未初始化)
if provider_name not in provider:
provider[provider_name] = {"url": None, "key": None}
# 根据键的类型填充 url 或 key
if key.endswith("_BASE_URL"):
provider[provider_name]["url"] = env_config[key]
elif key.endswith("_KEY"):
provider[provider_name]["key"] = env_config[key]
# 检查每个 provider 是否同时存在 url 和 key
for provider_name, config in provider.items():
if config["url"] is None or config["key"] is None:
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
async def graceful_shutdown():
try:
@@ -229,9 +199,6 @@ def raw_main():
easter_egg()
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
# 返回MainSystem实例
return MainSystem()

View File

@@ -981,8 +981,9 @@ class DefaultReplyer:
with Timer("LLM生成", {}): # 内部计时器,可选保留
# 加权随机选择一个模型配置
selected_model_config = self._select_weighted_model_config()
model_display_name = selected_model_config.get('model_name') or selected_model_config.get('name', 'N/A')
logger.info(
f"使用模型生成回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})"
f"使用模型生成回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
)
express_model = LLMRequest(

View File

@@ -0,0 +1,180 @@
from dataclasses import dataclass, field
from typing import List, Dict, Union
import threading
import time
from packaging.version import Version
NEWEST_VER = "0.1.1" # 当前支持的最新版本
@dataclass
class APIProvider:
name: str = "" # API提供商名称
base_url: str = "" # API基础URL
api_key: str = field(repr=False, default="") # API密钥(向后兼容)
api_keys: List[str] = field(repr=False, default_factory=list) # API密钥列表(新格式)
client_type: str = "openai" # 客户端类型(如openai/google等),默认为openai
# 多API Key管理相关属性
_current_key_index: int = field(default=0, init=False, repr=False) # 当前使用的key索引
_key_failure_count: Dict[int, int] = field(default_factory=dict, init=False, repr=False) # 每个key的失败次数
_key_last_failure_time: Dict[int, float] = field(default_factory=dict, init=False, repr=False) # 每个key最后失败时间
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) # 线程锁
def __post_init__(self):
"""初始化后处理确保API keys列表正确"""
# 向后兼容如果只设置了api_key将其添加到api_keys列表
if self.api_key and not self.api_keys:
self.api_keys = [self.api_key]
# 如果api_keys不为空但api_key为空,设置api_key为第一个
elif self.api_keys and not self.api_key:
self.api_key = self.api_keys[0]
# 初始化失败计数器
for i in range(len(self.api_keys)):
self._key_failure_count[i] = 0
self._key_last_failure_time[i] = 0
def get_current_api_key(self) -> str:
"""获取当前应该使用的API Key"""
with self._lock:
if not self.api_keys:
return ""
# 确保索引在有效范围内
if self._current_key_index >= len(self.api_keys):
self._current_key_index = 0
return self.api_keys[self._current_key_index]
def get_next_api_key(self) -> Union[str, None]:
"""获取下一个可用的API Key负载均衡"""
with self._lock:
if not self.api_keys:
return None
# 如果只有一个key直接返回
if len(self.api_keys) == 1:
return self.api_keys[0]
# 轮询到下一个key
self._current_key_index = (self._current_key_index + 1) % len(self.api_keys)
return self.api_keys[self._current_key_index]
def mark_key_failed(self, api_key: str) -> Union[str, None]:
"""标记某个API Key失败返回下一个可用的key"""
with self._lock:
if not self.api_keys or api_key not in self.api_keys:
return None
key_index = self.api_keys.index(api_key)
self._key_failure_count[key_index] += 1
self._key_last_failure_time[key_index] = time.time()
# 寻找下一个可用的key
current_time = time.time()
for _ in range(len(self.api_keys)):
self._current_key_index = (self._current_key_index + 1) % len(self.api_keys)
next_key_index = self._current_key_index
# 检查该key是否最近失败过(5分钟内失败超过3次则暂时跳过)
if (self._key_failure_count[next_key_index] <= 3 or
current_time - self._key_last_failure_time[next_key_index] > 300): # 5分钟后重试
return self.api_keys[next_key_index]
# 如果所有key都不可用,返回当前key让上层处理
return api_key
def reset_key_failures(self, api_key: str | None = None):
"""重置失败计数(成功调用后调用)"""
with self._lock:
if api_key and api_key in self.api_keys:
key_index = self.api_keys.index(api_key)
self._key_failure_count[key_index] = 0
self._key_last_failure_time[key_index] = 0
else:
# 重置所有key的失败计数
for i in range(len(self.api_keys)):
self._key_failure_count[i] = 0
self._key_last_failure_time[i] = 0
def get_api_key_stats(self) -> Dict[str, Dict[str, Union[int, float]]]:
"""获取API Key使用统计"""
with self._lock:
stats = {}
for i, key in enumerate(self.api_keys):
# 只显示key的前8位和后4位中间用*代替
masked_key = f"{key[:8]}***{key[-4:]}" if len(key) > 12 else "***"
stats[masked_key] = {
"failure_count": self._key_failure_count.get(i, 0),
"last_failure_time": self._key_last_failure_time.get(i, 0),
"is_current": i == self._current_key_index
}
return stats
@dataclass
class ModelInfo:
model_identifier: str = "" # 模型标识符(用于URL调用)
name: str = "" # 模型名称(用于模块调用)
api_provider: str = "" # API提供商(如OpenAI、Azure等)
# 以下用于模型计费
price_in: float = 0.0 # 每M token输入价格
price_out: float = 0.0 # 每M token输出价格
force_stream_mode: bool = False # 是否强制使用流式输出模式
# 新增:任务类型和能力字段
task_type: str = "" # 任务类型:llm_normal, llm_reasoning, vision, embedding, speech
capabilities: List[str] = field(default_factory=list) # 模型能力:text, vision, embedding, speech, tool_calling, reasoning
@dataclass
class RequestConfig:
max_retry: int = 2 # 最大重试次数(单个模型API调用失败时最多重试的次数)
timeout: int = 10 # API调用的超时时长,超过这个时长,本次请求将被视为“请求超时”(单位:秒)
retry_interval: int = 10 # 重试间隔:如果API调用失败,重试的间隔时间(单位:秒)
default_temperature: float = 0.7 # 默认的温度:如果bot_config.toml中没有设置temperature参数,默认使用这个值
default_max_tokens: int = 1024 # 默认的最大输出token数:如果bot_config.toml中没有设置max_tokens参数,默认使用这个值
@dataclass
class ModelUsageArgConfigItem:
"""模型使用的配置类
该类用于加载和存储子任务模型使用的配置
"""
name: str = "" # 模型名称
temperature: float | None = None # 温度
max_tokens: int | None = None # 最大token数
max_retry: int | None = None # 调用失败时的最大重试次数
@dataclass
class ModelUsageArgConfig:
"""子任务使用模型的配置类
该类用于加载和存储子任务使用的模型配置
"""
name: str = "" # 任务名称
usage: List[ModelUsageArgConfigItem] = field(
default_factory=lambda: []
) # 任务使用的模型列表
@dataclass
class ModuleConfig:
INNER_VERSION: Version | None = None # 配置文件版本
req_conf: RequestConfig = field(default_factory=lambda: RequestConfig()) # 请求配置
api_providers: Dict[str, APIProvider] = field(
default_factory=lambda: {}
) # API提供商列表
models: Dict[str, ModelInfo] = field(default_factory=lambda: {}) # 模型列表
task_model_arg_map: Dict[str, ModelUsageArgConfig] = field(
default_factory=lambda: {}
)
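
下面给出一个演示 APIProvider 多 API Key 轮换接口的最小示例(仅作示意:直接在代码中构造 APIProvider,实际使用中这些字段通常由 model_config.toml 加载,示例中的 key 均为虚构值):

from src.config.api_ada_configs import APIProvider

# 假设性示例:一个配置了两个 key 的提供商
provider = APIProvider(
    name="example",
    base_url="https://api.example.com/v1",
    api_keys=["sk-key-one-aaaaaaaa", "sk-key-two-bbbbbbbb"],
)
key = provider.get_current_api_key()       # 获取当前应使用的 key
next_key = provider.mark_key_failed(key)   # 某次调用失败:记录失败并切换到下一个可用 key
provider.reset_key_failures(next_key)      # 调用成功后清零该 key 的失败计数
print(provider.get_api_key_stats())        # 查看脱敏后的 key 使用统计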

View File

@@ -1,162 +0,0 @@
import shutil
import tomlkit
from tomlkit.items import Table, KeyType
from pathlib import Path
from datetime import datetime
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
return toml_table.trivia.comment
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
item = toml_table.value.get(key)
if item is not None and hasattr(item, "trivia"):
return item.trivia.comment
if hasattr(toml_table, "keys"):
for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key:
return k.trivia.comment
return None
def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, logs=None):
# 递归比较两个dict找出新增和删减项收集注释
if path is None:
path = []
if logs is None:
logs = []
if new_comments is None:
new_comments = {}
if old_comments is None:
old_comments = {}
# 新增项
for key in new:
if key == "version":
continue
if key not in old:
comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
# 删减项
for key in old:
if key == "version":
continue
if key not in new:
comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
return logs
def update_config():
print("开始更新配置文件...")
# 获取根目录路径
root_dir = Path(__file__).parent.parent.parent.parent
template_dir = root_dir / "template"
config_dir = root_dir / "config"
old_config_dir = config_dir / "old"
# 创建old目录如果不存在
old_config_dir.mkdir(exist_ok=True)
# 定义文件路径
template_path = template_dir / "bot_config_template.toml"
old_config_path = config_dir / "bot_config.toml"
new_config_path = config_dir / "bot_config.toml"
# 读取旧配置文件
old_config = {}
if old_config_path.exists():
print(f"发现旧配置文件: {old_config_path}")
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
# 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
print(f"已备份旧配置文件到: {old_backup_path}")
# 复制模板文件到配置目录
print(f"从模板文件创建新配置: {template_path}")
shutil.copy2(template_path, new_config_path)
# 读取新配置文件
with open(new_config_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
print(f"检测到版本号相同 (v{old_version}),跳过更新")
# 如果version相同恢复旧配置文件并返回
shutil.move(old_backup_path, old_config_path) # type: ignore
return
else:
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
# 输出新增和删减项及注释
if old_config:
print("配置项变动如下:")
logs = compare_dicts(new_config, old_config)
if logs:
for log in logs:
print(log)
else:
print("无新增或删减项")
# 递归更新配置
def update_dict(target, source):
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
if not value:
target[key] = tomlkit.array()
else:
# 特殊处理正则表达式数组和包含正则表达式的结构
if key == "ban_msgs_regex":
# 直接使用原始值,不进行额外处理
target[key] = value
elif key == "regex_rules":
# 对于regex_rules需要特殊处理其中的regex字段
target[key] = value
else:
# 检查是否包含正则表达式相关的字典项
contains_regex = False
if value and isinstance(value[0], dict) and "regex" in value[0]:
contains_regex = True
target[key] = value if contains_regex else tomlkit.array(str(value))
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
print("开始合并新旧配置...")
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
print("配置文件更新完成")
if __name__ == "__main__":
update_config()

View File

@@ -7,6 +7,10 @@ from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass
from rich.traceback import install
from packaging import version
from packaging.specifiers import SpecifierSet
from packaging.version import Version, InvalidVersion
from typing import Any, Dict, List
from src.common.logger import get_logger
from src.config.config_base import ConfigBase
@@ -36,6 +40,17 @@ from src.config.official_configs import (
CustomPromptConfig,
)
from .api_ada_configs import (
ModelUsageArgConfigItem,
ModelUsageArgConfig,
APIProvider,
ModelInfo,
NEWEST_VER,
ModuleConfig,
)
install(extra_lines=3)
@@ -52,6 +67,273 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
MMC_VERSION = "0.9.1"
def _get_config_version(toml: Dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据
Args:
toml[dict]: 输入的配置文件字典
Returns:
Version
"""
if "inner" in toml and "version" in toml["inner"]:
config_version: str = toml["inner"]["version"]
else:
config_version = "0.0.0" # 默认版本
try:
ver = version.parse(config_version)
except InvalidVersion as e:
logger.error(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
f"请检查配置文件,当前 version 键: {config_version}\n"
f"错误信息: {e}"
)
raise InvalidVersion(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
) from e
return ver
def _request_conf(parent: Dict, config: ModuleConfig):
request_conf_config = parent.get("request_conf")
config.req_conf.max_retry = request_conf_config.get(
"max_retry", config.req_conf.max_retry
)
config.req_conf.timeout = request_conf_config.get(
"timeout", config.req_conf.timeout
)
config.req_conf.retry_interval = request_conf_config.get(
"retry_interval", config.req_conf.retry_interval
)
config.req_conf.default_temperature = request_conf_config.get(
"default_temperature", config.req_conf.default_temperature
)
config.req_conf.default_max_tokens = request_conf_config.get(
"default_max_tokens", config.req_conf.default_max_tokens
)
def _api_providers(parent: Dict, config: ModuleConfig):
api_providers_config = parent.get("api_providers")
for provider in api_providers_config:
name = provider.get("name", None)
base_url = provider.get("base_url", None)
api_key = provider.get("api_key", None)
api_keys = provider.get("api_keys", []) # 新增支持多个API Key
client_type = provider.get("client_type", "openai")
if name in config.api_providers: # 查重
logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
if name and base_url:
# 处理API Key配置:支持单个api_key或多个api_keys
if api_keys:
# 使用新格式:api_keys列表
logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key")
elif api_key:
# 向后兼容:使用单个api_key
api_keys = [api_key]
logger.debug(f"API提供商 '{name}' 使用单个API Key(向后兼容模式)")
else:
logger.warning(f"API提供商 '{name}' 没有配置API Key,某些功能可能不可用")
config.api_providers[name] = APIProvider(
name=name,
base_url=base_url,
api_key=api_key, # 保留向后兼容
api_keys=api_keys, # 新格式
client_type=client_type,
)
else:
logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
def _models(parent: Dict, config: ModuleConfig):
models_config = parent.get("models")
for model in models_config:
model_identifier = model.get("model_identifier", None)
name = model.get("name", model_identifier)
api_provider = model.get("api_provider", None)
price_in = model.get("price_in", 0.0)
price_out = model.get("price_out", 0.0)
force_stream_mode = model.get("force_stream_mode", False)
task_type = model.get("task_type", "")
capabilities = model.get("capabilities", [])
if name in config.models: # 查重
logger.error(f"重复的模型名称: {name},请检查配置文件。")
raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
if model_identifier and api_provider:
# 检查API提供商是否存在
if api_provider not in config.api_providers:
logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
raise ValueError(
f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
)
config.models[name] = ModelInfo(
name=name,
model_identifier=model_identifier,
api_provider=api_provider,
price_in=price_in,
price_out=price_out,
force_stream_mode=force_stream_mode,
task_type=task_type,
capabilities=capabilities,
)
else:
logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
def _task_model_usage(parent: Dict, config: ModuleConfig):
model_usage_configs = parent.get("task_model_usage")
config.task_model_arg_map = {}
for task_name, item in model_usage_configs.items():
if task_name in config.task_model_arg_map:
logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
usage = []
if isinstance(item, Dict):
if "model" in item:
usage.append(
ModelUsageArgConfigItem(
name=item["model"],
temperature=item.get("temperature", None),
max_tokens=item.get("max_tokens", None),
max_retry=item.get("max_retry", None),
)
)
else:
logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
raise ValueError(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
elif isinstance(item, List):
for model in item:
if isinstance(model, Dict):
usage.append(
ModelUsageArgConfigItem(
name=model["model"],
temperature=model.get("temperature", None),
max_tokens=model.get("max_tokens", None),
max_retry=model.get("max_retry", None),
)
)
elif isinstance(model, str):
usage.append(
ModelUsageArgConfigItem(
name=model,
temperature=None,
max_tokens=None,
max_retry=None,
)
)
else:
logger.error(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
raise ValueError(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
elif isinstance(item, str):
usage.append(
ModelUsageArgConfigItem(
name=item,
temperature=None,
max_tokens=None,
max_retry=None,
)
)
config.task_model_arg_map[task_name] = ModelUsageArgConfig(
name=task_name,
usage=usage,
)
def api_ada_load_config(config_path: str) -> ModuleConfig:
"""从TOML配置文件加载配置"""
config = ModuleConfig()
include_configs: Dict[str, Dict[str, Any]] = {
"request_conf": {
"func": _request_conf,
"support": ">=0.0.0",
"necessary": False,
},
"api_providers": {"func": _api_providers, "support": ">=0.0.0"},
"models": {"func": _models, "support": ">=0.0.0"},
"task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
}
if os.path.exists(config_path):
with open(config_path, "rb") as f:
try:
toml_dict = tomlkit.load(f)
except tomlkit.TOMLDecodeError as e:
logger.critical(
f"配置文件model_list.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}"
)
exit(1)
# 获取配置文件版本
config.INNER_VERSION = _get_config_version(toml_dict)
# 检查版本
if config.INNER_VERSION > Version(NEWEST_VER):
logger.warning(
f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
)
# 解析配置文件
# 如果在配置中找到了需要的项,调用对应项的闭包函数处理
for key in include_configs:
if key in toml_dict:
group_specifier_set: SpecifierSet = SpecifierSet(
include_configs[key]["support"]
)
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifier_set:
# 如果版本在支持范围内,检查是否存在通知
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
# 调用闭包函数处理配置
(include_configs[key]["func"])(toml_dict, config)
else:
# 如果版本不在支持范围内,崩溃并提示用户
logger.error(
f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
f"当前程序仅支持以下版本范围: {group_specifier_set}"
)
raise InvalidVersion(
f"当前程序仅支持以下版本范围: {group_specifier_set}"
)
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif (
"necessary" in include_configs[key]
and include_configs[key].get("necessary") is False
):
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
else:
# 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
logger.info(f"成功加载配置文件: {config_path}")
return config
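
结合上面的解析函数,model_config.toml 的大致结构如下(仅为根据解析逻辑推断的示意草稿,其中提供商名称、模型名称与任务名称均为假设值,实际字段请以 template/model_config_template.toml 为准):

[inner]
version = "0.1.1"

[request_conf]                              # 可选段落
max_retry = 2
timeout = 10
retry_interval = 10

[[api_providers]]
name = "ExampleProvider"                    # 假设的提供商名称
base_url = "https://api.example.com/v1"
api_keys = ["sk-xxxx", "sk-yyyy"]           # 或向后兼容地使用单个 api_key = "sk-xxxx"
client_type = "openai"

[[models]]
model_identifier = "example-chat"           # 假设的模型标识符
name = "example-chat"
api_provider = "ExampleProvider"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]

[task_model_usage]
# 任务名(假设)→ 可为字符串、{ model = ... } 表,或二者组成的列表
replyer = { model = "example-chat", temperature = 0.7 }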
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
@@ -133,37 +415,74 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
return logs, changes
def update_config():
def _get_version_from_toml(toml_path):
"""从TOML文件中获取版本号"""
if not os.path.exists(toml_path):
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
def _version_tuple(v):
"""将版本字符串转换为元组以便比较"""
if v is None:
return (0,)
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
_update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
def _update_config_generic(config_name: str, template_name: str, should_quit_on_new: bool = True):
"""
通用的配置文件更新函数
Args:
config_name: 配置文件名(不含扩展名),如 'bot_config''model_config'
template_name: 模板文件名(不含扩展名),如 'bot_config_template''model_config_template'
should_quit_on_new: 创建新配置文件后是否退出程序
"""
# 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old")
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
# 定义文件路径
template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml")
old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
compare_path = os.path.join(compare_dir, "bot_config_template.toml")
template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
compare_path = os.path.join(compare_dir, f"{template_name}.toml")
# 创建compare目录如果不存在
os.makedirs(compare_dir, exist_ok=True)
# 处理compare下的模板文件
def get_version_from_toml(toml_path):
if not os.path.exists(toml_path):
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
template_version = get_version_from_toml(template_path)
compare_version = get_version_from_toml(compare_path)
def version_tuple(v):
if v is None:
return (0,)
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
template_version = _get_version_from_toml(template_path)
compare_version = _get_version_from_toml(compare_path)
# 先读取 compare 下的模板(如果有),用于默认值变动检测
if os.path.exists(compare_path):
@@ -183,7 +502,7 @@ def update_config():
old_config = tomlkit.load(f)
logs, changes = compare_default_values(new_config, compare_config)
if logs:
logger.info("检测到模板默认值变动如下:")
logger.info(f"检测到{config_name}模板默认值变动如下:")
for log in logs:
logger.info(log)
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
@@ -192,10 +511,10 @@ def update_config():
if old_value == old_default:
set_value_by_path(old_config, path, new_default)
logger.info(
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
)
else:
logger.info("未检测到模板默认值变动")
logger.info(f"未检测到{config_name}模板默认值变动")
# 保存旧配置的变更(后续合并逻辑会用到 old_config)
else:
old_config = None
@@ -203,22 +522,25 @@ def update_config():
# 检查 compare 下没有模板,或新模板版本更高,则复制
if not os.path.exists(compare_path):
shutil.copy2(template_path, compare_path)
logger.info(f"已将模板文件复制到: {compare_path}")
logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
else:
if version_tuple(template_version) > version_tuple(compare_version):
if _version_tuple(template_version) > _version_tuple(compare_version):
shutil.copy2(template_path, compare_path)
logger.info(f"模板版本较新已替换compare下的模板: {compare_path}")
logger.info(f"{config_name}模板版本较新已替换compare下的模板: {compare_path}")
else:
logger.debug(f"compare下的模板版本不低于当前模板无需替换: {compare_path}")
logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
# 检查配置文件是否存在
if not os.path.exists(old_config_path):
logger.info("配置文件不存在,从模板创建新配置")
logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,直接返回
quit()
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,根据参数决定是否退出
if should_quit_on_new:
quit()
else:
return
# 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
if old_config is None:
@@ -226,38 +548,36 @@ def update_config():
old_config = tomlkit.load(f)
# new_config 已经读取
# 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
)
else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建old目录如果不存在
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = os.path.join(old_config_dir, f"bot_config_{timestamp}.toml")
old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
logger.info(f"已备份旧配置文件到: {old_backup_path}")
logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
# 复制模板文件到配置目录
shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}")
logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
# 输出新增和删减项及注释
if old_config:
logger.info("配置项变动如下:\n----------------------------------------")
logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
logs = compare_dicts(new_config, old_config)
if logs:
for log in logs:
@@ -265,40 +585,24 @@ def update_config():
else:
logger.info("无新增或删减项")
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
logger.info("开始合并新旧配置...")
update_dict(new_config, old_config)
logger.info(f"开始合并{config_name}新旧配置...")
_update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
quit()
logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
def update_config():
"""更新bot_config.toml配置文件"""
_update_config_generic("bot_config", "bot_config_template", should_quit_on_new=True)
def update_model_config():
"""更新model_config.toml配置文件"""
_update_config_generic("model_config", "model_config_template", should_quit_on_new=False)
@dataclass
@@ -360,7 +664,9 @@ def get_config_dir() -> str:
# 获取配置文件路径
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config()
update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
logger.info("非常的新鲜,非常的美味!")

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from src.config.config_base import ConfigBase
from packaging.version import Version
"""
须知:
@@ -598,7 +599,6 @@ class LPMMKnowledgeConfig(ConfigBase):
embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致"""
@dataclass
class ModelConfig(ConfigBase):
"""模型配置类"""

src/llm_models/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Mai.To.The.Gate
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

View File

@@ -0,0 +1,69 @@
from typing import Any
# 常见Error Code Mapping (以OpenAI API为例)
error_code_mapping = {
400: "参数不正确",
401: "API-Key错误认证失败请检查/config/model_list.toml中的配置是否正确",
402: "账号余额不足",
403: "模型拒绝访问,可能需要实名或余额不足",
404: "Not Found",
413: "请求体过大,请尝试压缩图片或减少输入内容",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高",
}
class NetworkConnectionError(Exception):
"""连接异常,常见于网络问题或服务器不可用"""
def __init__(self):
super().__init__()
def __str__(self):
return "连接异常请检查网络连接状态或URL是否正确"
class ReqAbortException(Exception):
"""请求异常退出,常见于请求被中断或取消"""
def __init__(self, message: str | None = None):
super().__init__(message)
self.message = message
def __str__(self):
return self.message or "请求因未知原因异常终止"
class RespNotOkException(Exception):
"""请求响应异常,见于请求未能成功响应(非 '200 OK'"""
def __init__(self, status_code: int, message: str | None = None):
super().__init__(message)
self.status_code = status_code
self.message = message
def __str__(self):
if self.status_code in error_code_mapping:
return error_code_mapping[self.status_code]
elif self.message:
return self.message
else:
return f"未知的异常响应代码:{self.status_code}"
class RespParseException(Exception):
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
def __init__(self, ext_info: Any, message: str | None = None):
super().__init__(message)
self.ext_info = ext_info
self.message = message
def __str__(self):
return (
self.message
if self.message
else "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
)

View File

@@ -0,0 +1,380 @@
import asyncio
from typing import Callable, Any
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from .base_client import BaseClient, APIResponse
from src.config.api_ada_configs import (
ModelInfo,
ModelUsageArgConfigItem,
RequestConfig,
ModuleConfig,
)
from ..exceptions import (
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption
from ..utils import compress_messages
from src.common.logger import get_logger
logger = get_logger("模型客户端")
def _check_retry(
remain_try: int,
retry_interval: int,
can_retry_msg: str,
cannot_retry_msg: str,
can_retry_callable: Callable | None = None,
**kwargs,
) -> tuple[int, Any | None]:
"""
辅助函数:检查是否可以重试
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param can_retry_msg: 可以重试时的提示信息
:param cannot_retry_msg: 不可以重试时的提示信息
:return: (等待间隔:如果为0则不等待,为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if remain_try > 0:
# 还有重试机会
logger.warning(f"{can_retry_msg}")
if can_retry_callable is not None:
return retry_interval, can_retry_callable(**kwargs)
else:
return retry_interval, None
else:
# 达到最大重试次数
logger.warning(f"{cannot_retry_msg}")
return -1, None # 不再重试请求该模型
def _handle_resp_not_ok(
e: RespNotOkException,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
):
"""
处理响应错误异常
:param e: 异常对象
:param task_name: 任务名称
:param model_name: 模型名称
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param messages: (消息列表, 是否已压缩过)
:return: (等待间隔:如果为0则不等待,为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
# 响应错误
if e.status_code in [401, 403]:
# API Key认证错误 - 让多API Key机制处理给一次重试机会
if remain_try > 0:
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"API Key认证失败错误代码-{e.status_code}多API Key机制会自动切换"
)
return 0, None # 立即重试让底层客户端切换API Key
else:
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"所有API Key都认证失败错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None # 不再重试请求该模型
elif e.status_code in [400, 402, 404]:
# 其他客户端错误(不应该重试)
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"请求失败,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None # 不再重试请求该模型
elif e.status_code == 413:
if messages and not messages[1]:
# 消息列表不为空且未压缩,尝试压缩消息
return _check_retry(
remain_try,
0,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,尝试压缩消息后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,压缩消息后仍然过大,放弃请求"
),
can_retry_callable=compress_messages,
messages=messages[0],
)
# 没有消息可压缩
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,无法压缩消息,放弃请求。"
)
return -1, None
elif e.status_code == 429:
# 请求过于频繁 - 让多API Key机制处理适当延迟后重试
return _check_retry(
remain_try,
min(retry_interval, 5), # 限制最大延迟为5秒让API Key切换更快生效
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"请求过于频繁多API Key机制会自动切换{min(retry_interval, 5)}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求过于频繁所有API Key都被限制放弃请求"
),
)
elif e.status_code >= 500:
# 服务器错误
return _check_retry(
remain_try,
retry_interval,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"服务器错误,将于{retry_interval}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"服务器错误,超过最大重试次数,请稍后再试"
),
)
else:
# 未知错误
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"未知错误,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None
def default_exception_handler(
e: Exception,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
) -> tuple[int, list[Message] | None]:
"""
默认异常处理函数
:param e: 异常对象
:param task_name: 任务名称
:param model_name: 模型名称
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param messages: (消息列表, 是否已压缩过)
:return: (等待间隔:如果为0则不等待,为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if isinstance(e, NetworkConnectionError): # 网络连接错误
# 网络错误可能是某个API Key的端点问题给多API Key机制一次快速重试机会
return _check_retry(
remain_try,
min(retry_interval, 3), # 网络错误时减少等待时间让API Key切换更快
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"连接异常多API Key机制会尝试其他Key{min(retry_interval, 3)}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"连接异常超过最大重试次数请检查网络连接状态或URL是否正确"
),
)
elif isinstance(e, ReqAbortException):
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n请求被中断,详细信息-{str(e.message)}"
)
return -1, None # 不再重试请求该模型
elif isinstance(e, RespNotOkException):
return _handle_resp_not_ok(
e,
task_name,
model_name,
remain_try,
retry_interval,
messages,
)
elif isinstance(e, RespParseException):
# 响应解析错误
logger.error(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"响应解析错误,错误信息-{e.message}\n"
)
logger.debug(f"附加内容:\n{str(e.ext_info)}")
return -1, None # 不再重试请求该模型
else:
logger.error(
f"任务-'{task_name}' 模型-'{model_name}'\n未知异常,错误信息-{str(e)}"
)
return -1, None # 不再重试请求该模型
class ModelRequestHandler:
"""
模型请求处理器
"""
def __init__(
self,
task_name: str,
config: ModuleConfig,
api_client_map: dict[str, BaseClient],
):
self.task_name: str = task_name
"""任务名称"""
self.client_map: dict[str, BaseClient] = {}
"""API客户端列表"""
self.configs: list[tuple[ModelInfo, ModelUsageArgConfigItem]] = []
"""模型参数配置"""
self.req_conf: RequestConfig = config.req_conf
"""请求配置"""
# 获取模型与使用配置
for model_usage in config.task_model_arg_map[task_name].usage:
if model_usage.name not in config.models:
logger.error(f"Model '{model_usage.name}' not found in ModelManager")
raise KeyError(f"Model '{model_usage.name}' not found in ModelManager")
model_info = config.models[model_usage.name]
if model_info.api_provider not in self.client_map:
# 缓存API客户端
self.client_map[model_info.api_provider] = api_client_map[
model_info.api_provider
]
self.configs.append((model_info, model_usage)) # 添加模型与使用配置
async def get_response(
self,
messages: list[Message],
tool_options: list[ToolOption] | None = None,
response_format: RespFormat | None = None, # 暂不启用
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[ChatCompletion], APIResponse] | None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param messages: 消息列表
:param tool_options: 工具选项列表
:param response_format: 响应格式
:param stream_response_handler: 流式响应处理函数(可选)
:param async_response_parser: 响应解析函数(可选)
:param interrupt_flag: 中断信号量(可选,默认为None)
:return: APIResponse
"""
# 遍历可用模型,若获取响应失败,则使用下一个模型继续请求
for config_item in self.configs:
client = self.client_map[config_item[0].api_provider]
model_info: ModelInfo = config_item[0]
model_usage_config: ModelUsageArgConfigItem = config_item[1]
remain_try = (
model_usage_config.max_retry or self.req_conf.max_retry
) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1
compressed_messages = None
retry_interval = self.req_conf.retry_interval
while remain_try > 0:
try:
return await client.get_response(
model_info,
message_list=(compressed_messages or messages),
tool_options=tool_options,
max_tokens=model_usage_config.max_tokens
or self.req_conf.default_max_tokens,
temperature=model_usage_config.temperature
or self.req_conf.default_temperature,
response_format=response_format,
stream_response_handler=stream_response_handler,
async_response_parser=async_response_parser,
interrupt_flag=interrupt_flag,
)
except Exception as e:
logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
handle_res = default_exception_handler(
e,
self.task_name,
model_info.name,
remain_try,
retry_interval=self.req_conf.retry_interval,
messages=(messages, compressed_messages is not None),
)
if handle_res[0] == -1:
# 等待间隔为-1表示不再请求该模型
remain_try = 0
elif handle_res[0] != 0:
# 等待间隔不为0表示需要等待
await asyncio.sleep(handle_res[0])
retry_interval *= 2
if handle_res[1] is not None:
# 压缩消息
compressed_messages = handle_res[1]
logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用")
raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败
async def get_embedding(
self,
embedding_input: str,
) -> APIResponse:
"""
获取嵌入向量
:param embedding_input: 嵌入输入
:return: APIResponse
"""
for config in self.configs:
client = self.client_map[config[0].api_provider]
model_info: ModelInfo = config[0]
model_usage_config: ModelUsageArgConfigItem = config[1]
remain_try = (
model_usage_config.max_retry or self.req_conf.max_retry
) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1
while remain_try:
try:
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
)
except Exception as e:
logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
handle_res = default_exception_handler(
e,
self.task_name,
model_info.name,
remain_try,
retry_interval=self.req_conf.retry_interval,
)
if handle_res[0] == -1:
# 等待间隔为-1表示不再请求该模型
remain_try = 0
elif handle_res[0] != 0:
# 等待间隔不为0表示需要等待
await asyncio.sleep(handle_res[0])
logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用")
raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败

View File

@@ -0,0 +1,116 @@
import asyncio
from dataclasses import dataclass
from typing import Callable, Any
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
@dataclass
class UsageRecord:
"""
使用记录类
"""
model_name: str
"""模型名称"""
provider_name: str
"""提供商名称"""
prompt_tokens: int
"""提示token数"""
completion_tokens: int
"""完成token数"""
total_tokens: int
"""总token数"""
@dataclass
class APIResponse:
"""
API响应类
"""
content: str | None = None
"""响应内容"""
reasoning_content: str | None = None
"""推理内容"""
tool_calls: list[ToolCall] | None = None
"""工具调用 [(工具名称, 工具参数), ...]"""
embedding: list[float] | None = None
"""嵌入向量"""
usage: UsageRecord | None = None
"""使用情况 (prompt_tokens, completion_tokens, total_tokens)"""
raw_data: Any = None
"""响应原始数据"""
class BaseClient:
"""
基础客户端
"""
api_provider: APIProvider
def __init__(self, api_provider: APIProvider):
self.api_provider = api_provider
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[
[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项(可选,默认为None)
:param max_tokens: 最大token数(可选,默认为1024)
:param temperature: 温度(可选,默认为0.7)
:param response_format: 响应格式(可选,默认为 NotGiven)
:param stream_response_handler: 流式响应处理函数(可选)
:param async_response_parser: 响应解析函数(可选)
:param interrupt_flag: 中断信号量(可选,默认为None)
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
raise RuntimeError("This method should be overridden in subclasses")
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
raise RuntimeError("This method should be overridden in subclasses")

View File

@@ -0,0 +1,573 @@
import asyncio
import io
from collections.abc import Iterable
from typing import Callable, Iterator, TypeVar, AsyncIterator
from google import genai
from google.genai import types
from google.genai.types import FunctionDeclaration, GenerateContentResponse
from google.genai.errors import (
ClientError,
ServerError,
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
)
from .base_client import APIResponse, UsageRecord
from src.config.api_ada_configs import ModelInfo, APIProvider
from . import BaseClient
from src.common.logger import get_logger
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("Gemini客户端")
T = TypeVar("T")
def _convert_messages(
messages: list[Message],
) -> tuple[list[types.Content], list[str] | None]:
"""
转换消息格式 - 将消息转换为Gemini API所需的格式
:param messages: 消息列表
:return: 转换后的消息列表(和可能存在的system消息)
"""
def _convert_message_item(message: Message) -> types.Content:
"""
转换单个消息格式除了system和tool类型的消息
:param message: 消息对象
:return: 转换后的消息字典
"""
# 将openai格式的角色重命名为gemini格式的角色
if message.role == RoleType.Assistant:
role = "model"
elif message.role == RoleType.User:
role = "user"
# 添加Content
content: types.Part | list
if isinstance(message.content, str):
content = types.Part.from_text(message.content)
elif isinstance(message.content, list):
content = []
for item in message.content:
if isinstance(item, tuple):
content.append(
types.Part.from_bytes(
data=item[1], mime_type=f"image/{item[0].lower()}"
)
)
elif isinstance(item, str):
content.append(types.Part.from_text(item))
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
return types.Content(role=role, content=content)
temp_list: list[types.Content] = []
system_instructions: list[str] = []
for message in messages:
if message.role == RoleType.System:
if isinstance(message.content, str):
system_instructions.append(message.content)
else:
raise RuntimeError("你tm怎么往system里面塞图片base64")
elif message.role == RoleType.Tool:
if not message.tool_call_id:
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
else:
temp_list.append(_convert_message_item(message))
if system_instructions:
# 如果有system消息就把它加上去
ret: tuple = (temp_list, system_instructions)
else:
# 如果没有system消息就直接返回
ret: tuple = (temp_list, None)
return ret
def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]:
"""
转换工具选项格式 - 将工具选项转换为Gemini API所需的格式
:param tool_options: 工具选项列表
:return: 转换后的工具对象列表
"""
def _convert_tool_param(tool_option_param: ToolParam) -> dict:
"""
转换单个工具参数格式
:param tool_option_param: 工具参数对象
:return: 转换后的工具参数字典
"""
return {
"type": tool_option_param.param_type.value,
"description": tool_option_param.description,
}
def _convert_tool_option_item(tool_option: ToolOption) -> FunctionDeclaration:
"""
转换单个工具项格式
:param tool_option: 工具选项对象
:return: 转换后的Gemini工具选项对象
"""
ret = {
"name": tool_option.name,
"description": tool_option.description,
}
if tool_option.params:
ret["parameters"] = {
"type": "object",
"properties": {
param.name: _convert_tool_param(param)
for param in tool_option.params
},
"required": [
param.name for param in tool_option.params if param.required
],
}
ret1 = types.FunctionDeclaration(**ret)
return ret1
return [_convert_tool_option_item(tool_option) for tool_option in tool_options]
def _process_delta(
delta: GenerateContentResponse,
fc_delta_buffer: io.StringIO,
tool_calls_buffer: list[tuple[str, str, dict]],
):
if not hasattr(delta, "candidates") or len(delta.candidates) == 0:
raise RespParseException(delta, "响应解析失败缺失candidates字段")
if delta.text:
fc_delta_buffer.write(delta.text)
if delta.function_calls: # 为什么不用hasattr呢?因为这个属性一定有,即使是个空的
for call in delta.function_calls:
try:
if not isinstance(
call.args, dict
): # gemini返回的function call参数就是dict格式的了
raise RespParseException(
delta, "响应解析失败,工具调用参数无法解析为字典类型"
)
tool_calls_buffer.append(
(
call.id,
call.name,
call.args,
)
)
except Exception as e:
raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e
def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, dict]],
) -> APIResponse:
resp = APIResponse()
if _fc_delta_buffer.tell() > 0:
# 如果正式内容缓冲区不为空则将其写入APIResponse对象
resp.content = _fc_delta_buffer.getvalue()
_fc_delta_buffer.close()
if len(_tool_calls_buffer) > 0:
# 如果工具调用缓冲区不为空则将其解析为ToolCall对象列表
resp.tool_calls = []
for call_id, function_name, arguments_buffer in _tool_calls_buffer:
if arguments_buffer is not None:
arguments = arguments_buffer
if not isinstance(arguments, dict):
raise RespParseException(
None,
"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n"
f"{arguments_buffer}",
)
else:
arguments = None
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
return resp
async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
"""
将迭代器转换为异步迭代器
:param iterable: 迭代器对象
:return: 异步迭代器对象
"""
for item in iterable:
await asyncio.sleep(0)
yield item
async def _default_stream_response_handler(
resp_stream: Iterator[GenerateContentResponse],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, tuple[int, int, int]]:
"""
流式响应处理函数 - 处理Gemini API的流式响应
:param resp_stream: 流式响应对象,是一个神秘的iterator,我完全不知道这个玩意能不能跑,不过遍历一遍之后它就空了;如果跑不了一点的话,可以考虑改成别的东西
:return: APIResponse对象
"""
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[
tuple[str, str, dict]
] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
def _insure_buffer_closed():
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
async for chunk in _to_async_iterable(resp_stream):
# 检查是否有中断量
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量被设置则抛出ReqAbortException
raise ReqAbortException("请求被外部信号中断")
_process_delta(
chunk,
_fc_delta_buffer,
_tool_calls_buffer,
)
if chunk.usage_metadata:
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
chunk.usage_metadata.prompt_token_count,
chunk.usage_metadata.candidates_token_count
+ chunk.usage_metadata.thoughts_token_count,
chunk.usage_metadata.total_token_count,
)
try:
return _build_stream_api_resp(
_fc_delta_buffer,
_tool_calls_buffer,
), _usage_record
except Exception:
# 确保缓冲区被关闭
_insure_buffer_closed()
raise
def _default_normal_response_parser(
resp: GenerateContentResponse,
) -> tuple[APIResponse, tuple[int, int, int]]:
"""
解析对话补全响应 - 将Gemini API响应解析为APIResponse对象
:param resp: 响应对象
:return: APIResponse对象
"""
api_response = APIResponse()
if not hasattr(resp, "candidates") or len(resp.candidates) == 0:
raise RespParseException(resp, "响应解析失败缺失candidates字段")
if resp.text:
api_response.content = resp.text
if resp.function_calls:
api_response.tool_calls = []
for call in resp.function_calls:
try:
if not isinstance(call.args, dict):
raise RespParseException(
resp, "响应解析失败,工具调用参数无法解析为字典类型"
)
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args))
except Exception as e:
raise RespParseException(
resp, "响应解析失败,无法解析工具调用参数"
) from e
if resp.usage_metadata:
_usage_record = (
resp.usage_metadata.prompt_token_count,
resp.usage_metadata.candidates_token_count
+ resp.usage_metadata.thoughts_token_count,
resp.usage_metadata.total_token_count,
)
else:
_usage_record = None
api_response.raw_data = resp
return api_response, _usage_record
class GeminiClient(BaseClient):
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
# 不再在初始化时创建固定的client而是在请求时动态创建
self._clients_cache = {} # API Key -> genai.Client 的缓存
def _get_client(self, api_key: str = None) -> genai.Client:
"""获取或创建对应API Key的客户端"""
if api_key is None:
api_key = self.api_provider.get_current_api_key()
if not api_key:
raise ValueError(f"API Provider '{self.api_provider.name}' 没有可用的API Key")
# 使用缓存避免重复创建客户端
if api_key not in self._clients_cache:
self._clients_cache[api_key] = genai.Client(api_key=api_key)
return self._clients_cache[api_key]
async def _execute_with_fallback(self, func, *args, **kwargs):
"""执行请求并在失败时切换API Key"""
current_api_key = self.api_provider.get_current_api_key()
max_attempts = len(self.api_provider.api_keys) if self.api_provider.api_keys else 1
for attempt in range(max_attempts):
try:
client = self._get_client(current_api_key)
result = await func(client, *args, **kwargs)
# 成功时重置失败计数
self.api_provider.reset_key_failures(current_api_key)
return result
except (ClientError, ServerError) as e:
# 记录失败并尝试下一个API Key
logger.warning(f"API Key失败 (尝试 {attempt + 1}/{max_attempts}): {str(e)}")
if attempt < max_attempts - 1: # 还有重试机会
next_api_key = self.api_provider.mark_key_failed(current_api_key)
if next_api_key and next_api_key != current_api_key:
current_api_key = next_api_key
logger.info(f"切换到下一个API Key: {current_api_key[:8]}***{current_api_key[-4:]}")
continue
# 所有API Key都失败了重新抛出异常
raise RespNotOkException(e.status_code, e.message) from e
except Exception as e:
# 其他异常直接抛出
raise e
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
thinking_budget: int = 0,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[GenerateContentResponse], APIResponse]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项(可选,默认为None)
:param max_tokens: 最大token数(可选,默认为1024)
:param temperature: 温度(可选,默认为0.7)
:param thinking_budget: 思考预算(可选,默认为0)
:param response_format: 响应格式,默认为text/plain;如果输入的是JSON Schema,则必须遵守OpenAPI3.0格式,理论上和openai是一样的;暂不支持其它响应格式输入
:param stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler)
:param async_response_parser: 响应解析函数(可选,默认为default_response_parser)
:param interrupt_flag: 中断信号量(可选,默认为None)
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
return await self._execute_with_fallback(
self._get_response_internal,
model_info,
message_list,
tool_options,
max_tokens,
temperature,
thinking_budget,
response_format,
stream_response_handler,
async_response_parser,
interrupt_flag,
)
async def _get_response_internal(
self,
client: genai.Client,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
thinking_budget: int = 0,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[GenerateContentResponse], APIResponse]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""内部方法执行实际的API调用"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
if async_response_parser is None:
async_response_parser = _default_normal_response_parser
# 将messages构造为Gemini API所需的格式
messages = _convert_messages(message_list)
# 将tool_options转换为Gemini API所需的格式
tools = _convert_tool_options(tool_options) if tool_options else None
# 将response_format转换为Gemini API所需的格式
generation_config_dict = {
"max_output_tokens": max_tokens,
"temperature": temperature,
"response_modalities": ["TEXT"], # 暂时只支持文本输出
}
if "2.5" in model_info.model_identifier.lower():
# 我偷个懒,在这里识别一下2.5然后开摆,反正现在只有2.5支持思维链;然后我测试之后发现它不返回思考内容,反正我也怕他有朝一日返回了,我决定干掉任何有关的思维内容
generation_config_dict["thinking_config"] = types.ThinkingConfig(
thinking_budget=thinking_budget, include_thoughts=False
)
if tools:
generation_config_dict["tools"] = types.Tool(tools)
if messages[1]:
# 如果有system消息则将其添加到配置中
generation_config_dict["system_instructions"] = messages[1]
if response_format and response_format.format_type == RespFormatType.TEXT:
generation_config_dict["response_mime_type"] = "text/plain"
elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA):
generation_config_dict["response_mime_type"] = "application/json"
generation_config_dict["response_schema"] = response_format.to_dict()
generation_config = types.GenerateContentConfig(**generation_config_dict)
try:
if model_info.force_stream_mode:
req_task = asyncio.create_task(
client.aio.models.generate_content_stream(
model=model_info.model_identifier,
contents=messages[0],
config=generation_config,
)
)
while not req_task.done():
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量存在且被设置,则取消任务并抛出异常
req_task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态
resp, usage_record = await stream_response_handler(
req_task.result(), interrupt_flag
)
else:
req_task = asyncio.create_task(
client.aio.models.generate_content(
model=model_info.model_identifier,
contents=messages[0],
config=generation_config,
)
)
while not req_task.done():
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量存在且被设置,则取消任务并抛出异常
req_task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.status_code, e.message) from e
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
) as e:
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from e
except Exception as e:
raise NetworkConnectionError() from e
if usage_record:
resp.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=usage_record[0],
completion_tokens=usage_record[1],
total_tokens=usage_record[2],
)
return resp
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
return await self._execute_with_fallback(
self._get_embedding_internal,
model_info,
embedding_input,
)
async def _get_embedding_internal(
self,
client: genai.Client,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""内部方法执行实际的嵌入API调用"""
try:
raw_response: types.EmbedContentResponse = (
await client.aio.models.embed_content(
model=model_info.model_identifier,
contents=embedding_input,
config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
)
)
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.status_code) from e
except Exception as e:
raise NetworkConnectionError() from e
response = APIResponse()
# 解析嵌入响应和使用情况
if hasattr(raw_response, "embeddings"):
response.embedding = raw_response.embeddings[0].values
else:
raise RespParseException(raw_response, "响应解析失败缺失embeddings字段")
response.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=len(embedding_input),
completion_tokens=0,
total_tokens=len(embedding_input),
)
return response

View File

@@ -0,0 +1,647 @@
import asyncio
import io
import json
import re
from collections.abc import Iterable
from typing import Callable, Any
from openai import (
AsyncOpenAI,
APIConnectionError,
APIStatusError,
NOT_GIVEN,
AsyncStream,
)
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from .base_client import APIResponse, UsageRecord
from src.config.api_ada_configs import ModelInfo, APIProvider
from . import BaseClient
from src.common.logger import get_logger
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("OpenAI客户端")
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
"""
转换消息格式 - 将消息转换为OpenAI API所需的格式
:param messages: 消息列表
:return: 转换后的消息列表
"""
def _convert_message_item(message: Message) -> ChatCompletionMessageParam:
"""
转换单个消息格式
:param message: 消息对象
:return: 转换后的消息字典
"""
# 添加Content
content: str | list[dict[str, Any]]
if isinstance(message.content, str):
content = message.content
elif isinstance(message.content, list):
content = []
for item in message.content:
if isinstance(item, tuple):
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/{item[0].lower()};base64,{item[1]}"
},
}
)
elif isinstance(item, str):
content.append({"type": "text", "text": item})
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
ret = {
"role": message.role.value,
"content": content,
}
# 添加工具调用ID
if message.role == RoleType.Tool:
if not message.tool_call_id:
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
ret["tool_call_id"] = message.tool_call_id
return ret
return [_convert_message_item(message) for message in messages]
def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]:
"""
转换工具选项格式 - 将工具选项转换为OpenAI API所需的格式
:param tool_options: 工具选项列表
:return: 转换后的工具选项列表
"""
def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, str]:
"""
转换单个工具参数格式
:param tool_option_param: 工具参数对象
:return: 转换后的工具参数字典
"""
return {
"type": tool_option_param.param_type.value,
"description": tool_option_param.description,
}
def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
"""
转换单个工具项格式
:param tool_option: 工具选项对象
:return: 转换后的工具选项字典
"""
ret: dict[str, Any] = {
"name": tool_option.name,
"description": tool_option.description,
}
if tool_option.params:
ret["parameters"] = {
"type": "object",
"properties": {
param.name: _convert_tool_param(param)
for param in tool_option.params
},
"required": [
param.name for param in tool_option.params if param.required
],
}
return ret
return [
{
"type": "function",
"function": _convert_tool_option_item(tool_option),
}
for tool_option in tool_options
]
def _process_delta(
delta: ChoiceDelta,
has_rc_attr_flag: bool,
in_rc_flag: bool,
rc_delta_buffer: io.StringIO,
fc_delta_buffer: io.StringIO,
tool_calls_buffer: list[tuple[str, str, io.StringIO]],
) -> bool:
# 接收content
if has_rc_attr_flag:
# 有独立的推理内容块则无需考虑content内容的判读
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
# 如果有推理内容,则将其写入推理内容缓冲区
assert isinstance(delta.reasoning_content, str)
rc_delta_buffer.write(delta.reasoning_content)
elif delta.content:
# 如果有正式内容,则将其写入正式内容缓冲区
fc_delta_buffer.write(delta.content)
elif hasattr(delta, "content") and delta.content is not None:
# 没有独立的推理内容块,但有正式内容
if in_rc_flag:
# 当前在推理内容块中
if delta.content == "</think>":
# 如果当前内容是</think>,则将其视为推理内容的结束标记,退出推理内容块
in_rc_flag = False
else:
# 其他情况视为推理内容,加入推理内容缓冲区
rc_delta_buffer.write(delta.content)
elif delta.content == "<think>" and not fc_delta_buffer.getvalue():
# 如果当前内容是<think>,且正式内容缓冲区为空,说明<think>为输出的首个token
# 则将其视为推理内容的开始标记,进入推理内容块
in_rc_flag = True
else:
# 其他情况视为正式内容,加入正式内容缓冲区
fc_delta_buffer.write(delta.content)
# 接收tool_calls
if hasattr(delta, "tool_calls") and delta.tool_calls:
tool_call_delta = delta.tool_calls[0]
if tool_call_delta.index >= len(tool_calls_buffer):
# 调用索引号大于等于缓冲区长度,说明是新的工具调用
tool_calls_buffer.append(
(
tool_call_delta.id,
tool_call_delta.function.name,
io.StringIO(),
)
)
if tool_call_delta.function.arguments:
# 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中
tool_calls_buffer[tool_call_delta.index][2].write(
tool_call_delta.function.arguments
)
return in_rc_flag
def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_rc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, io.StringIO]],
) -> APIResponse:
resp = APIResponse()
if _rc_delta_buffer.tell() > 0:
# 如果推理内容缓冲区不为空则将其写入APIResponse对象
resp.reasoning_content = _rc_delta_buffer.getvalue()
_rc_delta_buffer.close()
if _fc_delta_buffer.tell() > 0:
# 如果正式内容缓冲区不为空则将其写入APIResponse对象
resp.content = _fc_delta_buffer.getvalue()
_fc_delta_buffer.close()
if _tool_calls_buffer:
# 如果工具调用缓冲区不为空则将其解析为ToolCall对象列表
resp.tool_calls = []
for call_id, function_name, arguments_buffer in _tool_calls_buffer:
if arguments_buffer.tell() > 0:
# 如果参数串缓冲区不为空则解析为JSON对象
raw_arg_data = arguments_buffer.getvalue()
arguments_buffer.close()
try:
arguments = json.loads(raw_arg_data)
if not isinstance(arguments, dict):
raise RespParseException(
None,
"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n"
f"{raw_arg_data}",
)
except json.JSONDecodeError as e:
raise RespParseException(
None,
"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:"
f"{raw_arg_data}",
) from e
else:
arguments_buffer.close()
arguments = None
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
return resp
async def _default_stream_response_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, tuple[int, int, int]]:
"""
流式响应处理函数 - 处理OpenAI API的流式响应
:param resp_stream: 流式响应对象
:return: APIResponse对象
"""
_has_rc_attr_flag = False # 标记是否有独立的推理内容块
_in_rc_flag = False # 标记是否在推理内容块中
_rc_delta_buffer = io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[
tuple[str, str, io.StringIO]
] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
def _insure_buffer_closed():
# 确保缓冲区被关闭
if _rc_delta_buffer and not _rc_delta_buffer.closed:
_rc_delta_buffer.close()
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
for _, _, buffer in _tool_calls_buffer:
if buffer and not buffer.closed:
buffer.close()
async for event in resp_stream:
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量被设置则抛出ReqAbortException
_insure_buffer_closed()
raise ReqAbortException("请求被外部信号中断")
delta = event.choices[0].delta # 获取当前块的delta内容
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
# 标记:有独立的推理内容块
_has_rc_attr_flag = True
_in_rc_flag = _process_delta(
delta,
_has_rc_attr_flag,
_in_rc_flag,
_rc_delta_buffer,
_fc_delta_buffer,
_tool_calls_buffer,
)
if event.usage:
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
event.usage.prompt_tokens,
event.usage.completion_tokens,
event.usage.total_tokens,
)
try:
return _build_stream_api_resp(
_fc_delta_buffer,
_rc_delta_buffer,
_tool_calls_buffer,
), _usage_record
except Exception:
# 确保缓冲区被关闭
_insure_buffer_closed()
raise
pattern = re.compile(
r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
re.DOTALL,
)
"""用于解析推理内容的正则表达式"""
def _default_normal_response_parser(
resp: ChatCompletion,
) -> tuple[APIResponse, tuple[int, int, int]]:
"""
解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
:param resp: 响应对象
:return: APIResponse对象
"""
api_response = APIResponse()
if not hasattr(resp, "choices") or len(resp.choices) == 0:
raise RespParseException(resp, "响应解析失败缺失choices字段")
message_part = resp.choices[0].message
if hasattr(message_part, "reasoning_content") and message_part.reasoning_content:
# 有有效的推理字段
api_response.content = message_part.content
api_response.reasoning_content = message_part.reasoning_content
elif message_part.content:
# 提取推理和内容
match = pattern.match(message_part.content)
if not match:
raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容")
if match.group("think") is not None:
result = match.group("think").strip(), match.group("content").strip()
elif match.group("think_unclosed") is not None:
result = match.group("think_unclosed").strip(), None
else:
result = None, match.group("content_only").strip()
api_response.reasoning_content, api_response.content = result
# 提取工具调用
if message_part.tool_calls:
api_response.tool_calls = []
for call in message_part.tool_calls:
try:
arguments = json.loads(call.function.arguments)
if not isinstance(arguments, dict):
raise RespParseException(
resp, "响应解析失败,工具调用参数无法解析为字典类型"
)
api_response.tool_calls.append(
ToolCall(call.id, call.function.name, arguments)
)
except json.JSONDecodeError as e:
raise RespParseException(
resp, "响应解析失败,无法解析工具调用参数"
) from e
# 提取Usage信息
if resp.usage:
_usage_record = (
resp.usage.prompt_tokens,
resp.usage.completion_tokens,
resp.usage.total_tokens,
)
else:
_usage_record = None
# 将原始响应存储在原始数据中
api_response.raw_data = resp
return api_response, _usage_record
class OpenaiClient(BaseClient):
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
# 不再在初始化时创建固定的client而是在请求时动态创建
self._clients_cache = {} # API Key -> AsyncOpenAI client 的缓存
    def _get_client(self, api_key: str | None = None) -> AsyncOpenAI:
"""获取或创建对应API Key的客户端"""
if api_key is None:
api_key = self.api_provider.get_current_api_key()
if not api_key:
raise ValueError(f"API Provider '{self.api_provider.name}' 没有可用的API Key")
# 使用缓存避免重复创建客户端
if api_key not in self._clients_cache:
self._clients_cache[api_key] = AsyncOpenAI(
base_url=self.api_provider.base_url,
api_key=api_key,
max_retries=0,
)
return self._clients_cache[api_key]
async def _execute_with_fallback(self, func, *args, **kwargs):
"""执行请求并在失败时切换API Key"""
current_api_key = self.api_provider.get_current_api_key()
max_attempts = len(self.api_provider.api_keys) if self.api_provider.api_keys else 1
for attempt in range(max_attempts):
try:
client = self._get_client(current_api_key)
result = await func(client, *args, **kwargs)
# 成功时重置失败计数
self.api_provider.reset_key_failures(current_api_key)
return result
except (APIStatusError, APIConnectionError) as e:
# 记录失败并尝试下一个API Key
logger.warning(f"API Key失败 (尝试 {attempt + 1}/{max_attempts}): {str(e)}")
if attempt < max_attempts - 1: # 还有重试机会
next_api_key = self.api_provider.mark_key_failed(current_api_key)
if next_api_key and next_api_key != current_api_key:
current_api_key = next_api_key
logger.info(f"切换到下一个API Key: {current_api_key[:8]}***{current_api_key[-4:]}")
continue
# 所有API Key都失败了重新抛出异常
if isinstance(e, APIStatusError):
raise RespNotOkException(e.status_code, e.message) from e
elif isinstance(e, APIConnectionError):
raise NetworkConnectionError(str(e)) from e
except Exception as e:
# 其他异常直接抛出
raise e
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[
[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项可选默认为None
:param max_tokens: 最大token数可选默认为1024
:param temperature: 温度可选默认为0.7
        :param response_format: 响应格式(可选,默认为 NotGiven)
:param stream_response_handler: 流式响应处理函数可选默认为default_stream_response_handler
:param async_response_parser: 响应解析函数可选默认为default_response_parser
:param interrupt_flag: 中断信号量可选默认为None
        :return: APIResponse 响应对象(含回复文本、推理文本、工具调用与用量信息)
"""
return await self._execute_with_fallback(
self._get_response_internal,
model_info,
message_list,
tool_options,
max_tokens,
temperature,
response_format,
stream_response_handler,
async_response_parser,
interrupt_flag,
)
async def _get_response_internal(
self,
client: AsyncOpenAI,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[
[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""内部方法执行实际的API调用"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
if async_response_parser is None:
async_response_parser = _default_normal_response_parser
# 将messages构造为OpenAI API所需的格式
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
# 将tool_options转换为OpenAI API所需的格式
tools: Iterable[ChatCompletionToolParam] = (
_convert_tool_options(tool_options) if tool_options else NOT_GIVEN
)
try:
if model_info.force_stream_mode:
req_task = asyncio.create_task(
client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
response_format=response_format.to_dict()
if response_format
else NOT_GIVEN,
)
)
while not req_task.done():
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量存在且被设置,则取消任务并抛出异常
req_task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态
resp, usage_record = await stream_response_handler(
req_task.result(), interrupt_flag
)
else:
# 发送请求并获取响应
req_task = asyncio.create_task(
client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
response_format=response_format.to_dict()
if response_format
else NOT_GIVEN,
)
)
while not req_task.done():
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量存在且被设置,则取消任务并抛出异常
req_task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
resp, usage_record = async_response_parser(req_task.result())
except APIConnectionError as e:
# 重封装APIConnectionError为NetworkConnectionError
raise NetworkConnectionError() from e
except APIStatusError as e:
# 重封装APIError为RespNotOkException
raise RespNotOkException(e.status_code, e.message) from e
if usage_record:
resp.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=usage_record[0],
completion_tokens=usage_record[1],
total_tokens=usage_record[2],
)
return resp
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
return await self._execute_with_fallback(
self._get_embedding_internal,
model_info,
embedding_input,
)
async def _get_embedding_internal(
self,
client: AsyncOpenAI,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""内部方法执行实际的嵌入API调用"""
try:
raw_response = await client.embeddings.create(
model=model_info.model_identifier,
input=embedding_input,
)
except APIConnectionError as e:
raise NetworkConnectionError() from e
except APIStatusError as e:
# 重封装APIError为RespNotOkException
raise RespNotOkException(e.status_code) from e
response = APIResponse()
# 解析嵌入响应
if len(raw_response.data) > 0:
response.embedding = raw_response.data[0].embedding
else:
raise RespParseException(
raw_response,
"响应解析失败,缺失嵌入数据。",
)
# 解析使用情况
if hasattr(raw_response, "usage"):
response.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=raw_response.usage.prompt_tokens,
completion_tokens=raw_response.usage.completion_tokens,
total_tokens=raw_response.usage.total_tokens,
)
return response
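
以下补充一个最小调用示意(非本 PR 自带代码):导入路径、API Key、ModelInfo 构造参数均为编者假设,实际以 api_ada_configs 与配置文件为准;它演示多 Key 配置下的一次对话请求,任一 Key 请求失败时由 _execute_with_fallback 自动切换。

import asyncio
from src.config.api_ada_configs import APIProvider, ModelInfo
from src.llm_models.payload_content.message import MessageBuilder, RoleType

async def chat_demo():
    provider = APIProvider(
        name="DeepSeek",
        base_url="https://api.deepseek.com/v1",
        api_keys=["sk-key-1", "sk-key-2"],  # 任一 Key 失败将自动切换到下一个
    )
    # ModelInfo 的构造参数为编者假设,实际以 api_ada_configs.ModelInfo 定义为准
    model = ModelInfo(
        model_identifier="deepseek-chat",
        name="deepseek-v3",
        api_provider="DeepSeek",
    )
    client = OpenaiClient(provider)
    msg = MessageBuilder().set_role(RoleType.User).add_text_content("你好").build()
    resp = await client.get_response(model, [msg], max_tokens=256)
    print(resp.content, resp.reasoning_content, resp.usage)

# asyncio.run(chat_demo())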

View File

@@ -0,0 +1,92 @@
import importlib
from typing import Dict
from src.config.config import model_config
from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
from src.common.logger import get_logger
from .model_client import ModelRequestHandler, BaseClient
logger = get_logger("模型管理器")
class ModelManager:
# TODO: 添加读写锁,防止异步刷新配置时发生数据竞争
def __init__(
self,
config: ModuleConfig,
):
self.config: ModuleConfig = config
"""配置信息"""
self.api_client_map: Dict[str, BaseClient] = {}
"""API客户端映射表"""
self._request_handler_cache: Dict[str, ModelRequestHandler] = {}
"""ModelRequestHandler缓存避免重复创建"""
for provider_name, api_provider in self.config.api_providers.items():
# 初始化API客户端
try:
# 根据配置动态加载实现
client_module = importlib.import_module(
f".model_client.{api_provider.client_type}_client", __package__
)
client_class = getattr(
client_module, f"{api_provider.client_type.capitalize()}Client"
)
if not issubclass(client_class, BaseClient):
raise TypeError(
f"'{client_class.__name__}' is not a subclass of 'BaseClient'"
)
self.api_client_map[api_provider.name] = client_class(
api_provider
) # 实例化放入api_client_map
except ImportError as e:
logger.error(f"Failed to import client module: {e}")
raise ImportError(
f"Failed to import client module for '{provider_name}': {e}"
) from e
def __getitem__(self, task_name: str) -> ModelRequestHandler:
"""
获取任务所需的模型客户端(封装)
使用缓存机制避免重复创建ModelRequestHandler
:param task_name: 任务名称
:return: 模型客户端
"""
if task_name not in self.config.task_model_arg_map:
raise KeyError(f"'{task_name}' not registered in ModelManager")
# 检查缓存中是否已存在
if task_name in self._request_handler_cache:
logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}")
return self._request_handler_cache[task_name]
# 创建新的ModelRequestHandler并缓存
logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}")
handler = ModelRequestHandler(
task_name=task_name,
config=self.config,
api_client_map=self.api_client_map,
)
self._request_handler_cache[task_name] = handler
return handler
def __setitem__(self, task_name: str, value: ModelUsageArgConfig):
"""
注册任务的模型使用配置
:param task_name: 任务名称
:param value: 模型使用配置
"""
self.config.task_model_arg_map[task_name] = value
def __contains__(self, task_name: str):
"""
判断任务是否已注册
:param task_name: 任务名称
:return: 是否在模型列表中
"""
return task_name in self.config.task_model_arg_map
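
以下补充一个最小使用示意(非本 PR 自带代码):假设 model_config 即为 ModuleConfig 实例(与本文件的 import 一致),任务名 "llm_normal" 取自配置模板中的示例。

from src.config.config import model_config

manager = ModelManager(model_config)
if "llm_normal" in manager:          # __contains__:任务是否已在 task_model_arg_map 中注册
    handler = manager["llm_normal"]  # __getitem__:创建并缓存 ModelRequestHandler
    handler_again = manager["llm_normal"]  # 再次获取时命中缓存,返回同一实例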

View File

@@ -0,0 +1,104 @@
from enum import Enum
# 设计这系列类的目的是为未来可能的扩展做准备
class RoleType(Enum):
System = "system"
User = "user"
Assistant = "assistant"
Tool = "tool"
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]
class Message:
def __init__(
self,
role: RoleType,
content: str | list[tuple[str, str] | str],
tool_call_id: str | None = None,
):
"""
初始化消息对象
不应直接修改Message类而应使用MessageBuilder类来构建对象
"""
self.role: RoleType = role
self.content: str | list[tuple[str, str] | str] = content
self.tool_call_id: str | None = tool_call_id
class MessageBuilder:
def __init__(self):
self.__role: RoleType = RoleType.User
self.__content: list[tuple[str, str] | str] = []
self.__tool_call_id: str | None = None
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
"""
设置角色默认为User
:param role: 角色
:return: MessageBuilder对象
"""
self.__role = role
return self
def add_text_content(self, text: str) -> "MessageBuilder":
"""
添加文本内容
:param text: 文本内容
:return: MessageBuilder对象
"""
self.__content.append(text)
return self
def add_image_content(
self, image_format: str, image_base64: str
) -> "MessageBuilder":
"""
添加图片内容
:param image_format: 图片格式
:param image_base64: 图片的base64编码
:return: MessageBuilder对象
"""
if image_format.lower() not in SUPPORTED_IMAGE_FORMATS:
raise ValueError("不受支持的图片格式")
if not image_base64:
raise ValueError("图片的base64编码不能为空")
self.__content.append((image_format, image_base64))
return self
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
"""
添加工具调用指令调用时请确保已设置为Tool角色
:param tool_call_id: 工具调用指令的id
:return: MessageBuilder对象
"""
if self.__role != RoleType.Tool:
raise ValueError("仅当角色为Tool时才能添加工具调用ID")
if not tool_call_id:
raise ValueError("工具调用ID不能为空")
self.__tool_call_id = tool_call_id
return self
def build(self) -> Message:
"""
构建消息对象
:return: Message对象
"""
if len(self.__content) == 0:
raise ValueError("内容不能为空")
if self.__role == RoleType.Tool and self.__tool_call_id is None:
raise ValueError("Tool角色的工具调用ID不能为空")
return Message(
role=self.__role,
content=(
self.__content[0]
if (len(self.__content) == 1 and isinstance(self.__content[0], str))
else self.__content
),
tool_call_id=self.__tool_call_id,
)
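
以下补充 MessageBuilder 的最小使用示意(非本 PR 自带代码)base64 字符串与工具调用 ID 均为占位值。

text_msg = MessageBuilder().set_role(RoleType.User).add_text_content("描述一下这张图").build()

image_msg = (
    MessageBuilder()
    .set_role(RoleType.User)
    .add_text_content("这是图片:")
    .add_image_content("png", "<base64占位>")
    .build()
)

tool_msg = (
    MessageBuilder()
    .set_role(RoleType.Tool)
    .add_text_content("工具执行结果")
    .add_tool_call("call_123")  # 仅 Tool 角色需要设置工具调用 ID
    .build()
)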

View File

@@ -0,0 +1,223 @@
from enum import Enum
from typing import Optional, Any
from pydantic import BaseModel
from typing_extensions import TypedDict, Required
class RespFormatType(Enum):
TEXT = "text" # 文本
JSON_OBJ = "json_object" # JSON
JSON_SCHEMA = "json_schema" # JSON Schema
class JsonSchema(TypedDict, total=False):
name: Required[str]
"""
The name of the response format.
Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
of 64.
"""
description: Optional[str]
"""
A description of what the response format is for, used by the model to determine
how to respond in the format.
"""
schema: dict[str, object]
"""
The schema for the response format, described as a JSON Schema object. Learn how
to build JSON schemas [here](https://json-schema.org/).
"""
strict: Optional[bool]
"""
Whether to enable strict schema adherence when generating the output. If set to
true, the model will always follow the exact schema defined in the `schema`
field. Only a subset of JSON Schema is supported when `strict` is `true`. To
learn more, read the
[Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
"""
def _json_schema_type_check(instance) -> str | None:
if "name" not in instance:
return "schema必须包含'name'字段"
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
return "schema的'name'字段必须是非空字符串"
if "description" in instance and (
not isinstance(instance["description"], str)
or instance["description"].strip() == ""
):
return "schema的'description'字段只能填入非空字符串"
if "schema" not in instance:
return "schema必须包含'schema'字段"
elif not isinstance(instance["schema"], dict):
return "schema的'schema'字段必须是字典详见https://json-schema.org/"
if "strict" in instance and not isinstance(instance["strict"], bool):
return "schema的'strict'字段只能填入布尔值"
return None
def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]:
"""
递归移除JSON Schema中的title字段
"""
if isinstance(schema, list):
# 如果当前Schema是列表则对所有dict/list子元素递归调用
for idx, item in enumerate(schema):
if isinstance(item, (dict, list)):
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典移除title字段并对所有dict/list子元素递归调用
if "title" in schema:
del schema["title"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
return schema
def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
"""
链接JSON Schema中的definitions字段
"""
def link_definitions_recursive(
path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any]
) -> dict[str, Any]:
"""
递归链接JSON Schema中的definitions字段
:param path: 当前路径
:param sub_schema: 子Schema
:param defs: Schema定义集
:return:
"""
if isinstance(sub_schema, list):
# 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive(
f"{path}/{str(i)}", sub_schema[i], defs
)
else:
# 否则为字典
if "$defs" in sub_schema:
                # 如果当前Schema有$defs字段则将其添加到defs中
key_prefix = f"{path}/$defs/"
for key, value in sub_schema["$defs"].items():
def_key = key_prefix + key
if def_key not in defs:
defs[def_key] = value
del sub_schema["$defs"]
if "$ref" in sub_schema:
# 如果当前Schema有$ref字段则将其替换为defs中的定义
def_key = sub_schema["$ref"]
if def_key in defs:
sub_schema = defs[def_key]
else:
raise ValueError(f"Schema中引用的定义'{def_key}'不存在")
# 遍历键值对
for key, value in sub_schema.items():
if isinstance(value, (dict, list)):
# 如果当前值是字典或列表,则递归调用
sub_schema[key] = link_definitions_recursive(
f"{path}/{key}", value, defs
)
return sub_schema
return link_definitions_recursive("#", schema, {})
def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
"""
递归移除JSON Schema中的$defs字段
"""
if isinstance(schema, list):
# 如果当前Schema是列表则对所有dict/list子元素递归调用
for idx, item in enumerate(schema):
if isinstance(item, (dict, list)):
                schema[idx] = _remove_defs(item)
elif isinstance(schema, dict):
        # 是字典移除$defs字段并对所有dict/list子元素递归调用
if "$defs" in schema:
del schema["$defs"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
                schema[key] = _remove_defs(value)
return schema
class RespFormat:
"""
响应格式
"""
@staticmethod
def _generate_schema_from_model(schema):
json_schema = {
"name": schema.__name__,
"schema": _remove_defs(
_link_definitions(_remove_title(schema.model_json_schema()))
),
"strict": False,
}
if schema.__doc__:
json_schema["description"] = schema.__doc__
return json_schema
def __init__(
self,
format_type: RespFormatType = RespFormatType.TEXT,
schema: type | JsonSchema | None = None,
):
"""
响应格式
:param format_type: 响应格式类型(默认为文本)
:param schema: 模板类或JsonSchema仅当format_type为JSON Schema时有效
"""
self.format_type: RespFormatType = format_type
if format_type == RespFormatType.JSON_SCHEMA:
if schema is None:
raise ValueError("当format_type为'JSON_SCHEMA'schema不能为空")
if isinstance(schema, dict):
if check_msg := _json_schema_type_check(schema):
raise ValueError(f"schema格式不正确{check_msg}")
self.schema = schema
elif issubclass(schema, BaseModel):
try:
json_schema = self._generate_schema_from_model(schema)
self.schema = json_schema
except Exception as e:
raise ValueError(
f"自动生成JSON Schema时发生异常请检查模型类{schema.__name__}的定义,详细信息:\n"
f"{schema.__name__}:\n"
) from e
else:
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
else:
self.schema = None
def to_dict(self):
"""
将响应格式转换为字典
:return: 字典
"""
if self.schema:
return {
"format_type": self.format_type.value,
"schema": self.schema,
}
else:
return {
"format_type": self.format_type.value,
}
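
以下补充一个最小示意(非本 PR 自带代码),演示用 Pydantic 模型自动生成 JSON Schema 响应格式;示例类名与字段为编者虚构。

from pydantic import BaseModel

class WeatherReply(BaseModel):
    """天气回复"""

    city: str
    temperature: float

fmt = RespFormat(RespFormatType.JSON_SCHEMA, schema=WeatherReply)
print(fmt.to_dict())  # {"format_type": "json_schema", "schema": {...}}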

View File

@@ -0,0 +1,155 @@
from enum import Enum
class ToolParamType(Enum):
"""
工具调用参数类型
"""
String = "string" # 字符串
Int = "integer" # 整型
Float = "float" # 浮点型
Boolean = "bool" # 布尔型
class ToolParam:
"""
工具调用参数
"""
def __init__(
self, name: str, param_type: ToolParamType, description: str, required: bool
):
"""
初始化工具调用参数
不应直接修改ToolParam类而应使用ToolOptionBuilder类来构建对象
:param name: 参数名称
:param param_type: 参数类型
:param description: 参数描述
:param required: 是否必填
"""
self.name: str = name
self.param_type: ToolParamType = param_type
self.description: str = description
self.required: bool = required
class ToolOption:
"""
工具调用项
"""
def __init__(
self,
name: str,
description: str,
params: list[ToolParam] | None = None,
):
"""
初始化工具调用项
不应直接修改ToolOption类而应使用ToolOptionBuilder类来构建对象
:param name: 工具名称
:param description: 工具描述
:param params: 工具参数列表
"""
self.name: str = name
self.description: str = description
self.params: list[ToolParam] | None = params
class ToolOptionBuilder:
"""
工具调用项构建器
"""
def __init__(self):
self.__name: str = ""
self.__description: str = ""
self.__params: list[ToolParam] = []
def set_name(self, name: str) -> "ToolOptionBuilder":
"""
设置工具名称
:param name: 工具名称
:return: ToolBuilder实例
"""
if not name:
raise ValueError("工具名称不能为空")
self.__name = name
return self
def set_description(self, description: str) -> "ToolOptionBuilder":
"""
设置工具描述
:param description: 工具描述
:return: ToolBuilder实例
"""
if not description:
raise ValueError("工具描述不能为空")
self.__description = description
return self
def add_param(
self,
name: str,
param_type: ToolParamType,
description: str,
required: bool = False,
) -> "ToolOptionBuilder":
"""
添加工具参数
:param name: 参数名称
:param param_type: 参数类型
:param description: 参数描述
:param required: 是否必填默认为False
:return: ToolBuilder实例
"""
if not name or not description:
raise ValueError("参数名称/描述不能为空")
self.__params.append(
ToolParam(
name=name,
param_type=param_type,
description=description,
required=required,
)
)
return self
def build(self):
"""
构建工具调用项
:return: 工具调用项
"""
if self.__name == "" or self.__description == "":
raise ValueError("工具名称/描述不能为空")
return ToolOption(
name=self.__name,
description=self.__description,
params=None if len(self.__params) == 0 else self.__params,
)
class ToolCall:
"""
来自模型反馈的工具调用
"""
def __init__(
self,
call_id: str,
func_name: str,
args: dict | None = None,
):
"""
初始化工具调用
:param call_id: 工具调用ID
:param func_name: 要调用的函数名称
:param args: 工具调用参数
"""
self.call_id: str = call_id
self.func_name: str = func_name
self.args: dict | None = args
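
以下补充 ToolOptionBuilder 的最小使用示意(非本 PR 自带代码),工具名称与参数均为编者虚构。

weather_tool = (
    ToolOptionBuilder()
    .set_name("get_weather")
    .set_description("查询指定城市的天气")
    .add_param("city", ToolParamType.String, "城市名称", required=True)
    .add_param("days", ToolParamType.Int, "预报天数", required=False)
    .build()
)
# 随后可作为 tool_options 传入客户端,例如 client.get_response(..., tool_options=[weather_tool])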

View File

@@ -0,0 +1,172 @@
from datetime import datetime
from enum import Enum
from typing import Tuple
from src.common.logger import get_logger
from src.config.api_ada_configs import ModelInfo
from src.common.database.database_model import LLMUsage
logger = get_logger("模型使用统计")
class ReqType(Enum):
"""
请求类型
"""
CHAT = "chat" # 对话请求
EMBEDDING = "embedding" # 嵌入请求
class UsageCallStatus(Enum):
"""
任务调用状态
"""
PROCESSING = "processing" # 处理中
SUCCESS = "success" # 成功
FAILURE = "failure" # 失败
CANCELED = "canceled" # 取消
class ModelUsageStatistic:
"""
模型使用统计类 - 使用SQLite+Peewee
"""
def __init__(self):
"""
初始化统计类
由于使用Peewee ORM不需要传入数据库实例
"""
# 确保表已经创建
try:
from src.common.database.database import db
db.create_tables([LLMUsage], safe=True)
except Exception as e:
logger.error(f"创建LLMUsage表失败: {e}")
@staticmethod
def _calculate_cost(
prompt_tokens: int, completion_tokens: int, model_info: ModelInfo
) -> float:
"""计算API调用成本
        使用模型的price_in和price_out价格计算输入和输出的成本
Args:
prompt_tokens: 输入token数量
completion_tokens: 输出token数量
model_info: 模型信息
Returns:
float: 总成本(元)
"""
        # 使用模型的price_in和price_out计算成本
input_cost = (prompt_tokens / 1000000) * model_info.price_in
output_cost = (completion_tokens / 1000000) * model_info.price_out
return round(input_cost + output_cost, 6)
def create_usage(
self,
model_name: str,
task_name: str = "N/A",
request_type: ReqType = ReqType.CHAT,
user_id: str = "system",
endpoint: str = "/chat/completions",
) -> int | None:
"""
创建模型使用情况记录
Args:
model_name: 模型名
task_name: 任务名称
request_type: 请求类型默认为Chat
user_id: 用户ID默认为system
endpoint: API端点
Returns:
int | None: 返回记录ID失败返回None
"""
try:
usage_record = LLMUsage.create(
model_name=model_name,
user_id=user_id,
request_type=request_type.value,
endpoint=endpoint,
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost=0.0,
status=UsageCallStatus.PROCESSING.value,
timestamp=datetime.now(),
)
logger.trace(
f"创建了一条模型使用情况记录 - 模型: {model_name}, "
f"子任务: {task_name}, 类型: {request_type.value}, "
f"用户: {user_id}, 记录ID: {usage_record.id}"
)
return usage_record.id
except Exception as e:
logger.error(f"创建模型使用情况记录失败: {str(e)}")
return None
def update_usage(
self,
record_id: int | None,
model_info: ModelInfo,
usage_data: Tuple[int, int, int] | None = None,
stat: UsageCallStatus = UsageCallStatus.SUCCESS,
ext_msg: str | None = None,
):
"""
更新模型使用情况
Args:
record_id: 记录ID
model_info: 模型信息
usage_data: 使用情况数据(输入token数量, 输出token数量, 总token数量)
stat: 任务调用状态
ext_msg: 额外信息
"""
if not record_id:
logger.error("更新模型使用情况失败: record_id不能为空")
return
if usage_data and len(usage_data) != 3:
logger.error("更新模型使用情况失败: usage_data的长度不正确应该为3个元素")
return
# 提取使用情况数据
prompt_tokens = usage_data[0] if usage_data else 0
completion_tokens = usage_data[1] if usage_data else 0
total_tokens = usage_data[2] if usage_data else 0
try:
# 使用Peewee更新记录
update_query = LLMUsage.update(
status=stat.value,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost=self._calculate_cost(
prompt_tokens, completion_tokens, model_info
) if usage_data else 0.0,
).where(LLMUsage.id == record_id)
updated_count = update_query.execute()
if updated_count == 0:
logger.warning(f"记录ID {record_id} 不存在,无法更新")
return
logger.debug(
f"Token使用情况 - 模型: {model_info.name}, "
f"记录ID: {record_id}, "
f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
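
以下补充一个最小使用示意(非本 PR 自带代码)model_info 需为有效的 ModelInfo 实例token 数字为示例值。

def record_chat_usage(model_info):
    statistic = ModelUsageStatistic()
    record_id = statistic.create_usage(
        model_name=model_info.name,
        task_name="llm_normal",
        request_type=ReqType.CHAT,
    )
    # 请求完成后回填用量usage_data 为 (prompt_tokens, completion_tokens, total_tokens)
    statistic.update_usage(
        record_id,
        model_info,
        usage_data=(120, 45, 165),
        stat=UsageCallStatus.SUCCESS,
    )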

152
src/llm_models/utils.py Normal file
View File

@@ -0,0 +1,152 @@
import base64
import io
from PIL import Image
from src.common.logger import get_logger
from .payload_content.message import Message, MessageBuilder
logger = get_logger("消息压缩工具")
def compress_messages(
messages: list[Message], img_target_size: int = 1 * 1024 * 1024
) -> list[Message]:
"""
压缩消息列表中的图片
:param messages: 消息列表
:param img_target_size: 图片目标大小默认1MB
:return: 压缩后的消息列表
"""
def reformat_static_image(image_data: bytes) -> bytes:
"""
将静态图片转换为JPEG格式
:param image_data: 图片数据
:return: 转换后的图片数据
"""
try:
            image = Image.open(io.BytesIO(image_data))
if image.format and (
image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]
):
# 静态图像转换为JPEG格式
reformated_image_data = io.BytesIO()
image.save(
reformated_image_data, format="JPEG", quality=95, optimize=True
)
image_data = reformated_image_data.getvalue()
return image_data
except Exception as e:
logger.error(f"图片转换格式失败: {str(e)}")
return image_data
def rescale_image(
image_data: bytes, scale: float
) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
"""
缩放图片
:param image_data: 图片数据
:param scale: 缩放比例
:return: 缩放后的图片数据
"""
try:
            image = Image.open(io.BytesIO(image_data))
# 原始尺寸
original_size = (image.width, image.height)
# 计算新的尺寸
new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
output_buffer = io.BytesIO()
if getattr(image, "is_animated", False):
# 动态图片,处理所有帧
frames = []
new_size = (new_size[0] // 2, new_size[1] // 2) # 动图,缩放尺寸再打折
for frame_idx in range(getattr(image, "n_frames", 1)):
image.seek(frame_idx)
new_frame = image.copy()
new_frame = new_frame.resize(new_size, Image.Resampling.LANCZOS)
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format="GIF",
save_all=True,
append_images=frames[1:],
optimize=True,
duration=image.info.get("duration", 100),
loop=image.info.get("loop", 0),
)
else:
# 静态图片,直接缩放保存
resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
resized_image.save(
output_buffer, format="JPEG", quality=95, optimize=True
)
return output_buffer.getvalue(), original_size, new_size
except Exception as e:
logger.error(f"图片缩放失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return image_data, None, None
def compress_base64_image(
base64_data: str, target_size: int = 1 * 1024 * 1024
) -> str:
original_b64_data_size = len(base64_data) # 计算原始数据大小
image_data = base64.b64decode(base64_data)
# 先尝试转换格式为JPEG
image_data = reformat_static_image(image_data)
base64_data = base64.b64encode(image_data).decode("utf-8")
if len(base64_data) <= target_size:
# 如果转换后小于目标大小,直接返回
logger.info(
f"成功将图片转为JPEG格式编码后大小: {len(base64_data) / 1024:.1f}KB"
)
return base64_data
# 如果转换后仍然大于目标大小,进行尺寸压缩
scale = min(1.0, target_size / len(base64_data))
image_data, original_size, new_size = rescale_image(image_data, scale)
base64_data = base64.b64encode(image_data).decode("utf-8")
if original_size and new_size:
logger.info(
f"压缩图片: {original_size[0]}x{original_size[1]} -> {new_size[0]}x{new_size[1]}\n"
f"压缩前大小: {original_b64_data_size / 1024:.1f}KB, 压缩后大小: {len(base64_data) / 1024:.1f}KB"
)
return base64_data
compressed_messages = []
for message in messages:
if isinstance(message.content, list):
# 检查content如有图片则压缩
message_builder = MessageBuilder()
for content_item in message.content:
if isinstance(content_item, tuple):
# 图片,进行压缩
message_builder.add_image_content(
content_item[0],
compress_base64_image(
content_item[1], target_size=img_target_size
),
)
else:
message_builder.add_text_content(content_item)
compressed_messages.append(message_builder.build())
else:
compressed_messages.append(message)
return compressed_messages
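
以下补充一个最小使用示意(非本 PR 自带代码):假设本地存在一张示例图片,演示在发送请求前压缩消息中的图片。

import base64

with open("example.png", "rb") as f:  # 示例图片路径为编者假设
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

msg = (
    MessageBuilder()
    .add_text_content("看看这张图")
    .add_image_content("png", img_b64)
    .build()
)
compressed = compress_messages([msg], img_target_size=512 * 1024)  # 目标约 512KB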

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
[inner]
version = "4.5.0"
version = "5.0.0"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请在修改后将version的值进行变更
@@ -227,120 +227,84 @@ show_prompt = false # 是否显示prompt
[model]
model_max_output_length = 1024 # 模型单次返回的最大token数
model_max_output_length = 800 # 模型单次返回的最大token数
#------------必填:组件模型------------
#------------模型任务配置------------
# 所有模型名称需要对应 model_config.toml 中配置的模型名称
[model.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2 #模型的输入价格(非必填,可以记录消耗)
pri_out = 8 #模型的输出价格(非必填,可以记录消耗)
#默认temp 0.2 如果你使用的是老V3或者其他模型请自己修改temp参数
temp = 0.2 #模型的温度新V3建议0.1-0.3
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2 # 模型温度新V3建议0.1-0.3
max_tokens = 800 # 最大输出token数
[model.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
# 强烈建议使用免费的小模型
name = "Qwen/Qwen3-8B"
provider = "SILICONFLOW"
pri_in = 0
pri_out = 0
temp = 0.7
model_name = "qwen3-8b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考
[model.replyer_1] # 首要回复模型,还用于表达器和表达方式学习
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2 #模型的输入价格(非必填,可以记录消耗)
pri_out = 8 #模型的输出价格(非必填,可以记录消耗)
#默认temp 0.2 如果你使用的是老V3或者其他模型请自己修改temp参数
temp = 0.2 #模型的温度新V3建议0.1-0.3
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2 # 模型温度新V3建议0.1-0.3
max_tokens = 800
[model.replyer_2] # 次要回复模型
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2 #模型的输入价格(非必填,可以记录消耗)
pri_out = 8 #模型的输出价格(非必填,可以记录消耗)
#默认temp 0.2 如果你使用的是老V3或者其他模型请自己修改temp参数
temp = 0.2 #模型的温度新V3建议0.1-0.3
model_name = "siliconflow-deepseek-r1" # 对应 model_config.toml 中的模型名称
temperature = 0.7 # 模型温度
max_tokens = 800
[model.planner] #决策:负责决定麦麦该做什么的模型
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2
pri_out = 8
temp = 0.3
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.3
max_tokens = 800
[model.emotion] #负责麦麦的情绪变化
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2
pri_out = 8
temp = 0.3
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.3
max_tokens = 800
[model.memory] # 记忆模型
name = "Qwen/Qwen3-30B-A3B"
provider = "SILICONFLOW"
pri_in = 0.7
pri_out = 2.8
temp = 0.7
model_name = "qwen3-30b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考
[model.vlm] # 图像识别模型
name = "Pro/Qwen/Qwen2.5-VL-7B-Instruct"
provider = "SILICONFLOW"
pri_in = 0.35
pri_out = 0.35
model_name = "qwen2.5-vl-72b" # 对应 model_config.toml 中的模型名称
max_tokens = 800
[model.voice] # 语音识别模型
name = "FunAudioLLM/SenseVoiceSmall"
provider = "SILICONFLOW"
pri_in = 0
pri_out = 0
model_name = "sensevoice-small" # 对应 model_config.toml 中的模型名称
[model.tool_use] #工具调用模型,需要使用支持工具调用的模型
name = "Qwen/Qwen3-14B"
provider = "SILICONFLOW"
pri_in = 0.5
pri_out = 2
temp = 0.7
model_name = "qwen3-14b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考qwen3 only
#嵌入模型
[model.embedding]
name = "BAAI/bge-m3"
provider = "SILICONFLOW"
pri_in = 0
pri_out = 0
model_name = "bge-m3" # 对应 model_config.toml 中的模型名称
#------------LPMM知识库模型------------
[model.lpmm_entity_extract] # 实体提取模型
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2
pri_out = 8
temp = 0.2
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2
max_tokens = 800
[model.lpmm_rdf_build] # RDF构建模型
name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW"
pri_in = 2
pri_out = 8
temp = 0.2
model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称
temperature = 0.2
max_tokens = 800
[model.lpmm_qa] # 问答模型
name = "Qwen/Qwen3-30B-A3B"
provider = "SILICONFLOW"
pri_in = 0.7
pri_out = 2.8
temp = 0.7
model_name = "deepseek-r1-distill-qwen-32b" # 对应 model_config.toml 中的模型名称
temperature = 0.7
max_tokens = 800
enable_thinking = false # 是否启用思考
[maim_message]
auth_token = [] # 认证令牌用于API验证为空则不启用验证
# 以下项目若要使用需要打开use_custom并单独配置maim_message的服务器

View File

@@ -0,0 +1,220 @@
[inner]
version = "0.2.1"
# 配置文件版本号迭代规则同bot_config.toml
#
# === 多API Key支持 ===
# 本配置文件支持为每个API服务商配置多个API Key实现以下功能
# 1. 错误自动切换当某个API Key失败时自动切换到下一个可用的Key
# 2. 负载均衡在多个可用的API Key之间循环使用避免单个Key的频率限制
# 3. 向后兼容仍然支持单个key字段的配置方式
#
# 配置方式:
# - 多Key配置使用 api_keys = ["key1", "key2", "key3"] 数组格式
# - 单Key配置使用 key = "your-key" 字符串格式(向后兼容)
#
# 错误处理机制:
# - 401/403认证错误立即切换到下一个API Key
# - 429频率限制等待后重试如果持续失败则切换Key
# - 网络错误短暂等待后重试失败则切换Key
# - 其他错误:按照正常重试机制处理
#
# === 任务类型和模型能力配置 ===
# 为了提高任务分配的准确性和可维护性,现在支持明确配置模型的任务类型和能力:
#
# task_type推荐配置:
# - 明确指定模型主要用于什么任务
# - 可选值llm_normal, llm_reasoning, vision, embedding, speech
# - 如果不配置系统会根据capabilities或模型名称自动推断不推荐
#
# capabilities推荐配置:
# - 描述模型支持的所有能力
# - 可选值text, vision, embedding, speech, tool_calling, reasoning
# - 支持多个能力的组合,如:["text", "vision"]
#
# 配置优先级:
# 1. task_type最高优先级直接指定任务类型
# 2. capabilities中等优先级根据能力推断任务类型
# 3. 模型名称关键字(最低优先级,不推荐依赖)
#
# 向后兼容:
# - 仍然支持 model_flags 字段,但建议迁移到 capabilities
# - 未配置新字段时会自动回退到基于模型名称的推断
[request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释)
#max_retry = 2 # 最大重试次数单个模型API调用失败最多重试的次数
#timeout = 10 # API调用的超时时长超过这个时长本次请求将被视为“请求超时”(单位:秒)
#retry_interval = 10 # 重试间隔如果API调用失败重试的间隔时间(单位:秒)
#default_temperature = 0.7 # 默认的温度如果bot_config.toml中没有设置temperature参数默认使用这个值
#default_max_tokens = 1024 # 默认的最大输出token数如果bot_config.toml中没有设置max_tokens参数默认使用这个值
[[api_providers]] # API服务提供商可以配置多个
name = "DeepSeek" # API服务商名称可随意命名在models的api-provider中需使用这个命名
base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL
# 支持多个API Key实现自动切换和负载均衡
api_keys = [ # API Key列表多个key支持错误自动切换和负载均衡
"sk-your-first-key-here",
"sk-your-second-key-here",
"sk-your-third-key-here"
]
# 向后兼容如果只有一个key也可以使用单个key字段
#key = "******" # API Key 可选默认为None
client_type = "openai" # 请求客户端(可选,默认值为"openai"使用gimini等Google系模型时请配置为"gemini"
[[api_providers]] # 特殊Google的Gimini使用特殊API与OpenAI格式不兼容需要配置client为"gemini"
name = "Google"
base_url = "https://api.google.com/v1"
# Google API同样支持多key配置
api_keys = [
"your-google-api-key-1",
"your-google-api-key-2"
]
client_type = "gemini"
[[api_providers]]
name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
# 单个key的示例向后兼容
key = "******"
#
#[[api_providers]]
#name = "LocalHost"
#base_url = "https://localhost:8888"
#key = "lm-studio"
[[models]] # 模型(可以配置多个)
# 模型标识符API服务商提供的模型标识符
model_identifier = "deepseek-chat"
# 模型名称可随意命名在bot_config.toml中需使用这个命名
#可选若无该字段则将自动使用model_identifier填充
name = "deepseek-v3"
# API服务商名称对应在api_providers中配置的服务商名称
api_provider = "DeepSeek"
# 任务类型(推荐配置,明确指定模型主要用于什么任务)
# 可选值llm_normal, llm_reasoning, vision, embedding, speech
# 如果不配置系统会根据capabilities或模型名称自动推断
task_type = "llm_normal"
# 模型能力列表(推荐配置,描述模型支持的能力)
# 可选值text, vision, embedding, speech, tool_calling, reasoning
capabilities = ["text", "tool_calling"]
# 输入价格用于API调用统计单位元/兆token可选若无该字段默认值为0
price_in = 2.0
# 输出价格用于API调用统计单位元/兆token可选若无该字段默认值为0
price_out = 8.0
# 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出)
#可选若无该字段默认值为false
#force_stream_mode = true
[[models]]
model_identifier = "deepseek-reasoner"
name = "deepseek-r1"
api_provider = "DeepSeek"
# 推理模型的配置示例
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "text", "tool_calling", "reasoning",]
price_in = 4.0
price_out = 16.0
[[models]]
model_identifier = "Pro/deepseek-ai/DeepSeek-V3"
name = "siliconflow-deepseek-v3"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 2.0
price_out = 8.0
[[models]]
model_identifier = "Pro/deepseek-ai/DeepSeek-R1"
name = "siliconflow-deepseek-r1"
api_provider = "SiliconFlow"
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
price_in = 4.0
price_out = 16.0
[[models]]
model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
name = "deepseek-r1-distill-qwen-32b"
api_provider = "SiliconFlow"
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
price_in = 4.0
price_out = 16.0
[[models]]
model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text"]
price_in = 0
price_out = 0
[[models]]
model_identifier = "Qwen/Qwen3-14B"
name = "qwen3-14b"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 0.5
price_out = 2.0
[[models]]
model_identifier = "Qwen/Qwen3-30B-A3B"
name = "qwen3-30b"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 0.7
price_out = 2.8
[[models]]
model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct"
name = "qwen2.5-vl-72b"
api_provider = "SiliconFlow"
# 视觉模型的配置示例
task_type = "vision"
capabilities = ["vision", "text"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "vision", "text",]
price_in = 4.13
price_out = 4.13
[[models]]
model_identifier = "FunAudioLLM/SenseVoiceSmall"
name = "sensevoice-small"
api_provider = "SiliconFlow"
# 语音模型的配置示例
task_type = "speech"
capabilities = ["speech"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "audio",]
price_in = 0
price_out = 0
[[models]]
model_identifier = "BAAI/bge-m3"
name = "bge-m3"
api_provider = "SiliconFlow"
# 嵌入模型的配置示例
task_type = "embedding"
capabilities = ["text", "embedding"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "text", "embedding",]
price_in = 0
price_out = 0
[task_model_usage]
llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0}
llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0}
embedding = "siliconflow-bge-m3"
#schedule = [
# "deepseek-v3",
# "deepseek-r1",
#]

View File

@@ -0,0 +1,220 @@
[inner]
version = "0.2.1"
# 配置文件版本号迭代规则同bot_config.toml
#
# === 多API Key支持 ===
# 本配置文件支持为每个API服务商配置多个API Key实现以下功能
# 1. 错误自动切换当某个API Key失败时自动切换到下一个可用的Key
# 2. 负载均衡在多个可用的API Key之间循环使用避免单个Key的频率限制
# 3. 向后兼容仍然支持单个key字段的配置方式
#
# 配置方式:
# - 多Key配置使用 api_keys = ["key1", "key2", "key3"] 数组格式
# - 单Key配置使用 key = "your-key" 字符串格式(向后兼容)
#
# 错误处理机制:
# - 401/403认证错误立即切换到下一个API Key
# - 429频率限制等待后重试如果持续失败则切换Key
# - 网络错误短暂等待后重试失败则切换Key
# - 其他错误:按照正常重试机制处理
#
# === 任务类型和模型能力配置 ===
# 为了提高任务分配的准确性和可维护性,现在支持明确配置模型的任务类型和能力:
#
# task_type推荐配置:
# - 明确指定模型主要用于什么任务
# - 可选值llm_normal, llm_reasoning, vision, embedding, speech
# - 如果不配置系统会根据capabilities或模型名称自动推断不推荐
#
# capabilities推荐配置:
# - 描述模型支持的所有能力
# - 可选值text, vision, embedding, speech, tool_calling, reasoning
# - 支持多个能力的组合,如:["text", "vision"]
#
# 配置优先级:
# 1. task_type最高优先级直接指定任务类型
# 2. capabilities中等优先级根据能力推断任务类型
# 3. 模型名称关键字(最低优先级,不推荐依赖)
#
# 向后兼容:
# - 仍然支持 model_flags 字段,但建议迁移到 capabilities
# - 未配置新字段时会自动回退到基于模型名称的推断
[request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释)
#max_retry = 2 # 最大重试次数单个模型API调用失败最多重试的次数
#timeout = 10 # API调用的超时时长超过这个时长本次请求将被视为“请求超时”(单位:秒)
#retry_interval = 10 # 重试间隔如果API调用失败重试的间隔时间(单位:秒)
#default_temperature = 0.7 # 默认的温度如果bot_config.toml中没有设置temperature参数默认使用这个值
#default_max_tokens = 1024 # 默认的最大输出token数如果bot_config.toml中没有设置max_tokens参数默认使用这个值
[[api_providers]] # API服务提供商可以配置多个
name = "DeepSeek" # API服务商名称可随意命名在models的api-provider中需使用这个命名
base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL
# 支持多个API Key实现自动切换和负载均衡
api_keys = [ # API Key列表多个key支持错误自动切换和负载均衡
"sk-your-first-key-here",
"sk-your-second-key-here",
"sk-your-third-key-here"
]
# 向后兼容如果只有一个key也可以使用单个key字段
#key = "******" # API Key 可选默认为None
client_type = "openai" # 请求客户端(可选,默认值为"openai"使用gimini等Google系模型时请配置为"gemini"
[[api_providers]] # 特殊Google的Gimini使用特殊API与OpenAI格式不兼容需要配置client为"gemini"
name = "Google"
base_url = "https://api.google.com/v1"
# Google API同样支持多key配置
api_keys = [
"your-google-api-key-1",
"your-google-api-key-2"
]
client_type = "gemini"
[[api_providers]]
name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
# 单个key的示例向后兼容
key = "******"
#
#[[api_providers]]
#name = "LocalHost"
#base_url = "https://localhost:8888"
#key = "lm-studio"
[[models]] # 模型(可以配置多个)
# 模型标识符API服务商提供的模型标识符
model_identifier = "deepseek-chat"
# 模型名称可随意命名在bot_config.toml中需使用这个命名
#可选若无该字段则将自动使用model_identifier填充
name = "deepseek-v3"
# API服务商名称对应在api_providers中配置的服务商名称
api_provider = "DeepSeek"
# 任务类型(推荐配置,明确指定模型主要用于什么任务)
# 可选值llm_normal, llm_reasoning, vision, embedding, speech
# 如果不配置系统会根据capabilities或模型名称自动推断
task_type = "llm_normal"
# 模型能力列表(推荐配置,描述模型支持的能力)
# 可选值text, vision, embedding, speech, tool_calling, reasoning
capabilities = ["text", "tool_calling"]
# 输入价格用于API调用统计单位元/兆token可选若无该字段默认值为0
price_in = 2.0
# 输出价格用于API调用统计单位元/兆token可选若无该字段默认值为0
price_out = 8.0
# 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出)
#可选若无该字段默认值为false
#force_stream_mode = true
[[models]]
model_identifier = "deepseek-reasoner"
name = "deepseek-r1"
api_provider = "DeepSeek"
# 推理模型的配置示例
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "text", "tool_calling", "reasoning",]
price_in = 4.0
price_out = 16.0
[[models]]
model_identifier = "Pro/deepseek-ai/DeepSeek-V3"
name = "siliconflow-deepseek-v3"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 2.0
price_out = 8.0
[[models]]
model_identifier = "Pro/deepseek-ai/DeepSeek-R1"
name = "siliconflow-deepseek-r1"
api_provider = "SiliconFlow"
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
price_in = 4.0
price_out = 16.0
[[models]]
model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
name = "deepseek-r1-distill-qwen-32b"
api_provider = "SiliconFlow"
task_type = "llm_reasoning"
capabilities = ["text", "tool_calling", "reasoning"]
price_in = 4.0
price_out = 16.0
[[models]]
model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text"]
price_in = 0
price_out = 0
[[models]]
model_identifier = "Qwen/Qwen3-14B"
name = "qwen3-14b"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 0.5
price_out = 2.0
[[models]]
model_identifier = "Qwen/Qwen3-30B-A3B"
name = "qwen3-30b"
api_provider = "SiliconFlow"
task_type = "llm_normal"
capabilities = ["text", "tool_calling"]
price_in = 0.7
price_out = 2.8
[[models]]
model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct"
name = "qwen2.5-vl-72b"
api_provider = "SiliconFlow"
# 视觉模型的配置示例
task_type = "vision"
capabilities = ["vision", "text"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "vision", "text",]
price_in = 4.13
price_out = 4.13
[[models]]
model_identifier = "FunAudioLLM/SenseVoiceSmall"
name = "sensevoice-small"
api_provider = "SiliconFlow"
# 语音模型的配置示例
task_type = "speech"
capabilities = ["speech"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "audio",]
price_in = 0
price_out = 0
[[models]]
model_identifier = "BAAI/bge-m3"
name = "bge-m3"
api_provider = "SiliconFlow"
# 嵌入模型的配置示例
task_type = "embedding"
capabilities = ["text", "embedding"]
# 保留向后兼容的model_flags字段已废弃建议使用capabilities
model_flags = [ "text", "embedding",]
price_in = 0
price_out = 0
[task_model_usage]
llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0}
llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0}
embedding = "siliconflow-bge-m3"
#schedule = [
# "deepseek-v3",
# "deepseek-r1",
#]

View File

@@ -1,23 +1,2 @@
HOST=127.0.0.1
PORT=8000
# 密钥和url
# 硅基流动
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
# DeepSeek官方
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
# 阿里百炼
BAILIAN_BASE_URL = https://dashscope.aliyuncs.com/compatible-mode/v1
# 火山引擎
HUOSHAN_BASE_URL =
# xxxxx平台
xxxxxxx_BASE_URL=https://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# 定义你要用的api的key(需要去对应网站申请哦)
DEEP_SEEK_KEY=
CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY=
BAILIAN_KEY =
HUOSHAN_KEY =
xxxxxxx_KEY=