Merge pull request #801 from UnCLAS-Prommer/dev
拆分_execute_request 第一步:拆分
This commit is contained in:
8
src/api/__init__.py
Normal file
8
src/api/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from strawberry.fastapi import GraphQLRouter
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
|
||||||
|
|
||||||
|
app.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
|
||||||
155
src/api/config_api.py
Normal file
155
src/api/config_api.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
import strawberry
|
||||||
|
|
||||||
|
# from packaging.version import Version, InvalidVersion
|
||||||
|
# from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
||||||
|
# from ..config.config import global_config
|
||||||
|
# import os
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
|
|
||||||
|
@strawberry.type
|
||||||
|
class BotConfig:
|
||||||
|
"""机器人配置类"""
|
||||||
|
|
||||||
|
INNER_VERSION: Version
|
||||||
|
MAI_VERSION: str # 硬编码的版本信息
|
||||||
|
|
||||||
|
# bot
|
||||||
|
BOT_QQ: Optional[int]
|
||||||
|
BOT_NICKNAME: Optional[str]
|
||||||
|
BOT_ALIAS_NAMES: List[str] # 别名,可以通过这个叫它
|
||||||
|
|
||||||
|
# group
|
||||||
|
talk_allowed_groups: set
|
||||||
|
talk_frequency_down_groups: set
|
||||||
|
ban_user_id: set
|
||||||
|
|
||||||
|
# personality
|
||||||
|
personality_core: str # 建议20字以内,谁再写3000字小作文敲谁脑袋
|
||||||
|
personality_sides: List[str]
|
||||||
|
# identity
|
||||||
|
identity_detail: List[str]
|
||||||
|
height: int # 身高 单位厘米
|
||||||
|
weight: int # 体重 单位千克
|
||||||
|
age: int # 年龄 单位岁
|
||||||
|
gender: str # 性别
|
||||||
|
appearance: str # 外貌特征
|
||||||
|
|
||||||
|
# schedule
|
||||||
|
ENABLE_SCHEDULE_GEN: bool # 是否启用日程生成
|
||||||
|
PROMPT_SCHEDULE_GEN: str
|
||||||
|
SCHEDULE_DOING_UPDATE_INTERVAL: int # 日程表更新间隔 单位秒
|
||||||
|
SCHEDULE_TEMPERATURE: float # 日程表温度,建议0.5-1.0
|
||||||
|
TIME_ZONE: str # 时区
|
||||||
|
|
||||||
|
# message
|
||||||
|
MAX_CONTEXT_SIZE: int # 上下文最大消息数
|
||||||
|
emoji_chance: float # 发送表情包的基础概率
|
||||||
|
thinking_timeout: int # 思考时间
|
||||||
|
max_response_length: int # 最大回复长度
|
||||||
|
message_buffer: bool # 消息缓冲器
|
||||||
|
|
||||||
|
ban_words: set
|
||||||
|
ban_msgs_regex: set
|
||||||
|
# heartflow
|
||||||
|
# enable_heartflow: bool = False # 是否启用心流
|
||||||
|
sub_heart_flow_update_interval: int # 子心流更新频率,间隔 单位秒
|
||||||
|
sub_heart_flow_freeze_time: int # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒
|
||||||
|
sub_heart_flow_stop_time: int # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒
|
||||||
|
heart_flow_update_interval: int # 心流更新频率,间隔 单位秒
|
||||||
|
observation_context_size: int # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩
|
||||||
|
compressed_length: int # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5
|
||||||
|
compress_length_limit: int # 最多压缩份数,超过该数值的压缩上下文会被删除
|
||||||
|
|
||||||
|
# willing
|
||||||
|
willing_mode: str # 意愿模式
|
||||||
|
response_willing_amplifier: float # 回复意愿放大系数
|
||||||
|
response_interested_rate_amplifier: float # 回复兴趣度放大系数
|
||||||
|
down_frequency_rate: float # 降低回复频率的群组回复意愿降低系数
|
||||||
|
emoji_response_penalty: float # 表情包回复惩罚
|
||||||
|
mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复
|
||||||
|
at_bot_inevitable_reply: bool # @bot 必然回复
|
||||||
|
|
||||||
|
# response
|
||||||
|
response_mode: str # 回复策略
|
||||||
|
MODEL_R1_PROBABILITY: float # R1模型概率
|
||||||
|
MODEL_V3_PROBABILITY: float # V3模型概率
|
||||||
|
# MODEL_R1_DISTILL_PROBABILITY: float # R1蒸馏模型概率
|
||||||
|
|
||||||
|
# emoji
|
||||||
|
max_emoji_num: int # 表情包最大数量
|
||||||
|
max_reach_deletion: bool # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包
|
||||||
|
EMOJI_CHECK_INTERVAL: int # 表情包检查间隔(分钟)
|
||||||
|
EMOJI_REGISTER_INTERVAL: int # 表情包注册间隔(分钟)
|
||||||
|
EMOJI_SAVE: bool # 偷表情包
|
||||||
|
EMOJI_CHECK: bool # 是否开启过滤
|
||||||
|
EMOJI_CHECK_PROMPT: str # 表情包过滤要求
|
||||||
|
|
||||||
|
# memory
|
||||||
|
build_memory_interval: int # 记忆构建间隔(秒)
|
||||||
|
memory_build_distribution: list # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
|
||||||
|
build_memory_sample_num: int # 记忆构建采样数量
|
||||||
|
build_memory_sample_length: int # 记忆构建采样长度
|
||||||
|
memory_compress_rate: float # 记忆压缩率
|
||||||
|
|
||||||
|
forget_memory_interval: int # 记忆遗忘间隔(秒)
|
||||||
|
memory_forget_time: int # 记忆遗忘时间(小时)
|
||||||
|
memory_forget_percentage: float # 记忆遗忘比例
|
||||||
|
|
||||||
|
memory_ban_words: list # 添加新的配置项默认值
|
||||||
|
|
||||||
|
# mood
|
||||||
|
mood_update_interval: float # 情绪更新间隔 单位秒
|
||||||
|
mood_decay_rate: float # 情绪衰减率
|
||||||
|
mood_intensity_factor: float # 情绪强度因子
|
||||||
|
|
||||||
|
# keywords
|
||||||
|
keywords_reaction_rules: list # 关键词回复规则
|
||||||
|
|
||||||
|
# chinese_typo
|
||||||
|
chinese_typo_enable: bool # 是否启用中文错别字生成器
|
||||||
|
chinese_typo_error_rate: float # 单字替换概率
|
||||||
|
chinese_typo_min_freq: int # 最小字频阈值
|
||||||
|
chinese_typo_tone_error_rate: float # 声调错误概率
|
||||||
|
chinese_typo_word_replace_rate: float # 整词替换概率
|
||||||
|
|
||||||
|
# response_splitter
|
||||||
|
enable_response_splitter: bool # 是否启用回复分割器
|
||||||
|
response_max_length: int # 回复允许的最大长度
|
||||||
|
response_max_sentence_num: int # 回复允许的最大句子数
|
||||||
|
|
||||||
|
# remote
|
||||||
|
remote_enable: bool # 是否启用远程控制
|
||||||
|
|
||||||
|
# experimental
|
||||||
|
enable_friend_chat: bool # 是否启用好友聊天
|
||||||
|
# enable_think_flow: bool # 是否启用思考流程
|
||||||
|
enable_pfc_chatting: bool # 是否启用PFC聊天
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
llm_reasoning: Dict[str, str] # LLM推理
|
||||||
|
# llm_reasoning_minor: Dict[str, str]
|
||||||
|
llm_normal: Dict[str, str] # LLM普通
|
||||||
|
llm_topic_judge: Dict[str, str] # LLM话题判断
|
||||||
|
llm_summary_by_topic: Dict[str, str] # LLM话题总结
|
||||||
|
llm_emotion_judge: Dict[str, str] # LLM情感判断
|
||||||
|
embedding: Dict[str, str] # 嵌入
|
||||||
|
vlm: Dict[str, str] # VLM
|
||||||
|
moderation: Dict[str, str] # 审核
|
||||||
|
|
||||||
|
# 实验性
|
||||||
|
llm_observation: Dict[str, str] # LLM观察
|
||||||
|
llm_sub_heartflow: Dict[str, str] # LLM子心流
|
||||||
|
llm_heartflow: Dict[str, str] # LLM心流
|
||||||
|
|
||||||
|
api_urls: Dict[str, str] # API URLs
|
||||||
|
|
||||||
|
|
||||||
|
@strawberry.type
|
||||||
|
class EnvConfig:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@strawberry.field
|
||||||
|
def get_env(self) -> str:
|
||||||
|
return "env"
|
||||||
@@ -151,7 +151,7 @@ class ReplyGenerator:
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成回复时出错: {e}")
|
logger.error(f"生成回复时出错: {str(e)}")
|
||||||
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
||||||
|
|
||||||
async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
|
async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union, Dict, Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
from aiohttp.client import ClientResponse
|
||||||
|
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
import base64
|
import base64
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -16,19 +18,72 @@ from ...config.config import global_config
|
|||||||
logger = get_module_logger("model_utils")
|
logger = get_module_logger("model_utils")
|
||||||
|
|
||||||
|
|
||||||
|
class PayLoadTooLargeError(Exception):
|
||||||
|
"""自定义异常类,用于处理请求体过大错误"""
|
||||||
|
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "请求体过大,请尝试压缩图片或减少输入内容。"
|
||||||
|
|
||||||
|
|
||||||
|
class RequestAbortException(Exception):
|
||||||
|
"""自定义异常类,用于处理请求中断异常"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, response: ClientResponse):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.response = response
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionDeniedException(Exception):
|
||||||
|
"""自定义异常类,用于处理访问拒绝的异常"""
|
||||||
|
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
|
||||||
|
# 常见Error Code Mapping
|
||||||
|
error_code_mapping = {
|
||||||
|
400: "参数不正确",
|
||||||
|
401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
|
||||||
|
402: "账号余额不足",
|
||||||
|
403: "需要实名,或余额不足",
|
||||||
|
404: "Not Found",
|
||||||
|
429: "请求过于频繁,请稍后再试",
|
||||||
|
500: "服务器内部故障",
|
||||||
|
503: "服务器负载过高",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class LLMRequest:
|
class LLMRequest:
|
||||||
# 定义需要转换的模型列表,作为类变量避免重复
|
# 定义需要转换的模型列表,作为类变量避免重复
|
||||||
MODELS_NEEDING_TRANSFORMATION = [
|
MODELS_NEEDING_TRANSFORMATION = [
|
||||||
"o3-mini",
|
"o1",
|
||||||
"o1-mini",
|
|
||||||
"o1-preview",
|
|
||||||
"o1-2024-12-17",
|
"o1-2024-12-17",
|
||||||
"o1-preview-2024-09-12",
|
"o1-mini",
|
||||||
"o3-mini-2025-01-31",
|
|
||||||
"o1-mini-2024-09-12",
|
"o1-mini-2024-09-12",
|
||||||
|
"o1-preview",
|
||||||
|
"o1-preview-2024-09-12",
|
||||||
|
"o1-pro",
|
||||||
|
"o1-pro-2025-03-19",
|
||||||
|
"o3",
|
||||||
|
"o3-2025-04-16",
|
||||||
|
"o3-mini",
|
||||||
|
"o3-mini-2025-01-31o4-mini",
|
||||||
|
"o4-mini-2025-04-16",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, model, **kwargs):
|
def __init__(self, model: dict, **kwargs):
|
||||||
# 将大写的配置键转换为小写并从config中获取实际值
|
# 将大写的配置键转换为小写并从config中获取实际值
|
||||||
try:
|
try:
|
||||||
self.api_key = os.environ[model["key"]]
|
self.api_key = os.environ[model["key"]]
|
||||||
@@ -37,7 +92,7 @@ class LLMRequest:
|
|||||||
logger.error(f"原始 model dict 信息:{model}")
|
logger.error(f"原始 model dict 信息:{model}")
|
||||||
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||||
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
|
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
|
||||||
self.model_name = model["name"]
|
self.model_name: str = model["name"]
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
|
|
||||||
self.stream = model.get("stream", False)
|
self.stream = model.get("stream", False)
|
||||||
@@ -123,6 +178,7 @@ class LLMRequest:
|
|||||||
output_cost = (completion_tokens / 1000000) * self.pri_out
|
output_cost = (completion_tokens / 1000000) * self.pri_out
|
||||||
return round(input_cost + output_cost, 6)
|
return round(input_cost + output_cost, 6)
|
||||||
|
|
||||||
|
'''
|
||||||
async def _execute_request(
|
async def _execute_request(
|
||||||
self,
|
self,
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
@@ -509,6 +565,404 @@ class LLMRequest:
|
|||||||
|
|
||||||
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
|
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
|
||||||
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")
|
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")
|
||||||
|
'''
|
||||||
|
|
||||||
|
async def _prepare_request(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
prompt: str = None,
|
||||||
|
image_base64: str = None,
|
||||||
|
image_format: str = None,
|
||||||
|
payload: dict = None,
|
||||||
|
retry_policy: dict = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""配置请求参数
|
||||||
|
Args:
|
||||||
|
endpoint: API端点路径 (如 "chat/completions")
|
||||||
|
prompt: prompt文本
|
||||||
|
image_base64: 图片的base64编码
|
||||||
|
image_format: 图片格式
|
||||||
|
payload: 请求体数据
|
||||||
|
retry_policy: 自定义重试策略
|
||||||
|
request_type: 请求类型
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 合并重试策略
|
||||||
|
default_retry = {
|
||||||
|
"max_retries": 3,
|
||||||
|
"base_wait": 10,
|
||||||
|
"retry_codes": [429, 413, 500, 503],
|
||||||
|
"abort_codes": [400, 401, 402, 403],
|
||||||
|
}
|
||||||
|
policy = {**default_retry, **(retry_policy or {})}
|
||||||
|
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
|
||||||
|
|
||||||
|
stream_mode = self.stream
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
if image_base64:
|
||||||
|
payload = await self._build_payload(prompt, image_base64, image_format)
|
||||||
|
elif payload is None:
|
||||||
|
payload = await self._build_payload(prompt)
|
||||||
|
|
||||||
|
if stream_mode:
|
||||||
|
payload["stream"] = stream_mode
|
||||||
|
|
||||||
|
return {
|
||||||
|
"policy": policy,
|
||||||
|
"payload": payload,
|
||||||
|
"api_url": api_url,
|
||||||
|
"stream_mode": stream_mode,
|
||||||
|
"image_base64": image_base64, # 保留必要的exception处理所需的原始数据
|
||||||
|
"image_format": image_format,
|
||||||
|
"prompt": prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _execute_request(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
prompt: str = None,
|
||||||
|
image_base64: str = None,
|
||||||
|
image_format: str = None,
|
||||||
|
payload: dict = None,
|
||||||
|
retry_policy: dict = None,
|
||||||
|
response_handler: callable = None,
|
||||||
|
user_id: str = "system",
|
||||||
|
request_type: str = None,
|
||||||
|
):
|
||||||
|
"""统一请求执行入口
|
||||||
|
Args:
|
||||||
|
endpoint: API端点路径 (如 "chat/completions")
|
||||||
|
prompt: prompt文本
|
||||||
|
image_base64: 图片的base64编码
|
||||||
|
image_format: 图片格式
|
||||||
|
payload: 请求体数据
|
||||||
|
retry_policy: 自定义重试策略
|
||||||
|
response_handler: 自定义响应处理器
|
||||||
|
user_id: 用户ID
|
||||||
|
request_type: 请求类型
|
||||||
|
"""
|
||||||
|
# 获取请求配置
|
||||||
|
request_content = await self._prepare_request(
|
||||||
|
endpoint, prompt, image_base64, image_format, payload, retry_policy
|
||||||
|
)
|
||||||
|
if request_type is None:
|
||||||
|
request_type = self.request_type
|
||||||
|
for retry in range(request_content["policy"]["max_retries"]):
|
||||||
|
try:
|
||||||
|
# 使用上下文管理器处理会话
|
||||||
|
headers = await self._build_headers()
|
||||||
|
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
|
||||||
|
if request_content["stream_mode"]:
|
||||||
|
headers["Accept"] = "text/event-stream"
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
request_content["api_url"], headers=headers, json=request_content["payload"]
|
||||||
|
) as response:
|
||||||
|
handled_result = await self._handle_response(
|
||||||
|
response, request_content, retry, response_handler, user_id, request_type, endpoint
|
||||||
|
)
|
||||||
|
return handled_result
|
||||||
|
except Exception as e:
|
||||||
|
handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
|
||||||
|
retry += count_delta # 降级不计入重试次数
|
||||||
|
if handled_payload:
|
||||||
|
# 如果降级成功,重新构建请求体
|
||||||
|
request_content["payload"] = handled_payload
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
|
||||||
|
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")
|
||||||
|
|
||||||
|
async def _handle_response(
|
||||||
|
self,
|
||||||
|
response: ClientResponse,
|
||||||
|
request_content: Dict[str, Any],
|
||||||
|
retry_count: int,
|
||||||
|
response_handler: callable,
|
||||||
|
user_id,
|
||||||
|
request_type,
|
||||||
|
endpoint,
|
||||||
|
) -> Union[Dict[str, Any], None]:
|
||||||
|
policy = request_content["policy"]
|
||||||
|
stream_mode = request_content["stream_mode"]
|
||||||
|
if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
|
||||||
|
await self._handle_error_response(response, retry_count, policy)
|
||||||
|
return
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
result = {}
|
||||||
|
if stream_mode:
|
||||||
|
# 将流式输出转化为非流式输出
|
||||||
|
result = await self._handle_stream_output(response)
|
||||||
|
else:
|
||||||
|
result = await response.json()
|
||||||
|
return (
|
||||||
|
response_handler(result)
|
||||||
|
if response_handler
|
||||||
|
else self._default_response_handler(result, user_id, request_type, endpoint)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
|
||||||
|
flag_delta_content_finished = False
|
||||||
|
accumulated_content = ""
|
||||||
|
usage = None # 初始化usage变量,避免未定义错误
|
||||||
|
reasoning_content = ""
|
||||||
|
content = ""
|
||||||
|
async for line_bytes in response.content:
|
||||||
|
try:
|
||||||
|
line = line_bytes.decode("utf-8").strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
if line.startswith("data:"):
|
||||||
|
data_str = line[5:].strip()
|
||||||
|
if data_str == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
chunk = json.loads(data_str)
|
||||||
|
if flag_delta_content_finished:
|
||||||
|
chunk_usage = chunk.get("usage", None)
|
||||||
|
if chunk_usage:
|
||||||
|
usage = chunk_usage # 获取token用量
|
||||||
|
else:
|
||||||
|
delta = chunk["choices"][0]["delta"]
|
||||||
|
delta_content = delta.get("content")
|
||||||
|
if delta_content is None:
|
||||||
|
delta_content = ""
|
||||||
|
accumulated_content += delta_content
|
||||||
|
# 检测流式输出文本是否结束
|
||||||
|
finish_reason = chunk["choices"][0].get("finish_reason")
|
||||||
|
if delta.get("reasoning_content", None):
|
||||||
|
reasoning_content += delta["reasoning_content"]
|
||||||
|
if finish_reason == "stop":
|
||||||
|
chunk_usage = chunk.get("usage", None)
|
||||||
|
if chunk_usage:
|
||||||
|
usage = chunk_usage
|
||||||
|
break
|
||||||
|
# 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
|
||||||
|
flag_delta_content_finished = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, GeneratorExit):
|
||||||
|
log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
|
||||||
|
else:
|
||||||
|
log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
|
||||||
|
logger.warning(log_content)
|
||||||
|
# 确保资源被正确清理
|
||||||
|
try:
|
||||||
|
await response.release()
|
||||||
|
except Exception as cleanup_error:
|
||||||
|
logger.error(f"清理资源时发生错误: {cleanup_error}")
|
||||||
|
# 返回已经累积的内容
|
||||||
|
content = accumulated_content
|
||||||
|
if not content:
|
||||||
|
content = accumulated_content
|
||||||
|
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
||||||
|
if think_match:
|
||||||
|
reasoning_content = think_match.group(1).strip()
|
||||||
|
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
|
||||||
|
result = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": content,
|
||||||
|
"reasoning_content": reasoning_content,
|
||||||
|
# 流式输出可能没有工具调用,此处不需要添加tool_calls字段
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": usage,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _handle_error_response(
|
||||||
|
self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]
|
||||||
|
) -> Union[Dict[str, any]]:
|
||||||
|
if response.status in policy["retry_codes"]:
|
||||||
|
wait_time = policy["base_wait"] * (2**retry_count)
|
||||||
|
logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
|
||||||
|
if response.status == 413:
|
||||||
|
logger.warning("请求体过大,尝试压缩...")
|
||||||
|
raise PayLoadTooLargeError("请求体过大")
|
||||||
|
elif response.status in [500, 503]:
|
||||||
|
logger.error(
|
||||||
|
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
|
||||||
|
)
|
||||||
|
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
|
||||||
|
else:
|
||||||
|
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
raise RuntimeError("请求限制(429)")
|
||||||
|
elif response.status in policy["abort_codes"]:
|
||||||
|
if response.status != 403:
|
||||||
|
raise RequestAbortException("请求出现错误,中断处理", response)
|
||||||
|
else:
|
||||||
|
raise PermissionDeniedException("模型禁止访问")
|
||||||
|
|
||||||
|
async def _handle_exception(
|
||||||
|
self, exception, retry_count: int, request_content: Dict[str, Any]
|
||||||
|
) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
|
||||||
|
policy = request_content["policy"]
|
||||||
|
payload = request_content["payload"]
|
||||||
|
wait_time = policy["base_wait"] * (2**retry_count)
|
||||||
|
if retry_count < policy["max_retries"] - 1:
|
||||||
|
keep_request = True
|
||||||
|
if isinstance(exception, RequestAbortException):
|
||||||
|
response = exception.response
|
||||||
|
logger.error(
|
||||||
|
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
|
||||||
|
)
|
||||||
|
# 尝试获取并记录服务器返回的详细错误信息
|
||||||
|
try:
|
||||||
|
error_json = await response.json()
|
||||||
|
if error_json and isinstance(error_json, list) and len(error_json) > 0:
|
||||||
|
# 处理多个错误的情况
|
||||||
|
for error_item in error_json:
|
||||||
|
if "error" in error_item and isinstance(error_item["error"], dict):
|
||||||
|
error_obj: dict = error_item["error"]
|
||||||
|
error_code = error_obj.get("code")
|
||||||
|
error_message = error_obj.get("message")
|
||||||
|
error_status = error_obj.get("status")
|
||||||
|
logger.error(
|
||||||
|
f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
|
||||||
|
)
|
||||||
|
elif isinstance(error_json, dict) and "error" in error_json:
|
||||||
|
# 处理单个错误对象的情况
|
||||||
|
error_obj = error_json.get("error", {})
|
||||||
|
error_code = error_obj.get("code")
|
||||||
|
error_message = error_obj.get("message")
|
||||||
|
error_status = error_obj.get("status")
|
||||||
|
logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}")
|
||||||
|
else:
|
||||||
|
# 记录原始错误响应内容
|
||||||
|
logger.error(f"服务器错误响应: {error_json}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"无法解析服务器错误响应: {str(e)}")
|
||||||
|
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
|
||||||
|
|
||||||
|
elif isinstance(exception, PermissionDeniedException):
|
||||||
|
# 只针对硅基流动的V3和R1进行降级处理
|
||||||
|
if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
|
||||||
|
old_model_name = self.model_name
|
||||||
|
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
|
||||||
|
logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
|
||||||
|
|
||||||
|
# 对全局配置进行更新
|
||||||
|
if global_config.llm_normal.get("name") == old_model_name:
|
||||||
|
global_config.llm_normal["name"] = self.model_name
|
||||||
|
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
|
||||||
|
if global_config.llm_reasoning.get("name") == old_model_name:
|
||||||
|
global_config.llm_reasoning["name"] = self.model_name
|
||||||
|
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
|
||||||
|
|
||||||
|
if payload and "model" in payload:
|
||||||
|
payload["model"] = self.model_name
|
||||||
|
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
return payload, -1
|
||||||
|
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}")
|
||||||
|
|
||||||
|
elif isinstance(exception, PayLoadTooLargeError):
|
||||||
|
if keep_request:
|
||||||
|
image_base64 = request_content["image_base64"]
|
||||||
|
compressed_image_base64 = compress_base64_image_by_scale(image_base64)
|
||||||
|
new_payload = await self._build_payload(
|
||||||
|
request_content["prompt"], compressed_image_base64, request_content["image_format"]
|
||||||
|
)
|
||||||
|
return new_payload, 0
|
||||||
|
else:
|
||||||
|
return None, 0
|
||||||
|
|
||||||
|
elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError):
|
||||||
|
if keep_request:
|
||||||
|
logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
return None, 0
|
||||||
|
else:
|
||||||
|
logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}")
|
||||||
|
raise RuntimeError(f"网络请求失败: {str(exception)}")
|
||||||
|
|
||||||
|
elif isinstance(exception, aiohttp.ClientResponseError):
|
||||||
|
# 处理aiohttp抛出的,除了policy中的status的响应错误
|
||||||
|
if keep_request:
|
||||||
|
logger.error(
|
||||||
|
f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {exception.status}, 错误: {exception.message}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
error_text = await exception.response.text()
|
||||||
|
error_json = json.loads(error_text)
|
||||||
|
if isinstance(error_json, list) and len(error_json) > 0:
|
||||||
|
# 处理多个错误的情况
|
||||||
|
for error_item in error_json:
|
||||||
|
if "error" in error_item and isinstance(error_item["error"], dict):
|
||||||
|
error_obj = error_item["error"]
|
||||||
|
logger.error(
|
||||||
|
f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
|
||||||
|
f"状态={error_obj.get('status')}, "
|
||||||
|
f"消息={error_obj.get('message')}"
|
||||||
|
)
|
||||||
|
elif isinstance(error_json, dict) and "error" in error_json:
|
||||||
|
error_obj = error_json.get("error", {})
|
||||||
|
logger.error(
|
||||||
|
f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
|
||||||
|
f"状态={error_obj.get('status')}, "
|
||||||
|
f"消息={error_obj.get('message')}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
|
||||||
|
except (json.JSONDecodeError, TypeError) as json_err:
|
||||||
|
logger.warning(
|
||||||
|
f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
|
||||||
|
)
|
||||||
|
except Exception as parse_err:
|
||||||
|
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
|
||||||
|
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
return None, 0
|
||||||
|
else:
|
||||||
|
logger.critical(
|
||||||
|
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
|
||||||
|
)
|
||||||
|
# 安全地检查和记录请求详情
|
||||||
|
handled_payload = await self._safely_record(request_content, payload)
|
||||||
|
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
|
||||||
|
raise RuntimeError(
|
||||||
|
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if keep_request:
|
||||||
|
logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
return None, 0
|
||||||
|
else:
|
||||||
|
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
|
||||||
|
# 安全地检查和记录请求详情
|
||||||
|
handled_payload = await self._safely_record(request_content, payload)
|
||||||
|
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
|
||||||
|
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
|
||||||
|
|
||||||
|
async def _safely_record(self, request_content: Dict[str, Any], payload: Dict[str, Any]):
|
||||||
|
image_base64: str = request_content.get("image_base64")
|
||||||
|
image_format: str = request_content.get("image_format")
|
||||||
|
if (
|
||||||
|
image_base64
|
||||||
|
and payload
|
||||||
|
and isinstance(payload, dict)
|
||||||
|
and "messages" in payload
|
||||||
|
and len(payload["messages"]) > 0
|
||||||
|
):
|
||||||
|
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
|
||||||
|
content = payload["messages"][0]["content"]
|
||||||
|
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
|
||||||
|
payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||||||
|
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||||
|
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||||
|
)
|
||||||
|
# if isinstance(content, str) and len(content) > 100:
|
||||||
|
# payload["messages"][0]["content"] = content[:100]
|
||||||
|
return payload
|
||||||
|
|
||||||
async def _transform_parameters(self, params: dict) -> dict:
|
async def _transform_parameters(self, params: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -532,30 +986,27 @@ class LLMRequest:
|
|||||||
# 复制一份参数,避免直接修改 self.params
|
# 复制一份参数,避免直接修改 self.params
|
||||||
params_copy = await self._transform_parameters(self.params)
|
params_copy = await self._transform_parameters(self.params)
|
||||||
if image_base64:
|
if image_base64:
|
||||||
payload = {
|
messages = [
|
||||||
"model": self.model_name,
|
{
|
||||||
"messages": [
|
"role": "user",
|
||||||
{
|
"content": [
|
||||||
"role": "user",
|
{"type": "text", "text": prompt},
|
||||||
"content": [
|
{
|
||||||
{"type": "text", "text": prompt},
|
"type": "image_url",
|
||||||
{
|
"image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
|
||||||
"type": "image_url",
|
},
|
||||||
"image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
|
],
|
||||||
},
|
}
|
||||||
],
|
]
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": global_config.max_response_length,
|
|
||||||
**params_copy,
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
payload = {
|
messages = [{"role": "user", "content": prompt}]
|
||||||
"model": self.model_name,
|
payload = {
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"model": self.model_name,
|
||||||
"max_tokens": global_config.max_response_length,
|
"messages": messages,
|
||||||
**params_copy,
|
**params_copy,
|
||||||
}
|
}
|
||||||
|
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
||||||
|
payload["max_tokens"] = global_config.max_response_length
|
||||||
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
||||||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||||
@@ -648,11 +1099,10 @@ class LLMRequest:
|
|||||||
|
|
||||||
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
|
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
|
||||||
"""异步方式根据输入的提示生成模型的响应"""
|
"""异步方式根据输入的提示生成模型的响应"""
|
||||||
# 构建请求体
|
# 构建请求体,不硬编码max_tokens
|
||||||
data = {
|
data = {
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
"max_tokens": global_config.max_response_length,
|
|
||||||
**self.params,
|
**self.params,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user