@@ -15,13 +15,8 @@ async def get_voice_text(voice_base64: str) -> str:
         logger.warning("Voice recognition is not enabled; cannot process voice messages")
         return "[Voice]"
     try:
-        # Decode the base64 audio data
-        # Make sure the base64 string contains only ASCII characters
-        if isinstance(voice_base64, str):
-            voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
-        voice_bytes = base64.b64decode(voice_base64)
         _llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="voice")
-        text = await _llm.generate_response_for_voice(voice_bytes)
+        text = await _llm.generate_response_for_voice(voice_base64)
         if text is None:
             logger.warning("Failed to generate text from the voice message")
             return "[Voice (text generation failed)]"
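
The net effect of this hunk: the helper no longer sanitizes or decodes the payload itself; it forwards the base64 string and lets the API client do the decoding (see the `OpenaiClient` hunk below). A minimal sketch of the new contract, with illustrative names only:

```python
# Sketch of the contract assumed by this hunk (names are illustrative): the
# caller supplies a base64 string; decoding to bytes happens later, inside the
# API client rather than in get_voice_text.
import base64

def make_voice_payload(raw_audio: bytes) -> str:
    # What upstream message handling is assumed to produce.
    return base64.b64encode(raw_audio).decode("ascii")

# text = await get_voice_text(make_voice_payload(raw_audio))
```
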
@@ -114,6 +114,21 @@ class BaseClient:
         """
         raise RuntimeError("This method should be overridden in subclasses")
 
+    async def get_audio_transcriptions(
+        self,
+        model_info: ModelInfo,
+        audio_base64: str,
+        extra_params: dict[str, Any] | None = None,
+    ) -> APIResponse:
+        """
+        Get an audio transcription.
+        :param model_info: model information
+        :param audio_base64: base64-encoded audio data
+        :param extra_params: extra request parameters
+        :return: the audio transcription response
+        """
+        raise RuntimeError("This method should be overridden in subclasses")
+
 
 class ClientRegistry:
     def __init__(self) -> None:
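
The base-class method is a stub, so every registered backend is expected to override it. A sketch of what a minimal override for some other backend could look like, reusing the `ModelInfo` and `APIResponse` types from the hunk above; `DummyClient` and its canned result are purely illustrative and not part of this diff:

```python
# Illustrative subclass only. It shows the expected contract: accept base64
# audio plus model info, return an APIResponse with `content` filled in.
from typing import Any

class DummyClient(BaseClient):
    async def get_audio_transcriptions(
        self,
        model_info: ModelInfo,
        audio_base64: str,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
        response = APIResponse()
        # A real backend would decode audio_base64 and call its own API here.
        response.content = "(transcribed text would go here)"
        return response
```
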
@@ -2,6 +2,7 @@ import asyncio
 import io
 import json
 import re
+import base64
 from collections.abc import Iterable
 from typing import Callable, Any, Coroutine, Optional
 from json_repair import repair_json
@@ -532,3 +533,38 @@ class OpenaiClient(BaseClient):
         )
 
         return response
+
+    async def get_audio_transcriptions(
+        self,
+        model_info: ModelInfo,
+        audio_base64: str,
+        extra_params: dict[str, Any] | None = None,
+    ) -> APIResponse:
+        """
+        Get an audio transcription.
+        :param model_info: model information
+        :param audio_base64: base64-encoded audio data
+        :param extra_params: extra request parameters
+        :return: the audio transcription response
+        """
+        try:
+            raw_response = await self.client.audio.transcriptions.create(
+                model=model_info.model_identifier,
+                file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))),
+                extra_body=extra_params
+            )
+        except APIConnectionError as e:
+            raise NetworkConnectionError() from e
+        except APIStatusError as e:
+            # Re-wrap the APIError as a RespNotOkException
+            raise RespNotOkException(e.status_code) from e
+        response = APIResponse()
+        # Parse the transcription response
+        if hasattr(raw_response, "text"):
+            response.content = raw_response.text
+        else:
+            raise RespParseException(
+                raw_response,
+                "Failed to parse the response: the transcription text is missing.",
+            )
+        return response
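
For reference, the implementation above maps onto a plain OpenAI SDK call. A standalone sketch of that call, assuming the official `openai` package and WAV audio; the model id, key handling, and variable names are assumptions, not taken from this repository:

```python
# Minimal standalone sketch of the same endpoint call. Assumed setup:
# `pip install openai` and OPENAI_API_KEY exported in the environment.
import base64
import io

from openai import AsyncOpenAI

async def transcribe(audio_b64: str) -> str:
    client = AsyncOpenAI()
    raw = await client.audio.transcriptions.create(
        model="whisper-1",  # placeholder id; the diff uses model_info.model_identifier
        file=("audio.wav", io.BytesIO(base64.b64decode(audio_b64))),
    )
    return raw.text
```
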
@@ -38,7 +38,7 @@ class RequestType(Enum):
 
     RESPONSE = "response"
     EMBEDDING = "embedding"
+    AUDIO = "audio"
 
 
 class LLMRequest:
     """LLM request class"""
@@ -106,8 +106,27 @@ class LLMRequest:
         )
         return content, (reasoning_content, model_info.name, tool_calls)
 
-    async def generate_response_for_voice(self):
-        pass
+    async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
+        """
+        Generate a response for a voice message.
+        Args:
+            voice_base64 (str): the base64-encoded voice data
+        Returns:
+            (Optional[str]): the generated text description, or None
+        """
+        # Model selection
+        model_info, api_provider, client = self._select_model()
+
+        # Send the request and handle the result
+        response = await self._execute_request(
+            api_provider=api_provider,
+            client=client,
+            request_type=RequestType.AUDIO,
+            model_info=model_info,
+            audio_base64=voice_base64,
+        )
+        return response.content or None
+
 
     async def generate_response_async(
         self,
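
Taken together with the client changes, this gives `LLMRequest` a voice path end to end. A hedged usage sketch, assuming `LLMRequest` and `model_config` are importable from this project and that a `voice` task set is configured; the file name, fallback string, and import locations are placeholders:

```python
# Hypothetical end-to-end use of the new voice path. Only the LLMRequest calls
# mirror this diff; "sample.wav" and the surrounding scaffolding are assumed.
import asyncio
import base64

async def transcribe_file(path: str) -> str:
    with open(path, "rb") as f:
        voice_b64 = base64.b64encode(f.read()).decode("ascii")
    llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="voice")
    text = await llm.generate_response_for_voice(voice_b64)
    return text or "[Voice (text generation failed)]"

# asyncio.run(transcribe_file("sample.wav"))
```
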
@@ -225,6 +244,7 @@ class LLMRequest:
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         embedding_input: str = "",
+        audio_base64: str = "",
     ) -> APIResponse:
         """
         The method that actually executes the request
@@ -255,6 +275,13 @@
                     embedding_input=embedding_input,
                     extra_params=model_info.extra_params,
                 )
+            elif request_type == RequestType.AUDIO:
+                assert audio_base64, "audio_base64 cannot be empty for audio requests"
+                return await client.get_audio_transcriptions(
+                    model_info=model_info,
+                    audio_base64=audio_base64,
+                    extra_params=model_info.extra_params,
+                )
         except Exception as e:
             logger.debug(f"Request failed: {str(e)}")
             # Handle the exception