refactor(llm): rework LLM request handling with a generic failover executor
Previously, the request methods for text, image, and voice each rolled their own failover handling, duplicating the same "try the next model" logic and making the code hard to follow. This change cleans that up:

- Introduce a generic `_execute_with_failover` executor that centralizes all of the model-switching plumbing in one place.
- Refactor the affected request methods (text, image, voice, embedding) so each one only defines its own single-request logic.
- Upgrade `_model_scheduler` to order models by real-time load, so the least busy model is tried first; the old `_select_model` is removed.

The result is easier to maintain, more robust against individual model failures, and simpler to extend.
commit b91d1e6bf5
parent f333611bfa
committed by Windpicker-owo
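
Before the diff, a minimal, self-contained sketch of the pattern this commit introduces: a generic failover loop that receives a per-request coroutine and walks candidate models until one succeeds. The names here (execute_with_failover, request_logic, the toy model list) are illustrative only and are not taken from the repository; the real implementation appears in the diff below.

import asyncio
from typing import Any, Awaitable, Callable, Iterable

async def execute_with_failover(
    models: Iterable[str],
    request_logic: Callable[[str], Awaitable[Any]],
    raise_on_failure: bool = True,
) -> Any:
    """Try each candidate model in order until one request succeeds."""
    last_exc: Exception | None = None
    for model in models:
        try:
            return await request_logic(model)
        except Exception as exc:  # a real version would special-case auth errors, etc.
            last_exc = exc
    if raise_on_failure:
        raise RuntimeError("all models failed") from last_exc
    return None

async def main() -> None:
    async def request_logic(model: str) -> str:
        # Stand-in for the actual per-model API call.
        if model == "primary":
            raise TimeoutError("simulated timeout")
        return f"answer from {model}"

    print(await execute_with_failover(["primary", "backup"], request_logic))

asyncio.run(main())
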
@@ -764,8 +764,7 @@ class LLMRequest:
         max_tokens: Optional[int] = None,
     ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
-        Generate a response for an image.
-
+        Generate a response for an image (with built-in failover).
         Args:
             prompt (str): The prompt text
             image_base64 (str): Base64-encoded image data
@@ -774,49 +773,79 @@ class LLMRequest:
         Returns:
             (Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, tool calls
         """
-        start_time = time.time()
-
-        # Image requests currently skip the elaborate failover strategy: pick a model and run it directly
-        selection_result = self._model_selector.select_best_available_model(set(), "response")
-        if not selection_result:
-            raise RuntimeError("No available model could be selected for the image response.")
-        model_info, api_provider, client = selection_result
-
         normalized_format = _normalize_image_format(image_format)
-        message = MessageBuilder().add_text_content(prompt).add_image_content(
-            image_base64=image_base64,
-            image_format=normalized_format,
-            support_formats=client.get_support_image_formats(),
-        ).build()
-
-        response = await self._executor.execute_request(
-            api_provider, client, RequestType.RESPONSE, model_info,
-            message_list=[message],
-            temperature=temperature,
-            max_tokens=max_tokens,
-        )
-
-        self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
-        content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
-        reasoning = response.reasoning_content or reasoning
-
-        return content, (reasoning, model_info.name, response.tool_calls)
+
+        async def request_logic(
+            model_info: ModelInfo, api_provider: APIProvider, client: BaseClient
+        ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+            start_time = time.time()
+            message_builder = MessageBuilder()
+            message_builder.add_text_content(prompt)
+            message_builder.add_image_content(
+                image_base64=image_base64,
+                image_format=normalized_format,
+                support_formats=client.get_support_image_formats(),
+            )
+            messages = [message_builder.build()]
+
+            response = await self._execute_request(
+                api_provider=api_provider,
+                client=client,
+                request_type=RequestType.RESPONSE,
+                model_info=model_info,
+                message_list=messages,
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+
+            content = response.content or ""
+            reasoning_content = response.reasoning_content or ""
+            tool_calls = response.tool_calls
+            if not reasoning_content and content:
+                content, extracted_reasoning = self._extract_reasoning(content)
+                reasoning_content = extracted_reasoning
+            if usage := response.usage:
+                await llm_usage_recorder.record_usage_to_database(
+                    model_info=model_info,
+                    model_usage=usage,
+                    user_id="system",
+                    time_cost=time.time() - start_time,
+                    request_type=self.request_type,
+                    endpoint="/chat/completions",
+                )
+            return content, (reasoning_content, model_info.name, tool_calls)
+
+        result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True)
+        if result:
+            return result
+
+        # Theoretically unreachable, because raise_on_failure=True raises instead of returning
+        raise RuntimeError("Image response generation failed; every model was tried and failed.")

     async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
         """
-        Generate a response for a voice clip (speech-to-text).
-        Uses the failover strategy so a result is still produced when the primary model fails.
+        Generate a response for a voice clip (with built-in failover).

         Args:
             voice_base64 (str): Base64-encoded audio data.

         Returns:
-            Optional[str]: The transcribed text, or None if every model fails.
+            (Optional[str]): The generated text, or None
         """
-        response, _ = await self._strategy.execute_with_failover(
-            RequestType.AUDIO, audio_base64=voice_base64
-        )
-        return response.content or None
+
+        async def request_logic(model_info: ModelInfo, api_provider: APIProvider, client: BaseClient) -> Optional[str]:
+            """Single-attempt request logic for one model."""
+            response = await self._execute_request(
+                api_provider=api_provider,
+                client=client,
+                request_type=RequestType.AUDIO,
+                model_info=model_info,
+                audio_base64=voice_base64,
+            )
+            return response.content or None
+
+        # For speech recognition we would rather return None than crash when every model fails
+        result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=False)
+        return result

     async def generate_response_async(
         self,
@@ -856,7 +885,76 @@ class LLMRequest:
                 raise e
         return "All concurrent requests failed", ("", "unknown", None)

-    async def _execute_single_text_request(
+    async def _execute_with_failover(
+        self,
+        request_callable: Callable[[ModelInfo, APIProvider, BaseClient], Coroutine[Any, Any, Any]],
+        raise_on_failure: bool = True,
+    ) -> Any:
+        """
+        Generic failover executor.
+
+        It walks the models in the order produced by the smart model scheduler and keeps trying
+        until one request succeeds or every model has failed.
+
+        Args:
+            request_callable: A function taking (model_info, api_provider, client) and returning a
+                coroutine that performs the actual request logic.
+            raise_on_failure: Whether to raise an exception when every model has failed.
+
+        Returns:
+            The result of the first successful request.
+
+        Raises:
+            RuntimeError: If every model fails and raise_on_failure is True.
+        """
+        failed_models = set()
+        last_exception: Optional[Exception] = None
+
+        # _model_scheduler reorders dynamically, so the loop only has to keep track of failed models
+        while True:
+            model_scheduler = self._model_scheduler(failed_models)
+            try:
+                model_info, api_provider, client = next(model_scheduler)
+            except StopIteration:
+                # No more models to try
+                break
+
+            model_name = model_info.name
+            logger.debug(f"Trying model: {model_name} (remaining available: {len(self.model_for_task.model_list) - len(failed_models)})")
+
+            try:
+                # Run the supplied request function
+                result = await request_callable(model_info, api_provider, client)
+                logger.debug(f"Model '{model_name}' generated a reply successfully.")
+                return result
+
+            except RespNotOkException as e:
+                # For some fatal HTTP errors (e.g. authentication failures) we may want to fail fast or mark the model as permanently failed
+                if e.status_code in [401, 403]:
+                    logger.error(f"Model '{model_name}' hit an authentication/permission error (Code: {e.status_code}); it stays disabled for the rest of this request.")
+                else:
+                    logger.warning(f"Model '{model_name}' request failed with HTTP status {e.status_code}; trying the next model.")
+                failed_models.add(model_name)
+                last_exception = e
+                continue
+
+            except Exception as e:
+                # Catch everything else (timeouts, parsing errors, runtime errors, ...)
+                logger.error(f"Exception while using model '{model_name}': {e}; trying the next model.")
+                failed_models.add(model_name)
+                last_exception = e
+                continue
+
+        # Every model has been tried and failed
+        logger.error("Every available model has been tried and failed.")
+        if raise_on_failure:
+            if last_exception:
+                raise RuntimeError("All model requests failed") from last_exception
+            raise RuntimeError("All model requests failed, and no specific exception was recorded")
+
+        # Otherwise return a default error result
+        return None
+
+    async def _execute_single_request(
         self,
         prompt: str,
         temperature: Optional[float] = None,
@@ -865,92 +963,283 @@ class LLMRequest:
         raise_when_empty: bool = True,
     ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
-        Internal method that performs a single text-generation request.
-        This is the core implementation behind `generate_response_async`: it handles the full
-        lifecycle of one request, including tool building, failover execution and usage recording.
-
-        Args:
-            prompt (str): The user prompt.
-            temperature (Optional[float]): Sampling temperature.
-            max_tokens (Optional[int]): Maximum number of tokens to generate.
-            tools (Optional[List[Dict[str, Any]]]): Available tools.
-            raise_when_empty (bool): Whether to raise when the response is empty.
-
-        Returns:
-            Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
-                (response content, (reasoning, model name, tool calls))
+        Perform a single text-generation request via the generic failover executor.
         """
-        start_time = time.time()
-        tool_options = self._build_tool_options(tools)
-
-        response, model_info = await self._strategy.execute_with_failover(
-            RequestType.RESPONSE,
-            raise_when_empty=raise_when_empty,
-            prompt=prompt,  # pass the raw prompt; the strategy handles it
-            tool_options=tool_options,
-            temperature=self.model_for_task.temperature if temperature is None else temperature,
-            max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
+
+        async def request_logic(
+            model_info: ModelInfo, api_provider: APIProvider, client: BaseClient
+        ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+            """Single-attempt request logic for one model."""
+            start_time = time.time()
+            model_name = model_info.name
+
+            # Check whether anti-truncation is enabled
+            use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
+            processed_prompt = prompt
+            if use_anti_truncation:
+                processed_prompt += self.anti_truncation_instruction
+                logger.info(f"Model '{model_name}' (task: '{self.task_name}') has anti-truncation enabled.")
+
+            processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
+
+            message_builder = MessageBuilder()
+            message_builder.add_text_content(processed_prompt)
+            messages = [message_builder.build()]
+            tool_built = self._build_tool_options(tools)
+
+            # Per-model retry logic for empty or truncated replies
+            empty_retry_count = 0
+            max_empty_retry = api_provider.max_retry
+            empty_retry_interval = api_provider.retry_interval
+
+            is_empty_reply = False
+            is_truncated = False
+
+            while empty_retry_count <= max_empty_retry:
+                response = await self._execute_request(
+                    api_provider=api_provider,
+                    client=client,
+                    request_type=RequestType.RESPONSE,
+                    model_info=model_info,
+                    message_list=messages,
+                    tool_options=tool_built,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                )
+
+                content = response.content or ""
+                reasoning_content = response.reasoning_content or ""
+                tool_calls = response.tool_calls
+
+                if not reasoning_content and content:
+                    content, extracted_reasoning = self._extract_reasoning(content)
+                    reasoning_content = extracted_reasoning
+
+                is_empty_reply = not tool_calls and (not content or content.strip() == "")
+                is_truncated = False
+                if use_anti_truncation:
+                    if content.endswith(self.end_marker):
+                        content = content[: -len(self.end_marker)].strip()
+                    else:
+                        is_truncated = True
+
+                if not is_empty_reply and not is_truncated:
+                    # Got a usable response
+                    if usage := response.usage:
+                        llm_usage_recorder.record_usage_to_database(
+                            model_info=model_info,
+                            model_usage=usage,
+                            time_cost=time.time() - start_time,
+                            user_id="system",
+                            request_type=self.request_type,
+                            endpoint="/chat/completions",
+                        )
+
+                    if not content and not tool_calls:
+                        if raise_when_empty:
+                            raise RuntimeError("Generated an empty reply")
+                        content = "The generated response was empty"
+
+                    return content, (reasoning_content, model_name, tool_calls)
+
+                # Reaching this point means the reply was empty or truncated, so retry
+                empty_retry_count += 1
+                if empty_retry_count <= max_empty_retry:
+                    reason = "empty reply" if is_empty_reply else "truncation"
+                    logger.warning(
+                        f"Model '{model_name}' produced a {reason}; regenerating ({empty_retry_count}/{max_empty_retry})..."
+                    )
+                    if empty_retry_interval > 0:
+                        await asyncio.sleep(empty_retry_interval)
+                    continue  # retry with the current model
+
+            # The loop ended, so the retry budget is exhausted
+            reason = "empty reply" if is_empty_reply else "truncation"
+            logger.error(f"Model '{model_name}' still returned a {reason} after {max_empty_retry} retries.")
+            raise RuntimeError(f"Model '{model_name}' reached the empty-reply/truncation retry limit")
+
+        # Hand the per-model logic to the generic failover executor
+        result = await self._execute_with_failover(
+            request_callable=request_logic, raise_on_failure=raise_when_empty
         )
-
-        self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
-
-        if not response.content and not response.tool_calls:
-            if raise_when_empty:
-                raise RuntimeError("The selected model generated an empty reply.")
-            response.content = "The generated response was empty"
-
-        return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)
+        if result:
+            return result
+
+        # If every model failed and we are not raising, return a default error message
+        return "All model requests failed", ("", "unknown", None)

     async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
-        """
-        Fetch an embedding vector.
+        """Fetch an embedding vector (with built-in failover)

         Args:
             embedding_input (str): The text to embed

         Returns:
             (Tuple[List[float], str]): (embedding vector, name of the model used)
         """
-        start_time = time.time()
-        response, model_info = await self._strategy.execute_with_failover(
-            RequestType.EMBEDDING,
-            embedding_input=embedding_input
-        )
-
-        self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
-
-        if not response.embedding:
-            raise RuntimeError("Failed to obtain the embedding")
-
-        return response.embedding, model_info.name
+
+        async def request_logic(
+            model_info: ModelInfo, api_provider: APIProvider, client: BaseClient
+        ) -> Tuple[List[float], str]:
+            """Single-attempt request logic for one model."""
+            start_time = time.time()
+            response = await self._execute_request(
+                api_provider=api_provider,
+                client=client,
+                request_type=RequestType.EMBEDDING,
+                model_info=model_info,
+                embedding_input=embedding_input,
+            )
+
+            embedding = response.embedding
+            if not embedding:
+                raise RuntimeError(f"Model '{model_info.name}' did not return an embedding.")
+
+            if usage := response.usage:
+                await llm_usage_recorder.record_usage_to_database(
+                    model_info=model_info,
+                    time_cost=time.time() - start_time,
+                    model_usage=usage,
+                    user_id="system",
+                    request_type=self.request_type,
+                    endpoint="/embeddings",
+                )
+
+            return embedding, model_info.name
+
+        result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True)
+        if result:
+            return result
+
+        # Theoretically unreachable, because raise_on_failure=True raises instead of returning
+        raise RuntimeError("Failed to obtain an embedding; every model was tried and failed.")
+
+    def _model_scheduler(
+        self, failed_models: set | None = None
+    ) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
+        """
+        A smart model scheduler: it orders the models dynamically by real-time load and skips
+        models that have already failed.
+        """
+        # sourcery skip: class-extract-method
+        if failed_models is None:
+            failed_models = set()
+
+        # 1. Keep only the models that have not failed yet
+        available_models = [name for name in self.model_for_task.model_list if name not in failed_models]
+
+        # 2. Sort the candidates with the load-balancing heuristic
+        # key: total_tokens + penalty * 300 + usage_penalty * 1000
+        sorted_models = sorted(
+            available_models,
+            key=lambda name: self.model_usage[name][0]
+            + self.model_usage[name][1] * 300
+            + self.model_usage[name][2] * 1000,
+        )
+
+        if not sorted_models:
+            logger.warning("Every model has failed or is unavailable; the scheduler has nothing to offer.")
+            return
+
+        logger.debug(f"Model scheduling order: {', '.join(sorted_models)}")
+
+        # 3. Yield the model info in the best order
+        for model_name in sorted_models:
+            model_info = model_config.get_model_info(model_name)
+            api_provider = model_config.get_provider(model_info.api_provider)
+            force_new_client = self.request_type == "embedding"
+            client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
+            yield model_info, api_provider, client

-    def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
+    async def _execute_request(
+        self,
+        api_provider: APIProvider,
+        client: BaseClient,
+        request_type: RequestType,
+        model_info: ModelInfo,
+        message_list: List[Message] | None = None,
+        tool_options: list[ToolOption] | None = None,
+        response_format: RespFormat | None = None,
+        stream_response_handler: Optional[Callable] = None,
+        async_response_parser: Optional[Callable] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        embedding_input: str = "",
+        audio_base64: str = "",
+    ) -> APIResponse:
         """
-        Record model usage.
-
-        This method first updates the model's cumulative token usage in memory, then spawns an
-        asynchronous task that writes the detailed usage data (model info, token counts, latency,
-        and so on) to the database.
-
-        Args:
-            model_info (ModelInfo): The model that was used.
-            usage (Optional[UsageRecord]): The usage record returned by the API.
-            time_cost (float): Total time spent on this request.
-            endpoint (str): The API endpoint requested (e.g. "/chat/completions").
+        The method that actually performs a request.
+
+        It contains the retry and exception-handling logic.
         """
-        if usage:
-            # Step 1: update the in-memory token counters used for load balancing
+        retry_remain = api_provider.max_retry
+        compressed_messages: Optional[List[Message]] = None
+
+        # Bump the usage penalty to mark this model as currently being tried
+        total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+        self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
+
+        try:
+            while retry_remain > 0:
+                try:
+                    if request_type == RequestType.RESPONSE:
+                        assert message_list is not None, "message_list cannot be None for response requests"
+                        return await client.get_response(
+                            model_info=model_info,
+                            message_list=(compressed_messages or message_list),
+                            tool_options=tool_options,
+                            max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
+                            temperature=self.model_for_task.temperature if temperature is None else temperature,
+                            response_format=response_format,
+                            stream_response_handler=stream_response_handler,
+                            async_response_parser=async_response_parser,
+                            extra_params=model_info.extra_params,
+                        )
+                    elif request_type == RequestType.EMBEDDING:
+                        assert embedding_input, "embedding_input cannot be empty for embedding requests"
+                        return await client.get_embedding(
+                            model_info=model_info,
+                            embedding_input=embedding_input,
+                            extra_params=model_info.extra_params,
+                        )
+                    elif request_type == RequestType.AUDIO:
+                        assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
+                        return await client.get_audio_transcriptions(
+                            model_info=model_info,
+                            audio_base64=audio_base64,
+                            extra_params=model_info.extra_params,
+                        )
+                except Exception as e:
+                    logger.debug(f"Request failed: {str(e)}")
+                    # Handle the exception
+                    total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+                    self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
+
+                    wait_interval, compressed_messages = self._default_exception_handler(
+                        e,
+                        self.task_name,
+                        model_info=model_info,
+                        api_provider=api_provider,
+                        remain_try=retry_remain,
+                        retry_interval=api_provider.retry_interval,
+                        messages=(message_list, compressed_messages is not None) if message_list else None,
+                    )
+
+                    if wait_interval == -1:
+                        retry_remain = 0  # stop retrying
+                    elif wait_interval > 0:
+                        logger.info(f"Waiting {wait_interval} seconds before retrying...")
+                        await asyncio.sleep(wait_interval)
+                finally:
+                    # Decrement in finally to avoid an infinite loop
+                    retry_remain -= 1
+
+            # Reached only when every retry has failed
+            logger.error(f"Model '{model_info.name}' request failed after the maximum of {api_provider.max_retry} retries")
+            raise RuntimeError("Request failed; maximum retry count reached")
+        finally:
+            # Whether the request succeeded or failed, undo the usage penalty
             total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-            self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
-
-            # Step 2: spawn a background task that writes the usage data to the database asynchronously
-            asyncio.create_task(llm_usage_recorder.record_usage_to_database(
-                model_info=model_info,
-                model_usage=usage,
-                user_id="system",  # adjust per business needs
-                time_cost=time_cost,
-                request_type=self.task_name,
-                endpoint=endpoint,
-            ))
+            self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)

     @staticmethod
     def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
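
For reference, the scheduler's ordering above reduces to one weighted key per model. A standalone illustration with invented counters and model names (not taken from the repository); each entry is (total_tokens, penalty, usage_penalty), matching the tuple layout used in the diff:

model_usage = {
    "model-a": (12_000, 0, 0),  # heavy token load, healthy
    "model-b": (3_000, 2, 0),   # lighter load, two recent errors
    "model-c": (500, 0, 1),     # almost idle, one request currently in flight
}

def load_key(name: str) -> int:
    total_tokens, penalty, usage_penalty = model_usage[name]
    # Same weighting as the scheduler: an error costs 300, an in-flight request costs 1000.
    return total_tokens + penalty * 300 + usage_penalty * 1000

print(sorted(model_usage, key=load_key))
# ['model-c', 'model-b', 'model-a'] -- the least loaded model is tried first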