fix: 优化LLMRequest类,初始化请求处理器并简化任务映射逻辑
This commit is contained in:
@@ -1,11 +1,10 @@
|
|||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Tuple, Union, Dict, Any
|
from typing import Tuple, Union
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
import base64
|
import base64
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
import copy # 添加copy模块用于深拷贝
|
|
||||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||||
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -135,6 +134,9 @@ class LLMRequest:
|
|||||||
# 确定使用哪个任务配置
|
# 确定使用哪个任务配置
|
||||||
task_name = self._determine_task_name(model)
|
task_name = self._determine_task_name(model)
|
||||||
|
|
||||||
|
# 初始化 request_handler
|
||||||
|
self.request_handler = None
|
||||||
|
|
||||||
# 尝试初始化新架构
|
# 尝试初始化新架构
|
||||||
if NEW_ARCHITECTURE_AVAILABLE and ModelManager_class is not None:
|
if NEW_ARCHITECTURE_AVAILABLE and ModelManager_class is not None:
|
||||||
try:
|
try:
|
||||||
@@ -231,12 +233,7 @@ class LLMRequest:
|
|||||||
return "speech"
|
return "speech"
|
||||||
else:
|
else:
|
||||||
# 根据request_type确定,映射到配置文件中定义的任务
|
# 根据request_type确定,映射到配置文件中定义的任务
|
||||||
if self.request_type in ["memory", "emotion"]:
|
return "llm_reasoning" if self.request_type == "reasoning" else "llm_normal"
|
||||||
return "llm_normal" # 映射到配置中的llm_normal任务
|
|
||||||
elif self.request_type in ["reasoning"]:
|
|
||||||
return "llm_reasoning" # 映射到配置中的llm_reasoning任务
|
|
||||||
else:
|
|
||||||
return "llm_normal" # 默认使用llm_normal任务
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _init_database():
|
def _init_database():
|
||||||
@@ -254,7 +251,7 @@ class LLMRequest:
|
|||||||
completion_tokens: int,
|
completion_tokens: int,
|
||||||
total_tokens: int,
|
total_tokens: int,
|
||||||
user_id: str = "system",
|
user_id: str = "system",
|
||||||
request_type: str = None,
|
request_type: str | None = None,
|
||||||
endpoint: str = "/chat/completions",
|
endpoint: str = "/chat/completions",
|
||||||
):
|
):
|
||||||
"""记录模型使用情况到数据库
|
"""记录模型使用情况到数据库
|
||||||
@@ -314,10 +311,7 @@ class LLMRequest:
|
|||||||
"""CoT思维链提取"""
|
"""CoT思维链提取"""
|
||||||
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
|
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
|
||||||
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
|
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
|
||||||
if match:
|
reasoning = match.group(1).strip() if match else ""
|
||||||
reasoning = match.group(1).strip()
|
|
||||||
else:
|
|
||||||
reasoning = ""
|
|
||||||
return content, reasoning
|
return content, reasoning
|
||||||
|
|
||||||
# === 主要API方法 ===
|
# === 主要API方法 ===
|
||||||
@@ -333,6 +327,11 @@ class LLMRequest:
|
|||||||
f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
|
f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.request_handler is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"模型 {self.model_name} 请求处理器未初始化,无法处理图片请求"
|
||||||
|
)
|
||||||
|
|
||||||
if MessageBuilder is None:
|
if MessageBuilder is None:
|
||||||
raise RuntimeError("MessageBuilder不可用,请检查新架构配置")
|
raise RuntimeError("MessageBuilder不可用,请检查新架构配置")
|
||||||
|
|
||||||
@@ -346,7 +345,7 @@ class LLMRequest:
|
|||||||
messages = [message_builder.build()]
|
messages = [message_builder.build()]
|
||||||
|
|
||||||
# 使用新架构发送请求(只传递支持的参数)
|
# 使用新架构发送请求(只传递支持的参数)
|
||||||
response = await self.request_handler.get_response(
|
response = await self.request_handler.get_response( # type: ignore
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tool_options=None,
|
tool_options=None,
|
||||||
response_format=None
|
response_format=None
|
||||||
@@ -401,20 +400,22 @@ class LLMRequest:
|
|||||||
f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
|
f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.request_handler is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"模型 {self.model_name} 请求处理器未初始化,无法处理语音请求"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 构建语音识别请求参数
|
# 构建语音识别请求参数
|
||||||
# 注意:新架构中的语音识别可能使用不同的方法
|
# 注意:新架构中的语音识别可能使用不同的方法
|
||||||
# 这里先使用get_response方法,可能需要根据实际API调整
|
# 这里先使用get_response方法,可能需要根据实际API调整
|
||||||
response = await self.request_handler.get_response(
|
response = await self.request_handler.get_response( # type: ignore
|
||||||
messages=[], # 语音识别可能不需要消息
|
messages=[], # 语音识别可能不需要消息
|
||||||
tool_options=None
|
tool_options=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# 新架构返回的是 APIResponse 对象,直接提取文本内容
|
# 新架构返回的是 APIResponse 对象,直接提取文本内容
|
||||||
if response.content:
|
return (response.content,) if response.content else ("",)
|
||||||
return response.content
|
|
||||||
else:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"模型 {self.model_name} 语音识别失败: {str(e)}")
|
logger.error(f"模型 {self.model_name} 语音识别失败: {str(e)}")
|
||||||
@@ -438,6 +439,11 @@ class LLMRequest:
|
|||||||
f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
|
f"模型 {self.model_name} 无法使用新架构,请检查 config/model_config.toml 中的 API 配置。"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.request_handler is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"模型 {self.model_name} 请求处理器未初始化,无法生成响应"
|
||||||
|
)
|
||||||
|
|
||||||
if MessageBuilder is None:
|
if MessageBuilder is None:
|
||||||
raise RuntimeError("MessageBuilder不可用,请检查新架构配置")
|
raise RuntimeError("MessageBuilder不可用,请检查新架构配置")
|
||||||
|
|
||||||
@@ -448,7 +454,7 @@ class LLMRequest:
|
|||||||
messages = [message_builder.build()]
|
messages = [message_builder.build()]
|
||||||
|
|
||||||
# 使用新架构发送请求(只传递支持的参数)
|
# 使用新架构发送请求(只传递支持的参数)
|
||||||
response = await self.request_handler.get_response(
|
response = await self.request_handler.get_response( # type: ignore
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tool_options=None,
|
tool_options=None,
|
||||||
response_format=None
|
response_format=None
|
||||||
@@ -504,7 +510,7 @@ class LLMRequest:
|
|||||||
Returns:
|
Returns:
|
||||||
list: embedding向量,如果失败则返回None
|
list: embedding向量,如果失败则返回None
|
||||||
"""
|
"""
|
||||||
if len(text) < 1:
|
if not text:
|
||||||
logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
|
logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -512,10 +518,14 @@ class LLMRequest:
|
|||||||
logger.warning(f"模型 {self.model_name} 无法使用新架构,embedding请求将被跳过")
|
logger.warning(f"模型 {self.model_name} 无法使用新架构,embedding请求将被跳过")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if self.request_handler is None:
|
||||||
|
logger.warning(f"模型 {self.model_name} 请求处理器未初始化,embedding请求将被跳过")
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 构建embedding请求参数
|
# 构建embedding请求参数
|
||||||
# 使用新架构的get_embedding方法
|
# 使用新架构的get_embedding方法
|
||||||
response = await self.request_handler.get_embedding(text)
|
response = await self.request_handler.get_embedding(text) # type: ignore
|
||||||
|
|
||||||
# 新架构返回的是 APIResponse 对象,直接提取embedding
|
# 新架构返回的是 APIResponse 对象,直接提取embedding
|
||||||
if response.embedding:
|
if response.embedding:
|
||||||
@@ -551,7 +561,7 @@ class LLMRequest:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
def compress_base64_image_by_scale(base64_data: str, target_size: int = int(0.8 * 1024 * 1024)) -> str:
|
||||||
"""压缩base64格式的图片到指定大小
|
"""压缩base64格式的图片到指定大小
|
||||||
Args:
|
Args:
|
||||||
base64_data: base64编码的图片数据
|
base64_data: base64编码的图片数据
|
||||||
@@ -589,7 +599,8 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10
|
|||||||
# 如果是GIF,处理所有帧
|
# 如果是GIF,处理所有帧
|
||||||
if getattr(img, "is_animated", False):
|
if getattr(img, "is_animated", False):
|
||||||
frames = []
|
frames = []
|
||||||
for frame_idx in range(img.n_frames):
|
n_frames = getattr(img, 'n_frames', 1)
|
||||||
|
for frame_idx in range(n_frames):
|
||||||
img.seek(frame_idx)
|
img.seek(frame_idx)
|
||||||
new_frame = img.copy()
|
new_frame = img.copy()
|
||||||
new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折
|
new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折
|
||||||
|
|||||||
Reference in New Issue
Block a user