refactor(llm_models): 重构并发请求逻辑以提高代码复用性
将并发请求的通用逻辑提取到一个新的 `execute_concurrently` 辅助函数中。此举简化了 `LLMRequest.get_response` 方法,使其更易于阅读和维护。 现在,`get_response` 方法调用 `execute_concurrently` 来处理并发执行,而不是在方法内部直接管理任务创建和结果收集。同时,改进了单个请求失败时的异常处理和重试逻辑,使其在并发和非并发模式下都更加健壮。 Co-authored-by: 雅诺狐 <foxcyber907@users.noreply.github.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import random
|
||||
|
||||
from enum import Enum
|
||||
from rich.traceback import install
|
||||
from typing import Tuple, List, Dict, Optional, Callable, Any
|
||||
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
@@ -74,6 +74,50 @@ class RequestType(Enum):
|
||||
AUDIO = "audio"
|
||||
|
||||
|
||||
async def execute_concurrently(
|
||||
coro_callable: Callable[..., Coroutine[Any, Any, Any]],
|
||||
concurrency_count: int,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
执行并发请求并从成功的结果中随机选择一个。
|
||||
|
||||
Args:
|
||||
coro_callable (Callable): 要并发执行的协程函数。
|
||||
concurrency_count (int): 并发执行的次数。
|
||||
*args: 传递给协程函数的位置参数。
|
||||
**kwargs: 传递给协程函数的关键字参数。
|
||||
|
||||
Returns:
|
||||
Any: 其中一个成功执行的结果。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果所有并发请求都失败。
|
||||
"""
|
||||
logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
|
||||
tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)]
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
successful_results = [res for res in results if not isinstance(res, Exception)]
|
||||
|
||||
if successful_results:
|
||||
selected = random.choice(successful_results)
|
||||
logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个")
|
||||
return selected
|
||||
|
||||
# 如果所有请求都失败了,记录所有异常并抛出第一个
|
||||
for i, res in enumerate(results):
|
||||
if isinstance(res, Exception):
|
||||
logger.error(f"并发任务 {i+1}/{concurrency_count} 失败: {res}")
|
||||
|
||||
first_exception = next((res for res in results if isinstance(res, Exception)), None)
|
||||
if first_exception:
|
||||
raise first_exception
|
||||
|
||||
raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")
|
||||
|
||||
|
||||
class LLMRequest:
|
||||
"""LLM请求类"""
|
||||
|
||||
@@ -194,43 +238,31 @@ class LLMRequest:
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 检查是否需要并发请求
|
||||
concurrency_count = getattr(self.model_for_task, 'concurrency_count', 1)
|
||||
|
||||
concurrency_count = getattr(self.model_for_task, "concurrency_count", 1)
|
||||
|
||||
if concurrency_count <= 1:
|
||||
# 单次请求,原有逻辑
|
||||
# 单次请求
|
||||
return await self._execute_single_request(prompt, temperature, max_tokens, tools, raise_when_empty)
|
||||
|
||||
|
||||
# 并发请求
|
||||
logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
|
||||
tasks = [
|
||||
self._execute_single_request(prompt, temperature, max_tokens, tools, False)
|
||||
for _ in range(concurrency_count)
|
||||
]
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
successful_results = []
|
||||
for result in results:
|
||||
if not isinstance(result, Exception):
|
||||
successful_results.append(result)
|
||||
|
||||
if successful_results:
|
||||
# 随机选择一个成功结果
|
||||
selected = random.choice(successful_results) if len(successful_results) > 1 else successful_results[0]
|
||||
logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个")
|
||||
return selected
|
||||
elif raise_when_empty:
|
||||
raise RuntimeError(f"所有{concurrency_count}个并发请求都失败了")
|
||||
else:
|
||||
return "所有并发请求都失败了", ("", "unknown", None)
|
||||
|
||||
# 为 _execute_single_request 传递参数时,将 raise_when_empty 设为 False,
|
||||
# 这样单个请求失败时不会立即抛出异常,而是由 gather 统一处理
|
||||
return await execute_concurrently(
|
||||
self._execute_single_request,
|
||||
concurrency_count,
|
||||
prompt,
|
||||
temperature,
|
||||
max_tokens,
|
||||
tools,
|
||||
raise_when_empty=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}")
|
||||
if raise_when_empty:
|
||||
raise e
|
||||
return "并发请求异常", ("", "unknown", None)
|
||||
return "所有并发请求都失败了", ("", "unknown", None)
|
||||
|
||||
async def _execute_single_request(
|
||||
self,
|
||||
@@ -264,33 +296,31 @@ class LLMRequest:
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
tool_options=tool_built,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tool_options=tool_built,
|
||||
)
|
||||
|
||||
content = response.content
|
||||
content = response.content or ""
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
tool_calls = response.tool_calls
|
||||
|
||||
# 从内容中提取<think>标签的推理内容(向后兼容)
|
||||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
|
||||
|
||||
# 检测是否为空回复
|
||||
is_empty_reply = not content or content.strip() == ""
|
||||
|
||||
|
||||
if is_empty_reply and empty_retry_count < max_empty_retry:
|
||||
empty_retry_count += 1
|
||||
logger.warning(f"检测到空回复,正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成")
|
||||
|
||||
|
||||
if empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
|
||||
|
||||
model_info, api_provider, client = self._select_model()
|
||||
continue
|
||||
|
||||
|
||||
# 记录使用情况
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
@@ -301,7 +331,7 @@ class LLMRequest:
|
||||
request_type=self.request_type,
|
||||
endpoint="/chat/completions",
|
||||
)
|
||||
|
||||
|
||||
# 处理空回复
|
||||
if not content:
|
||||
if raise_when_empty:
|
||||
@@ -311,16 +341,28 @@ class LLMRequest:
|
||||
logger.info(f"经过 {empty_retry_count} 次重试后成功生成回复")
|
||||
|
||||
return content, (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if empty_retry_count == 0:
|
||||
raise e
|
||||
else:
|
||||
logger.error(f"重试过程中出错: {e}")
|
||||
logger.error(f"请求执行失败: {e}")
|
||||
if raise_when_empty:
|
||||
# 在非并发模式下,如果第一次尝试就失败,则直接抛出异常
|
||||
if empty_retry_count == 0:
|
||||
raise
|
||||
|
||||
# 如果在重试过程中失败,则继续重试
|
||||
empty_retry_count += 1
|
||||
if empty_retry_count <= max_empty_retry and empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
continue
|
||||
if empty_retry_count <= max_empty_retry:
|
||||
logger.warning(f"请求失败,将在 {empty_retry_interval} 秒后进行第 {empty_retry_count}/{max_empty_retry} 次重试...")
|
||||
if empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"经过 {max_empty_retry} 次重试后仍然失败")
|
||||
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复") from e
|
||||
else:
|
||||
# 在并发模式下,单个请求的失败不应中断整个并发流程,
|
||||
# 而是将异常返回给调用者(即 execute_concurrently)进行统一处理
|
||||
raise # 重新抛出异常,由 execute_concurrently 中的 gather 捕获
|
||||
|
||||
# 重试失败
|
||||
if raise_when_empty:
|
||||
|
||||
Reference in New Issue
Block a user