Add support for voice-type messages

Windpicker-owo
2025-07-17 14:50:19 +08:00
parent 8768b5d31b
commit 587aca4d18
5 changed files with 157 additions and 27 deletions

View File

@@ -9,6 +9,7 @@ from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_voice import get_voice_text
from .chat_stream import ChatStream
install(extra_lines=3)
@@ -106,6 +107,7 @@ class MessageRecv(Message):
self.has_emoji = False
self.is_picid = False
self.has_picid = False
self.is_voice = False
self.is_mentioned = None
self.is_command = False
@@ -156,6 +158,14 @@ class MessageRecv(Message):
        if isinstance(segment.data, str):
            return await get_image_manager().get_emoji_description(segment.data)
        return "[发了一个表情包,网卡了加载不出来]"
    elif segment.type == "voice":
        self.has_picid = False
        self.is_picid = False
        self.is_emoji = False
        self.is_voice = True
        if isinstance(segment.data, str):
            return await get_voice_text(segment.data)
        return "[发了一段语音,网卡了加载不出来]"
    elif segment.type == "mention_bot":
        self.is_picid = False
        self.is_emoji = False
@@ -233,6 +243,14 @@ class MessageRecvS4U(MessageRecv):
        if isinstance(segment.data, str):
            return await get_image_manager().get_emoji_description(segment.data)
        return "[发了一个表情包,网卡了加载不出来]"
    elif segment.type == "voice":
        self.has_picid = False
        self.is_picid = False
        self.is_emoji = False
        self.is_voice = True
        if isinstance(segment.data, str):
            return await get_voice_text(segment.data)
        return "[发了一段语音,网卡了加载不出来]"
    elif segment.type == "mention_bot":
        self.is_picid = False
        self.is_emoji = False
@@ -343,6 +361,10 @@ class MessageProcessBase(Message):
        if isinstance(seg.data, str):
            return await get_image_manager().get_emoji_description(seg.data)
        return "[表情,网卡了加载不出来]"
    elif seg.type == "voice":
        if isinstance(seg.data, str):
            return await get_voice_text(seg.data)
        return "[发了一段语音,网卡了加载不出来]"
    elif seg.type == "at":
        return f"[@{seg.data}]"
    elif seg.type == "reply":

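For reference, this is roughly what an inbound voice segment looks like. The Seg(type=..., data=...) construction is assumed from how segments are read above, and the base64 payload is a truncated, illustrative value, not real audio:

from maim_message import Seg

# A voice segment mirrors the image/emoji convention: `data` carries the
# base64-encoded audio bytes as a plain ASCII string.
voice_seg = Seg(type="voice", data="UklGRiQAAABXQVZF")  # truncated, illustrative base64

# The branches above route any seg.type == "voice" through get_voice_text(seg.data),
# falling back to a placeholder string when the payload is not a str.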
View File

@@ -0,0 +1,46 @@
import base64
import traceback

from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from rich.traceback import install

install(extra_lines=3)

logger = get_logger("chat_voice")


async def get_voice_text(voice_base64: str) -> str:
    """Transcribe a base64-encoded voice clip and return display text."""
    try:
        # Make sure the base64 string contains only ASCII characters
        if isinstance(voice_base64, str):
            voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
        voice_bytes = base64.b64decode(voice_base64)
        _llm = LLMRequest(model=global_config.model.voice, request_type="voice")
        text = await _llm.generate_response_for_voice(voice_bytes)
        if text is None:
            logger.warning("未能生成语音文本")
            return "[语音(文本生成失败)]"
        logger.debug(f"语音转写结果:{text}")
        return f"[语音:{text}]"
    except Exception as e:
        traceback.print_exc()
        logger.error(f"语音转文字失败: {str(e)}")
        return "[语音]"

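A minimal usage sketch for the helper above. The sample.wav path is hypothetical, and it assumes a configured [model.voice] entry as shown in the template at the end of this commit:

import asyncio
import base64

from src.chat.utils.utils_voice import get_voice_text


async def main():
    # Encode a local clip the same way adapters populate segment.data
    with open("sample.wav", "rb") as f:
        voice_base64 = base64.b64encode(f.read()).decode("ascii")
    print(await get_voice_text(voice_base64))  # e.g. "[语音:<transcript>]"

asyncio.run(main())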
View File

@@ -630,6 +630,9 @@ class ModelConfig(ConfigBase):
vlm: dict[str, Any] = field(default_factory=lambda: {})
"""Vision-language model configuration"""
voice: dict[str, Any] = field(default_factory=lambda: {})
"""Speech-recognition model configuration"""
tool_use: dict[str, Any] = field(default_factory=lambda: {})
"""Tool-use model configuration (requires a model that supports tool calling)"""

View File

@@ -216,6 +216,8 @@ class LLMRequest:
prompt: str = None,
image_base64: str = None,
image_format: str = None,
file_bytes: bytes = None,
file_format: str = None,
payload: dict = None,
retry_policy: dict = None,
) -> Dict[str, Any]:
@@ -225,6 +227,8 @@ class LLMRequest:
prompt: prompt text
image_base64: base64-encoded image
image_format: image format
file_bytes: raw binary content of the file
file_format: file format (extension)
payload: request body
retry_policy: custom retry policy
request_type: request type
@@ -246,30 +250,33 @@ class LLMRequest:
# Build the request body
if image_base64:
    payload = await self._build_payload(prompt, image_base64, image_format)
elif file_bytes:
    payload = await self._build_formdata_payload(file_bytes, file_format)
elif payload is None:
    payload = await self._build_payload(prompt)

if not file_bytes:
    # These tuning parameters only apply to JSON bodies, not multipart uploads
    if stream_mode:
        payload["stream"] = stream_mode
    if self.temp != 0.7:
        payload["temperature"] = self.temp
    # Add the enable_thinking parameter if it is not the default value (False)
    if not self.enable_thinking:
        payload["enable_thinking"] = False
    if self.thinking_budget != 4096:
        payload["thinking_budget"] = self.thinking_budget
    if self.max_tokens:
        payload["max_tokens"] = self.max_tokens
    # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
    #     payload["max_tokens"] = global_config.model.model_max_output_length
    # If max_tokens is still present and the model needs the new name, convert it here
    if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
        payload["max_completion_tokens"] = payload.pop("max_tokens")
return {
"policy": policy,
@@ -278,6 +285,8 @@ class LLMRequest:
"stream_mode": stream_mode,
"image_base64": image_base64, # 保留必要的exception处理所需的原始数据
"image_format": image_format,
"file_bytes": file_bytes,
"file_format": file_format,
"prompt": prompt,
}
@@ -287,6 +296,8 @@ class LLMRequest:
prompt: str = None,
image_base64: str = None,
image_format: str = None,
file_bytes: bytes = None,
file_format: str = None,
payload: dict = None,
retry_policy: dict = None,
response_handler: callable = None,
@@ -299,6 +310,8 @@ class LLMRequest:
prompt: prompt text
image_base64: base64-encoded image
image_format: image format
file_bytes: raw binary content of the file
file_format: file format (extension)
payload: request body
retry_policy: custom retry policy
response_handler: custom response handler
@@ -307,25 +320,38 @@ class LLMRequest:
"""
# Fetch the request configuration
request_content = await self._prepare_request(
    endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy
)

if request_type is None:
    request_type = self.request_type

for retry in range(request_content["policy"]["max_retries"]):
    try:
        # Multipart uploads must not carry a JSON Content-Type,
        # so the header variant follows file_bytes
        headers = await self._build_headers(is_formdata=bool(file_bytes))
        # Apparently required for OpenAI streaming; adding it for Aliyun's
        # qwq-plus has no adverse effect
        if request_content["stream_mode"]:
            headers["Accept"] = "text/event-stream"
        async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
            # form-data is uploaded differently: `data=` for multipart
            # bodies, `json=` for regular JSON bodies
            post_kwargs = (
                {"data": request_content["payload"]}
                if file_bytes
                else {"json": request_content["payload"]}
            )
            async with session.post(
                request_content["api_url"], headers=headers, **post_kwargs
            ) as response:
                handled_result = await self._handle_response(
                    response, request_content, retry, response_handler, user_id, request_type, endpoint
                )
                return handled_result
    except Exception as e:
        handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
        retry += count_delta  # downgrades do not count against the retry budget
@@ -640,6 +666,23 @@ class LLMRequest:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
return new_params
async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData:
    """Build a multipart/form-data request body."""
    # Feed the bytes to FormData directly; no round-trip through a temp file
    data = aiohttp.FormData()
    data.add_field(
        "file",
        file_bytes,
        filename=f"file.{file_format}",
        content_type=f"audio/{file_format}",
    )
    data.add_field("model", self.model_name)
    return data
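For clarity, here is a self-contained sketch of the same multipart flow against an OpenAI-compatible /audio/transcriptions endpoint. The api_url argument and the {"text": ...} response shape are assumptions, not confirmed by this diff:

import aiohttp


async def transcribe(api_url: str, api_key: str, wav_bytes: bytes, model: str) -> str:
    # Build the multipart body directly from bytes; no temp file required
    form = aiohttp.FormData()
    form.add_field("file", wav_bytes, filename="voice.wav", content_type="audio/wav")
    form.add_field("model", model)
    # Deliberately no Content-Type header: aiohttp sets
    # "multipart/form-data; boundary=..." itself when given FormData via data=
    headers = {"Authorization": f"Bearer {api_key}"}
    async with aiohttp.ClientSession() as session:
        async with session.post(api_url, headers=headers, data=form) as resp:
            result = await resp.json()
            return result.get("text", "")  # assumed OpenAI-style response shape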
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
"""构建请求体"""
# 复制一份参数,避免直接修改 self.params
@@ -725,7 +768,8 @@ class LLMRequest:
return content, reasoning_content, tool_calls
else:
return content, reasoning_content
elif "text" in result and result["text"]:
return result["text"]
return "没有返回结果", ""
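The bare-string return above is what generate_response_for_voice ultimately yields. A note on the assumed response shape, which this diff does not confirm:

# Chat completions return {"choices": [...]} and are unpacked into
# (content, reasoning) tuples above; a transcription response is assumed to be
# OpenAI-compatible, carrying only a top-level "text" field:
transcription_result = {"text": "你好,在吗"}
# so the parser falls through to the `"text" in result` branch and hands back
# the transcript as a plain string rather than a (content, reasoning) tuple.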
@staticmethod
@@ -739,11 +783,15 @@ class LLMRequest:
reasoning = ""
return content, reasoning
async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict:
    """Build the request headers; multipart bodies must leave Content-Type
    unset so aiohttp can inject the form boundary itself."""
    if no_key:
        # masked variant keeps real API keys out of screenshots
        if is_formdata:
            return {"Authorization": "Bearer **********"}
        return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
    else:
        if is_formdata:
            return {"Authorization": f"Bearer {self.api_key}"}
        return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
@@ -761,6 +809,11 @@ class LLMRequest:
content, reasoning_content = response
return content, reasoning_content
async def generate_response_for_voice(self, voice_bytes: bytes) -> str:
    """Asynchronously generate a transcription for the given voice audio."""
    response = await self._execute_request(
        endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav"
    )
    return response
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
"""异步方式根据输入的提示生成模型的响应"""
# 构建请求体不硬编码max_tokens

View File

@@ -294,6 +294,12 @@ provider = "SILICONFLOW"
pri_in = 0.35
pri_out = 0.35
[model.voice] # speech-recognition (ASR) model
name = "FunAudioLLM/SenseVoiceSmall"
provider = "SILICONFLOW"
pri_in = 0
pri_out = 0
[model.tool_use] # tool-calling model; requires a model that supports tool calls
name = "Qwen/Qwen3-14B"
provider = "SILICONFLOW"