Merge pull request #801 from UnCLAS-Prommer/dev

Split _execute_request, step 1: the split itself
SengokuCola
2025-04-22 11:35:47 +08:00
committed by GitHub
4 changed files with 647 additions and 34 deletions

src/api/__init__.py (new file, 8 lines)

@@ -0,0 +1,8 @@
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
app = FastAPI()
graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
app.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
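
An illustrative sketch, not part of this diff, of how the schema=None placeholder could later be filled in, assuming a hypothetical Query type (strawberry needs a concrete query type before the router can serve requests):

import strawberry
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter


@strawberry.type
class Query:
    @strawberry.field
    def ping(self) -> str:  # placeholder field so the schema is valid
        return "pong"


schema = strawberry.Schema(query=Query)
app = FastAPI()
app.include_router(GraphQLRouter(schema), prefix="/graphql", tags=["GraphQL"])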

src/api/config_api.py (new file, 155 lines)

@@ -0,0 +1,155 @@
from typing import Dict, List, Optional
import strawberry
# from packaging.version import Version, InvalidVersion
# from packaging.specifiers import SpecifierSet, InvalidSpecifier
# from ..config.config import global_config
# import os
from packaging.version import Version
@strawberry.type
class BotConfig:
"""机器人配置类"""
INNER_VERSION: Version
MAI_VERSION: str # 硬编码的版本信息
# bot
BOT_QQ: Optional[int]
BOT_NICKNAME: Optional[str]
BOT_ALIAS_NAMES: List[str] # 别名,可以通过这个叫它
# group
talk_allowed_groups: set
talk_frequency_down_groups: set
ban_user_id: set
# personality
personality_core: str # 建议20字以内谁再写3000字小作文敲谁脑袋
personality_sides: List[str]
# identity
identity_detail: List[str]
height: int # 身高 单位厘米
weight: int # 体重 单位千克
age: int # 年龄 单位岁
gender: str # 性别
appearance: str # 外貌特征
# schedule
ENABLE_SCHEDULE_GEN: bool # 是否启用日程生成
PROMPT_SCHEDULE_GEN: str
SCHEDULE_DOING_UPDATE_INTERVAL: int # 日程表更新间隔 单位秒
SCHEDULE_TEMPERATURE: float # 日程表温度建议0.5-1.0
TIME_ZONE: str # 时区
# message
MAX_CONTEXT_SIZE: int # 上下文最大消息数
emoji_chance: float # 发送表情包的基础概率
thinking_timeout: int # 思考时间
max_response_length: int # 最大回复长度
message_buffer: bool # 消息缓冲器
ban_words: set
ban_msgs_regex: set
# heartflow
# enable_heartflow: bool = False # 是否启用心流
sub_heart_flow_update_interval: int # 子心流更新频率,间隔 单位秒
sub_heart_flow_freeze_time: int # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒
sub_heart_flow_stop_time: int # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒
heart_flow_update_interval: int # 心流更新频率,间隔 单位秒
observation_context_size: int # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩
compressed_length: int # 不能大于observation_context_size,心流上下文压缩的最短压缩长度超过心流观察到的上下文长度会压缩最短压缩长度为5
compress_length_limit: int # 最多压缩份数,超过该数值的压缩上下文会被删除
# willing
willing_mode: str # 意愿模式
response_willing_amplifier: float # 回复意愿放大系数
response_interested_rate_amplifier: float # 回复兴趣度放大系数
down_frequency_rate: float # 降低回复频率的群组回复意愿降低系数
emoji_response_penalty: float # 表情包回复惩罚
mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复
at_bot_inevitable_reply: bool # @bot 必然回复
# response
response_mode: str # 回复策略
MODEL_R1_PROBABILITY: float # R1模型概率
MODEL_V3_PROBABILITY: float # V3模型概率
# MODEL_R1_DISTILL_PROBABILITY: float # R1蒸馏模型概率
# emoji
max_emoji_num: int # 表情包最大数量
max_reach_deletion: bool # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包
EMOJI_CHECK_INTERVAL: int # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int # 表情包注册间隔(分钟)
EMOJI_SAVE: bool # 偷表情包
EMOJI_CHECK: bool # 是否开启过滤
EMOJI_CHECK_PROMPT: str # 表情包过滤要求
# memory
build_memory_interval: int # 记忆构建间隔(秒)
memory_build_distribution: list # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
build_memory_sample_num: int # 记忆构建采样数量
build_memory_sample_length: int # 记忆构建采样长度
memory_compress_rate: float # 记忆压缩率
forget_memory_interval: int # 记忆遗忘间隔(秒)
memory_forget_time: int # 记忆遗忘时间(小时)
memory_forget_percentage: float # 记忆遗忘比例
memory_ban_words: list # 添加新的配置项默认值
# mood
mood_update_interval: float # 情绪更新间隔 单位秒
mood_decay_rate: float # 情绪衰减率
mood_intensity_factor: float # 情绪强度因子
# keywords
keywords_reaction_rules: list # 关键词回复规则
# chinese_typo
chinese_typo_enable: bool # 是否启用中文错别字生成器
chinese_typo_error_rate: float # 单字替换概率
chinese_typo_min_freq: int # 最小字频阈值
chinese_typo_tone_error_rate: float # 声调错误概率
chinese_typo_word_replace_rate: float # 整词替换概率
# response_splitter
enable_response_splitter: bool # 是否启用回复分割器
response_max_length: int # 回复允许的最大长度
response_max_sentence_num: int # 回复允许的最大句子数
# remote
remote_enable: bool # 是否启用远程控制
# experimental
enable_friend_chat: bool # 是否启用好友聊天
# enable_think_flow: bool # 是否启用思考流程
enable_pfc_chatting: bool # 是否启用PFC聊天
# 模型配置
llm_reasoning: Dict[str, str] # LLM推理
# llm_reasoning_minor: Dict[str, str]
llm_normal: Dict[str, str] # LLM普通
llm_topic_judge: Dict[str, str] # LLM话题判断
llm_summary_by_topic: Dict[str, str] # LLM话题总结
llm_emotion_judge: Dict[str, str] # LLM情感判断
embedding: Dict[str, str] # 嵌入
vlm: Dict[str, str] # VLM
moderation: Dict[str, str] # 审核
# 实验性
llm_observation: Dict[str, str] # LLM观察
llm_sub_heartflow: Dict[str, str] # LLM子心流
llm_heartflow: Dict[str, str] # LLM心流
api_urls: Dict[str, str] # API URLs
@strawberry.type
class EnvConfig:
pass
@strawberry.field
def get_env(self) -> str:
return "env"


@@ -151,7 +151,7 @@ class ReplyGenerator:
return content
except Exception as e:
-logger.error(f"生成回复时出错: {e}")
+logger.error(f"生成回复时出错: {str(e)}")
return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:


@@ -2,9 +2,11 @@ import asyncio
import json
import re
from datetime import datetime
-from typing import Tuple, Union
+from typing import Tuple, Union, Dict, Any
import aiohttp
+from aiohttp.client import ClientResponse
from src.common.logger import get_module_logger
import base64
from PIL import Image
@@ -16,19 +18,72 @@ from ...config.config import global_config
logger = get_module_logger("model_utils")
class PayLoadTooLargeError(Exception):
"""自定义异常类,用于处理请求体过大错误"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return "请求体过大,请尝试压缩图片或减少输入内容。"
class RequestAbortException(Exception):
"""自定义异常类,用于处理请求中断异常"""
def __init__(self, message: str, response: ClientResponse):
super().__init__(message)
self.message = message
self.response = response
def __str__(self):
return self.message
class PermissionDeniedException(Exception):
"""自定义异常类,用于处理访问拒绝的异常"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return self.message
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高",
}
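The mapping and the three exception classes above are consumed by the handlers added further down; a rough usage sketch, illustrative only and reusing the names defined above:

def describe_status(status: int) -> str:
    # fall back to the raw status code for anything not in the mapping
    return error_code_mapping.get(status, f"HTTP {status}")


print(describe_status(429))  # 请求过于频繁,请稍后再试
print(describe_status(413))  # HTTP 413 (not in the mapping, fallback path)

try:
    raise PayLoadTooLargeError("请求体过大")
except PayLoadTooLargeError as exc:
    print(exc)  # __str__ always returns the fixed "compress the image" hint
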
class LLMRequest:
# 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
"o1",
"o1-2024-12-17",
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"o1-pro",
"o1-pro-2025-03-19",
"o3",
"o3-2025-04-16",
"o3-mini",
"o3-mini-2025-01-31",
"o4-mini",
"o4-mini-2025-04-16",
]
-def __init__(self, model, **kwargs):
+def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = os.environ[model["key"]]
@@ -37,7 +92,7 @@ class LLMRequest:
logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
-self.model_name = model["name"]
+self.model_name: str = model["name"]
self.params = kwargs
self.stream = model.get("stream", False)
@@ -123,6 +178,7 @@ class LLMRequest:
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
'''
async def _execute_request(
self,
endpoint: str,
@@ -509,6 +565,404 @@ class LLMRequest:
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数API请求仍然失败")
'''
async def _prepare_request(
self,
endpoint: str,
prompt: str = None,
image_base64: str = None,
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
) -> Dict[str, Any]:
"""配置请求参数
Args:
endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本
image_base64: 图片的base64编码
image_format: 图片格式
payload: 请求体数据
retry_policy: 自定义重试策略
"""
# 合并重试策略
default_retry = {
"max_retries": 3,
"base_wait": 10,
"retry_codes": [429, 413, 500, 503],
"abort_codes": [400, 401, 402, 403],
}
policy = {**default_retry, **(retry_policy or {})}
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
stream_mode = self.stream
# 构建请求体
if image_base64:
payload = await self._build_payload(prompt, image_base64, image_format)
elif payload is None:
payload = await self._build_payload(prompt)
if stream_mode:
payload["stream"] = stream_mode
return {
"policy": policy,
"payload": payload,
"api_url": api_url,
"stream_mode": stream_mode,
"image_base64": image_base64, # 保留必要的exception处理所需的原始数据
"image_format": image_format,
"prompt": prompt,
}
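
The retry policy merge in _prepare_request is a plain dict unpack: caller-supplied keys override the defaults, missing keys fall back. A standalone sketch with a hypothetical helper name:

from typing import Optional


DEFAULT_RETRY = {
    "max_retries": 3,
    "base_wait": 10,
    "retry_codes": [429, 413, 500, 503],
    "abort_codes": [400, 401, 402, 403],
}


def merge_retry_policy(retry_policy: Optional[dict] = None) -> dict:
    # later keys win, so caller overrides beat the defaults
    return {**DEFAULT_RETRY, **(retry_policy or {})}


print(merge_retry_policy({"max_retries": 5}))
# {'max_retries': 5, 'base_wait': 10, 'retry_codes': [429, 413, 500, 503], 'abort_codes': [400, 401, 402, 403]}
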
async def _execute_request(
self,
endpoint: str,
prompt: str = None,
image_base64: str = None,
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
response_handler: callable = None,
user_id: str = "system",
request_type: str = None,
):
"""统一请求执行入口
Args:
endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本
image_base64: 图片的base64编码
image_format: 图片格式
payload: 请求体数据
retry_policy: 自定义重试策略
response_handler: 自定义响应处理器
user_id: 用户ID
request_type: 请求类型
"""
# 获取请求配置
request_content = await self._prepare_request(
endpoint, prompt, image_base64, image_format, payload, retry_policy
)
if request_type is None:
request_type = self.request_type
for retry in range(request_content["policy"]["max_retries"]):
try:
# 使用上下文管理器处理会话
headers = await self._build_headers()
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
if request_content["stream_mode"]:
headers["Accept"] = "text/event-stream"
async with aiohttp.ClientSession() as session:
async with session.post(
request_content["api_url"], headers=headers, json=request_content["payload"]
) as response:
handled_result = await self._handle_response(
response, request_content, retry, response_handler, user_id, request_type, endpoint
)
return handled_result
except Exception as e:
handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
retry += count_delta # 降级不计入重试次数
if handled_payload:
# 如果降级成功,重新构建请求体
request_content["payload"] = handled_payload
continue
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数API请求仍然失败")
async def _handle_response(
self,
response: ClientResponse,
request_content: Dict[str, Any],
retry_count: int,
response_handler: callable,
user_id,
request_type,
endpoint,
) -> Union[Dict[str, Any], None]:
policy = request_content["policy"]
stream_mode = request_content["stream_mode"]
if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
await self._handle_error_response(response, retry_count, policy)
return
response.raise_for_status()
result = {}
if stream_mode:
# 将流式输出转化为非流式输出
result = await self._handle_stream_output(response)
else:
result = await response.json()
return (
response_handler(result)
if response_handler
else self._default_response_handler(result, user_id, request_type, endpoint)
)
async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
flag_delta_content_finished = False
accumulated_content = ""
usage = None # 初始化usage变量避免未定义错误
reasoning_content = ""
content = ""
async for line_bytes in response.content:
try:
line = line_bytes.decode("utf-8").strip()
if not line:
continue
if line.startswith("data:"):
data_str = line[5:].strip()
if data_str == "[DONE]":
break
try:
chunk = json.loads(data_str)
if flag_delta_content_finished:
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage # 获取token用量
else:
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
if delta_content is None:
delta_content = ""
accumulated_content += delta_content
# 检测流式输出文本是否结束
finish_reason = chunk["choices"][0].get("finish_reason")
if delta.get("reasoning_content", None):
reasoning_content += delta["reasoning_content"]
if finish_reason == "stop":
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage
break
# 部分平台在文本输出结束前不会返回token用量此时需要再获取一次chunk
flag_delta_content_finished = True
except Exception as e:
logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}")
except Exception as e:
if isinstance(e, GeneratorExit):
log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
else:
log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
logger.warning(log_content)
# 确保资源被正确清理
try:
await response.release()
except Exception as cleanup_error:
logger.error(f"清理资源时发生错误: {cleanup_error}")
# 返回已经累积的内容
content = accumulated_content
if not content:
content = accumulated_content
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
result = {
"choices": [
{
"message": {
"content": content,
"reasoning_content": reasoning_content,
# 流式输出可能没有工具调用此处不需要添加tool_calls字段
}
}
],
"usage": usage,
}
return result
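
A synchronous, self-contained sketch of the SSE parsing that _handle_stream_output performs; the helper name is hypothetical and it only covers the happy path of accumulating deltas until [DONE] while keeping the last usage block:

import json


def parse_sse_lines(lines):
    content, reasoning, usage = "", "", None
    for raw in lines:
        line = raw.strip()
        if not line.startswith("data:"):
            continue
        data = line[5:].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        usage = chunk.get("usage") or usage  # some providers only send usage in the final chunk
        choices = chunk.get("choices") or []
        if choices:
            delta = choices[0].get("delta", {})
            content += delta.get("content") or ""
            reasoning += delta.get("reasoning_content") or ""
    return {"content": content, "reasoning_content": reasoning, "usage": usage}


demo = [
    'data: {"choices": [{"delta": {"content": "你"}}]}',
    'data: {"choices": [{"delta": {"content": "好"}, "finish_reason": "stop"}], "usage": {"total_tokens": 5}}',
    "data: [DONE]",
]
print(parse_sse_lines(demo))  # {'content': '你好', 'reasoning_content': '', 'usage': {'total_tokens': 5}}
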
async def _handle_error_response(
self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]
) -> None:
if response.status in policy["retry_codes"]:
wait_time = policy["base_wait"] * (2**retry_count)
logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
if response.status == 413:
logger.warning("请求体过大,尝试压缩...")
raise PayLoadTooLargeError("请求体过大")
elif response.status in [500, 503]:
logger.error(
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
)
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
else:
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
raise RuntimeError("请求限制(429)")
elif response.status in policy["abort_codes"]:
if response.status != 403:
raise RequestAbortException("请求出现错误,中断处理", response)
else:
raise PermissionDeniedException("模型禁止访问")
async def _handle_exception(
self, exception, retry_count: int, request_content: Dict[str, Any]
) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
policy = request_content["policy"]
payload = request_content["payload"]
wait_time = policy["base_wait"] * (2**retry_count)
keep_request = retry_count < policy["max_retries"] - 1
if isinstance(exception, RequestAbortException):
response = exception.response
logger.error(
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
)
# 尝试获取并记录服务器返回的详细错误信息
try:
error_json = await response.json()
if error_json and isinstance(error_json, list) and len(error_json) > 0:
# 处理多个错误的情况
for error_item in error_json:
if "error" in error_item and isinstance(error_item["error"], dict):
error_obj: dict = error_item["error"]
error_code = error_obj.get("code")
error_message = error_obj.get("message")
error_status = error_obj.get("status")
logger.error(
f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
)
elif isinstance(error_json, dict) and "error" in error_json:
# 处理单个错误对象的情况
error_obj = error_json.get("error", {})
error_code = error_obj.get("code")
error_message = error_obj.get("message")
error_status = error_obj.get("status")
logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}")
else:
# 记录原始错误响应内容
logger.error(f"服务器错误响应: {error_json}")
except Exception as e:
logger.warning(f"无法解析服务器错误响应: {str(e)}")
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
elif isinstance(exception, PermissionDeniedException):
# 只针对硅基流动的V3和R1进行降级处理
if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name:
global_config.llm_normal["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
if payload and "model" in payload:
payload["model"] = self.model_name
await asyncio.sleep(wait_time)
return payload, -1
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}")
elif isinstance(exception, PayLoadTooLargeError):
if keep_request:
image_base64 = request_content["image_base64"]
compressed_image_base64 = compress_base64_image_by_scale(image_base64)
new_payload = await self._build_payload(
request_content["prompt"], compressed_image_base64, request_content["image_format"]
)
return new_payload, 0
else:
return None, 0
elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError):
if keep_request:
logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}")
await asyncio.sleep(wait_time)
return None, 0
else:
logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}")
raise RuntimeError(f"网络请求失败: {str(exception)}")
elif isinstance(exception, aiohttp.ClientResponseError):
# 处理aiohttp抛出的除了policy中的status的响应错误
if keep_request:
logger.error(
f"模型 {self.model_name} HTTP响应错误等待{wait_time}秒后重试... 状态码: {exception.status}, 错误: {exception.message}"
)
try:
error_text = await exception.response.text()
error_json = json.loads(error_text)
if isinstance(error_json, list) and len(error_json) > 0:
# 处理多个错误的情况
for error_item in error_json:
if "error" in error_item and isinstance(error_item["error"], dict):
error_obj = error_item["error"]
logger.error(
f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
f"状态={error_obj.get('status')}, "
f"消息={error_obj.get('message')}"
)
elif isinstance(error_json, dict) and "error" in error_json:
error_obj = error_json.get("error", {})
logger.error(
f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
f"状态={error_obj.get('status')}, "
f"消息={error_obj.get('message')}"
)
else:
logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
except (json.JSONDecodeError, TypeError) as json_err:
logger.warning(
f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
)
except Exception as parse_err:
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
await asyncio.sleep(wait_time)
return None, 0
else:
logger.critical(
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
)
# 安全地检查和记录请求详情
handled_payload = await self._safely_record(request_content, payload)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
raise RuntimeError(
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
)
else:
if keep_request:
logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}")
await asyncio.sleep(wait_time)
return None, 0
else:
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
# 安全地检查和记录请求详情
handled_payload = await self._safely_record(request_content, payload)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
async def _safely_record(self, request_content: Dict[str, Any], payload: Dict[str, Any]):
image_base64: str = request_content.get("image_base64")
image_format: str = request_content.get("image_format")
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
# if isinstance(content, str) and len(content) > 100:
# payload["messages"][0]["content"] = content[:100]
return payload
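
The log-masking trick in _safely_record boils down to keeping only the ends of the base64 string; a tiny standalone sketch with a hypothetical helper name:

def mask_base64_for_log(image_base64: str, image_format: str = "jpeg", keep: int = 10) -> str:
    # enough to recognise the payload in logs without dumping megabytes of base64
    return (
        f"data:image/{image_format.lower()};base64,"
        f"{image_base64[:keep]}...{image_base64[-keep:]}"
    )


print(mask_base64_for_log("A" * 40000, "png"))  # data:image/png;base64,AAAAAAAAAA...AAAAAAAAAA
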
async def _transform_parameters(self, params: dict) -> dict:
"""
@@ -532,9 +986,7 @@ class LLMRequest:
# 复制一份参数,避免直接修改 self.params
params_copy = await self._transform_parameters(self.params)
if image_base64:
-payload = {
-"model": self.model_name,
-"messages": [
+messages = [
{
"role": "user",
"content": [
@@ -545,17 +997,16 @@ class LLMRequest:
},
],
}
-],
-"max_tokens": global_config.max_response_length,
-**params_copy,
-}
+]
else:
+messages = [{"role": "user", "content": prompt}]
payload = {
"model": self.model_name,
-"messages": [{"role": "user", "content": prompt}],
-"max_tokens": global_config.max_response_length,
+"messages": messages,
**params_copy,
}
+if "max_tokens" not in payload and "max_completion_tokens" not in payload:
+payload["max_tokens"] = global_config.max_response_length
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")
@@ -648,11 +1099,10 @@ class LLMRequest:
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
"""异步方式根据输入的提示生成模型的响应"""
-# 构建请求体
+# 构建请求体,不硬编码max_tokens
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
-"max_tokens": global_config.max_response_length,
**self.params,
**kwargs,
}