diff --git a/bot.py b/bot.py index 72ea65d29..b8f154cd3 100644 --- a/bot.py +++ b/bot.py @@ -74,36 +74,6 @@ def easter_egg(): print(rainbow_text) -def scan_provider(env_config: dict): - provider = {} - - # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 - # 避免 GPG_KEY 这样的变量干扰检查 - env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) - - # 遍历 env_config 的所有键 - for key in env_config: - # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 - if key.endswith("_BASE_URL") or key.endswith("_KEY"): - # 提取 provider 名称 - provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 - - # 初始化 provider 的字典(如果尚未初始化) - if provider_name not in provider: - provider[provider_name] = {"url": None, "key": None} - - # 根据键的类型填充 url 或 key - if key.endswith("_BASE_URL"): - provider[provider_name]["url"] = env_config[key] - elif key.endswith("_KEY"): - provider[provider_name]["key"] = env_config[key] - - # 检查每个 provider 是否同时存在 url 和 key - for provider_name, config in provider.items(): - if config["url"] is None or config["key"] is None: - logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") - raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") - async def graceful_shutdown(): try: @@ -229,9 +199,6 @@ def raw_main(): easter_egg() - env_config = {key: os.getenv(key) for key in os.environ} - scan_provider(env_config) - # 返回MainSystem实例 return MainSystem() diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 51313d4e1..dd691e484 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -981,8 +981,9 @@ class DefaultReplyer: with Timer("LLM生成", {}): # 内部计时器,可选保留 # 加权随机选择一个模型配置 selected_model_config = self._select_weighted_model_config() + model_display_name = selected_model_config.get('model_name') or selected_model_config.get('name', 'N/A') logger.info( - f"使用模型生成回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})" + f"使用模型生成回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})" ) express_model = LLMRequest( diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py new file mode 100644 index 000000000..b68bf1ae7 --- /dev/null +++ b/src/config/api_ada_configs.py @@ -0,0 +1,180 @@ +from dataclasses import dataclass, field +from typing import List, Dict, Union +import threading +import time + +from packaging.version import Version + +NEWEST_VER = "0.1.1" # 当前支持的最新版本 + +@dataclass +class APIProvider: + name: str = "" # API提供商名称 + base_url: str = "" # API基础URL + api_key: str = field(repr=False, default="") # API密钥(向后兼容) + api_keys: List[str] = field(repr=False, default_factory=list) # API密钥列表(新格式) + client_type: str = "openai" # 客户端类型(如openai/google等,默认为openai) + + # 多API Key管理相关属性 + _current_key_index: int = field(default=0, init=False, repr=False) # 当前使用的key索引 + _key_failure_count: Dict[int, int] = field(default_factory=dict, init=False, repr=False) # 每个key的失败次数 + _key_last_failure_time: Dict[int, float] = field(default_factory=dict, init=False, repr=False) # 每个key最后失败时间 + _lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) # 线程锁 + + def __post_init__(self): + """初始化后处理,确保API keys列表正确""" + # 向后兼容:如果只设置了api_key,将其添加到api_keys列表 + if self.api_key and not self.api_keys: + self.api_keys = [self.api_key] + # 如果api_keys不为空但api_key为空,设置api_key为第一个 + elif self.api_keys and not self.api_key: + self.api_key = self.api_keys[0] + + # 初始化失败计数器 + for i in range(len(self.api_keys)): + 
self._key_failure_count[i] = 0 + self._key_last_failure_time[i] = 0 + + def get_current_api_key(self) -> str: + """获取当前应该使用的API Key""" + with self._lock: + if not self.api_keys: + return "" + + # 确保索引在有效范围内 + if self._current_key_index >= len(self.api_keys): + self._current_key_index = 0 + + return self.api_keys[self._current_key_index] + + def get_next_api_key(self) -> Union[str, None]: + """获取下一个可用的API Key(负载均衡)""" + with self._lock: + if not self.api_keys: + return None + + # 如果只有一个key,直接返回 + if len(self.api_keys) == 1: + return self.api_keys[0] + + # 轮询到下一个key + self._current_key_index = (self._current_key_index + 1) % len(self.api_keys) + return self.api_keys[self._current_key_index] + + def mark_key_failed(self, api_key: str) -> Union[str, None]: + """标记某个API Key失败,返回下一个可用的key""" + with self._lock: + if not self.api_keys or api_key not in self.api_keys: + return None + + key_index = self.api_keys.index(api_key) + self._key_failure_count[key_index] += 1 + self._key_last_failure_time[key_index] = time.time() + + # 寻找下一个可用的key + current_time = time.time() + for _ in range(len(self.api_keys)): + self._current_key_index = (self._current_key_index + 1) % len(self.api_keys) + next_key_index = self._current_key_index + + # 检查该key是否最近失败过(5分钟内失败超过3次则暂时跳过) + if (self._key_failure_count[next_key_index] <= 3 or + current_time - self._key_last_failure_time[next_key_index] > 300): # 5分钟后重试 + return self.api_keys[next_key_index] + + # 如果所有key都不可用,返回当前key(让上层处理) + return api_key + + def reset_key_failures(self, api_key: str | None = None): + """重置失败计数(成功调用后调用)""" + with self._lock: + if api_key and api_key in self.api_keys: + key_index = self.api_keys.index(api_key) + self._key_failure_count[key_index] = 0 + self._key_last_failure_time[key_index] = 0 + else: + # 重置所有key的失败计数 + for i in range(len(self.api_keys)): + self._key_failure_count[i] = 0 + self._key_last_failure_time[i] = 0 + + def get_api_key_stats(self) -> Dict[str, Dict[str, Union[int, float]]]: + """获取API Key使用统计""" + with self._lock: + stats = {} + for i, key in enumerate(self.api_keys): + # 只显示key的前8位和后4位,中间用*代替 + masked_key = f"{key[:8]}***{key[-4:]}" if len(key) > 12 else "***" + stats[masked_key] = { + "failure_count": self._key_failure_count.get(i, 0), + "last_failure_time": self._key_last_failure_time.get(i, 0), + "is_current": i == self._current_key_index + } + return stats + + +@dataclass +class ModelInfo: + model_identifier: str = "" # 模型标识符(用于URL调用) + name: str = "" # 模型名称(用于模块调用) + api_provider: str = "" # API提供商(如OpenAI、Azure等) + + # 以下用于模型计费 + price_in: float = 0.0 # 每M token输入价格 + price_out: float = 0.0 # 每M token输出价格 + + force_stream_mode: bool = False # 是否强制使用流式输出模式 + + # 新增:任务类型和能力字段 + task_type: str = "" # 任务类型:llm_normal, llm_reasoning, vision, embedding, speech + capabilities: List[str] = field(default_factory=list) # 模型能力:text, vision, embedding, speech, tool_calling, reasoning + + +@dataclass +class RequestConfig: + max_retry: int = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) + timeout: int = ( + 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒) + ) + retry_interval: int = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒) + default_temperature: float = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值) + default_max_tokens: int = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值) + + +@dataclass +class ModelUsageArgConfigItem: + """模型使用的配置类 + 该类用于加载和存储子任务模型使用的配置 + """ + + name: str = "" # 模型名称 + temperature: float | None = None # 温度 + max_tokens: int | None = None # 最大token数 + max_retry: int | None = None # 调用失败时的最大重试次数 + + 
+@dataclass +class ModelUsageArgConfig: + """子任务使用模型的配置类 + 该类用于加载和存储子任务使用的模型配置 + """ + + name: str = "" # 任务名称 + usage: List[ModelUsageArgConfigItem] = field( + default_factory=lambda: [] + ) # 任务使用的模型列表 + + + +@dataclass +class ModuleConfig: + INNER_VERSION: Version | None = None # 配置文件版本 + + req_conf: RequestConfig = field(default_factory=lambda: RequestConfig()) # 请求配置 + api_providers: Dict[str, APIProvider] = field( + default_factory=lambda: {} + ) # API提供商列表 + models: Dict[str, ModelInfo] = field(default_factory=lambda: {}) # 模型列表 + task_model_arg_map: Dict[str, ModelUsageArgConfig] = field( + default_factory=lambda: {} + ) \ No newline at end of file diff --git a/src/config/auto_update.py b/src/config/auto_update.py deleted file mode 100644 index e6471e808..000000000 --- a/src/config/auto_update.py +++ /dev/null @@ -1,162 +0,0 @@ -import shutil -import tomlkit -from tomlkit.items import Table, KeyType -from pathlib import Path -from datetime import datetime - - -def get_key_comment(toml_table, key): - # 获取key的注释(如果有) - if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"): - return toml_table.trivia.comment - if hasattr(toml_table, "value") and isinstance(toml_table.value, dict): - item = toml_table.value.get(key) - if item is not None and hasattr(item, "trivia"): - return item.trivia.comment - if hasattr(toml_table, "keys"): - for k in toml_table.keys(): - if isinstance(k, KeyType) and k.key == key: - return k.trivia.comment - return None - - -def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, logs=None): - # 递归比较两个dict,找出新增和删减项,收集注释 - if path is None: - path = [] - if logs is None: - logs = [] - if new_comments is None: - new_comments = {} - if old_comments is None: - old_comments = {} - # 新增项 - for key in new: - if key == "version": - continue - if key not in old: - comment = get_key_comment(new, key) - logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") - elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): - compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs) - # 删减项 - for key in old: - if key == "version": - continue - if key not in new: - comment = get_key_comment(old, key) - logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") - return logs - - -def update_config(): - print("开始更新配置文件...") - # 获取根目录路径 - root_dir = Path(__file__).parent.parent.parent.parent - template_dir = root_dir / "template" - config_dir = root_dir / "config" - old_config_dir = config_dir / "old" - - # 创建old目录(如果不存在) - old_config_dir.mkdir(exist_ok=True) - - # 定义文件路径 - template_path = template_dir / "bot_config_template.toml" - old_config_path = config_dir / "bot_config.toml" - new_config_path = config_dir / "bot_config.toml" - - # 读取旧配置文件 - old_config = {} - if old_config_path.exists(): - print(f"发现旧配置文件: {old_config_path}") - with open(old_config_path, "r", encoding="utf-8") as f: - old_config = tomlkit.load(f) - - # 生成带时间戳的新文件名 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml" - - # 移动旧配置文件到old目录 - shutil.move(old_config_path, old_backup_path) - print(f"已备份旧配置文件到: {old_backup_path}") - - # 复制模板文件到配置目录 - print(f"从模板文件创建新配置: {template_path}") - shutil.copy2(template_path, new_config_path) - - # 读取新配置文件 - with open(new_config_path, "r", encoding="utf-8") as f: - new_config = tomlkit.load(f) - - # 检查version是否相同 - if old_config and "inner" in old_config and "inner" in new_config: - 
old_version = old_config["inner"].get("version") # type: ignore - new_version = new_config["inner"].get("version") # type: ignore - if old_version and new_version and old_version == new_version: - print(f"检测到版本号相同 (v{old_version}),跳过更新") - # 如果version相同,恢复旧配置文件并返回 - shutil.move(old_backup_path, old_config_path) # type: ignore - return - else: - print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") - - # 输出新增和删减项及注释 - if old_config: - print("配置项变动如下:") - logs = compare_dicts(new_config, old_config) - if logs: - for log in logs: - print(log) - else: - print("无新增或删减项") - - # 递归更新配置 - def update_dict(target, source): - for key, value in source.items(): - # 跳过version字段的更新 - if key == "version": - continue - if key in target: - if isinstance(value, dict) and isinstance(target[key], (dict, Table)): - update_dict(target[key], value) - else: - try: - # 对数组类型进行特殊处理 - if isinstance(value, list): - # 如果是空数组,确保它保持为空数组 - if not value: - target[key] = tomlkit.array() - else: - # 特殊处理正则表达式数组和包含正则表达式的结构 - if key == "ban_msgs_regex": - # 直接使用原始值,不进行额外处理 - target[key] = value - elif key == "regex_rules": - # 对于regex_rules,需要特殊处理其中的regex字段 - target[key] = value - else: - # 检查是否包含正则表达式相关的字典项 - contains_regex = False - if value and isinstance(value[0], dict) and "regex" in value[0]: - contains_regex = True - - target[key] = value if contains_regex else tomlkit.array(str(value)) - else: - # 其他类型使用item方法创建新值 - target[key] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - target[key] = value - - # 将旧配置的值更新到新配置中 - print("开始合并新旧配置...") - update_dict(new_config, old_config) - - # 保存更新后的配置(保留注释和格式) - with open(new_config_path, "w", encoding="utf-8") as f: - f.write(tomlkit.dumps(new_config)) - print("配置文件更新完成") - - -if __name__ == "__main__": - update_config() diff --git a/src/config/config.py b/src/config/config.py index 805a17d48..bbbf30cd3 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -7,6 +7,10 @@ from tomlkit import TOMLDocument from tomlkit.items import Table, KeyType from dataclasses import field, dataclass from rich.traceback import install +from packaging import version +from packaging.specifiers import SpecifierSet +from packaging.version import Version, InvalidVersion +from typing import Any, Dict, List from src.common.logger import get_logger from src.config.config_base import ConfigBase @@ -36,6 +40,17 @@ from src.config.official_configs import ( CustomPromptConfig, ) +from .api_ada_configs import ( + ModelUsageArgConfigItem, + ModelUsageArgConfig, + APIProvider, + ModelInfo, + NEWEST_VER, + ModuleConfig, +) + + + install(extra_lines=3) @@ -52,6 +67,273 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template") MMC_VERSION = "0.9.1" + + +def _get_config_version(toml: Dict) -> Version: + """提取配置文件的 SpecifierSet 版本数据 + Args: + toml[dict]: 输入的配置文件字典 + Returns: + Version + """ + + if "inner" in toml and "version" in toml["inner"]: + config_version: str = toml["inner"]["version"] + else: + config_version = "0.0.0" # 默认版本 + + try: + ver = version.parse(config_version) + except InvalidVersion as e: + logger.error( + "配置文件中 inner段 的 version 键是错误的版本描述\n" + f"请检查配置文件,当前 version 键: {config_version}\n" + f"错误信息: {e}" + ) + raise InvalidVersion( + "配置文件中 inner段 的 version 键是错误的版本描述\n" + ) from e + + return ver + + +def _request_conf(parent: Dict, config: ModuleConfig): + request_conf_config = parent.get("request_conf") + config.req_conf.max_retry = request_conf_config.get( + "max_retry", config.req_conf.max_retry + ) + config.req_conf.timeout = request_conf_config.get( + 
"timeout", config.req_conf.timeout + ) + config.req_conf.retry_interval = request_conf_config.get( + "retry_interval", config.req_conf.retry_interval + ) + config.req_conf.default_temperature = request_conf_config.get( + "default_temperature", config.req_conf.default_temperature + ) + config.req_conf.default_max_tokens = request_conf_config.get( + "default_max_tokens", config.req_conf.default_max_tokens + ) + + +def _api_providers(parent: Dict, config: ModuleConfig): + api_providers_config = parent.get("api_providers") + for provider in api_providers_config: + name = provider.get("name", None) + base_url = provider.get("base_url", None) + api_key = provider.get("api_key", None) + api_keys = provider.get("api_keys", []) # 新增:支持多个API Key + client_type = provider.get("client_type", "openai") + + if name in config.api_providers: # 查重 + logger.error(f"重复的API提供商名称: {name},请检查配置文件。") + raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。") + + if name and base_url: + # 处理API Key配置:支持单个api_key或多个api_keys + if api_keys: + # 使用新格式:api_keys列表 + logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key") + elif api_key: + # 向后兼容:使用单个api_key + api_keys = [api_key] + logger.debug(f"API提供商 '{name}' 使用单个API Key(向后兼容模式)") + else: + logger.warning(f"API提供商 '{name}' 没有配置API Key,某些功能可能不可用") + + config.api_providers[name] = APIProvider( + name=name, + base_url=base_url, + api_key=api_key, # 保留向后兼容 + api_keys=api_keys, # 新格式 + client_type=client_type, + ) + else: + logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") + raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。") + + +def _models(parent: Dict, config: ModuleConfig): + models_config = parent.get("models") + for model in models_config: + model_identifier = model.get("model_identifier", None) + name = model.get("name", model_identifier) + api_provider = model.get("api_provider", None) + price_in = model.get("price_in", 0.0) + price_out = model.get("price_out", 0.0) + force_stream_mode = model.get("force_stream_mode", False) + task_type = model.get("task_type", "") + capabilities = model.get("capabilities", []) + + if name in config.models: # 查重 + logger.error(f"重复的模型名称: {name},请检查配置文件。") + raise KeyError(f"重复的模型名称: {name},请检查配置文件。") + + if model_identifier and api_provider: + # 检查API提供商是否存在 + if api_provider not in config.api_providers: + logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。") + raise ValueError( + f"未声明的API提供商 '{api_provider}' ,请检查配置文件。" + ) + config.models[name] = ModelInfo( + name=name, + model_identifier=model_identifier, + api_provider=api_provider, + price_in=price_in, + price_out=price_out, + force_stream_mode=force_stream_mode, + task_type=task_type, + capabilities=capabilities, + ) + else: + logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。") + raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。") + + +def _task_model_usage(parent: Dict, config: ModuleConfig): + model_usage_configs = parent.get("task_model_usage") + config.task_model_arg_map = {} + for task_name, item in model_usage_configs.items(): + if task_name in config.task_model_arg_map: + logger.error(f"子任务 {task_name} 已存在,请检查配置文件。") + raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。") + + usage = [] + if isinstance(item, Dict): + if "model" in item: + usage.append( + ModelUsageArgConfigItem( + name=item["model"], + temperature=item.get("temperature", None), + max_tokens=item.get("max_tokens", None), + max_retry=item.get("max_retry", None), + ) + ) + else: + logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。") + raise ValueError( + f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" + ) + elif isinstance(item, 
List): + for model in item: + if isinstance(model, Dict): + usage.append( + ModelUsageArgConfigItem( + name=model["model"], + temperature=model.get("temperature", None), + max_tokens=model.get("max_tokens", None), + max_retry=model.get("max_retry", None), + ) + ) + elif isinstance(model, str): + usage.append( + ModelUsageArgConfigItem( + name=model, + temperature=None, + max_tokens=None, + max_retry=None, + ) + ) + else: + logger.error( + f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" + ) + raise ValueError( + f"子任务 {task_name} 的模型配置不合法,请检查配置文件。" + ) + elif isinstance(item, str): + usage.append( + ModelUsageArgConfigItem( + name=item, + temperature=None, + max_tokens=None, + max_retry=None, + ) + ) + + config.task_model_arg_map[task_name] = ModelUsageArgConfig( + name=task_name, + usage=usage, + ) + + +def api_ada_load_config(config_path: str) -> ModuleConfig: + """从TOML配置文件加载配置""" + config = ModuleConfig() + + include_configs: Dict[str, Dict[str, Any]] = { + "request_conf": { + "func": _request_conf, + "support": ">=0.0.0", + "necessary": False, + }, + "api_providers": {"func": _api_providers, "support": ">=0.0.0"}, + "models": {"func": _models, "support": ">=0.0.0"}, + "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"}, + } + + if os.path.exists(config_path): + with open(config_path, "rb") as f: + try: + toml_dict = tomlkit.load(f) + except tomlkit.TOMLDecodeError as e: + logger.critical( + f"配置文件model_list.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}" + ) + exit(1) + + # 获取配置文件版本 + config.INNER_VERSION = _get_config_version(toml_dict) + + # 检查版本 + if config.INNER_VERSION > Version(NEWEST_VER): + logger.warning( + f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。" + ) + + # 解析配置文件 + # 如果在配置中找到了需要的项,调用对应项的闭包函数处理 + for key in include_configs: + if key in toml_dict: + group_specifier_set: SpecifierSet = SpecifierSet( + include_configs[key]["support"] + ) + + # 检查配置文件版本是否在支持范围内 + if config.INNER_VERSION in group_specifier_set: + # 如果版本在支持范围内,检查是否存在通知 + if "notice" in include_configs[key]: + logger.warning(include_configs[key]["notice"]) + # 调用闭包函数处理配置 + (include_configs[key]["func"])(toml_dict, config) + else: + # 如果版本不在支持范围内,崩溃并提示用户 + logger.error( + f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n" + f"当前程序仅支持以下版本范围: {group_specifier_set}" + ) + raise InvalidVersion( + f"当前程序仅支持以下版本范围: {group_specifier_set}" + ) + + # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理 + elif ( + "necessary" in include_configs[key] + and include_configs[key].get("necessary") is False + ): + # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理 + if key == "keywords_reaction": + pass + else: + # 如果用户根本没有需要的配置项,提示缺少配置 + logger.error(f"配置文件中缺少必需的字段: '{key}'") + raise KeyError(f"配置文件中缺少必需的字段: '{key}'") + + logger.info(f"成功加载配置文件: {config_path}") + + return config + def get_key_comment(toml_table, key): # 获取key的注释(如果有) if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"): @@ -133,37 +415,74 @@ def compare_default_values(new, old, path=None, logs=None, changes=None): return logs, changes -def update_config(): +def _get_version_from_toml(toml_path): + """从TOML文件中获取版本号""" + if not os.path.exists(toml_path): + return None + with open(toml_path, "r", encoding="utf-8") as f: + doc = tomlkit.load(f) + if "inner" in doc and "version" in doc["inner"]: # type: ignore + return doc["inner"]["version"] # type: ignore + return None + + +def _version_tuple(v): + """将版本字符串转换为元组以便比较""" + if v is None: + return (0,) + return tuple(int(x) if x.isdigit() else 0 for x in 
str(v).replace("v", "").split("-")[0].split(".")) + + +def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): + """ + 将source字典的值更新到target字典中(如果target中存在相同的键) + """ + for key, value in source.items(): + # 跳过version字段的更新 + if key == "version": + continue + if key in target: + target_value = target[key] + if isinstance(value, dict) and isinstance(target_value, (dict, Table)): + _update_dict(target_value, value) + else: + try: + # 对数组类型进行特殊处理 + if isinstance(value, list): + # 如果是空数组,确保它保持为空数组 + target[key] = tomlkit.array(str(value)) if value else tomlkit.array() + else: + # 其他类型使用item方法创建新值 + target[key] = tomlkit.item(value) + except (TypeError, ValueError): + # 如果转换失败,直接赋值 + target[key] = value + + +def _update_config_generic(config_name: str, template_name: str, should_quit_on_new: bool = True): + """ + 通用的配置文件更新函数 + + Args: + config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config' + template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template' + should_quit_on_new: 创建新配置文件后是否退出程序 + """ # 获取根目录路径 old_config_dir = os.path.join(CONFIG_DIR, "old") compare_dir = os.path.join(TEMPLATE_DIR, "compare") # 定义文件路径 - template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml") - old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") - new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") - compare_path = os.path.join(compare_dir, "bot_config_template.toml") + template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml") + old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml") + new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml") + compare_path = os.path.join(compare_dir, f"{template_name}.toml") # 创建compare目录(如果不存在) os.makedirs(compare_dir, exist_ok=True) - # 处理compare下的模板文件 - def get_version_from_toml(toml_path): - if not os.path.exists(toml_path): - return None - with open(toml_path, "r", encoding="utf-8") as f: - doc = tomlkit.load(f) - if "inner" in doc and "version" in doc["inner"]: # type: ignore - return doc["inner"]["version"] # type: ignore - return None - - template_version = get_version_from_toml(template_path) - compare_version = get_version_from_toml(compare_path) - - def version_tuple(v): - if v is None: - return (0,) - return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split(".")) + template_version = _get_version_from_toml(template_path) + compare_version = _get_version_from_toml(compare_path) # 先读取 compare 下的模板(如果有),用于默认值变动检测 if os.path.exists(compare_path): @@ -183,7 +502,7 @@ def update_config(): old_config = tomlkit.load(f) logs, changes = compare_default_values(new_config, compare_config) if logs: - logger.info("检测到模板默认值变动如下:") + logger.info(f"检测到{config_name}模板默认值变动如下:") for log in logs: logger.info(log) # 检查旧配置是否等于旧默认值,如果是则更新为新默认值 @@ -192,10 +511,10 @@ def update_config(): if old_value == old_default: set_value_by_path(old_config, path, new_default) logger.info( - f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" + f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" ) else: - logger.info("未检测到模板默认值变动") + logger.info(f"未检测到{config_name}模板默认值变动") # 保存旧配置的变更(后续合并逻辑会用到 old_config) else: old_config = None @@ -203,22 +522,25 @@ def update_config(): # 检查 compare 下没有模板,或新模板版本更高,则复制 if not os.path.exists(compare_path): shutil.copy2(template_path, compare_path) - logger.info(f"已将模板文件复制到: {compare_path}") + logger.info(f"已将{config_name}模板文件复制到: {compare_path}") else: - if 
version_tuple(template_version) > version_tuple(compare_version): + if _version_tuple(template_version) > _version_tuple(compare_version): shutil.copy2(template_path, compare_path) - logger.info(f"模板版本较新,已替换compare下的模板: {compare_path}") + logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}") else: - logger.debug(f"compare下的模板版本不低于当前模板,无需替换: {compare_path}") + logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}") # 检查配置文件是否存在 if not os.path.exists(old_config_path): - logger.info("配置文件不存在,从模板创建新配置") + logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置") os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 shutil.copy2(template_path, old_config_path) # 复制模板文件 - logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}") - # 如果是新创建的配置文件,直接返回 - quit() + logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}") + # 如果是新创建的配置文件,根据参数决定是否退出 + if should_quit_on_new: + quit() + else: + return # 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次) if old_config is None: @@ -226,38 +548,36 @@ def update_config(): old_config = tomlkit.load(f) # new_config 已经读取 - # 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用 - # 检查version是否相同 if old_config and "inner" in old_config and "inner" in new_config: old_version = old_config["inner"].get("version") # type: ignore new_version = new_config["inner"].get("version") # type: ignore if old_version and new_version and old_version == new_version: - logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") + logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新") return else: logger.info( - f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" + f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" ) else: - logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") + logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新") # 创建old目录(如果不存在) os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = os.path.join(old_config_dir, f"bot_config_{timestamp}.toml") + old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml") # 移动旧配置文件到old目录 shutil.move(old_config_path, old_backup_path) - logger.info(f"已备份旧配置文件到: {old_backup_path}") + logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}") # 复制模板文件到配置目录 shutil.copy2(template_path, new_config_path) - logger.info(f"已创建新配置文件: {new_config_path}") + logger.info(f"已创建新{config_name}配置文件: {new_config_path}") # 输出新增和删减项及注释 if old_config: - logger.info("配置项变动如下:\n----------------------------------------") + logger.info(f"{config_name}配置项变动如下:\n----------------------------------------") logs = compare_dicts(new_config, old_config) if logs: for log in logs: @@ -265,40 +585,24 @@ def update_config(): else: logger.info("无新增或删减项") - def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): - """ - 将source字典的值更新到target字典中(如果target中存在相同的键) - """ - for key, value in source.items(): - # 跳过version字段的更新 - if key == "version": - continue - if key in target: - target_value = target[key] - if isinstance(value, dict) and isinstance(target_value, (dict, Table)): - update_dict(target_value, value) - else: - try: - # 对数组类型进行特殊处理 - if isinstance(value, list): - # 如果是空数组,确保它保持为空数组 - target[key] = tomlkit.array(str(value)) if value else tomlkit.array() - else: - # 其他类型使用item方法创建新值 - target[key] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - 
target[key] = value - # 将旧配置的值更新到新配置中 - logger.info("开始合并新旧配置...") - update_dict(new_config, old_config) + logger.info(f"开始合并{config_name}新旧配置...") + _update_dict(new_config, old_config) # 保存更新后的配置(保留注释和格式) with open(new_config_path, "w", encoding="utf-8") as f: f.write(tomlkit.dumps(new_config)) - logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") - quit() + logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") + + +def update_config(): + """更新bot_config.toml配置文件""" + _update_config_generic("bot_config", "bot_config_template", should_quit_on_new=True) + + +def update_model_config(): + """更新model_config.toml配置文件""" + _update_config_generic("model_config", "model_config_template", should_quit_on_new=False) @dataclass @@ -360,7 +664,9 @@ def get_config_dir() -> str: # 获取配置文件路径 logger.info(f"MaiCore当前版本: {MMC_VERSION}") update_config() +update_model_config() logger.info("正在品鉴配置文件...") global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml")) -logger.info("非常的新鲜,非常的美味!") +model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml")) +logger.info("非常的新鲜,非常的美味!") \ No newline at end of file diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 2c9f847c4..08acf97c6 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from typing import Any, Literal, Optional from src.config.config_base import ConfigBase +from packaging.version import Version """ 须知: @@ -598,7 +599,6 @@ class LPMMKnowledgeConfig(ConfigBase): embedding_dimension: int = 1024 """嵌入向量维度,应该与模型的输出维度一致""" - @dataclass class ModelConfig(ConfigBase): """模型配置类""" diff --git a/src/llm_models/LICENSE b/src/llm_models/LICENSE new file mode 100644 index 000000000..8b3236ed5 --- /dev/null +++ b/src/llm_models/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Mai.To.The.Gate + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
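Reviewer note (not part of the diff): a minimal sketch of how the multi-key rotation added in src/config/api_ada_configs.py (APIProvider.get_current_api_key / mark_key_failed / reset_key_failures / get_api_key_stats) is expected to be driven by a caller. The provider name, base URL, and keys below are placeholder values, not taken from this changeset.

from src.config.api_ada_configs import APIProvider

# Illustrative only: a placeholder provider configured with two keys.
provider = APIProvider(
    name="example",
    base_url="https://api.example.com/v1",
    api_keys=["sk-key-one-aaaaaaaa", "sk-key-two-bbbbbbbb"],
)

key = provider.get_current_api_key()        # initially the first key in api_keys
next_key = provider.mark_key_failed(key)    # record the failure and rotate to the next usable key
provider.reset_key_failures(next_key)       # clear the counters after a successful call
print(provider.get_api_key_stats())         # masked keys with failure counts and the current-key flag

Per the dataclass's __post_init__, passing only api_keys also back-fills the legacy api_key field with the first entry, so older call sites that read api_key keep working.
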
diff --git a/src/llm_models/__init__.py b/src/llm_models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llm_models/exceptions.py b/src/llm_models/exceptions.py new file mode 100644 index 000000000..0ced8dd14 --- /dev/null +++ b/src/llm_models/exceptions.py @@ -0,0 +1,69 @@ +from typing import Any + + +# 常见Error Code Mapping (以OpenAI API为例) +error_code_mapping = { + 400: "参数不正确", + 401: "API-Key错误,认证失败,请检查/config/model_list.toml中的配置是否正确", + 402: "账号余额不足", + 403: "模型拒绝访问,可能需要实名或余额不足", + 404: "Not Found", + 413: "请求体过大,请尝试压缩图片或减少输入内容", + 429: "请求过于频繁,请稍后再试", + 500: "服务器内部故障", + 503: "服务器负载过高", +} + + +class NetworkConnectionError(Exception): + """连接异常,常见于网络问题或服务器不可用""" + + def __init__(self): + super().__init__() + + def __str__(self): + return "连接异常,请检查网络连接状态或URL是否正确" + + +class ReqAbortException(Exception): + """请求异常退出,常见于请求被中断或取消""" + + def __init__(self, message: str | None = None): + super().__init__(message) + self.message = message + + def __str__(self): + return self.message or "请求因未知原因异常终止" + + +class RespNotOkException(Exception): + """请求响应异常,见于请求未能成功响应(非 '200 OK')""" + + def __init__(self, status_code: int, message: str | None = None): + super().__init__(message) + self.status_code = status_code + self.message = message + + def __str__(self): + if self.status_code in error_code_mapping: + return error_code_mapping[self.status_code] + elif self.message: + return self.message + else: + return f"未知的异常响应代码:{self.status_code}" + + +class RespParseException(Exception): + """响应解析错误,常见于响应格式不正确或解析方法不匹配""" + + def __init__(self, ext_info: Any, message: str | None = None): + super().__init__(message) + self.ext_info = ext_info + self.message = message + + def __str__(self): + return ( + self.message + if self.message + else "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法" + ) diff --git a/src/llm_models/model_client/__init__.py b/src/llm_models/model_client/__init__.py new file mode 100644 index 000000000..7e57c82d6 --- /dev/null +++ b/src/llm_models/model_client/__init__.py @@ -0,0 +1,380 @@ +import asyncio +from typing import Callable, Any + +from openai import AsyncStream +from openai.types.chat import ChatCompletionChunk, ChatCompletion + +from .base_client import BaseClient, APIResponse +from src.config.api_ada_configs import ( + ModelInfo, + ModelUsageArgConfigItem, + RequestConfig, + ModuleConfig, +) +from ..exceptions import ( + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, +) +from ..payload_content.message import Message +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolOption +from ..utils import compress_messages +from src.common.logger import get_logger + +logger = get_logger("模型客户端") + + +def _check_retry( + remain_try: int, + retry_interval: int, + can_retry_msg: str, + cannot_retry_msg: str, + can_retry_callable: Callable | None = None, + **kwargs, +) -> tuple[int, Any | None]: + """ + 辅助函数:检查是否可以重试 + :param remain_try: 剩余尝试次数 + :param retry_interval: 重试间隔 + :param can_retry_msg: 可以重试时的提示信息 + :param cannot_retry_msg: 不可以重试时的提示信息 + :return: (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + if remain_try > 0: + # 还有重试机会 + logger.warning(f"{can_retry_msg}") + if can_retry_callable is not None: + return retry_interval, can_retry_callable(**kwargs) + else: + return retry_interval, None + else: + # 达到最大重试次数 + logger.warning(f"{cannot_retry_msg}") + return -1, None # 不再重试请求该模型 + + +def _handle_resp_not_ok( + e: RespNotOkException, + task_name: str, + model_name: str, + 
remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, +): + """ + 处理响应错误异常 + :param e: 异常对象 + :param task_name: 任务名称 + :param model_name: 模型名称 + :param remain_try: 剩余尝试次数 + :param retry_interval: 重试间隔 + :param messages: (消息列表, 是否已压缩过) + :return: (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + # 响应错误 + if e.status_code in [401, 403]: + # API Key认证错误 - 让多API Key机制处理,给一次重试机会 + if remain_try > 0: + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"API Key认证失败(错误代码-{e.status_code}),多API Key机制会自动切换" + ) + return 0, None # 立即重试,让底层客户端切换API Key + else: + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"所有API Key都认证失败,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None # 不再重试请求该模型 + elif e.status_code in [400, 402, 404]: + # 其他客户端错误(不应该重试) + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"请求失败,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None # 不再重试请求该模型 + elif e.status_code == 413: + if messages and not messages[1]: + # 消息列表不为空且未压缩,尝试压缩消息 + return _check_retry( + remain_try, + 0, + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求体过大,尝试压缩消息后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求体过大,压缩消息后仍然过大,放弃请求" + ), + can_retry_callable=compress_messages, + messages=messages[0], + ) + # 没有消息可压缩 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求体过大,无法压缩消息,放弃请求。" + ) + return -1, None + elif e.status_code == 429: + # 请求过于频繁 - 让多API Key机制处理,适当延迟后重试 + return _check_retry( + remain_try, + min(retry_interval, 5), # 限制最大延迟为5秒,让API Key切换更快生效 + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"请求过于频繁,多API Key机制会自动切换,{min(retry_interval, 5)}秒后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "请求过于频繁,所有API Key都被限制,放弃请求" + ), + ) + elif e.status_code >= 500: + # 服务器错误 + return _check_retry( + remain_try, + retry_interval, + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"服务器错误,将于{retry_interval}秒后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + "服务器错误,超过最大重试次数,请稍后再试" + ), + ) + else: + # 未知错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"未知错误,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None + + +def default_exception_handler( + e: Exception, + task_name: str, + model_name: str, + remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, +) -> tuple[int, list[Message] | None]: + """ + 默认异常处理函数 + :param e: 异常对象 + :param task_name: 任务名称 + :param model_name: 模型名称 + :param remain_try: 剩余尝试次数 + :param retry_interval: 重试间隔 + :param messages: (消息列表, 是否已压缩过) + :return (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + + if isinstance(e, NetworkConnectionError): # 网络连接错误 + # 网络错误可能是某个API Key的端点问题,给多API Key机制一次快速重试机会 + return _check_retry( + remain_try, + min(retry_interval, 3), # 网络错误时减少等待时间,让API Key切换更快 + can_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"连接异常,多API Key机制会尝试其他Key,{min(retry_interval, 3)}秒后重试" + ), + cannot_retry_msg=( + f"任务-'{task_name}' 模型-'{model_name}'\n" + f"连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确" + ), + ) + elif isinstance(e, ReqAbortException): + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}'\n请求被中断,详细信息-{str(e.message)}" + ) + return -1, None # 不再重试请求该模型 + elif isinstance(e, RespNotOkException): + return _handle_resp_not_ok( + e, + task_name, + model_name, + remain_try, + retry_interval, + messages, + ) + elif isinstance(e, RespParseException): + # 响应解析错误 + logger.error( + 
f"任务-'{task_name}' 模型-'{model_name}'\n" + f"响应解析错误,错误信息-{e.message}\n" + ) + logger.debug(f"附加内容:\n{str(e.ext_info)}") + return -1, None # 不再重试请求该模型 + else: + logger.error( + f"任务-'{task_name}' 模型-'{model_name}'\n未知异常,错误信息-{str(e)}" + ) + return -1, None # 不再重试请求该模型 + + +class ModelRequestHandler: + """ + 模型请求处理器 + """ + + def __init__( + self, + task_name: str, + config: ModuleConfig, + api_client_map: dict[str, BaseClient], + ): + self.task_name: str = task_name + """任务名称""" + + self.client_map: dict[str, BaseClient] = {} + """API客户端列表""" + + self.configs: list[tuple[ModelInfo, ModelUsageArgConfigItem]] = [] + """模型参数配置""" + + self.req_conf: RequestConfig = config.req_conf + """请求配置""" + + # 获取模型与使用配置 + for model_usage in config.task_model_arg_map[task_name].usage: + if model_usage.name not in config.models: + logger.error(f"Model '{model_usage.name}' not found in ModelManager") + raise KeyError(f"Model '{model_usage.name}' not found in ModelManager") + model_info = config.models[model_usage.name] + + if model_info.api_provider not in self.client_map: + # 缓存API客户端 + self.client_map[model_info.api_provider] = api_client_map[ + model_info.api_provider + ] + + self.configs.append((model_info, model_usage)) # 添加模型与使用配置 + + async def get_response( + self, + messages: list[Message], + tool_options: list[ToolOption] | None = None, + response_format: RespFormat | None = None, # 暂不启用 + stream_response_handler: Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], APIResponse + ] + | None = None, + async_response_parser: Callable[[ChatCompletion], APIResponse] | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """ + 获取对话响应 + :param messages: 消息列表 + :param tool_options: 工具选项列表 + :param response_format: 响应格式 + :param stream_response_handler: 流式响应处理函数(可选) + :param async_response_parser: 响应解析函数(可选) + :param interrupt_flag: 中断信号量(可选,默认为None) + :return: APIResponse + """ + # 遍历可用模型,若获取响应失败,则使用下一个模型继续请求 + for config_item in self.configs: + client = self.client_map[config_item[0].api_provider] + model_info: ModelInfo = config_item[0] + model_usage_config: ModelUsageArgConfigItem = config_item[1] + + remain_try = ( + model_usage_config.max_retry or self.req_conf.max_retry + ) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1 + + compressed_messages = None + retry_interval = self.req_conf.retry_interval + while remain_try > 0: + try: + return await client.get_response( + model_info, + message_list=(compressed_messages or messages), + tool_options=tool_options, + max_tokens=model_usage_config.max_tokens + or self.req_conf.default_max_tokens, + temperature=model_usage_config.temperature + or self.req_conf.default_temperature, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + interrupt_flag=interrupt_flag, + ) + except Exception as e: + logger.debug(e) + remain_try -= 1 # 剩余尝试次数减1 + + # 处理异常 + handle_res = default_exception_handler( + e, + self.task_name, + model_info.name, + remain_try, + retry_interval=self.req_conf.retry_interval, + messages=(messages, compressed_messages is not None), + ) + + if handle_res[0] == -1: + # 等待间隔为-1,表示不再请求该模型 + remain_try = 0 + elif handle_res[0] != 0: + # 等待间隔不为0,表示需要等待 + await asyncio.sleep(handle_res[0]) + retry_interval *= 2 + + if handle_res[1] is not None: + # 压缩消息 + compressed_messages = handle_res[1] + + logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用") + raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败 + + async def get_embedding( + self, + 
embedding_input: str, + ) -> APIResponse: + """ + 获取嵌入向量 + :param embedding_input: 嵌入输入 + :return: APIResponse + """ + for config in self.configs: + client = self.client_map[config[0].api_provider] + model_info: ModelInfo = config[0] + model_usage_config: ModelUsageArgConfigItem = config[1] + remain_try = ( + model_usage_config.max_retry or self.req_conf.max_retry + ) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1 + + while remain_try: + try: + return await client.get_embedding( + model_info=model_info, + embedding_input=embedding_input, + ) + except Exception as e: + logger.debug(e) + remain_try -= 1 # 剩余尝试次数减1 + + # 处理异常 + handle_res = default_exception_handler( + e, + self.task_name, + model_info.name, + remain_try, + retry_interval=self.req_conf.retry_interval, + ) + + if handle_res[0] == -1: + # 等待间隔为-1,表示不再请求该模型 + remain_try = 0 + elif handle_res[0] != 0: + # 等待间隔不为0,表示需要等待 + await asyncio.sleep(handle_res[0]) + + logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用") + raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败 diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py new file mode 100644 index 000000000..50a379d34 --- /dev/null +++ b/src/llm_models/model_client/base_client.py @@ -0,0 +1,116 @@ +import asyncio +from dataclasses import dataclass +from typing import Callable, Any + +from openai import AsyncStream +from openai.types.chat import ChatCompletionChunk, ChatCompletion + +from src.config.api_ada_configs import ModelInfo, APIProvider +from ..payload_content.message import Message +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolOption, ToolCall + + +@dataclass +class UsageRecord: + """ + 使用记录类 + """ + + model_name: str + """模型名称""" + + provider_name: str + """提供商名称""" + + prompt_tokens: int + """提示token数""" + + completion_tokens: int + """完成token数""" + + total_tokens: int + """总token数""" + + +@dataclass +class APIResponse: + """ + API响应类 + """ + + content: str | None = None + """响应内容""" + + reasoning_content: str | None = None + """推理内容""" + + tool_calls: list[ToolCall] | None = None + """工具调用 [(工具名称, 工具参数), ...]""" + + embedding: list[float] | None = None + """嵌入向量""" + + usage: UsageRecord | None = None + """使用情况 (prompt_tokens, completion_tokens, total_tokens)""" + + raw_data: Any = None + """响应原始数据""" + + +class BaseClient: + """ + 基础客户端 + """ + + api_provider: APIProvider + + def __init__(self, api_provider: APIProvider): + self.api_provider = api_provider + + async def get_response( + self, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + response_format: RespFormat | None = None, + stream_response_handler: Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], + tuple[APIResponse, tuple[int, int, int]], + ] + | None = None, + async_response_parser: Callable[ + [ChatCompletion], tuple[APIResponse, tuple[int, int, int]] + ] + | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """ + 获取对话响应 + :param model_info: 模型信息 + :param message_list: 对话体 + :param tool_options: 工具选项(可选,默认为None) + :param max_tokens: 最大token数(可选,默认为1024) + :param temperature: 温度(可选,默认为0.7) + :param response_format: 响应格式(可选,默认为 NotGiven ) + :param stream_response_handler: 流式响应处理函数(可选) + :param async_response_parser: 响应解析函数(可选) + :param interrupt_flag: 中断信号量(可选,默认为None) + :return: (响应文本, 推理文本, 工具调用, 其他数据) + """ + raise RuntimeError("This method should be overridden in 
subclasses") + + async def get_embedding( + self, + model_info: ModelInfo, + embedding_input: str, + ) -> APIResponse: + """ + 获取文本嵌入 + :param model_info: 模型信息 + :param embedding_input: 嵌入输入文本 + :return: 嵌入响应 + """ + raise RuntimeError("This method should be overridden in subclasses") diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py new file mode 100644 index 000000000..a2c715a21 --- /dev/null +++ b/src/llm_models/model_client/gemini_client.py @@ -0,0 +1,573 @@ +import asyncio +import io +from collections.abc import Iterable +from typing import Callable, Iterator, TypeVar, AsyncIterator + +from google import genai +from google.genai import types +from google.genai.types import FunctionDeclaration, GenerateContentResponse +from google.genai.errors import ( + ClientError, + ServerError, + UnknownFunctionCallArgumentError, + UnsupportedFunctionError, + FunctionInvocationError, +) + +from .base_client import APIResponse, UsageRecord +from src.config.api_ada_configs import ModelInfo, APIProvider +from . import BaseClient +from src.common.logger import get_logger + +from ..exceptions import ( + RespParseException, + NetworkConnectionError, + RespNotOkException, + ReqAbortException, +) +from ..payload_content.message import Message, RoleType +from ..payload_content.resp_format import RespFormat, RespFormatType +from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall + +logger = get_logger("Gemini客户端") +T = TypeVar("T") + + +def _convert_messages( + messages: list[Message], +) -> tuple[list[types.Content], list[str] | None]: + """ + 转换消息格式 - 将消息转换为Gemini API所需的格式 + :param messages: 消息列表 + :return: 转换后的消息列表(和可能存在的system消息) + """ + + def _convert_message_item(message: Message) -> types.Content: + """ + 转换单个消息格式,除了system和tool类型的消息 + :param message: 消息对象 + :return: 转换后的消息字典 + """ + + # 将openai格式的角色重命名为gemini格式的角色 + if message.role == RoleType.Assistant: + role = "model" + elif message.role == RoleType.User: + role = "user" + + # 添加Content + content: types.Part | list + if isinstance(message.content, str): + content = types.Part.from_text(message.content) + elif isinstance(message.content, list): + content = [] + for item in message.content: + if isinstance(item, tuple): + content.append( + types.Part.from_bytes( + data=item[1], mime_type=f"image/{item[0].lower()}" + ) + ) + elif isinstance(item, str): + content.append(types.Part.from_text(item)) + else: + raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") + + return types.Content(role=role, content=content) + + temp_list: list[types.Content] = [] + system_instructions: list[str] = [] + for message in messages: + if message.role == RoleType.System: + if isinstance(message.content, str): + system_instructions.append(message.content) + else: + raise RuntimeError("你tm怎么往system里面塞图片base64?") + elif message.role == RoleType.Tool: + if not message.tool_call_id: + raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") + else: + temp_list.append(_convert_message_item(message)) + if system_instructions: + # 如果有system消息,就把它加上去 + ret: tuple = (temp_list, system_instructions) + else: + # 如果没有system消息,就直接返回 + ret: tuple = (temp_list, None) + + return ret + + +def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]: + """ + 转换工具选项格式 - 将工具选项转换为Gemini API所需的格式 + :param tool_options: 工具选项列表 + :return: 转换后的工具对象列表 + """ + + def _convert_tool_param(tool_option_param: ToolParam) -> dict: + """ + 转换单个工具参数格式 + :param tool_option_param: 工具参数对象 + :return: 转换后的工具参数字典 + """ 
+ return { + "type": tool_option_param.param_type.value, + "description": tool_option_param.description, + } + + def _convert_tool_option_item(tool_option: ToolOption) -> FunctionDeclaration: + """ + 转换单个工具项格式 + :param tool_option: 工具选项对象 + :return: 转换后的Gemini工具选项对象 + """ + ret = { + "name": tool_option.name, + "description": tool_option.description, + } + if tool_option.params: + ret["parameters"] = { + "type": "object", + "properties": { + param.name: _convert_tool_param(param) + for param in tool_option.params + }, + "required": [ + param.name for param in tool_option.params if param.required + ], + } + ret1 = types.FunctionDeclaration(**ret) + return ret1 + + return [_convert_tool_option_item(tool_option) for tool_option in tool_options] + + +def _process_delta( + delta: GenerateContentResponse, + fc_delta_buffer: io.StringIO, + tool_calls_buffer: list[tuple[str, str, dict]], +): + if not hasattr(delta, "candidates") or len(delta.candidates) == 0: + raise RespParseException(delta, "响应解析失败,缺失candidates字段") + + if delta.text: + fc_delta_buffer.write(delta.text) + + if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的 + for call in delta.function_calls: + try: + if not isinstance( + call.args, dict + ): # gemini返回的function call参数就是dict格式的了 + raise RespParseException( + delta, "响应解析失败,工具调用参数无法解析为字典类型" + ) + tool_calls_buffer.append( + ( + call.id, + call.name, + call.args, + ) + ) + except Exception as e: + raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e + + +def _build_stream_api_resp( + _fc_delta_buffer: io.StringIO, + _tool_calls_buffer: list[tuple[str, str, dict]], +) -> APIResponse: + resp = APIResponse() + + if _fc_delta_buffer.tell() > 0: + # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 + resp.content = _fc_delta_buffer.getvalue() + _fc_delta_buffer.close() + if len(_tool_calls_buffer) > 0: + # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 + resp.tool_calls = [] + for call_id, function_name, arguments_buffer in _tool_calls_buffer: + if arguments_buffer is not None: + arguments = arguments_buffer + if not isinstance(arguments, dict): + raise RespParseException( + None, + "响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n" + f"{arguments_buffer}", + ) + else: + arguments = None + + resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) + + return resp + + +async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]: + """ + 将迭代器转换为异步迭代器 + :param iterable: 迭代器对象 + :return: 异步迭代器对象 + """ + for item in iterable: + await asyncio.sleep(0) + yield item + + +async def _default_stream_response_handler( + resp_stream: Iterator[GenerateContentResponse], + interrupt_flag: asyncio.Event | None, +) -> tuple[APIResponse, tuple[int, int, int]]: + """ + 流式响应处理函数 - 处理Gemini API的流式响应 + :param resp_stream: 流式响应对象,是一个神秘的iterator,我完全不知道这个玩意能不能跑,不过遍历一遍之后它就空了,如果跑不了一点的话可以考虑改成别的东西 + :return: APIResponse对象 + """ + _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 + _tool_calls_buffer: list[ + tuple[str, str, dict] + ] = [] # 工具调用缓冲区,用于存储接收到的工具调用 + _usage_record = None # 使用情况记录 + + def _insure_buffer_closed(): + if _fc_delta_buffer and not _fc_delta_buffer.closed: + _fc_delta_buffer.close() + + async for chunk in _to_async_iterable(resp_stream): + # 检查是否有中断量 + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量被设置,则抛出ReqAbortException + raise ReqAbortException("请求被外部信号中断") + + _process_delta( + chunk, + _fc_delta_buffer, + _tool_calls_buffer, + ) + + if chunk.usage_metadata: + # 如果有使用情况,则将其存储在APIResponse对象中 + _usage_record = ( + chunk.usage_metadata.prompt_token_count, + 
chunk.usage_metadata.candidates_token_count + + chunk.usage_metadata.thoughts_token_count, + chunk.usage_metadata.total_token_count, + ) + try: + return _build_stream_api_resp( + _fc_delta_buffer, + _tool_calls_buffer, + ), _usage_record + except Exception: + # 确保缓冲区被关闭 + _insure_buffer_closed() + raise + + +def _default_normal_response_parser( + resp: GenerateContentResponse, +) -> tuple[APIResponse, tuple[int, int, int]]: + """ + 解析对话补全响应 - 将Gemini API响应解析为APIResponse对象 + :param resp: 响应对象 + :return: APIResponse对象 + """ + api_response = APIResponse() + + if not hasattr(resp, "candidates") or len(resp.candidates) == 0: + raise RespParseException(resp, "响应解析失败,缺失candidates字段") + + if resp.text: + api_response.content = resp.text + + if resp.function_calls: + api_response.tool_calls = [] + for call in resp.function_calls: + try: + if not isinstance(call.args, dict): + raise RespParseException( + resp, "响应解析失败,工具调用参数无法解析为字典类型" + ) + api_response.tool_calls.append(ToolCall(call.id, call.name, call.args)) + except Exception as e: + raise RespParseException( + resp, "响应解析失败,无法解析工具调用参数" + ) from e + + if resp.usage_metadata: + _usage_record = ( + resp.usage_metadata.prompt_token_count, + resp.usage_metadata.candidates_token_count + + resp.usage_metadata.thoughts_token_count, + resp.usage_metadata.total_token_count, + ) + else: + _usage_record = None + + api_response.raw_data = resp + + return api_response, _usage_record + + +class GeminiClient(BaseClient): + def __init__(self, api_provider: APIProvider): + super().__init__(api_provider) + # 不再在初始化时创建固定的client,而是在请求时动态创建 + self._clients_cache = {} # API Key -> genai.Client 的缓存 + + def _get_client(self, api_key: str = None) -> genai.Client: + """获取或创建对应API Key的客户端""" + if api_key is None: + api_key = self.api_provider.get_current_api_key() + + if not api_key: + raise ValueError(f"API Provider '{self.api_provider.name}' 没有可用的API Key") + + # 使用缓存避免重复创建客户端 + if api_key not in self._clients_cache: + self._clients_cache[api_key] = genai.Client(api_key=api_key) + + return self._clients_cache[api_key] + + async def _execute_with_fallback(self, func, *args, **kwargs): + """执行请求并在失败时切换API Key""" + current_api_key = self.api_provider.get_current_api_key() + max_attempts = len(self.api_provider.api_keys) if self.api_provider.api_keys else 1 + + for attempt in range(max_attempts): + try: + client = self._get_client(current_api_key) + result = await func(client, *args, **kwargs) + # 成功时重置失败计数 + self.api_provider.reset_key_failures(current_api_key) + return result + + except (ClientError, ServerError) as e: + # 记录失败并尝试下一个API Key + logger.warning(f"API Key失败 (尝试 {attempt + 1}/{max_attempts}): {str(e)}") + + if attempt < max_attempts - 1: # 还有重试机会 + next_api_key = self.api_provider.mark_key_failed(current_api_key) + if next_api_key and next_api_key != current_api_key: + current_api_key = next_api_key + logger.info(f"切换到下一个API Key: {current_api_key[:8]}***{current_api_key[-4:]}") + continue + + # 所有API Key都失败了,重新抛出异常 + raise RespNotOkException(e.status_code, e.message) from e + + except Exception as e: + # 其他异常直接抛出 + raise e + + async def get_response( + self, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + thinking_budget: int = 0, + response_format: RespFormat | None = None, + stream_response_handler: Callable[ + [Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse + ] + | None = None, + async_response_parser: 
Callable[[GenerateContentResponse], APIResponse] + | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """ + 获取对话响应 + :param model_info: 模型信息 + :param message_list: 对话体 + :param tool_options: 工具选项(可选,默认为None) + :param max_tokens: 最大token数(可选,默认为1024) + :param temperature: 温度(可选,默认为0.7) + :param thinking_budget: 思考预算(可选,默认为0) + :param response_format: 响应格式(默认为text/plain,如果是输入的JSON Schema则必须遵守OpenAPI3.0格式,理论上和openai是一样的,暂不支持其它相应格式输入) + :param stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler) + :param async_response_parser: 响应解析函数(可选,默认为default_response_parser) + :param interrupt_flag: 中断信号量(可选,默认为None) + :return: (响应文本, 推理文本, 工具调用, 其他数据) + """ + return await self._execute_with_fallback( + self._get_response_internal, + model_info, + message_list, + tool_options, + max_tokens, + temperature, + thinking_budget, + response_format, + stream_response_handler, + async_response_parser, + interrupt_flag, + ) + + async def _get_response_internal( + self, + client: genai.Client, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + thinking_budget: int = 0, + response_format: RespFormat | None = None, + stream_response_handler: Callable[ + [Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse + ] + | None = None, + async_response_parser: Callable[[GenerateContentResponse], APIResponse] + | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """内部方法:执行实际的API调用""" + if stream_response_handler is None: + stream_response_handler = _default_stream_response_handler + + if async_response_parser is None: + async_response_parser = _default_normal_response_parser + + # 将messages构造为Gemini API所需的格式 + messages = _convert_messages(message_list) + # 将tool_options转换为Gemini API所需的格式 + tools = _convert_tool_options(tool_options) if tool_options else None + # 将response_format转换为Gemini API所需的格式 + generation_config_dict = { + "max_output_tokens": max_tokens, + "temperature": temperature, + "response_modalities": ["TEXT"], # 暂时只支持文本输出 + } + if "2.5" in model_info.model_identifier.lower(): + # 我偷个懒,在这里识别一下2.5然后开摆,反正现在只有2.5支持思维链,然后我测试之后发现它不返回思考内容,反正我也怕他有朝一日返回了,我决定干掉任何有关的思维内容 + generation_config_dict["thinking_config"] = types.ThinkingConfig( + thinking_budget=thinking_budget, include_thoughts=False + ) + if tools: + generation_config_dict["tools"] = types.Tool(tools) + if messages[1]: + # 如果有system消息,则将其添加到配置中 + generation_config_dict["system_instructions"] = messages[1] + if response_format and response_format.format_type == RespFormatType.TEXT: + generation_config_dict["response_mime_type"] = "text/plain" + elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA): + generation_config_dict["response_mime_type"] = "application/json" + generation_config_dict["response_schema"] = response_format.to_dict() + + generation_config = types.GenerateContentConfig(**generation_config_dict) + + try: + if model_info.force_stream_mode: + req_task = asyncio.create_task( + client.aio.models.generate_content_stream( + model=model_info.model_identifier, + contents=messages[0], + config=generation_config, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 + resp, usage_record = await stream_response_handler( + 
req_task.result(), interrupt_flag + ) + else: + req_task = asyncio.create_task( + client.aio.models.generate_content( + model=model_info.model_identifier, + contents=messages[0], + config=generation_config, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态 + + resp, usage_record = async_response_parser(req_task.result()) + except (ClientError, ServerError) as e: + # 重封装ClientError和ServerError为RespNotOkException + raise RespNotOkException(e.status_code, e.message) from e + except ( + UnknownFunctionCallArgumentError, + UnsupportedFunctionError, + FunctionInvocationError, + ) as e: + raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from e + except Exception as e: + raise NetworkConnectionError() from e + + if usage_record: + resp.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=usage_record[0], + completion_tokens=usage_record[1], + total_tokens=usage_record[2], + ) + + return resp + + async def get_embedding( + self, + model_info: ModelInfo, + embedding_input: str, + ) -> APIResponse: + """ + 获取文本嵌入 + :param model_info: 模型信息 + :param embedding_input: 嵌入输入文本 + :return: 嵌入响应 + """ + return await self._execute_with_fallback( + self._get_embedding_internal, + model_info, + embedding_input, + ) + + async def _get_embedding_internal( + self, + client: genai.Client, + model_info: ModelInfo, + embedding_input: str, + ) -> APIResponse: + """内部方法:执行实际的嵌入API调用""" + try: + raw_response: types.EmbedContentResponse = ( + await client.aio.models.embed_content( + model=model_info.model_identifier, + contents=embedding_input, + config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"), + ) + ) + except (ClientError, ServerError) as e: + # 重封装ClientError和ServerError为RespNotOkException + raise RespNotOkException(e.status_code) from e + except Exception as e: + raise NetworkConnectionError() from e + + response = APIResponse() + + # 解析嵌入响应和使用情况 + if hasattr(raw_response, "embeddings"): + response.embedding = raw_response.embeddings[0].values + else: + raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段") + + response.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=len(embedding_input), + completion_tokens=0, + total_tokens=len(embedding_input), + ) + + return response diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py new file mode 100644 index 000000000..a70458ffe --- /dev/null +++ b/src/llm_models/model_client/openai_client.py @@ -0,0 +1,647 @@ +import asyncio +import io +import json +import re +from collections.abc import Iterable +from typing import Callable, Any + +from openai import ( + AsyncOpenAI, + APIConnectionError, + APIStatusError, + NOT_GIVEN, + AsyncStream, +) +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionToolParam, +) +from openai.types.chat.chat_completion_chunk import ChoiceDelta + +from .base_client import APIResponse, UsageRecord +from src.config.api_ada_configs import ModelInfo, APIProvider +from . 
import BaseClient +from src.common.logger import get_logger + +from ..exceptions import ( + RespParseException, + NetworkConnectionError, + RespNotOkException, + ReqAbortException, +) +from ..payload_content.message import Message, RoleType +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall + +logger = get_logger("OpenAI客户端") + + +def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]: + """ + 转换消息格式 - 将消息转换为OpenAI API所需的格式 + :param messages: 消息列表 + :return: 转换后的消息列表 + """ + + def _convert_message_item(message: Message) -> ChatCompletionMessageParam: + """ + 转换单个消息格式 + :param message: 消息对象 + :return: 转换后的消息字典 + """ + + # 添加Content + content: str | list[dict[str, Any]] + if isinstance(message.content, str): + content = message.content + elif isinstance(message.content, list): + content = [] + for item in message.content: + if isinstance(item, tuple): + content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/{item[0].lower()};base64,{item[1]}" + }, + } + ) + elif isinstance(item, str): + content.append({"type": "text", "text": item}) + else: + raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") + + ret = { + "role": message.role.value, + "content": content, + } + + # 添加工具调用ID + if message.role == RoleType.Tool: + if not message.tool_call_id: + raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") + ret["tool_call_id"] = message.tool_call_id + + return ret + + return [_convert_message_item(message) for message in messages] + + +def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]: + """ + 转换工具选项格式 - 将工具选项转换为OpenAI API所需的格式 + :param tool_options: 工具选项列表 + :return: 转换后的工具选项列表 + """ + + def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, str]: + """ + 转换单个工具参数格式 + :param tool_option_param: 工具参数对象 + :return: 转换后的工具参数字典 + """ + return { + "type": tool_option_param.param_type.value, + "description": tool_option_param.description, + } + + def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]: + """ + 转换单个工具项格式 + :param tool_option: 工具选项对象 + :return: 转换后的工具选项字典 + """ + ret: dict[str, Any] = { + "name": tool_option.name, + "description": tool_option.description, + } + if tool_option.params: + ret["parameters"] = { + "type": "object", + "properties": { + param.name: _convert_tool_param(param) + for param in tool_option.params + }, + "required": [ + param.name for param in tool_option.params if param.required + ], + } + return ret + + return [ + { + "type": "function", + "function": _convert_tool_option_item(tool_option), + } + for tool_option in tool_options + ] + + +def _process_delta( + delta: ChoiceDelta, + has_rc_attr_flag: bool, + in_rc_flag: bool, + rc_delta_buffer: io.StringIO, + fc_delta_buffer: io.StringIO, + tool_calls_buffer: list[tuple[str, str, io.StringIO]], +) -> bool: + # 接收content + if has_rc_attr_flag: + # 有独立的推理内容块,则无需考虑content内容的判读 + if hasattr(delta, "reasoning_content") and delta.reasoning_content: + # 如果有推理内容,则将其写入推理内容缓冲区 + assert isinstance(delta.reasoning_content, str) + rc_delta_buffer.write(delta.reasoning_content) + elif delta.content: + # 如果有正式内容,则将其写入正式内容缓冲区 + fc_delta_buffer.write(delta.content) + elif hasattr(delta, "content") and delta.content is not None: + # 没有独立的推理内容块,但有正式内容 + if in_rc_flag: + # 当前在推理内容块中 + if delta.content == "": + # 如果当前内容是,则将其视为推理内容的结束标记,退出推理内容块 + in_rc_flag = False + else: + # 其他情况视为推理内容,加入推理内容缓冲区 + rc_delta_buffer.write(delta.content) + elif 
delta.content == "" and not fc_delta_buffer.getvalue(): + # 如果当前内容是,且正式内容缓冲区为空,说明为输出的首个token + # 则将其视为推理内容的开始标记,进入推理内容块 + in_rc_flag = True + else: + # 其他情况视为正式内容,加入正式内容缓冲区 + fc_delta_buffer.write(delta.content) + # 接收tool_calls + if hasattr(delta, "tool_calls") and delta.tool_calls: + tool_call_delta = delta.tool_calls[0] + + if tool_call_delta.index >= len(tool_calls_buffer): + # 调用索引号大于等于缓冲区长度,说明是新的工具调用 + tool_calls_buffer.append( + ( + tool_call_delta.id, + tool_call_delta.function.name, + io.StringIO(), + ) + ) + + if tool_call_delta.function.arguments: + # 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中 + tool_calls_buffer[tool_call_delta.index][2].write( + tool_call_delta.function.arguments + ) + + return in_rc_flag + + +def _build_stream_api_resp( + _fc_delta_buffer: io.StringIO, + _rc_delta_buffer: io.StringIO, + _tool_calls_buffer: list[tuple[str, str, io.StringIO]], +) -> APIResponse: + resp = APIResponse() + + if _rc_delta_buffer.tell() > 0: + # 如果推理内容缓冲区不为空,则将其写入APIResponse对象 + resp.reasoning_content = _rc_delta_buffer.getvalue() + _rc_delta_buffer.close() + if _fc_delta_buffer.tell() > 0: + # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 + resp.content = _fc_delta_buffer.getvalue() + _fc_delta_buffer.close() + if _tool_calls_buffer: + # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 + resp.tool_calls = [] + for call_id, function_name, arguments_buffer in _tool_calls_buffer: + if arguments_buffer.tell() > 0: + # 如果参数串缓冲区不为空,则解析为JSON对象 + raw_arg_data = arguments_buffer.getvalue() + arguments_buffer.close() + try: + arguments = json.loads(raw_arg_data) + if not isinstance(arguments, dict): + raise RespParseException( + None, + "响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n" + f"{raw_arg_data}", + ) + except json.JSONDecodeError as e: + raise RespParseException( + None, + "响应解析失败,无法解析工具调用参数。工具调用参数原始响应:" + f"{raw_arg_data}", + ) from e + else: + arguments_buffer.close() + arguments = None + + resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) + + return resp + + +async def _default_stream_response_handler( + resp_stream: AsyncStream[ChatCompletionChunk], + interrupt_flag: asyncio.Event | None, +) -> tuple[APIResponse, tuple[int, int, int]]: + """ + 流式响应处理函数 - 处理OpenAI API的流式响应 + :param resp_stream: 流式响应对象 + :return: APIResponse对象 + """ + + _has_rc_attr_flag = False # 标记是否有独立的推理内容块 + _in_rc_flag = False # 标记是否在推理内容块中 + _rc_delta_buffer = io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容 + _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 + _tool_calls_buffer: list[ + tuple[str, str, io.StringIO] + ] = [] # 工具调用缓冲区,用于存储接收到的工具调用 + _usage_record = None # 使用情况记录 + + def _insure_buffer_closed(): + # 确保缓冲区被关闭 + if _rc_delta_buffer and not _rc_delta_buffer.closed: + _rc_delta_buffer.close() + if _fc_delta_buffer and not _fc_delta_buffer.closed: + _fc_delta_buffer.close() + for _, _, buffer in _tool_calls_buffer: + if buffer and not buffer.closed: + buffer.close() + + async for event in resp_stream: + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量被设置,则抛出ReqAbortException + _insure_buffer_closed() + raise ReqAbortException("请求被外部信号中断") + + delta = event.choices[0].delta # 获取当前块的delta内容 + + if hasattr(delta, "reasoning_content") and delta.reasoning_content: + # 标记:有独立的推理内容块 + _has_rc_attr_flag = True + + _in_rc_flag = _process_delta( + delta, + _has_rc_attr_flag, + _in_rc_flag, + _rc_delta_buffer, + _fc_delta_buffer, + _tool_calls_buffer, + ) + + if event.usage: + # 如果有使用情况,则将其存储在APIResponse对象中 + _usage_record = ( + event.usage.prompt_tokens, + event.usage.completion_tokens, + event.usage.total_tokens, + ) + + try: + return 
_build_stream_api_resp( + _fc_delta_buffer, + _rc_delta_buffer, + _tool_calls_buffer, + ), _usage_record + except Exception: + # 确保缓冲区被关闭 + _insure_buffer_closed() + raise + + +pattern = re.compile( + r"(?P.*?)(?P.*)|(?P.*)|(?P.+)", + re.DOTALL, +) +"""用于解析推理内容的正则表达式""" + + +def _default_normal_response_parser( + resp: ChatCompletion, +) -> tuple[APIResponse, tuple[int, int, int]]: + """ + 解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象 + :param resp: 响应对象 + :return: APIResponse对象 + """ + api_response = APIResponse() + + if not hasattr(resp, "choices") or len(resp.choices) == 0: + raise RespParseException(resp, "响应解析失败,缺失choices字段") + message_part = resp.choices[0].message + + if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: + # 有有效的推理字段 + api_response.content = message_part.content + api_response.reasoning_content = message_part.reasoning_content + elif message_part.content: + # 提取推理和内容 + match = pattern.match(message_part.content) + if not match: + raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容") + if match.group("think") is not None: + result = match.group("think").strip(), match.group("content").strip() + elif match.group("think_unclosed") is not None: + result = match.group("think_unclosed").strip(), None + else: + result = None, match.group("content_only").strip() + api_response.reasoning_content, api_response.content = result + + # 提取工具调用 + if message_part.tool_calls: + api_response.tool_calls = [] + for call in message_part.tool_calls: + try: + arguments = json.loads(call.function.arguments) + if not isinstance(arguments, dict): + raise RespParseException( + resp, "响应解析失败,工具调用参数无法解析为字典类型" + ) + api_response.tool_calls.append( + ToolCall(call.id, call.function.name, arguments) + ) + except json.JSONDecodeError as e: + raise RespParseException( + resp, "响应解析失败,无法解析工具调用参数" + ) from e + + # 提取Usage信息 + if resp.usage: + _usage_record = ( + resp.usage.prompt_tokens, + resp.usage.completion_tokens, + resp.usage.total_tokens, + ) + else: + _usage_record = None + + # 将原始响应存储在原始数据中 + api_response.raw_data = resp + + return api_response, _usage_record + + +class OpenaiClient(BaseClient): + def __init__(self, api_provider: APIProvider): + super().__init__(api_provider) + # 不再在初始化时创建固定的client,而是在请求时动态创建 + self._clients_cache = {} # API Key -> AsyncOpenAI client 的缓存 + + def _get_client(self, api_key: str = None) -> AsyncOpenAI: + """获取或创建对应API Key的客户端""" + if api_key is None: + api_key = self.api_provider.get_current_api_key() + + if not api_key: + raise ValueError(f"API Provider '{self.api_provider.name}' 没有可用的API Key") + + # 使用缓存避免重复创建客户端 + if api_key not in self._clients_cache: + self._clients_cache[api_key] = AsyncOpenAI( + base_url=self.api_provider.base_url, + api_key=api_key, + max_retries=0, + ) + + return self._clients_cache[api_key] + + async def _execute_with_fallback(self, func, *args, **kwargs): + """执行请求并在失败时切换API Key""" + current_api_key = self.api_provider.get_current_api_key() + max_attempts = len(self.api_provider.api_keys) if self.api_provider.api_keys else 1 + + for attempt in range(max_attempts): + try: + client = self._get_client(current_api_key) + result = await func(client, *args, **kwargs) + # 成功时重置失败计数 + self.api_provider.reset_key_failures(current_api_key) + return result + + except (APIStatusError, APIConnectionError) as e: + # 记录失败并尝试下一个API Key + logger.warning(f"API Key失败 (尝试 {attempt + 1}/{max_attempts}): {str(e)}") + + if attempt < max_attempts - 1: # 还有重试机会 + next_api_key = self.api_provider.mark_key_failed(current_api_key) + if 
next_api_key and next_api_key != current_api_key: + current_api_key = next_api_key + logger.info(f"切换到下一个API Key: {current_api_key[:8]}***{current_api_key[-4:]}") + continue + + # 所有API Key都失败了,重新抛出异常 + if isinstance(e, APIStatusError): + raise RespNotOkException(e.status_code, e.message) from e + elif isinstance(e, APIConnectionError): + raise NetworkConnectionError(str(e)) from e + + except Exception as e: + # 其他异常直接抛出 + raise e + + async def get_response( + self, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + response_format: RespFormat | None = None, + stream_response_handler: Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], + tuple[APIResponse, tuple[int, int, int]], + ] + | None = None, + async_response_parser: Callable[ + [ChatCompletion], tuple[APIResponse, tuple[int, int, int]] + ] + | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """ + 获取对话响应 + :param model_info: 模型信息 + :param message_list: 对话体 + :param tool_options: 工具选项(可选,默认为None) + :param max_tokens: 最大token数(可选,默认为1024) + :param temperature: 温度(可选,默认为0.7) + :param response_format: 响应格式(可选,默认为 NotGiven ) + :param stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler) + :param async_response_parser: 响应解析函数(可选,默认为default_response_parser) + :param interrupt_flag: 中断信号量(可选,默认为None) + :return: (响应文本, 推理文本, 工具调用, 其他数据) + """ + return await self._execute_with_fallback( + self._get_response_internal, + model_info, + message_list, + tool_options, + max_tokens, + temperature, + response_format, + stream_response_handler, + async_response_parser, + interrupt_flag, + ) + + async def _get_response_internal( + self, + client: AsyncOpenAI, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + response_format: RespFormat | None = None, + stream_response_handler: Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], + tuple[APIResponse, tuple[int, int, int]], + ] + | None = None, + async_response_parser: Callable[ + [ChatCompletion], tuple[APIResponse, tuple[int, int, int]] + ] + | None = None, + interrupt_flag: asyncio.Event | None = None, + ) -> APIResponse: + """内部方法:执行实际的API调用""" + if stream_response_handler is None: + stream_response_handler = _default_stream_response_handler + + if async_response_parser is None: + async_response_parser = _default_normal_response_parser + + # 将messages构造为OpenAI API所需的格式 + messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list) + # 将tool_options转换为OpenAI API所需的格式 + tools: Iterable[ChatCompletionToolParam] = ( + _convert_tool_options(tool_options) if tool_options else NOT_GIVEN + ) + + try: + if model_info.force_stream_mode: + req_task = asyncio.create_task( + client.chat.completions.create( + model=model_info.model_identifier, + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + stream=True, + response_format=response_format.to_dict() + if response_format + else NOT_GIVEN, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 + + resp, usage_record = await stream_response_handler( + req_task.result(), interrupt_flag + ) + else: + # 发送请求并获取响应 + req_task = asyncio.create_task( 
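+                # Note: arguments left as NOT_GIVEN are omitted from the request by the
+                # OpenAI SDK, so tools/response_format are only sent when actually set.
+                # This non-stream branch mirrors the stream branch above, differing in
+                # stream=False and a longer polling interval.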
+ client.chat.completions.create( + model=model_info.model_identifier, + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + stream=False, + response_format=response_format.to_dict() + if response_format + else NOT_GIVEN, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态 + + resp, usage_record = async_response_parser(req_task.result()) + except APIConnectionError as e: + # 重封装APIConnectionError为NetworkConnectionError + raise NetworkConnectionError() from e + except APIStatusError as e: + # 重封装APIError为RespNotOkException + raise RespNotOkException(e.status_code, e.message) from e + + if usage_record: + resp.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=usage_record[0], + completion_tokens=usage_record[1], + total_tokens=usage_record[2], + ) + + return resp + + async def get_embedding( + self, + model_info: ModelInfo, + embedding_input: str, + ) -> APIResponse: + """ + 获取文本嵌入 + :param model_info: 模型信息 + :param embedding_input: 嵌入输入文本 + :return: 嵌入响应 + """ + return await self._execute_with_fallback( + self._get_embedding_internal, + model_info, + embedding_input, + ) + + async def _get_embedding_internal( + self, + client: AsyncOpenAI, + model_info: ModelInfo, + embedding_input: str, + ) -> APIResponse: + """内部方法:执行实际的嵌入API调用""" + try: + raw_response = await client.embeddings.create( + model=model_info.model_identifier, + input=embedding_input, + ) + except APIConnectionError as e: + raise NetworkConnectionError() from e + except APIStatusError as e: + # 重封装APIError为RespNotOkException + raise RespNotOkException(e.status_code) from e + + response = APIResponse() + + # 解析嵌入响应 + if len(raw_response.data) > 0: + response.embedding = raw_response.data[0].embedding + else: + raise RespParseException( + raw_response, + "响应解析失败,缺失嵌入数据。", + ) + + # 解析使用情况 + if hasattr(raw_response, "usage"): + response.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=raw_response.usage.prompt_tokens, + completion_tokens=raw_response.usage.completion_tokens, + total_tokens=raw_response.usage.total_tokens, + ) + + return response diff --git a/src/llm_models/model_manager.py b/src/llm_models/model_manager.py new file mode 100644 index 000000000..36d63c72e --- /dev/null +++ b/src/llm_models/model_manager.py @@ -0,0 +1,92 @@ +import importlib +from typing import Dict + +from src.config.config import model_config +from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig +from src.common.logger import get_logger + +from .model_client import ModelRequestHandler, BaseClient + +logger = get_logger("模型管理器") + +class ModelManager: + # TODO: 添加读写锁,防止异步刷新配置时发生数据竞争 + + def __init__( + self, + config: ModuleConfig, + ): + self.config: ModuleConfig = config + """配置信息""" + + self.api_client_map: Dict[str, BaseClient] = {} + """API客户端映射表""" + + self._request_handler_cache: Dict[str, ModelRequestHandler] = {} + """ModelRequestHandler缓存,避免重复创建""" + + for provider_name, api_provider in self.config.api_providers.items(): + # 初始化API客户端 + try: + # 根据配置动态加载实现 + client_module = importlib.import_module( + f".model_client.{api_provider.client_type}_client", __package__ + ) + client_class = getattr( + client_module, f"{api_provider.client_type.capitalize()}Client" + ) + if not issubclass(client_class, 
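+                # Naming convention assumed by the dynamic import above: client_type
+                # "openai" maps to module ".model_client.openai_client" and class
+                # "OpenaiClient"; "gemini" maps to "gemini_client" / "GeminiClient".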
BaseClient): + raise TypeError( + f"'{client_class.__name__}' is not a subclass of 'BaseClient'" + ) + self.api_client_map[api_provider.name] = client_class( + api_provider + ) # 实例化,放入api_client_map + except ImportError as e: + logger.error(f"Failed to import client module: {e}") + raise ImportError( + f"Failed to import client module for '{provider_name}': {e}" + ) from e + + def __getitem__(self, task_name: str) -> ModelRequestHandler: + """ + 获取任务所需的模型客户端(封装) + 使用缓存机制避免重复创建ModelRequestHandler + :param task_name: 任务名称 + :return: 模型客户端 + """ + if task_name not in self.config.task_model_arg_map: + raise KeyError(f"'{task_name}' not registered in ModelManager") + + # 检查缓存中是否已存在 + if task_name in self._request_handler_cache: + logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}") + return self._request_handler_cache[task_name] + + # 创建新的ModelRequestHandler并缓存 + logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}") + handler = ModelRequestHandler( + task_name=task_name, + config=self.config, + api_client_map=self.api_client_map, + ) + self._request_handler_cache[task_name] = handler + return handler + + def __setitem__(self, task_name: str, value: ModelUsageArgConfig): + """ + 注册任务的模型使用配置 + :param task_name: 任务名称 + :param value: 模型使用配置 + """ + self.config.task_model_arg_map[task_name] = value + + def __contains__(self, task_name: str): + """ + 判断任务是否已注册 + :param task_name: 任务名称 + :return: 是否在模型列表中 + """ + return task_name in self.config.task_model_arg_map + + diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py new file mode 100644 index 000000000..26202ca11 --- /dev/null +++ b/src/llm_models/payload_content/message.py @@ -0,0 +1,104 @@ +from enum import Enum + + +# 设计这系列类的目的是为未来可能的扩展做准备 + + +class RoleType(Enum): + System = "system" + User = "user" + Assistant = "assistant" + Tool = "tool" + + +SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] + + +class Message: + def __init__( + self, + role: RoleType, + content: str | list[tuple[str, str] | str], + tool_call_id: str | None = None, + ): + """ + 初始化消息对象 + (不应直接修改Message类,而应使用MessageBuilder类来构建对象) + """ + self.role: RoleType = role + self.content: str | list[tuple[str, str] | str] = content + self.tool_call_id: str | None = tool_call_id + + +class MessageBuilder: + def __init__(self): + self.__role: RoleType = RoleType.User + self.__content: list[tuple[str, str] | str] = [] + self.__tool_call_id: str | None = None + + def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder": + """ + 设置角色(默认为User) + :param role: 角色 + :return: MessageBuilder对象 + """ + self.__role = role + return self + + def add_text_content(self, text: str) -> "MessageBuilder": + """ + 添加文本内容 + :param text: 文本内容 + :return: MessageBuilder对象 + """ + self.__content.append(text) + return self + + def add_image_content( + self, image_format: str, image_base64: str + ) -> "MessageBuilder": + """ + 添加图片内容 + :param image_format: 图片格式 + :param image_base64: 图片的base64编码 + :return: MessageBuilder对象 + """ + if image_format.lower() not in SUPPORTED_IMAGE_FORMATS: + raise ValueError("不受支持的图片格式") + if not image_base64: + raise ValueError("图片的base64编码不能为空") + self.__content.append((image_format, image_base64)) + return self + + def add_tool_call(self, tool_call_id: str) -> "MessageBuilder": + """ + 添加工具调用指令(调用时请确保已设置为Tool角色) + :param tool_call_id: 工具调用指令的id + :return: MessageBuilder对象 + """ + if self.__role != RoleType.Tool: + raise ValueError("仅当角色为Tool时才能添加工具调用ID") + if not tool_call_id: + raise 
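+        # Illustrative use of MessageBuilder (a sketch only; image_b64 is a
+        # hypothetical base64 string, not defined in this module):
+        #   msg = (MessageBuilder()
+        #          .set_role(RoleType.User)
+        #          .add_text_content("Describe this image")
+        #          .add_image_content("png", image_b64)
+        #          .build())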
ValueError("工具调用ID不能为空") + self.__tool_call_id = tool_call_id + return self + + def build(self) -> Message: + """ + 构建消息对象 + :return: Message对象 + """ + if len(self.__content) == 0: + raise ValueError("内容不能为空") + if self.__role == RoleType.Tool and self.__tool_call_id is None: + raise ValueError("Tool角色的工具调用ID不能为空") + + return Message( + role=self.__role, + content=( + self.__content[0] + if (len(self.__content) == 1 and isinstance(self.__content[0], str)) + else self.__content + ), + tool_call_id=self.__tool_call_id, + ) diff --git a/src/llm_models/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py new file mode 100644 index 000000000..ab2e2edf4 --- /dev/null +++ b/src/llm_models/payload_content/resp_format.py @@ -0,0 +1,223 @@ +from enum import Enum +from typing import Optional, Any + +from pydantic import BaseModel +from typing_extensions import TypedDict, Required + + +class RespFormatType(Enum): + TEXT = "text" # 文本 + JSON_OBJ = "json_object" # JSON + JSON_SCHEMA = "json_schema" # JSON Schema + + +class JsonSchema(TypedDict, total=False): + name: Required[str] + """ + The name of the response format. + + Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length + of 64. + """ + + description: Optional[str] + """ + A description of what the response format is for, used by the model to determine + how to respond in the format. + """ + + schema: dict[str, object] + """ + The schema for the response format, described as a JSON Schema object. Learn how + to build JSON schemas [here](https://json-schema.org/). + """ + + strict: Optional[bool] + """ + Whether to enable strict schema adherence when generating the output. If set to + true, the model will always follow the exact schema defined in the `schema` + field. Only a subset of JSON Schema is supported when `strict` is `true`. To + learn more, read the + [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). 
+ """ + + +def _json_schema_type_check(instance) -> str | None: + if "name" not in instance: + return "schema必须包含'name'字段" + elif not isinstance(instance["name"], str) or instance["name"].strip() == "": + return "schema的'name'字段必须是非空字符串" + if "description" in instance and ( + not isinstance(instance["description"], str) + or instance["description"].strip() == "" + ): + return "schema的'description'字段只能填入非空字符串" + if "schema" not in instance: + return "schema必须包含'schema'字段" + elif not isinstance(instance["schema"], dict): + return "schema的'schema'字段必须是字典,详见https://json-schema.org/" + if "strict" in instance and not isinstance(instance["strict"], bool): + return "schema的'strict'字段只能填入布尔值" + + return None + + +def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]: + """ + 递归移除JSON Schema中的title字段 + """ + if isinstance(schema, list): + # 如果当前Schema是列表,则对所有dict/list子元素递归调用 + for idx, item in enumerate(schema): + if isinstance(item, (dict, list)): + schema[idx] = _remove_title(item) + elif isinstance(schema, dict): + # 是字典,移除title字段,并对所有dict/list子元素递归调用 + if "title" in schema: + del schema["title"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_title(value) + + return schema + + +def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]: + """ + 链接JSON Schema中的definitions字段 + """ + + def link_definitions_recursive( + path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any] + ) -> dict[str, Any]: + """ + 递归链接JSON Schema中的definitions字段 + :param path: 当前路径 + :param sub_schema: 子Schema + :param defs: Schema定义集 + :return: + """ + if isinstance(sub_schema, list): + # 如果当前Schema是列表,则遍历每个元素 + for i in range(len(sub_schema)): + if isinstance(sub_schema[i], dict): + sub_schema[i] = link_definitions_recursive( + f"{path}/{str(i)}", sub_schema[i], defs + ) + else: + # 否则为字典 + if "$defs" in sub_schema: + # 如果当前Schema有$def字段,则将其添加到defs中 + key_prefix = f"{path}/$defs/" + for key, value in sub_schema["$defs"].items(): + def_key = key_prefix + key + if def_key not in defs: + defs[def_key] = value + del sub_schema["$defs"] + if "$ref" in sub_schema: + # 如果当前Schema有$ref字段,则将其替换为defs中的定义 + def_key = sub_schema["$ref"] + if def_key in defs: + sub_schema = defs[def_key] + else: + raise ValueError(f"Schema中引用的定义'{def_key}'不存在") + # 遍历键值对 + for key, value in sub_schema.items(): + if isinstance(value, (dict, list)): + # 如果当前值是字典或列表,则递归调用 + sub_schema[key] = link_definitions_recursive( + f"{path}/{key}", value, defs + ) + + return sub_schema + + return link_definitions_recursive("#", schema, {}) + + +def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]: + """ + 递归移除JSON Schema中的$defs字段 + """ + if isinstance(schema, list): + # 如果当前Schema是列表,则对所有dict/list子元素递归调用 + for idx, item in enumerate(schema): + if isinstance(item, (dict, list)): + schema[idx] = _remove_title(item) + elif isinstance(schema, dict): + # 是字典,移除title字段,并对所有dict/list子元素递归调用 + if "$defs" in schema: + del schema["$defs"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_title(value) + + return schema + + +class RespFormat: + """ + 响应格式 + """ + + @staticmethod + def _generate_schema_from_model(schema): + json_schema = { + "name": schema.__name__, + "schema": _remove_defs( + _link_definitions(_remove_title(schema.model_json_schema())) + ), + "strict": False, + } + if schema.__doc__: + json_schema["description"] = schema.__doc__ + return json_schema + + def __init__( + self, + format_type: RespFormatType = 
RespFormatType.TEXT, + schema: type | JsonSchema | None = None, + ): + """ + 响应格式 + :param format_type: 响应格式类型(默认为文本) + :param schema: 模板类或JsonSchema(仅当format_type为JSON Schema时有效) + """ + self.format_type: RespFormatType = format_type + + if format_type == RespFormatType.JSON_SCHEMA: + if schema is None: + raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空") + if isinstance(schema, dict): + if check_msg := _json_schema_type_check(schema): + raise ValueError(f"schema格式不正确,{check_msg}") + + self.schema = schema + elif issubclass(schema, BaseModel): + try: + json_schema = self._generate_schema_from_model(schema) + + self.schema = json_schema + except Exception as e: + raise ValueError( + f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n" + f"{schema.__name__}:\n" + ) from e + else: + raise ValueError("schema必须是BaseModel的子类或JsonSchema") + else: + self.schema = None + + def to_dict(self): + """ + 将响应格式转换为字典 + :return: 字典 + """ + if self.schema: + return { + "format_type": self.format_type.value, + "schema": self.schema, + } + else: + return { + "format_type": self.format_type.value, + } diff --git a/src/llm_models/payload_content/tool_option.py b/src/llm_models/payload_content/tool_option.py new file mode 100644 index 000000000..8a9bbdb31 --- /dev/null +++ b/src/llm_models/payload_content/tool_option.py @@ -0,0 +1,155 @@ +from enum import Enum + + +class ToolParamType(Enum): + """ + 工具调用参数类型 + """ + + String = "string" # 字符串 + Int = "integer" # 整型 + Float = "float" # 浮点型 + Boolean = "bool" # 布尔型 + + +class ToolParam: + """ + 工具调用参数 + """ + + def __init__( + self, name: str, param_type: ToolParamType, description: str, required: bool + ): + """ + 初始化工具调用参数 + (不应直接修改ToolParam类,而应使用ToolOptionBuilder类来构建对象) + :param name: 参数名称 + :param param_type: 参数类型 + :param description: 参数描述 + :param required: 是否必填 + """ + self.name: str = name + self.param_type: ToolParamType = param_type + self.description: str = description + self.required: bool = required + + +class ToolOption: + """ + 工具调用项 + """ + + def __init__( + self, + name: str, + description: str, + params: list[ToolParam] | None = None, + ): + """ + 初始化工具调用项 + (不应直接修改ToolOption类,而应使用ToolOptionBuilder类来构建对象) + :param name: 工具名称 + :param description: 工具描述 + :param params: 工具参数列表 + """ + self.name: str = name + self.description: str = description + self.params: list[ToolParam] | None = params + + +class ToolOptionBuilder: + """ + 工具调用项构建器 + """ + + def __init__(self): + self.__name: str = "" + self.__description: str = "" + self.__params: list[ToolParam] = [] + + def set_name(self, name: str) -> "ToolOptionBuilder": + """ + 设置工具名称 + :param name: 工具名称 + :return: ToolBuilder实例 + """ + if not name: + raise ValueError("工具名称不能为空") + self.__name = name + return self + + def set_description(self, description: str) -> "ToolOptionBuilder": + """ + 设置工具描述 + :param description: 工具描述 + :return: ToolBuilder实例 + """ + if not description: + raise ValueError("工具描述不能为空") + self.__description = description + return self + + def add_param( + self, + name: str, + param_type: ToolParamType, + description: str, + required: bool = False, + ) -> "ToolOptionBuilder": + """ + 添加工具参数 + :param name: 参数名称 + :param param_type: 参数类型 + :param description: 参数描述 + :param required: 是否必填(默认为False) + :return: ToolBuilder实例 + """ + if not name or not description: + raise ValueError("参数名称/描述不能为空") + + self.__params.append( + ToolParam( + name=name, + param_type=param_type, + description=description, + required=required, + ) + ) + + return self + + def build(self): + """ + 
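+        Example (usage sketch only; "get_weather" is a hypothetical tool):
+            tool = (ToolOptionBuilder()
+                    .set_name("get_weather")
+                    .set_description("Query the current weather for a city")
+                    .add_param("city", ToolParamType.String, "City name", required=True)
+                    .build())
+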
构建工具调用项 + :return: 工具调用项 + """ + if self.__name == "" or self.__description == "": + raise ValueError("工具名称/描述不能为空") + + return ToolOption( + name=self.__name, + description=self.__description, + params=None if len(self.__params) == 0 else self.__params, + ) + + +class ToolCall: + """ + 来自模型反馈的工具调用 + """ + + def __init__( + self, + call_id: str, + func_name: str, + args: dict | None = None, + ): + """ + 初始化工具调用 + :param call_id: 工具调用ID + :param func_name: 要调用的函数名称 + :param args: 工具调用参数 + """ + self.call_id: str = call_id + self.func_name: str = func_name + self.args: dict | None = args diff --git a/src/llm_models/usage_statistic.py b/src/llm_models/usage_statistic.py new file mode 100644 index 000000000..176c4b7b1 --- /dev/null +++ b/src/llm_models/usage_statistic.py @@ -0,0 +1,172 @@ +from datetime import datetime +from enum import Enum +from typing import Tuple + +from src.common.logger import get_logger +from src.config.api_ada_configs import ModelInfo +from src.common.database.database_model import LLMUsage + +logger = get_logger("模型使用统计") + + +class ReqType(Enum): + """ + 请求类型 + """ + + CHAT = "chat" # 对话请求 + EMBEDDING = "embedding" # 嵌入请求 + + +class UsageCallStatus(Enum): + """ + 任务调用状态 + """ + + PROCESSING = "processing" # 处理中 + SUCCESS = "success" # 成功 + FAILURE = "failure" # 失败 + CANCELED = "canceled" # 取消 + + +class ModelUsageStatistic: + """ + 模型使用统计类 - 使用SQLite+Peewee + """ + + def __init__(self): + """ + 初始化统计类 + 由于使用Peewee ORM,不需要传入数据库实例 + """ + # 确保表已经创建 + try: + from src.common.database.database import db + db.create_tables([LLMUsage], safe=True) + except Exception as e: + logger.error(f"创建LLMUsage表失败: {e}") + + @staticmethod + def _calculate_cost( + prompt_tokens: int, completion_tokens: int, model_info: ModelInfo + ) -> float: + """计算API调用成本 + 使用模型的pri_in和pri_out价格计算输入和输出的成本 + + Args: + prompt_tokens: 输入token数量 + completion_tokens: 输出token数量 + model_info: 模型信息 + + Returns: + float: 总成本(元) + """ + # 使用模型的pri_in和pri_out计算成本 + input_cost = (prompt_tokens / 1000000) * model_info.price_in + output_cost = (completion_tokens / 1000000) * model_info.price_out + return round(input_cost + output_cost, 6) + + def create_usage( + self, + model_name: str, + task_name: str = "N/A", + request_type: ReqType = ReqType.CHAT, + user_id: str = "system", + endpoint: str = "/chat/completions", + ) -> int | None: + """ + 创建模型使用情况记录 + + Args: + model_name: 模型名 + task_name: 任务名称 + request_type: 请求类型,默认为Chat + user_id: 用户ID,默认为system + endpoint: API端点 + + Returns: + int | None: 返回记录ID,失败返回None + """ + try: + usage_record = LLMUsage.create( + model_name=model_name, + user_id=user_id, + request_type=request_type.value, + endpoint=endpoint, + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost=0.0, + status=UsageCallStatus.PROCESSING.value, + timestamp=datetime.now(), + ) + + logger.trace( + f"创建了一条模型使用情况记录 - 模型: {model_name}, " + f"子任务: {task_name}, 类型: {request_type.value}, " + f"用户: {user_id}, 记录ID: {usage_record.id}" + ) + + return usage_record.id + except Exception as e: + logger.error(f"创建模型使用情况记录失败: {str(e)}") + return None + + def update_usage( + self, + record_id: int | None, + model_info: ModelInfo, + usage_data: Tuple[int, int, int] | None = None, + stat: UsageCallStatus = UsageCallStatus.SUCCESS, + ext_msg: str | None = None, + ): + """ + 更新模型使用情况 + + Args: + record_id: 记录ID + model_info: 模型信息 + usage_data: 使用情况数据(输入token数量, 输出token数量, 总token数量) + stat: 任务调用状态 + ext_msg: 额外信息 + """ + if not record_id: + logger.error("更新模型使用情况失败: record_id不能为空") + return + + if usage_data and 
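+        # Cost example (hypothetical prices): with price_in=2.0 and price_out=8.0
+        # per million tokens, 1500 prompt tokens + 500 completion tokens cost
+        # 1500 / 1e6 * 2.0 + 500 / 1e6 * 8.0 = 0.007 (rounded to 6 decimals by
+        # _calculate_cost above).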
len(usage_data) != 3: + logger.error("更新模型使用情况失败: usage_data的长度不正确,应该为3个元素") + return + + # 提取使用情况数据 + prompt_tokens = usage_data[0] if usage_data else 0 + completion_tokens = usage_data[1] if usage_data else 0 + total_tokens = usage_data[2] if usage_data else 0 + + try: + # 使用Peewee更新记录 + update_query = LLMUsage.update( + status=stat.value, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost=self._calculate_cost( + prompt_tokens, completion_tokens, model_info + ) if usage_data else 0.0, + ).where(LLMUsage.id == record_id) + + updated_count = update_query.execute() + + if updated_count == 0: + logger.warning(f"记录ID {record_id} 不存在,无法更新") + return + + logger.debug( + f"Token使用情况 - 模型: {model_info.name}, " + f"记录ID: {record_id}, " + f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, " + f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " + f"总计: {total_tokens}" + ) + except Exception as e: + logger.error(f"记录token使用情况失败: {str(e)}") diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py new file mode 100644 index 000000000..352df5a43 --- /dev/null +++ b/src/llm_models/utils.py @@ -0,0 +1,152 @@ +import base64 +import io + +from PIL import Image + +from src.common.logger import get_logger +from .payload_content.message import Message, MessageBuilder + +logger = get_logger("消息压缩工具") + + +def compress_messages( + messages: list[Message], img_target_size: int = 1 * 1024 * 1024 +) -> list[Message]: + """ + 压缩消息列表中的图片 + :param messages: 消息列表 + :param img_target_size: 图片目标大小,默认1MB + :return: 压缩后的消息列表 + """ + + def reformat_static_image(image_data: bytes) -> bytes: + """ + 将静态图片转换为JPEG格式 + :param image_data: 图片数据 + :return: 转换后的图片数据 + """ + try: + image = Image.open(image_data) + + if image.format and ( + image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"] + ): + # 静态图像,转换为JPEG格式 + reformated_image_data = io.BytesIO() + image.save( + reformated_image_data, format="JPEG", quality=95, optimize=True + ) + image_data = reformated_image_data.getvalue() + + return image_data + except Exception as e: + logger.error(f"图片转换格式失败: {str(e)}") + return image_data + + def rescale_image( + image_data: bytes, scale: float + ) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]: + """ + 缩放图片 + :param image_data: 图片数据 + :param scale: 缩放比例 + :return: 缩放后的图片数据 + """ + try: + image = Image.open(image_data) + + # 原始尺寸 + original_size = (image.width, image.height) + + # 计算新的尺寸 + new_size = (int(original_size[0] * scale), int(original_size[1] * scale)) + + output_buffer = io.BytesIO() + + if getattr(image, "is_animated", False): + # 动态图片,处理所有帧 + frames = [] + new_size = (new_size[0] // 2, new_size[1] // 2) # 动图,缩放尺寸再打折 + for frame_idx in range(getattr(image, "n_frames", 1)): + image.seek(frame_idx) + new_frame = image.copy() + new_frame = new_frame.resize(new_size, Image.Resampling.LANCZOS) + frames.append(new_frame) + + # 保存到缓冲区 + frames[0].save( + output_buffer, + format="GIF", + save_all=True, + append_images=frames[1:], + optimize=True, + duration=image.info.get("duration", 100), + loop=image.info.get("loop", 0), + ) + else: + # 静态图片,直接缩放保存 + resized_image = image.resize(new_size, Image.Resampling.LANCZOS) + resized_image.save( + output_buffer, format="JPEG", quality=95, optimize=True + ) + + return output_buffer.getvalue(), original_size, new_size + + except Exception as e: + logger.error(f"图片缩放失败: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return image_data, None, None + + def compress_base64_image( + base64_data: 
str, target_size: int = 1 * 1024 * 1024 + ) -> str: + original_b64_data_size = len(base64_data) # 计算原始数据大小 + + image_data = base64.b64decode(base64_data) + + # 先尝试转换格式为JPEG + image_data = reformat_static_image(image_data) + base64_data = base64.b64encode(image_data).decode("utf-8") + if len(base64_data) <= target_size: + # 如果转换后小于目标大小,直接返回 + logger.info( + f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB" + ) + return base64_data + + # 如果转换后仍然大于目标大小,进行尺寸压缩 + scale = min(1.0, target_size / len(base64_data)) + image_data, original_size, new_size = rescale_image(image_data, scale) + base64_data = base64.b64encode(image_data).decode("utf-8") + + if original_size and new_size: + logger.info( + f"压缩图片: {original_size[0]}x{original_size[1]} -> {new_size[0]}x{new_size[1]}\n" + f"压缩前大小: {original_b64_data_size / 1024:.1f}KB, 压缩后大小: {len(base64_data) / 1024:.1f}KB" + ) + + return base64_data + + compressed_messages = [] + for message in messages: + if isinstance(message.content, list): + # 检查content,如有图片则压缩 + message_builder = MessageBuilder() + for content_item in message.content: + if isinstance(content_item, tuple): + # 图片,进行压缩 + message_builder.add_image_content( + content_item[0], + compress_base64_image( + content_item[1], target_size=img_target_size + ), + ) + else: + message_builder.add_text_content(content_item) + compressed_messages.append(message_builder.build()) + else: + compressed_messages.append(message) + + return compressed_messages diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 98d93db13..805a47343 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,26 +1,55 @@ -import asyncio -import json import re from datetime import datetime -from typing import Tuple, Union, Dict, Any, Callable -import aiohttp -from aiohttp.client import ClientResponse +from typing import Tuple, Union from src.common.logger import get_logger import base64 from PIL import Image import io -import os -import copy # 添加copy模块用于深拷贝 from src.common.database.database import db # 确保 db 被导入用于 create_tables from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 from src.config.config import global_config -from src.common.tcp_connector import get_tcp_connector from rich.traceback import install install(extra_lines=3) logger = get_logger("model_utils") +# 导入具体的异常类型用于精确的异常处理 +try: + from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException + SPECIFIC_EXCEPTIONS_AVAILABLE = True +except ImportError: + logger.warning("无法导入具体异常类型,将使用通用异常处理") + NetworkConnectionError = Exception + ReqAbortException = Exception + RespNotOkException = Exception + RespParseException = Exception + SPECIFIC_EXCEPTIONS_AVAILABLE = False + +# 新架构导入 - 使用延迟导入以支持fallback模式 +try: + from .model_manager import ModelManager + from .model_client import ModelRequestHandler + from .payload_content.message import MessageBuilder + + # 不在模块级别初始化ModelManager,延迟到实际使用时 + ModelManager_class = ModelManager + model_manager = None # 延迟初始化 + + # 添加请求处理器缓存,避免重复创建 + _request_handler_cache = {} # 格式: {(model_name, task_name): ModelRequestHandler} + + NEW_ARCHITECTURE_AVAILABLE = True + logger.info("新架构模块导入成功") +except Exception as e: + logger.warning(f"新架构不可用,将使用fallback模式: {str(e)}") + ModelManager_class = None + model_manager = None + ModelRequestHandler = None + MessageBuilder = None + _request_handler_cache = {} + NEW_ARCHITECTURE_AVAILABLE = False + class PayLoadTooLargeError(Exception): """自定义异常类,用于处理请求体过大错误""" @@ -36,10 +65,9 @@ 
class PayLoadTooLargeError(Exception): class RequestAbortException(Exception): """自定义异常类,用于处理请求中断异常""" - def __init__(self, message: str, response: ClientResponse): + def __init__(self, message: str): super().__init__(message) self.message = message - self.response = response def __str__(self): return self.message @@ -59,7 +87,7 @@ class PermissionDeniedException(Exception): # 常见Error Code Mapping error_code_mapping = { 400: "参数不正确", - 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~", + 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确", 402: "账号余额不足", 403: "需要实名,或余额不足", 404: "Not Found", @@ -69,32 +97,14 @@ error_code_mapping = { } -async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]): - """安全地记录请求体,用于调试日志,不会修改原始payload对象""" - # 创建payload的深拷贝,避免修改原始对象 - safe_payload = copy.deepcopy(payload) - - image_base64: str = request_content.get("image_base64") - image_format: str = request_content.get("image_format") - if ( - image_base64 - and safe_payload - and isinstance(safe_payload, dict) - and "messages" in safe_payload - and len(safe_payload["messages"]) > 0 - ): - if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]: - content = safe_payload["messages"][0]["content"] - if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]: - # 只修改拷贝的对象,用于安全的日志记录 - safe_payload["messages"][0]["content"][1]["image_url"]["url"] = ( - f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," - f"{image_base64[:10]}...{image_base64[-10:]}" - ) - return safe_payload class LLMRequest: + """ + 重构后的LLM请求类,基于新的model_manager和model_client架构 + 保持向后兼容的API接口 + """ + # 定义需要转换的模型列表,作为类变量避免重复 MODELS_NEEDING_TRANSFORMATION = [ "o1", @@ -114,42 +124,105 @@ class LLMRequest: ] def __init__(self, model: dict, **kwargs): - # 将大写的配置键转换为小写并从config中获取实际值 - logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('name', 'Unknown')}") - logger.debug(f"🔍 [模型初始化] 模型配置: {model}") + """ + 初始化LLM请求实例 + Args: + model: 模型配置字典,兼容旧格式和新格式 + **kwargs: 额外参数 + """ + logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('model_name', model.get('name', 'Unknown'))}") + logger.debug(f"🔍 [模型初始化] 输入的模型配置: {model}") logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}") - try: - # print(f"model['provider']: {model['provider']}") - self.api_key = os.environ[f"{model['provider']}_KEY"] - self.base_url = os.environ[f"{model['provider']}_BASE_URL"] - logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL") - except AttributeError as e: - logger.error(f"原始 model dict 信息:{model}") - logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") - raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e - except KeyError: - logger.warning( - f"找不到{model['provider']}_KEY或{model['provider']}_BASE_URL环境变量,请检查配置文件或环境变量设置。" - ) - self.model_name: str = model["name"] - self.params = kwargs - - # 记录配置文件中声明了哪些参数(不管值是什么) - self.has_enable_thinking = "enable_thinking" in model - self.has_thinking_budget = "thinking_budget" in model + # 兼容新旧模型配置格式 + # 新格式使用 model_name,旧格式使用 name + self.model_name: str = model.get("model_name", model.get("name", "")) + # 如果传入的配置不完整,自动从全局配置中获取完整配置 + if not all(key in model for key in ["task_type", "capabilities"]): + logger.debug("🔍 [模型初始化] 检测到不完整的模型配置,尝试获取完整配置") + if (full_model_config := self._get_full_model_config(self.model_name)): + logger.debug("🔍 [模型初始化] 成功获取完整模型配置,合并配置信息") + # 合并配置:运行时参数优先,但添加缺失的配置字段 + model = {**full_model_config, **model} + logger.debug(f"🔍 [模型初始化] 合并后的模型配置: {model}") + else: + logger.warning(f"⚠️ 
[模型初始化] 无法获取模型 {self.model_name} 的完整配置,使用原始配置") + + # 在新架构中,provider信息从model_config.toml自动获取,不需要在这里设置 + self.provider = model.get("provider", "") # 保留兼容性,但在新架构中不使用 + + # 从全局配置中获取任务配置 + self.request_type = kwargs.pop("request_type", "default") + + # 确定使用哪个任务配置 + task_name = self._determine_task_name(model) + + # 初始化 request_handler + self.request_handler = None + + # 尝试初始化新架构 + if NEW_ARCHITECTURE_AVAILABLE and ModelManager_class is not None: + try: + # 延迟初始化ModelManager + global model_manager, _request_handler_cache + if model_manager is None: + from src.config.config import model_config + model_manager = ModelManager_class(model_config) + logger.debug("🔍 [模型初始化] ModelManager延迟初始化成功") + + # 构建缓存键 + cache_key = (self.model_name, task_name) + + # 检查是否已有缓存的请求处理器 + if cache_key in _request_handler_cache: + self.request_handler = _request_handler_cache[cache_key] + logger.debug(f"🚀 [性能优化] 从LLMRequest缓存获取请求处理器: {cache_key}") + else: + # 使用新架构获取模型请求处理器 + self.request_handler = model_manager[task_name] + _request_handler_cache[cache_key] = self.request_handler + logger.debug(f"🔧 [性能优化] 创建并缓存LLMRequest请求处理器: {cache_key}") + + logger.debug(f"🔍 [模型初始化] 成功获取模型请求处理器,任务: {task_name}") + self.use_new_architecture = True + except Exception as e: + logger.warning(f"无法使用新架构,任务 {task_name} 初始化失败: {e}") + logger.warning("回退到兼容模式,某些功能可能受限") + self.request_handler = None + self.use_new_architecture = False + else: + logger.warning("新架构不可用,使用兼容模式") + logger.warning("回退到兼容模式,某些功能可能受限") + self.request_handler = None + self.use_new_architecture = False + + # 保存原始参数用于向后兼容 + self.params = kwargs + + # 兼容性属性,从模型配置中提取 + # 新格式和旧格式都支持 self.enable_thinking = model.get("enable_thinking", False) - self.temp = model.get("temp", 0.7) + self.temp = model.get("temperature", model.get("temp", 0.7)) # 新格式用temperature,旧格式用temp self.thinking_budget = model.get("thinking_budget", 4096) self.stream = model.get("stream", False) self.pri_in = model.get("pri_in", 0) self.pri_out = model.get("pri_out", 0) self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length) - # print(f"max_tokens: {self.max_tokens}") + + # 记录配置文件中声明了哪些参数(不管值是什么) + self.has_enable_thinking = "enable_thinking" in model + self.has_thinking_budget = "thinking_budget" in model + self.pri_out = model.get("pri_out", 0) + self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length) + + # 记录配置文件中声明了哪些参数(不管值是什么) + self.has_enable_thinking = "enable_thinking" in model + self.has_thinking_budget = "thinking_budget" in model logger.debug("🔍 [模型初始化] 模型参数设置完成:") logger.debug(f" - model_name: {self.model_name}") + logger.debug(f" - provider: {self.provider}") logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}") logger.debug(f" - enable_thinking: {self.enable_thinking}") logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}") @@ -157,15 +230,146 @@ class LLMRequest: logger.debug(f" - temp: {self.temp}") logger.debug(f" - stream: {self.stream}") logger.debug(f" - max_tokens: {self.max_tokens}") - logger.debug(f" - base_url: {self.base_url}") + logger.debug(f" - use_new_architecture: {self.use_new_architecture}") # 获取数据库实例 self._init_database() - - # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default" - self.request_type = kwargs.pop("request_type", "default") + logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}") + def _determine_task_name(self, model: dict) -> str: + """ + 根据模型配置确定任务名称 + 优先使用配置文件中明确定义的任务类型,避免基于模型名称的脆弱推断 + + Args: + model: 模型配置字典 + Returns: + 任务名称 + """ + # 调试信息:打印模型配置字典的所有键 + 
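+        # Resolution order: an explicit task_type wins, then capabilities, then
+        # model-name keywords as a last resort. A hypothetical model_config.toml
+        # entry (field names follow ModelInfo; the table name is an assumption):
+        #   [[models]]
+        #   name = "my-vision-model"
+        #   model_identifier = "provider/my-vision-model"
+        #   task_type = "vision"
+        #   capabilities = ["vision", "text"]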
logger.debug(f"🔍 [任务确定] 模型配置字典的所有键: {list(model.keys())}") + logger.debug(f"🔍 [任务确定] 模型配置字典内容: {model}") + + # 获取模型名称 + model_name = model.get("model_name", model.get("name", "")) + + # 方法1: 优先使用配置文件中明确定义的 task_type 字段 + if "task_type" in model: + task_type = model["task_type"] + logger.debug(f"🎯 [任务确定] 使用配置中的 task_type: {task_type}") + return task_type + + # 方法2: 使用 capabilities 字段来推断主要任务类型 + if "capabilities" in model: + capabilities = model["capabilities"] + if isinstance(capabilities, list): + # 按优先级顺序检查能力 + if "vision" in capabilities: + logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: vision") + return "vision" + elif "embedding" in capabilities: + logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: embedding") + return "embedding" + elif "speech" in capabilities: + logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 推断为: speech") + return "speech" + elif "text" in capabilities: + # 如果只有文本能力,则根据request_type细分 + task = "llm_reasoning" if self.request_type == "reasoning" else "llm_normal" + logger.debug(f"🎯 [任务确定] 从 capabilities {capabilities} 和 request_type {self.request_type} 推断为: {task}") + return task + + # 方法3: 向后兼容 - 基于模型名称的关键字推断(不推荐但保留兼容性) + logger.warning(f"⚠️ [任务确定] 配置中未找到 task_type 或 capabilities,回退到基于模型名称的推断: {model_name}") + logger.warning("⚠️ [建议] 请在 model_config.toml 中为模型添加明确的 task_type 或 capabilities 字段") + + # 保留原有的关键字匹配逻辑作为fallback + if any(keyword in model_name.lower() for keyword in ["vlm", "vision", "gpt-4o", "claude", "vl-"]): + logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: vision") + return "vision" + elif any(keyword in model_name.lower() for keyword in ["embed", "text-embedding", "bge-"]): + logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: embedding") + return "embedding" + elif any(keyword in model_name.lower() for keyword in ["whisper", "speech", "voice"]): + logger.debug(f"🎯 [任务确定] 从模型名称 {model_name} 推断为: speech") + return "speech" + else: + # 根据request_type确定,映射到配置文件中定义的任务 + task = "llm_reasoning" if self.request_type == "reasoning" else "llm_normal" + logger.debug(f"🎯 [任务确定] 从 request_type {self.request_type} 推断为: {task}") + return task + + def _get_full_model_config(self, model_name: str) -> dict | None: + """ + 根据模型名称从全局配置中获取完整的模型配置 + 现在直接使用已解析的ModelInfo对象,不再读取TOML文件 + + Args: + model_name: 模型名称 + Returns: + 完整的模型配置字典,如果找不到则返回None + """ + try: + from src.config.config import model_config + return self._get_model_config_from_parsed(model_name, model_config) + + except Exception as e: + logger.warning(f"⚠️ [配置查找] 获取模型配置时出错: {str(e)}") + return None + + def _get_model_config_from_parsed(self, model_name: str, model_config) -> dict | None: + """ + 从已解析的配置对象中获取模型配置 + 使用扩展后的ModelInfo类,包含task_type和capabilities字段 + """ + try: + # 直接通过模型名称查找 + if model_name in model_config.models: + model_info = model_config.models[model_name] + logger.debug(f"🔍 [配置查找] 找到模型 {model_name} 的配置对象: {model_info}") + + # 将ModelInfo对象转换为字典 + model_dict = { + "model_identifier": model_info.model_identifier, + "name": model_info.name, + "api_provider": model_info.api_provider, + "price_in": model_info.price_in, + "price_out": model_info.price_out, + "force_stream_mode": model_info.force_stream_mode, + "task_type": model_info.task_type, + "capabilities": model_info.capabilities, + } + + logger.debug(f"🔍 [配置查找] 转换后的模型配置字典: {model_dict}") + return model_dict + + # 如果直接查找失败,尝试通过model_identifier查找 + for name, model_info in model_config.models.items(): + if (model_info.model_identifier == model_name or + hasattr(model_info, 'model_name') and model_info.model_name == 
model_name): + + logger.debug(f"🔍 [配置查找] 通过标识符找到模型 {model_name} (配置名称: {name})") + # 同样转换为字典 + model_dict = { + "model_identifier": model_info.model_identifier, + "name": model_info.name, + "api_provider": model_info.api_provider, + "price_in": model_info.price_in, + "price_out": model_info.price_out, + "force_stream_mode": model_info.force_stream_mode, + "task_type": model_info.task_type, + "capabilities": model_info.capabilities, + } + + return model_dict + + return None + + except Exception as e: + logger.warning(f"⚠️ [配置查找] 从已解析配置获取模型配置时出错: {str(e)}") + return None + @staticmethod def _init_database(): """初始化数据库集合""" @@ -182,7 +386,7 @@ class LLMRequest: completion_tokens: int, total_tokens: int, user_id: str = "system", - request_type: str = None, + request_type: str | None = None, endpoint: str = "/chat/completions", ): """记录模型使用情况到数据库 @@ -237,726 +441,253 @@ class LLMRequest: output_cost = (completion_tokens / 1000000) * self.pri_out return round(input_cost + output_cost, 6) - async def _prepare_request( - self, - endpoint: str, - prompt: str = None, - image_base64: str = None, - image_format: str = None, - file_bytes: bytes = None, - file_format: str = None, - payload: dict = None, - retry_policy: dict = None, - ) -> Dict[str, Any]: - """配置请求参数 - Args: - endpoint: API端点路径 (如 "chat/completions") - prompt: prompt文本 - image_base64: 图片的base64编码 - image_format: 图片格式 - file_bytes: 文件的二进制数据 - file_format: 文件格式 - payload: 请求体数据 - retry_policy: 自定义重试策略 - request_type: 请求类型 - """ - - # 合并重试策略 - default_retry = { - "max_retries": 3, - "base_wait": 10, - "retry_codes": [429, 413, 500, 503], - "abort_codes": [400, 401, 402, 403], - } - policy = {**default_retry, **(retry_policy or {})} - - api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" - - stream_mode = self.stream - - # 构建请求体 - if image_base64: - payload = await self._build_payload(prompt, image_base64, image_format) - elif file_bytes: - payload = await self._build_formdata_payload(file_bytes, file_format) - elif payload is None: - payload = await self._build_payload(prompt) - - if not file_bytes: - if stream_mode: - payload["stream"] = stream_mode - - if self.temp != 0.7: - payload["temperature"] = self.temp - - # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false) - if self.has_enable_thinking: - payload["enable_thinking"] = self.enable_thinking - - # 添加thinking_budget参数(只有配置文件中声明了才添加) - if self.has_thinking_budget: - payload["thinking_budget"] = self.thinking_budget - - if self.max_tokens: - payload["max_tokens"] = self.max_tokens - - # if "max_tokens" not in payload and "max_completion_tokens" not in payload: - # payload["max_tokens"] = global_config.model.model_max_output_length - # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: - payload["max_completion_tokens"] = payload.pop("max_tokens") - - return { - "policy": policy, - "payload": payload, - "api_url": api_url, - "stream_mode": stream_mode, - "image_base64": image_base64, # 保留必要的exception处理所需的原始数据 - "image_format": image_format, - "file_bytes": file_bytes, - "file_format": file_format, - "prompt": prompt, - } - - async def _execute_request( - self, - endpoint: str, - prompt: str = None, - image_base64: str = None, - image_format: str = None, - file_bytes: bytes = None, - file_format: str = None, - payload: dict = None, - retry_policy: dict = None, - response_handler: Callable = None, - user_id: str = "system", - request_type: str = None, - ): - """统一请求执行入口 - Args: - endpoint: 
API端点路径 (如 "chat/completions") - prompt: prompt文本 - image_base64: 图片的base64编码 - image_format: 图片格式 - file_bytes: 文件的二进制数据 - file_format: 文件格式 - payload: 请求体数据 - retry_policy: 自定义重试策略 - response_handler: 自定义响应处理器 - user_id: 用户ID - request_type: 请求类型 - """ - # 获取请求配置 - request_content = await self._prepare_request( - endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy - ) - if request_type is None: - request_type = self.request_type - for retry in range(request_content["policy"]["max_retries"]): - try: - # 使用上下文管理器处理会话 - if file_bytes: - headers = await self._build_headers(is_formdata=True) - else: - headers = await self._build_headers(is_formdata=False) - # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响 - if request_content["stream_mode"]: - headers["Accept"] = "text/event-stream" - - # 添加请求发送前的调试信息 - logger.debug(f"🔍 [请求调试] 模型 {self.model_name} 准备发送请求") - logger.debug(f"🔍 [请求调试] API URL: {request_content['api_url']}") - logger.debug(f"🔍 [请求调试] 请求头: {await self._build_headers(no_key=True, is_formdata=file_bytes is not None)}") - - if not file_bytes: - # 安全地记录请求体(隐藏敏感信息) - safe_payload = await _safely_record(request_content, request_content["payload"]) - logger.debug(f"🔍 [请求调试] 请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}") - else: - logger.debug(f"🔍 [请求调试] 文件上传请求,文件格式: {request_content['file_format']}") - - async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: - post_kwargs = {"headers": headers} - # form-data数据上传方式不同 - if file_bytes: - post_kwargs["data"] = request_content["payload"] - else: - post_kwargs["json"] = request_content["payload"] - - async with session.post(request_content["api_url"], **post_kwargs) as response: - handled_result = await self._handle_response( - response, request_content, retry, response_handler, user_id, request_type, endpoint - ) - return handled_result - - except Exception as e: - handled_payload, count_delta = await self._handle_exception(e, retry, request_content) - retry += count_delta # 降级不计入重试次数 - if handled_payload: - # 如果降级成功,重新构建请求体 - request_content["payload"] = handled_payload - continue - - logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败") - raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败") - - async def _handle_response( - self, - response: ClientResponse, - request_content: Dict[str, Any], - retry_count: int, - response_handler: Callable, - user_id, - request_type, - endpoint, - ): - policy = request_content["policy"] - stream_mode = request_content["stream_mode"] - if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]: - await self._handle_error_response(response, retry_count, policy) - return None - - response.raise_for_status() - result = {} - if stream_mode: - # 将流式输出转化为非流式输出 - result = await self._handle_stream_output(response) - else: - result = await response.json() - return ( - response_handler(result) - if response_handler - else self._default_response_handler(result, user_id, request_type, endpoint) - ) - - async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]: - flag_delta_content_finished = False - accumulated_content = "" - usage = None # 初始化usage变量,避免未定义错误 - reasoning_content = "" - content = "" - tool_calls = None # 初始化工具调用变量 - - async for line_bytes in response.content: - try: - line = line_bytes.decode("utf-8").strip() - if not line: - continue - if line.startswith("data:"): - data_str = line[5:].strip() - if data_str == "[DONE]": - break - try: - chunk = 
json.loads(data_str) - if flag_delta_content_finished: - chunk_usage = chunk.get("usage", None) - if chunk_usage: - usage = chunk_usage # 获取token用量 - else: - delta = chunk["choices"][0]["delta"] - delta_content = delta.get("content") - if delta_content is None: - delta_content = "" - accumulated_content += delta_content - - # 提取工具调用信息 - if "tool_calls" in delta: - if tool_calls is None: - tool_calls = delta["tool_calls"] - else: - # 合并工具调用信息 - tool_calls.extend(delta["tool_calls"]) - - # 检测流式输出文本是否结束 - finish_reason = chunk["choices"][0].get("finish_reason") - if delta.get("reasoning_content", None): - reasoning_content += delta["reasoning_content"] - if finish_reason == "stop" or finish_reason == "tool_calls": - chunk_usage = chunk.get("usage", None) - if chunk_usage: - usage = chunk_usage - break - # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk - flag_delta_content_finished = True - except Exception as e: - logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}") - except Exception as e: - if isinstance(e, GeneratorExit): - log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..." - else: - log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}" - logger.warning(log_content) - # 确保资源被正确清理 - try: - await response.release() - except Exception as cleanup_error: - logger.error(f"清理资源时发生错误: {cleanup_error}") - # 返回已经累积的内容 - content = accumulated_content - if not content: - content = accumulated_content - think_match = re.search(r"(.*?)", content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r".*?", "", content, flags=re.DOTALL).strip() - - # 构建消息对象 - message = { - "content": content, - "reasoning_content": reasoning_content, - } - - # 如果有工具调用,添加到消息中 - if tool_calls: - message["tool_calls"] = tool_calls - - result = { - "choices": [{"message": message}], - "usage": usage, - } - return result - - async def _handle_error_response(self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]): - if response.status in policy["retry_codes"]: - wait_time = policy["base_wait"] * (2**retry_count) - logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试") - if response.status == 413: - logger.warning("请求体过大,尝试压缩...") - raise PayLoadTooLargeError("请求体过大") - elif response.status in [500, 503]: - logger.error( - f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}" - ) - raise RuntimeError("服务器负载过高,模型回复失败QAQ") - else: - logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...") - raise RuntimeError("请求限制(429)") - elif response.status in policy["abort_codes"]: - # 特别处理400错误,添加详细调试信息 - if response.status == 400: - logger.error(f"🔍 [调试信息] 模型 {self.model_name} 参数错误 (400) - 开始详细诊断") - logger.error(f"🔍 [调试信息] 模型名称: {self.model_name}") - logger.error(f"🔍 [调试信息] API地址: {self.base_url}") - logger.error("🔍 [调试信息] 模型配置参数:") - logger.error(f" - enable_thinking: {self.enable_thinking}") - logger.error(f" - temp: {self.temp}") - logger.error(f" - thinking_budget: {self.thinking_budget}") - logger.error(f" - stream: {self.stream}") - logger.error(f" - max_tokens: {self.max_tokens}") - logger.error(f" - pri_in: {self.pri_in}") - logger.error(f" - pri_out: {self.pri_out}") - logger.error(f"🔍 [调试信息] 原始params: {self.params}") - - # 尝试获取服务器返回的详细错误信息 - try: - error_text = await response.text() - logger.error(f"🔍 [调试信息] 服务器返回的原始错误内容: {error_text}") - - try: - error_json = json.loads(error_text) - logger.error(f"🔍 [调试信息] 解析后的错误JSON: {json.dumps(error_json, indent=2, ensure_ascii=False)}") - 
except json.JSONDecodeError: - logger.error("🔍 [调试信息] 错误响应不是有效的JSON格式") - except Exception as e: - logger.error(f"🔍 [调试信息] 无法读取错误响应内容: {str(e)}") - - raise RequestAbortException("参数错误,请检查调试信息", response) - elif response.status != 403: - raise RequestAbortException("请求出现错误,中断处理", response) - else: - raise PermissionDeniedException("模型禁止访问") - - async def _handle_exception( - self, exception, retry_count: int, request_content: Dict[str, Any] - ) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]: - policy = request_content["policy"] - payload = request_content["payload"] - wait_time = policy["base_wait"] * (2**retry_count) - keep_request = False - if retry_count < policy["max_retries"] - 1: - keep_request = True - if isinstance(exception, RequestAbortException): - response = exception.response - logger.error( - f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}" - ) - - # 如果是400错误,额外输出请求体信息用于调试 - if response.status == 400: - logger.error("🔍 [异常调试] 400错误 - 请求体调试信息:") - try: - safe_payload = await _safely_record(request_content, payload) - logger.error(f"🔍 [异常调试] 发送的请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}") - except Exception as debug_error: - logger.error(f"🔍 [异常调试] 无法安全记录请求体: {str(debug_error)}") - logger.error(f"🔍 [异常调试] 原始payload类型: {type(payload)}") - if isinstance(payload, dict): - logger.error(f"🔍 [异常调试] 原始payload键: {list(payload.keys())}") - - # print(request_content) - # print(response) - # 尝试获取并记录服务器返回的详细错误信息 - try: - error_json = await response.json() - if error_json and isinstance(error_json, list) and len(error_json) > 0: - # 处理多个错误的情况 - for error_item in error_json: - if "error" in error_item and isinstance(error_item["error"], dict): - error_obj: dict = error_item["error"] - error_code = error_obj.get("code") - error_message = error_obj.get("message") - error_status = error_obj.get("status") - logger.error( - f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}" - ) - elif isinstance(error_json, dict) and "error" in error_json: - # 处理单个错误对象的情况 - error_obj = error_json.get("error", {}) - error_code = error_obj.get("code") - error_message = error_obj.get("message") - error_status = error_obj.get("status") - logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}") - else: - # 记录原始错误响应内容 - logger.error(f"服务器错误响应: {error_json}") - except Exception as e: - logger.warning(f"无法解析服务器错误响应: {str(e)}") - raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") - - elif isinstance(exception, PermissionDeniedException): - # 只针对硅基流动的V3和R1进行降级处理 - if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/": - old_model_name = self.model_name - self.model_name = self.model_name[4:] # 移除"Pro/"前缀 - logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}") - - # 对全局配置进行更新 - if global_config.model.replyer_2.get("name") == old_model_name: - global_config.model.replyer_2["name"] = self.model_name - logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}") - if global_config.model.replyer_1.get("name") == old_model_name: - global_config.model.replyer_1["name"] = self.model_name - logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}") - - if payload and "model" in payload: - payload["model"] = self.model_name - - await asyncio.sleep(wait_time) - return payload, -1 - raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}") - - elif isinstance(exception, PayLoadTooLargeError): - if keep_request: - image_base64 = 
request_content["image_base64"] - compressed_image_base64 = compress_base64_image_by_scale(image_base64) - new_payload = await self._build_payload( - request_content["prompt"], compressed_image_base64, request_content["image_format"] - ) - return new_payload, 0 - else: - return None, 0 - - elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError): - if keep_request: - logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}") - await asyncio.sleep(wait_time) - return None, 0 - else: - logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}") - raise RuntimeError(f"网络请求失败: {str(exception)}") - - elif isinstance(exception, aiohttp.ClientResponseError): - # 处理aiohttp抛出的,除了policy中的status的响应错误 - if keep_request: - logger.error( - f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {exception.status}, 错误: {exception.message}" - ) - try: - error_text = await exception.response.text() - error_json = json.loads(error_text) - if isinstance(error_json, list) and len(error_json) > 0: - # 处理多个错误的情况 - for error_item in error_json: - if "error" in error_item and isinstance(error_item["error"], dict): - error_obj = error_item["error"] - logger.error( - f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, " - f"状态={error_obj.get('status')}, " - f"消息={error_obj.get('message')}" - ) - elif isinstance(error_json, dict) and "error" in error_json: - error_obj = error_json.get("error", {}) - logger.error( - f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, " - f"状态={error_obj.get('status')}, " - f"消息={error_obj.get('message')}" - ) - else: - logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}") - except (json.JSONDecodeError, TypeError) as json_err: - logger.warning( - f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}" - ) - except Exception as parse_err: - logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}") - - await asyncio.sleep(wait_time) - return None, 0 - else: - logger.critical( - f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}" - ) - # 安全地检查和记录请求详情 - handled_payload = await _safely_record(request_content, payload) - logger.critical( - f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}" - ) - raise RuntimeError( - f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}" - ) - - else: - if keep_request: - logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 
错误: {str(exception)}") - await asyncio.sleep(wait_time) - return None, 0 - else: - logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}") - # 安全地检查和记录请求详情 - handled_payload = await _safely_record(request_content, payload) - logger.critical( - f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}" - ) - raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") - - async def _transform_parameters(self, params: dict) -> dict: - """ - 根据模型名称转换参数: - - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数, - 并将 'max_tokens' 重命名为 'max_completion_tokens' - """ - # 复制一份参数,避免直接修改原始数据 - new_params = dict(params) - - logger.debug(f"🔍 [参数转换] 模型 {self.model_name} 开始参数转换") - logger.debug(f"🔍 [参数转换] 是否为CoT模型: {self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION}") - logger.debug(f"🔍 [参数转换] CoT模型列表: {self.MODELS_NEEDING_TRANSFORMATION}") - - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION: - logger.debug("🔍 [参数转换] 检测到CoT模型,开始参数转换") - # 删除 'temperature' 参数(如果存在),但避免删除我们在_build_payload中添加的自定义温度 - if "temperature" in new_params and new_params["temperature"] == 0.7: - removed_temp = new_params.pop("temperature") - logger.debug(f"🔍 [参数转换] 移除默认temperature参数: {removed_temp}") - # 如果存在 'max_tokens',则重命名为 'max_completion_tokens' - if "max_tokens" in new_params: - old_value = new_params["max_tokens"] - new_params["max_completion_tokens"] = new_params.pop("max_tokens") - logger.debug(f"🔍 [参数转换] 参数重命名: max_tokens({old_value}) -> max_completion_tokens({new_params['max_completion_tokens']})") - else: - logger.debug("🔍 [参数转换] 非CoT模型,无需参数转换") - - logger.debug(f"🔍 [参数转换] 转换前参数: {params}") - logger.debug(f"🔍 [参数转换] 转换后参数: {new_params}") - return new_params - - async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData: - """构建form-data请求体""" - # 目前只适配了音频文件 - # 如果后续要支持其他类型的文件,可以在这里添加更多的处理逻辑 - data = aiohttp.FormData() - content_type_list = { - "wav": "audio/wav", - "mp3": "audio/mpeg", - "ogg": "audio/ogg", - "flac": "audio/flac", - "aac": "audio/aac", - } - - content_type = content_type_list.get(file_format) - if not content_type: - logger.warning(f"暂不支持的文件类型: {file_format}") - - data.add_field( - "file", - io.BytesIO(file_bytes), - filename=f"file.{file_format}", - content_type=f"{content_type}", # 根据实际文件类型设置 - ) - data.add_field("model", self.model_name) - return data - - async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict: - """构建请求体""" - # 复制一份参数,避免直接修改 self.params - logger.debug(f"🔍 [参数构建] 模型 {self.model_name} 开始构建请求体") - logger.debug(f"🔍 [参数构建] 原始self.params: {self.params}") - - params_copy = await self._transform_parameters(self.params) - logger.debug(f"🔍 [参数构建] 转换后的params_copy: {params_copy}") - - if image_base64: - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}, - }, - ], - } - ] - else: - messages = [{"role": "user", "content": prompt}] - - payload = { - "model": self.model_name, - "messages": messages, - **params_copy, - } - - logger.debug(f"🔍 [参数构建] 基础payload构建完成: {list(payload.keys())}") - - # 添加temp参数(如果不是默认值0.7) - if self.temp != 0.7: - payload["temperature"] = self.temp - logger.debug(f"🔍 [参数构建] 添加temperature参数: {self.temp}") - - # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false) - if self.has_enable_thinking: - payload["enable_thinking"] = self.enable_thinking - logger.debug(f"🔍 [参数构建] 
添加enable_thinking参数: {self.enable_thinking}") - - # 添加thinking_budget参数(只有配置文件中声明了才添加) - if self.has_thinking_budget: - payload["thinking_budget"] = self.thinking_budget - logger.debug(f"🔍 [参数构建] 添加thinking_budget参数: {self.thinking_budget}") - - if self.max_tokens: - payload["max_tokens"] = self.max_tokens - logger.debug(f"🔍 [参数构建] 添加max_tokens参数: {self.max_tokens}") - - # if "max_tokens" not in payload and "max_completion_tokens" not in payload: - # payload["max_tokens"] = global_config.model.model_max_output_length - # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: - old_value = payload["max_tokens"] - payload["max_completion_tokens"] = payload.pop("max_tokens") - logger.debug(f"🔍 [参数构建] CoT模型参数转换: max_tokens({old_value}) -> max_completion_tokens({payload['max_completion_tokens']})") - - logger.debug(f"🔍 [参数构建] 最终payload键列表: {list(payload.keys())}") - return payload - - def _default_response_handler( - self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions" - ): - """默认响应解析""" - if "choices" in result and result["choices"]: - message = result["choices"][0]["message"] - content = message.get("content", "") - content, reasoning = self._extract_reasoning(content) - reasoning_content = message.get("model_extra", {}).get("reasoning_content", "") - if not reasoning_content: - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - reasoning_content = reasoning - - # 提取工具调用信息 - tool_calls = message.get("tool_calls", None) - - # 记录token使用情况 - usage = result.get("usage", {}) - if usage: - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - total_tokens = usage.get("total_tokens", 0) - self._record_usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - user_id=user_id, - request_type=request_type if request_type is not None else self.request_type, - endpoint=endpoint, - ) - - # 只有当tool_calls存在且不为空时才返回 - if tool_calls: - logger.debug(f"检测到工具调用: {tool_calls}") - return content, reasoning_content, tool_calls - else: - return content, reasoning_content - elif "text" in result and result["text"]: - return result["text"] - return "没有返回结果", "" - @staticmethod def _extract_reasoning(content: str) -> Tuple[str, str]: """CoT思维链提取""" match = re.search(r"(?:)?(.*?)", content, re.DOTALL) content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() - if match: - reasoning = match.group(1).strip() - else: - reasoning = "" + reasoning = match[1].strip() if match else "" return content, reasoning - async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict: - """构建请求头""" - if no_key: - if is_formdata: - return {"Authorization": "Bearer **********"} - return {"Authorization": "Bearer **********", "Content-Type": "application/json"} + def _handle_model_exception(self, e: Exception, operation: str) -> None: + """ + 统一的模型异常处理方法 + 根据异常类型提供更精确的错误信息和处理策略 + + Args: + e: 捕获的异常 + operation: 操作类型(用于日志记录) + """ + operation_desc = { + "image": "图片响应生成", + "voice": "语音识别", + "text": "文本响应生成", + "embedding": "向量嵌入获取" + } + + op_name = operation_desc.get(operation, operation) + + if SPECIFIC_EXCEPTIONS_AVAILABLE: + # 使用具体异常类型进行精确处理 + if isinstance(e, NetworkConnectionError): + logger.error(f"模型 {self.model_name} {op_name}失败: 网络连接错误") + raise RuntimeError("网络连接异常,请检查网络连接状态或API服务器地址是否正确") from e + + elif 
isinstance(e, ReqAbortException): + logger.error(f"模型 {self.model_name} {op_name}失败: 请求被中断") + raise RuntimeError("请求被中断或取消,请稍后重试") from e + + elif isinstance(e, RespNotOkException): + logger.error(f"模型 {self.model_name} {op_name}失败: HTTP响应错误 {e.status_code}") + # 重新抛出原始异常,保留详细的状态码信息 + raise e + + elif isinstance(e, RespParseException): + logger.error(f"模型 {self.model_name} {op_name}失败: 响应解析错误") + raise RuntimeError("API响应格式异常,请检查模型配置或联系管理员") from e + + else: + # 未知异常,使用通用处理 + logger.error(f"模型 {self.model_name} {op_name}失败: 未知错误 {type(e).__name__}: {str(e)}") + self._handle_generic_exception(e, op_name) else: - if is_formdata: - return {"Authorization": f"Bearer {self.api_key}"} - return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} - # 防止小朋友们截图自己的key + # 如果无法导入具体异常,使用通用处理 + logger.error(f"模型 {self.model_name} {op_name}失败: {str(e)}") + self._handle_generic_exception(e, op_name) + + def _handle_generic_exception(self, e: Exception, operation: str) -> None: + """ + 通用异常处理(向后兼容的错误字符串匹配) + + Args: + e: 捕获的异常 + operation: 操作描述 + """ + error_str = str(e) + + # 基于错误消息内容的分类处理 + if "401" in error_str or "API key" in error_str or "认证" in error_str: + raise RuntimeError("API key 错误,认证失败,请检查 config/model_config.toml 中的 API key 配置是否正确") from e + elif "429" in error_str or "频繁" in error_str or "rate limit" in error_str: + raise RuntimeError("请求过于频繁,请稍后再试") from e + elif "500" in error_str or "503" in error_str or "服务器" in error_str: + raise RuntimeError("服务器负载过高,模型回复失败QAQ") from e + elif "413" in error_str or "payload" in error_str.lower() or "过大" in error_str: + raise RuntimeError("请求体过大,请尝试压缩图片或减少输入内容") from e + elif "timeout" in error_str.lower() or "超时" in error_str: + raise RuntimeError("请求超时,请检查网络连接或稍后重试") from e + else: + raise RuntimeError(f"模型 {self.model_name} {operation}失败: {str(e)}") from e + + # === 主要API方法 === + # 这些方法提供与新架构的桥接 async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple: - """根据输入的提示和图片生成模型的异步响应""" - - response = await self._execute_request( - endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format - ) - # 根据返回值的长度决定怎么处理 - if len(response) == 3: - content, reasoning_content, tool_calls = response - return content, reasoning_content, tool_calls - else: - content, reasoning_content = response - return content, reasoning_content + """ + 根据输入的提示和图片生成模型的异步响应 + 使用新架构的模型请求处理器 + """ + if not self.use_new_architecture: + raise RuntimeError( + f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。" + ) + + if self.request_handler is None: + raise RuntimeError( + f"模型 {self.model_name} 请求处理器未初始化,无法处理图片请求" + ) + + if MessageBuilder is None: + raise RuntimeError("MessageBuilder不可用,请检查新架构配置") + + try: + # 构建包含图片的消息 + message_builder = MessageBuilder() + message_builder.add_text_content(prompt).add_image_content( + image_format=image_format, + image_base64=image_base64 + ) + messages = [message_builder.build()] + + # 使用新架构发送请求(只传递支持的参数) + response = await self.request_handler.get_response( # type: ignore + messages=messages, + tool_options=None, + response_format=None + ) + + # 新架构返回的是 APIResponse 对象,直接提取内容 + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + + # 记录token使用情况 + if response.usage: + 
self._record_usage( + prompt_tokens=response.usage.prompt_tokens or 0, + completion_tokens=response.usage.completion_tokens or 0, + total_tokens=response.usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/chat/completions" + ) + + # 返回格式兼容旧版本 + if tool_calls: + return content, reasoning_content, tool_calls + else: + return content, reasoning_content + + except Exception as e: + self._handle_model_exception(e, "image") + # 这行代码永远不会执行,因为_handle_model_exception总是抛出异常 + # 但是为了满足类型检查的要求,我们添加一个不可达的返回语句 + return "", "" # pragma: no cover async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple: - """根据输入的语音文件生成模型的异步响应""" - response = await self._execute_request( - endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav" - ) - return response + """ + 根据输入的语音文件生成模型的异步响应 + 使用新架构的模型请求处理器 + """ + if not self.use_new_architecture: + raise RuntimeError( + f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。" + ) + + if self.request_handler is None: + raise RuntimeError( + f"模型 {self.model_name} 请求处理器未初始化,无法处理语音请求" + ) + + try: + # 构建语音识别请求参数 + # 注意:新架构中的语音识别可能使用不同的方法 + # 这里先使用get_response方法,可能需要根据实际API调整 + response = await self.request_handler.get_response( # type: ignore + messages=[], # 语音识别可能不需要消息 + tool_options=None + ) + + # 新架构返回的是 APIResponse 对象,直接提取文本内容 + return (response.content,) if response.content else ("",) + + except Exception as e: + self._handle_model_exception(e, "voice") + # 不可达的返回语句,仅用于满足类型检查 + return ("",) # pragma: no cover async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]: - """异步方式根据输入的提示生成模型的响应""" - # 构建请求体,不硬编码max_tokens - data = { - "model": self.model_name, - "messages": [{"role": "user", "content": prompt}], - **self.params, - **kwargs, - } - - response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt) - # 原样返回响应,不做处理 - - if len(response) == 3: - content, reasoning_content, tool_calls = response - return content, (reasoning_content, self.model_name, tool_calls) - else: - content, reasoning_content = response - return content, (reasoning_content, self.model_name) + """ + 异步方式根据输入的提示生成模型的响应 + 使用新架构的模型请求处理器,如无法使用则抛出错误 + """ + if not self.use_new_architecture: + raise RuntimeError( + f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。" + ) + + if self.request_handler is None: + raise RuntimeError( + f"模型 {self.model_name} 请求处理器未初始化,无法生成响应" + ) + + if MessageBuilder is None: + raise RuntimeError("MessageBuilder不可用,请检查新架构配置") + + try: + # 构建消息 + message_builder = MessageBuilder() + message_builder.add_text_content(prompt) + messages = [message_builder.build()] + + # 使用新架构发送请求(只传递支持的参数) + response = await self.request_handler.get_response( # type: ignore + messages=messages, + tool_options=None, + response_format=None + ) + + # 新架构返回的是 APIResponse 对象,直接提取内容 + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + + # 记录token使用情况 + if response.usage: + self._record_usage( + prompt_tokens=response.usage.prompt_tokens or 0, + completion_tokens=response.usage.completion_tokens or 0, + total_tokens=response.usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/chat/completions" + ) + + # 返回格式兼容旧版本 + if tool_calls: + return 
content, (reasoning_content, self.model_name, tool_calls) + else: + return content, (reasoning_content, self.model_name) + + except Exception as e: + self._handle_model_exception(e, "text") + # 不可达的返回语句,仅用于满足类型检查 + return "", ("", self.model_name) # pragma: no cover async def get_embedding(self, text: str) -> Union[list, None]: - """异步方法:获取文本的embedding向量 + """ + 异步方法:获取文本的embedding向量 + 使用新架构的模型请求处理器 Args: text: 需要获取embedding的文本 @@ -964,45 +695,55 @@ class LLMRequest: Returns: list: embedding向量,如果失败则返回None """ - - if len(text) < 1: + if not text: logger.debug("该消息没有长度,不再发送获取embedding向量的请求") return None - def embedding_handler(result): - """处理响应""" - if "data" in result and len(result["data"]) > 0: - # 提取 token 使用信息 - usage = result.get("usage", {}) - if usage: - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - total_tokens = usage.get("total_tokens", 0) - # 记录 token 使用情况 - self._record_usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - user_id="system", # 可以根据需要修改 user_id - # request_type="embedding", # 请求类型为 embedding - request_type=self.request_type, # 请求类型为 text - endpoint="/embeddings", # API 端点 - ) - return result["data"][0].get("embedding", None) - return result["data"][0].get("embedding", None) + if not self.use_new_architecture: + logger.warning(f"模型 {self.model_name} 无法使用新架构,embedding请求将被跳过") return None - embedding = await self._execute_request( - endpoint="/embeddings", - prompt=text, - payload={"model": self.model_name, "input": text, "encoding_format": "float"}, - retry_policy={"max_retries": 2, "base_wait": 6}, - response_handler=embedding_handler, - ) - return embedding + if self.request_handler is None: + logger.warning(f"模型 {self.model_name} 请求处理器未初始化,embedding请求将被跳过") + return None + + try: + # 构建embedding请求参数 + # 使用新架构的get_embedding方法 + response = await self.request_handler.get_embedding(text) # type: ignore + + # 新架构返回的是 APIResponse 对象,直接提取embedding + if response.embedding: + embedding = response.embedding + + # 记录token使用情况 + if response.usage: + self._record_usage( + prompt_tokens=response.usage.prompt_tokens or 0, + completion_tokens=response.usage.completion_tokens or 0, + total_tokens=response.usage.total_tokens or 0, + user_id="system", + request_type=self.request_type, + endpoint="/embeddings" + ) + + return embedding + else: + logger.warning(f"模型 {self.model_name} 返回的embedding响应为空") + return None + + except Exception as e: + # 对于embedding请求,我们记录错误但不抛出异常,而是返回None + # 这是为了保持与原有行为的兼容性 + try: + self._handle_model_exception(e, "embedding") + except RuntimeError: + # 捕获_handle_model_exception抛出的RuntimeError,转换为警告日志 + logger.warning(f"模型 {self.model_name} embedding请求失败,返回None: {str(e)}") + return None -def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str: +def compress_base64_image_by_scale(base64_data: str, target_size: int = int(0.8 * 1024 * 1024)) -> str: """压缩base64格式的图片到指定大小 Args: base64_data: base64编码的图片数据 @@ -1040,7 +781,8 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 # 如果是GIF,处理所有帧 if getattr(img, "is_animated", False): frames = [] - for frame_idx in range(img.n_frames): + n_frames = getattr(img, 'n_frames', 1) + for frame_idx in range(n_frames): img.seek(frame_idx) new_frame = img.copy() new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折 diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 
39857d669..fa9466c6d 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "4.5.0" +version = "5.0.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 @@ -227,120 +227,84 @@ show_prompt = false # 是否显示prompt [model] -model_max_output_length = 1024 # 模型单次返回的最大token数 +model_max_output_length = 800 # 模型单次返回的最大token数 -#------------必填:组件模型------------ +#------------模型任务配置------------ +# 所有模型名称需要对应 model_config.toml 中配置的模型名称 [model.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 #模型的输入价格(非必填,可以记录消耗) -pri_out = 8 #模型的输出价格(非必填,可以记录消耗) -#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数 -temp = 0.2 #模型的温度,新V3建议0.1-0.3 +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.2 # 模型温度,新V3建议0.1-0.3 +max_tokens = 800 # 最大输出token数 [model.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型 -# 强烈建议使用免费的小模型 -name = "Qwen/Qwen3-8B" -provider = "SILICONFLOW" -pri_in = 0 -pri_out = 0 -temp = 0.7 +model_name = "qwen3-8b" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 +max_tokens = 800 enable_thinking = false # 是否启用思考 [model.replyer_1] # 首要回复模型,还用于表达器和表达方式学习 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 #模型的输入价格(非必填,可以记录消耗) -pri_out = 8 #模型的输出价格(非必填,可以记录消耗) -#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数 -temp = 0.2 #模型的温度,新V3建议0.1-0.3 +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.2 # 模型温度,新V3建议0.1-0.3 +max_tokens = 800 [model.replyer_2] # 次要回复模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 #模型的输入价格(非必填,可以记录消耗) -pri_out = 8 #模型的输出价格(非必填,可以记录消耗) -#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数 -temp = 0.2 #模型的温度,新V3建议0.1-0.3 +model_name = "siliconflow-deepseek-r1" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 # 模型温度 +max_tokens = 800 [model.planner] #决策:负责决定麦麦该做什么的模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 -pri_out = 8 -temp = 0.3 +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.3 +max_tokens = 800 [model.emotion] #负责麦麦的情绪变化 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 -pri_out = 8 -temp = 0.3 - +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.3 +max_tokens = 800 [model.memory] # 记忆模型 -name = "Qwen/Qwen3-30B-A3B" -provider = "SILICONFLOW" -pri_in = 0.7 -pri_out = 2.8 -temp = 0.7 +model_name = "qwen3-30b" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 +max_tokens = 800 enable_thinking = false # 是否启用思考 [model.vlm] # 图像识别模型 -name = "Pro/Qwen/Qwen2.5-VL-7B-Instruct" -provider = "SILICONFLOW" -pri_in = 0.35 -pri_out = 0.35 +model_name = "qwen2.5-vl-72b" # 对应 model_config.toml 中的模型名称 +max_tokens = 800 [model.voice] # 语音识别模型 -name = "FunAudioLLM/SenseVoiceSmall" -provider = "SILICONFLOW" -pri_in = 0 -pri_out = 0 +model_name = "sensevoice-small" # 对应 model_config.toml 中的模型名称 [model.tool_use] #工具调用模型,需要使用支持工具调用的模型 -name = "Qwen/Qwen3-14B" -provider = "SILICONFLOW" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 +model_name = "qwen3-14b" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 +max_tokens = 800 enable_thinking = false # 是否启用思考(qwen3 only) #嵌入模型 [model.embedding] -name = "BAAI/bge-m3" -provider = "SILICONFLOW" -pri_in = 0 -pri_out = 0 - +model_name = "bge-m3" # 对应 model_config.toml 中的模型名称 #------------LPMM知识库模型------------ [model.lpmm_entity_extract] # 实体提取模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider 
= "SILICONFLOW" -pri_in = 2 -pri_out = 8 -temp = 0.2 - +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.2 +max_tokens = 800 [model.lpmm_rdf_build] # RDF构建模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 -pri_out = 8 -temp = 0.2 - +model_name = "siliconflow-deepseek-v3" # 对应 model_config.toml 中的模型名称 +temperature = 0.2 +max_tokens = 800 [model.lpmm_qa] # 问答模型 -name = "Qwen/Qwen3-30B-A3B" -provider = "SILICONFLOW" -pri_in = 0.7 -pri_out = 2.8 -temp = 0.7 +model_name = "deepseek-r1-distill-qwen-32b" # 对应 model_config.toml 中的模型名称 +temperature = 0.7 +max_tokens = 800 enable_thinking = false # 是否启用思考 + [maim_message] auth_token = [] # 认证令牌,用于API验证,为空则不启用验证 # 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器 diff --git a/template/compare/model_config_template.toml b/template/compare/model_config_template.toml new file mode 100644 index 000000000..8ab187626 --- /dev/null +++ b/template/compare/model_config_template.toml @@ -0,0 +1,220 @@ +[inner] +version = "0.2.1" + +# 配置文件版本号迭代规则同bot_config.toml +# +# === 多API Key支持 === +# 本配置文件支持为每个API服务商配置多个API Key,实现以下功能: +# 1. 错误自动切换:当某个API Key失败时,自动切换到下一个可用的Key +# 2. 负载均衡:在多个可用的API Key之间循环使用,避免单个Key的频率限制 +# 3. 向后兼容:仍然支持单个key字段的配置方式 +# +# 配置方式: +# - 多Key配置:使用 api_keys = ["key1", "key2", "key3"] 数组格式 +# - 单Key配置:使用 key = "your-key" 字符串格式(向后兼容) +# +# 错误处理机制: +# - 401/403认证错误:立即切换到下一个API Key +# - 429频率限制:等待后重试,如果持续失败则切换Key +# - 网络错误:短暂等待后重试,失败则切换Key +# - 其他错误:按照正常重试机制处理 +# +# === 任务类型和模型能力配置 === +# 为了提高任务分配的准确性和可维护性,现在支持明确配置模型的任务类型和能力: +# +# task_type(推荐配置): +# - 明确指定模型主要用于什么任务 +# - 可选值:llm_normal, llm_reasoning, vision, embedding, speech +# - 如果不配置,系统会根据capabilities或模型名称自动推断(不推荐) +# +# capabilities(推荐配置): +# - 描述模型支持的所有能力 +# - 可选值:text, vision, embedding, speech, tool_calling, reasoning +# - 支持多个能力的组合,如:["text", "vision"] +# +# 配置优先级: +# 1. task_type(最高优先级,直接指定任务类型) +# 2. capabilities(中等优先级,根据能力推断任务类型) +# 3. 
模型名称关键字(最低优先级,不推荐依赖) +# +# 向后兼容: +# - 仍然支持 model_flags 字段,但建议迁移到 capabilities +# - 未配置新字段时会自动回退到基于模型名称的推断 + +[request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释) +#max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) +#timeout = 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒) +#retry_interval = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒) +#default_temperature = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值) +#default_max_tokens = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值) + + +[[api_providers]] # API服务提供商(可以配置多个) +name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名) +base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL +# 支持多个API Key,实现自动切换和负载均衡 +api_keys = [ # API Key列表(多个key支持错误自动切换和负载均衡) + "sk-your-first-key-here", + "sk-your-second-key-here", + "sk-your-third-key-here" +] +# 向后兼容:如果只有一个key,也可以使用单个key字段 +#key = "******" # API Key (可选,默认为None) +client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini") + +[[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"gemini" +name = "Google" +base_url = "https://api.google.com/v1" +# Google API同样支持多key配置 +api_keys = [ + "your-google-api-key-1", + "your-google-api-key-2" +] +client_type = "gemini" + +[[api_providers]] +name = "SiliconFlow" +base_url = "https://api.siliconflow.cn/v1" +# 单个key的示例(向后兼容) +key = "******" +# +#[[api_providers]] +#name = "LocalHost" +#base_url = "https://localhost:8888" +#key = "lm-studio" + + +[[models]] # 模型(可以配置多个) +# 模型标识符(API服务商提供的模型标识符) +model_identifier = "deepseek-chat" +# 模型名称(可随意命名,在bot_config.toml中需使用这个命名) +#(可选,若无该字段,则将自动使用model_identifier填充) +name = "deepseek-v3" +# API服务商名称(对应在api_providers中配置的服务商名称) +api_provider = "DeepSeek" +# 任务类型(推荐配置,明确指定模型主要用于什么任务) +# 可选值:llm_normal, llm_reasoning, vision, embedding, speech +# 如果不配置,系统会根据capabilities或模型名称自动推断 +task_type = "llm_normal" +# 模型能力列表(推荐配置,描述模型支持的能力) +# 可选值:text, vision, embedding, speech, tool_calling, reasoning +capabilities = ["text", "tool_calling"] +# 输入价格(用于API调用统计,单位:元/兆token)(可选,若无该字段,默认值为0) +price_in = 2.0 +# 输出价格(用于API调用统计,单位:元/兆token)(可选,若无该字段,默认值为0) +price_out = 8.0 +# 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出) +#(可选,若无该字段,默认值为false) +#force_stream_mode = true + +[[models]] +model_identifier = "deepseek-reasoner" +name = "deepseek-r1" +api_provider = "DeepSeek" +# 推理模型的配置示例 +task_type = "llm_reasoning" +capabilities = ["text", "tool_calling", "reasoning"] +# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) +model_flags = [ "text", "tool_calling", "reasoning",] +price_in = 4.0 +price_out = 16.0 + +[[models]] +model_identifier = "Pro/deepseek-ai/DeepSeek-V3" +name = "siliconflow-deepseek-v3" +api_provider = "SiliconFlow" +task_type = "llm_normal" +capabilities = ["text", "tool_calling"] +price_in = 2.0 +price_out = 8.0 + +[[models]] +model_identifier = "Pro/deepseek-ai/DeepSeek-R1" +name = "siliconflow-deepseek-r1" +api_provider = "SiliconFlow" +task_type = "llm_reasoning" +capabilities = ["text", "tool_calling", "reasoning"] +price_in = 4.0 +price_out = 16.0 + +[[models]] +model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +name = "deepseek-r1-distill-qwen-32b" +api_provider = "SiliconFlow" +task_type = "llm_reasoning" +capabilities = ["text", "tool_calling", "reasoning"] +price_in = 4.0 +price_out = 16.0 + +[[models]] +model_identifier = "Qwen/Qwen3-8B" +name = "qwen3-8b" +api_provider = "SiliconFlow" +task_type = "llm_normal" +capabilities = ["text"] +price_in = 0 +price_out = 0 + +[[models]] +model_identifier = "Qwen/Qwen3-14B" +name = "qwen3-14b" +api_provider = "SiliconFlow" 
+task_type = "llm_normal" +capabilities = ["text", "tool_calling"] +price_in = 0.5 +price_out = 2.0 + +[[models]] +model_identifier = "Qwen/Qwen3-30B-A3B" +name = "qwen3-30b" +api_provider = "SiliconFlow" +task_type = "llm_normal" +capabilities = ["text", "tool_calling"] +price_in = 0.7 +price_out = 2.8 + +[[models]] +model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct" +name = "qwen2.5-vl-72b" +api_provider = "SiliconFlow" +# 视觉模型的配置示例 +task_type = "vision" +capabilities = ["vision", "text"] +# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) +model_flags = [ "vision", "text",] +price_in = 4.13 +price_out = 4.13 + +[[models]] +model_identifier = "FunAudioLLM/SenseVoiceSmall" +name = "sensevoice-small" +api_provider = "SiliconFlow" +# 语音模型的配置示例 +task_type = "speech" +capabilities = ["speech"] +# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) +model_flags = [ "audio",] +price_in = 0 +price_out = 0 + +[[models]] +model_identifier = "BAAI/bge-m3" +name = "bge-m3" +api_provider = "SiliconFlow" +# 嵌入模型的配置示例 +task_type = "embedding" +capabilities = ["text", "embedding"] +# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) +model_flags = [ "text", "embedding",] +price_in = 0 +price_out = 0 + + +[task_model_usage] +llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0} +llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0} +embedding = "siliconflow-bge-m3" +#schedule = [ +# "deepseek-v3", +# "deepseek-r1", +#] \ No newline at end of file diff --git a/template/model_config_template.toml b/template/model_config_template.toml new file mode 100644 index 000000000..8ab187626 --- /dev/null +++ b/template/model_config_template.toml @@ -0,0 +1,220 @@ +[inner] +version = "0.2.1" + +# 配置文件版本号迭代规则同bot_config.toml +# +# === 多API Key支持 === +# 本配置文件支持为每个API服务商配置多个API Key,实现以下功能: +# 1. 错误自动切换:当某个API Key失败时,自动切换到下一个可用的Key +# 2. 负载均衡:在多个可用的API Key之间循环使用,避免单个Key的频率限制 +# 3. 向后兼容:仍然支持单个key字段的配置方式 +# +# 配置方式: +# - 多Key配置:使用 api_keys = ["key1", "key2", "key3"] 数组格式 +# - 单Key配置:使用 key = "your-key" 字符串格式(向后兼容) +# +# 错误处理机制: +# - 401/403认证错误:立即切换到下一个API Key +# - 429频率限制:等待后重试,如果持续失败则切换Key +# - 网络错误:短暂等待后重试,失败则切换Key +# - 其他错误:按照正常重试机制处理 +# +# === 任务类型和模型能力配置 === +# 为了提高任务分配的准确性和可维护性,现在支持明确配置模型的任务类型和能力: +# +# task_type(推荐配置): +# - 明确指定模型主要用于什么任务 +# - 可选值:llm_normal, llm_reasoning, vision, embedding, speech +# - 如果不配置,系统会根据capabilities或模型名称自动推断(不推荐) +# +# capabilities(推荐配置): +# - 描述模型支持的所有能力 +# - 可选值:text, vision, embedding, speech, tool_calling, reasoning +# - 支持多个能力的组合,如:["text", "vision"] +# +# 配置优先级: +# 1. task_type(最高优先级,直接指定任务类型) +# 2. capabilities(中等优先级,根据能力推断任务类型) +# 3. 
模型名称关键字(最低优先级,不推荐依赖) +# +# 向后兼容: +# - 仍然支持 model_flags 字段,但建议迁移到 capabilities +# - 未配置新字段时会自动回退到基于模型名称的推断 + +[request_conf] # 请求配置(此配置项数值均为默认值,如想修改,请取消对应条目的注释) +#max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) +#timeout = 10 # API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒) +#retry_interval = 10 # 重试间隔(如果API调用失败,重试的间隔时间,单位:秒) +#default_temperature = 0.7 # 默认的温度(如果bot_config.toml中没有设置temperature参数,默认使用这个值) +#default_max_tokens = 1024 # 默认的最大输出token数(如果bot_config.toml中没有设置max_tokens参数,默认使用这个值) + + +[[api_providers]] # API服务提供商(可以配置多个) +name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名) +base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL +# 支持多个API Key,实现自动切换和负载均衡 +api_keys = [ # API Key列表(多个key支持错误自动切换和负载均衡) + "sk-your-first-key-here", + "sk-your-second-key-here", + "sk-your-third-key-here" +] +# 向后兼容:如果只有一个key,也可以使用单个key字段 +#key = "******" # API Key (可选,默认为None) +client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini") + +[[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"gemini" +name = "Google" +base_url = "https://api.google.com/v1" +# Google API同样支持多key配置 +api_keys = [ + "your-google-api-key-1", + "your-google-api-key-2" +] +client_type = "gemini" + +[[api_providers]] +name = "SiliconFlow" +base_url = "https://api.siliconflow.cn/v1" +# 单个key的示例(向后兼容) +key = "******" +# +#[[api_providers]] +#name = "LocalHost" +#base_url = "https://localhost:8888" +#key = "lm-studio" + + +[[models]] # 模型(可以配置多个) +# 模型标识符(API服务商提供的模型标识符) +model_identifier = "deepseek-chat" +# 模型名称(可随意命名,在bot_config.toml中需使用这个命名) +#(可选,若无该字段,则将自动使用model_identifier填充) +name = "deepseek-v3" +# API服务商名称(对应在api_providers中配置的服务商名称) +api_provider = "DeepSeek" +# 任务类型(推荐配置,明确指定模型主要用于什么任务) +# 可选值:llm_normal, llm_reasoning, vision, embedding, speech +# 如果不配置,系统会根据capabilities或模型名称自动推断 +task_type = "llm_normal" +# 模型能力列表(推荐配置,描述模型支持的能力) +# 可选值:text, vision, embedding, speech, tool_calling, reasoning +capabilities = ["text", "tool_calling"] +# 输入价格(用于API调用统计,单位:元/兆token)(可选,若无该字段,默认值为0) +price_in = 2.0 +# 输出价格(用于API调用统计,单位:元/兆token)(可选,若无该字段,默认值为0) +price_out = 8.0 +# 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出) +#(可选,若无该字段,默认值为false) +#force_stream_mode = true + +[[models]] +model_identifier = "deepseek-reasoner" +name = "deepseek-r1" +api_provider = "DeepSeek" +# 推理模型的配置示例 +task_type = "llm_reasoning" +capabilities = ["text", "tool_calling", "reasoning"] +# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) +model_flags = [ "text", "tool_calling", "reasoning",] +price_in = 4.0 +price_out = 16.0 + +[[models]] +model_identifier = "Pro/deepseek-ai/DeepSeek-V3" +name = "siliconflow-deepseek-v3" +api_provider = "SiliconFlow" +task_type = "llm_normal" +capabilities = ["text", "tool_calling"] +price_in = 2.0 +price_out = 8.0 + +[[models]] +model_identifier = "Pro/deepseek-ai/DeepSeek-R1" +name = "siliconflow-deepseek-r1" +api_provider = "SiliconFlow" +task_type = "llm_reasoning" +capabilities = ["text", "tool_calling", "reasoning"] +price_in = 4.0 +price_out = 16.0 + +[[models]] +model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +name = "deepseek-r1-distill-qwen-32b" +api_provider = "SiliconFlow" +task_type = "llm_reasoning" +capabilities = ["text", "tool_calling", "reasoning"] +price_in = 4.0 +price_out = 16.0 + +[[models]] +model_identifier = "Qwen/Qwen3-8B" +name = "qwen3-8b" +api_provider = "SiliconFlow" +task_type = "llm_normal" +capabilities = ["text"] +price_in = 0 +price_out = 0 + +[[models]] +model_identifier = "Qwen/Qwen3-14B" +name = "qwen3-14b" +api_provider = "SiliconFlow" 
+task_type = "llm_normal" +capabilities = ["text", "tool_calling"] +price_in = 0.5 +price_out = 2.0 + +[[models]] +model_identifier = "Qwen/Qwen3-30B-A3B" +name = "qwen3-30b" +api_provider = "SiliconFlow" +task_type = "llm_normal" +capabilities = ["text", "tool_calling"] +price_in = 0.7 +price_out = 2.8 + +[[models]] +model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct" +name = "qwen2.5-vl-72b" +api_provider = "SiliconFlow" +# 视觉模型的配置示例 +task_type = "vision" +capabilities = ["vision", "text"] +# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) +model_flags = [ "vision", "text",] +price_in = 4.13 +price_out = 4.13 + +[[models]] +model_identifier = "FunAudioLLM/SenseVoiceSmall" +name = "sensevoice-small" +api_provider = "SiliconFlow" +# 语音模型的配置示例 +task_type = "speech" +capabilities = ["speech"] +# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) +model_flags = [ "audio",] +price_in = 0 +price_out = 0 + +[[models]] +model_identifier = "BAAI/bge-m3" +name = "bge-m3" +api_provider = "SiliconFlow" +# 嵌入模型的配置示例 +task_type = "embedding" +capabilities = ["text", "embedding"] +# 保留向后兼容的model_flags字段(已废弃,建议使用capabilities) +model_flags = [ "text", "embedding",] +price_in = 0 +price_out = 0 + + +[task_model_usage] +llm_reasoning = {model="deepseek-r1", temperature=0.8, max_tokens=1024, max_retry=0} +llm_normal = {model="deepseek-r1", max_tokens=1024, max_retry=0} +embedding = "siliconflow-bge-m3" +#schedule = [ +# "deepseek-v3", +# "deepseek-r1", +#] \ No newline at end of file diff --git a/template/template.env b/template/template.env index 4718203d7..d9b6e2bd1 100644 --- a/template/template.env +++ b/template/template.env @@ -1,23 +1,2 @@ HOST=127.0.0.1 -PORT=8000 - -# 密钥和url - -# 硅基流动 -SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ -# DeepSeek官方 -DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 -# 阿里百炼 -BAILIAN_BASE_URL = https://dashscope.aliyuncs.com/compatible-mode/v1 -# 火山引擎 -HUOSHAN_BASE_URL = -# xxxxx平台 -xxxxxxx_BASE_URL=https://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx - -# 定义你要用的api的key(需要去对应网站申请哦) -DEEP_SEEK_KEY= -CHAT_ANY_WHERE_KEY= -SILICONFLOW_KEY= -BAILIAN_KEY = -HUOSHAN_KEY = -xxxxxxx_KEY= +PORT=8000 \ No newline at end of file