Add support for voice-type messages

Windpicker-owo
2025-07-17 14:50:19 +08:00
parent 8768b5d31b
commit 587aca4d18
5 changed files with 157 additions and 27 deletions


@@ -216,6 +216,8 @@ class LLMRequest:
         prompt: str = None,
         image_base64: str = None,
         image_format: str = None,
+        file_bytes: bytes = None,
+        file_format: str = None,
         payload: dict = None,
         retry_policy: dict = None,
     ) -> Dict[str, Any]:
@@ -225,6 +227,8 @@ class LLMRequest:
             prompt: prompt text
             image_base64: base64-encoded image data
             image_format: image format
+            file_bytes: raw binary data of the file
+            file_format: file format
             payload: request body data
             retry_policy: custom retry policy
             request_type: request type
@@ -246,30 +250,33 @@ class LLMRequest:
         # Build the request body
         if image_base64:
             payload = await self._build_payload(prompt, image_base64, image_format)
+        elif file_bytes:
+            payload = await self._build_formdata_payload(file_bytes, file_format)
         elif payload is None:
             payload = await self._build_payload(prompt)
-        if stream_mode:
-            payload["stream"] = stream_mode
+        if not file_bytes:
+            if stream_mode:
+                payload["stream"] = stream_mode

-        if self.temp != 0.7:
-            payload["temperature"] = self.temp
+            if self.temp != 0.7:
+                payload["temperature"] = self.temp

-        # Add the enable_thinking parameter (if not the default value False)
-        if not self.enable_thinking:
-            payload["enable_thinking"] = False
+            # Add the enable_thinking parameter (if not the default value False)
+            if not self.enable_thinking:
+                payload["enable_thinking"] = False

-        if self.thinking_budget != 4096:
-            payload["thinking_budget"] = self.thinking_budget
+            if self.thinking_budget != 4096:
+                payload["thinking_budget"] = self.thinking_budget

-        if self.max_tokens:
-            payload["max_tokens"] = self.max_tokens
+            if self.max_tokens:
+                payload["max_tokens"] = self.max_tokens

-        # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
-        #     payload["max_tokens"] = global_config.model.model_max_output_length
-        # If max_tokens is still present in the payload and needs converting, check once more here
-        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
-            payload["max_completion_tokens"] = payload.pop("max_tokens")
+            # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
+            #     payload["max_tokens"] = global_config.model.model_max_output_length
+            # If max_tokens is still present in the payload and needs converting, check once more here
+            if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
+                payload["max_completion_tokens"] = payload.pop("max_tokens")

         return {
             "policy": policy,
@@ -278,6 +285,8 @@ class LLMRequest:
             "stream_mode": stream_mode,
             "image_base64": image_base64,  # keep the raw data needed for exception handling
             "image_format": image_format,
+            "file_bytes": file_bytes,
+            "file_format": file_format,
             "prompt": prompt,
         }
@@ -287,6 +296,8 @@ class LLMRequest:
         prompt: str = None,
         image_base64: str = None,
         image_format: str = None,
+        file_bytes: bytes = None,
+        file_format: str = None,
         payload: dict = None,
         retry_policy: dict = None,
         response_handler: callable = None,
@@ -299,6 +310,8 @@ class LLMRequest:
             prompt: prompt text
             image_base64: base64-encoded image data
             image_format: image format
+            file_bytes: raw binary data of the file
+            file_format: file format
             payload: request body data
             retry_policy: custom retry policy
             response_handler: custom response handler
@@ -307,25 +320,38 @@ class LLMRequest:
"""
# 获取请求配置
request_content = await self._prepare_request(
endpoint, prompt, image_base64, image_format, payload, retry_policy
endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy
)
if request_type is None:
request_type = self.request_type
for retry in range(request_content["policy"]["max_retries"]):
try:
# 使用上下文管理器处理会话
headers = await self._build_headers()
if file_bytes:
headers = await self._build_headers(is_formdata=True)
else:
headers = await self._build_headers(is_formdata=False)
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
if request_content["stream_mode"]:
headers["Accept"] = "text/event-stream"
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
async with session.post(
request_content["api_url"], headers=headers, json=request_content["payload"]
) as response:
handled_result = await self._handle_response(
response, request_content, retry, response_handler, user_id, request_type, endpoint
)
return handled_result
if file_bytes:
#form-data数据上传方式不同
async with session.post(
request_content["api_url"], headers=headers, data=request_content["payload"]
) as response:
handled_result = await self._handle_response(
response, request_content, retry, response_handler, user_id, request_type, endpoint
)
return handled_result
else:
async with session.post(
request_content["api_url"], headers=headers, json=request_content["payload"]
) as response:
handled_result = await self._handle_response(
response, request_content, retry, response_handler, user_id, request_type, endpoint
)
return handled_result
except Exception as e:
handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
retry += count_delta # 降级不计入重试次数
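
A note on the two post paths above: with aiohttp, json= serializes the dict and sets Content-Type: application/json, while handing a FormData object to data= makes the client encode the body as multipart/form-data and generate the boundary itself. A minimal standalone sketch of the difference (the URL and payload here are placeholders, not from this commit):

    import aiohttp

    async def post_examples(url: str, file_bytes: bytes) -> None:
        async with aiohttp.ClientSession() as session:
            # JSON body: aiohttp serializes the dict and sets Content-Type: application/json
            async with session.post(url, json={"prompt": "hello"}) as response:
                print(response.status)
            # Multipart body: aiohttp encodes the FormData and adds the boundary header itself
            form = aiohttp.FormData()
            form.add_field("file", file_bytes, filename="file.wav", content_type="audio/wav")
            async with session.post(url, data=form) as response:
                print(response.status)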
@@ -640,6 +666,23 @@ class LLMRequest:
             new_params["max_completion_tokens"] = new_params.pop("max_tokens")
         return new_params

+    async def _build_formdata_payload(self, file_bytes: bytes, file_format: str):
+        """Build a form-data request body"""
+        # Very ugly approach: write the file to disk first, then read it back; there should be a better way
+        with open(f"file.{file_format}", "wb") as f:
+            f.write(file_bytes)
+        data = aiohttp.FormData()
+        data.add_field(
+            "file",
+            open(f"file.{file_format}", "rb"),
+            filename=f"file.{file_format}",
+            content_type="audio/wav",
+        )
+        data.add_field("model", self.model_name)
+        return data
+
     async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
         """Build the request body"""
         # Copy the params to avoid mutating self.params directly
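
The temp-file round trip that the comment above calls ugly can indeed be avoided: FormData.add_field accepts raw bytes (or a file-like object such as io.BytesIO) directly, so nothing needs to touch the disk. A possible in-memory rewrite, offered as a sketch rather than as part of this commit (the audio/{file_format} MIME mapping is an assumption; the committed code hard-codes audio/wav):

    import aiohttp

    async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData:
        """Build a form-data request body entirely in memory."""
        data = aiohttp.FormData()
        data.add_field(
            "file",
            file_bytes,  # raw bytes are accepted directly; no temp file required
            filename=f"file.{file_format}",
            content_type=f"audio/{file_format}",  # assumption: the format maps 1:1 to a MIME subtype
        )
        data.add_field("model", self.model_name)
        return data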
@@ -725,7 +768,8 @@ class LLMRequest:
                 return content, reasoning_content, tool_calls
             else:
                 return content, reasoning_content
+        elif "text" in result and result["text"]:
+            return result["text"]
         return "没有返回结果", ""

     @staticmethod
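
The new elif above exists because transcription endpoints respond with a flat object carrying a top-level text field instead of the chat-style choices list, so none of the existing branches match. A minimal illustration of the two shapes (field names follow the OpenAI-compatible /audio/transcriptions convention; extra fields vary by provider):

    def extract_text(result: dict) -> str:
        """Toy dispatcher over the two response shapes."""
        if "choices" in result:  # chat completion: {"choices": [{"message": {"content": ...}}]}
            return result["choices"][0]["message"]["content"]
        if "text" in result:  # transcription: {"text": "transcribed speech"}
            return result["text"]
        return ""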
@@ -739,11 +783,15 @@ class LLMRequest:
             reasoning = ""
         return content, reasoning

-    async def _build_headers(self, no_key: bool = False) -> dict:
+    async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict:
         """Build the request headers"""
         if no_key:
+            if is_formdata:
+                return {"Authorization": "Bearer **********"}
             return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
         else:
+            if is_formdata:
+                return {"Authorization": f"Bearer {self.api_key}"}
             return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
             # Keeps people from leaking their own key in screenshots
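
Dropping Content-Type in the is_formdata branches is deliberate: a multipart body is only parsable together with its boundary parameter, which aiohttp generates at send time, so hard-coding the header would break the upload. In sketch form:

    def build_multipart_headers(api_key: str) -> dict:
        # Auth only: aiohttp appends "Content-Type: multipart/form-data; boundary=..." itself.
        # Hard-coding "multipart/form-data" here would omit the boundary and make the body unparsable.
        return {"Authorization": f"Bearer {api_key}"}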
@@ -761,6 +809,11 @@ class LLMRequest:
             content, reasoning_content = response
             return content, reasoning_content

+    async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple:
+        """Generate the model's async response from input voice audio"""
+        response = await self._execute_request(endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav")
+        return response
+
     async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
         """Asynchronously generate the model's response from the input prompt"""
         # Build the request body (without hard-coding max_tokens)
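
Finally, a sketch of how the new voice entry point might be driven end to end, assuming llm is an already-configured LLMRequest instance and sample.wav exists on disk (neither appears in this commit):

    import asyncio

    async def main() -> None:
        with open("sample.wav", "rb") as f:
            voice_bytes = f.read()
        # llm: a configured LLMRequest instance (assumed); the method fixes file_format to "wav"
        transcript = await llm.generate_response_for_voice(voice_bytes)
        print(transcript)

    asyncio.run(main())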