refactor(llm_models): restructure concurrent request logic to improve code reuse

Extract the common concurrent-request logic into a new `execute_concurrently` helper function. This simplifies the `LLMRequest.get_response` method, making it easier to read and maintain. `get_response` now calls `execute_concurrently` to handle concurrent execution instead of managing task creation and result collection inside the method itself. Exception handling and retry logic for individual request failures are also improved, making them more robust in both concurrent and non-concurrent modes.

Co-authored-by: 雅诺狐 <foxcyber907@users.noreply.github.com>
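For illustration, a minimal sketch of how the new helper is meant to be driven. It assumes the `execute_concurrently` function from the diff below is in scope; `flaky_llm_call` is a hypothetical stand-in for a real model request, not part of this repository:

    import asyncio
    import random

    async def flaky_llm_call(prompt: str) -> str:
        # Hypothetical stand-in for a real model request; fails at random.
        await asyncio.sleep(0.05)
        if random.random() < 0.3:
            raise RuntimeError("transient upstream error")
        return f"response to: {prompt}"

    async def main() -> None:
        # Fire the same request three times; execute_concurrently returns one
        # randomly chosen successful result, or raises if all three fail.
        result = await execute_concurrently(flaky_llm_call, 3, "hello")
        print(result)

    asyncio.run(main())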
@@ -5,7 +5,7 @@ import random
 from enum import Enum
 from rich.traceback import install
-from typing import Tuple, List, Dict, Optional, Callable, Any
+from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
 
 from src.common.logger import get_logger
 from src.config.config import model_config
@@ -74,6 +74,50 @@ class RequestType(Enum):
     AUDIO = "audio"
 
 
+async def execute_concurrently(
+    coro_callable: Callable[..., Coroutine[Any, Any, Any]],
+    concurrency_count: int,
+    *args,
+    **kwargs,
+) -> Any:
+    """
+    Run concurrent requests and pick one of the successful results at random.
+
+    Args:
+        coro_callable (Callable): The coroutine function to execute concurrently.
+        concurrency_count (int): The number of concurrent executions.
+        *args: Positional arguments passed to the coroutine function.
+        **kwargs: Keyword arguments passed to the coroutine function.
+
+    Returns:
+        Any: One of the successfully executed results.
+
+    Raises:
+        RuntimeError: If all concurrent requests fail.
+    """
+    logger.info(f"Concurrent request mode enabled, concurrency: {concurrency_count}")
+    tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)]
+
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+    successful_results = [res for res in results if not isinstance(res, Exception)]
+
+    if successful_results:
+        selected = random.choice(successful_results)
+        logger.info(f"Concurrent requests finished, picked one of {len(successful_results)} successful results")
+        return selected
+
+    # If every request failed, log all exceptions and raise the first one
+    for i, res in enumerate(results):
+        if isinstance(res, Exception):
+            logger.error(f"Concurrent task {i+1}/{concurrency_count} failed: {res}")
+
+    first_exception = next((res for res in results if isinstance(res, Exception)), None)
+    if first_exception:
+        raise first_exception
+
+    raise RuntimeError(f"All {concurrency_count} concurrent requests failed, but no specific exception information is available")
+
+
 class LLMRequest:
     """LLM request class"""
 
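The helper's error handling hinges on `asyncio.gather(..., return_exceptions=True)` returning failures as exception objects instead of raising. A self-contained illustration of that behaviour:

    import asyncio

    async def ok() -> str:
        return "fine"

    async def boom() -> str:
        raise ValueError("failed")

    async def demo() -> None:
        # With return_exceptions=True, gather never raises; exceptions come
        # back as values in the results list, in task order.
        results = await asyncio.gather(ok(), boom(), ok(), return_exceptions=True)
        successes = [r for r in results if not isinstance(r, Exception)]
        print(successes)         # ['fine', 'fine']
        print(repr(results[1]))  # ValueError('failed')

    asyncio.run(demo())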
@@ -194,43 +238,31 @@ class LLMRequest:
         Returns:
             (Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, list of tool calls
         """
-        start_time = time.time()
-
         # Check whether concurrent requests are needed
-        concurrency_count = getattr(self.model_for_task, 'concurrency_count', 1)
+        concurrency_count = getattr(self.model_for_task, "concurrency_count", 1)
 
         if concurrency_count <= 1:
-            # Single request, original logic
+            # Single request
             return await self._execute_single_request(prompt, temperature, max_tokens, tools, raise_when_empty)
 
         # Concurrent requests
-        logger.info(f"Concurrent request mode enabled, concurrency: {concurrency_count}")
-        tasks = [
-            self._execute_single_request(prompt, temperature, max_tokens, tools, False)
-            for _ in range(concurrency_count)
-        ]
-
         try:
-            results = await asyncio.gather(*tasks, return_exceptions=True)
-            successful_results = []
-            for result in results:
-                if not isinstance(result, Exception):
-                    successful_results.append(result)
-
-            if successful_results:
-                # Pick one successful result at random
-                selected = random.choice(successful_results) if len(successful_results) > 1 else successful_results[0]
-                logger.info(f"Concurrent requests finished, picked one of {len(successful_results)} successful results")
-                return selected
-            elif raise_when_empty:
-                raise RuntimeError(f"All {concurrency_count} concurrent requests failed")
-            else:
-                return "All concurrent requests failed", ("", "unknown", None)
-
+            # Pass raise_when_empty=False to _execute_single_request so that a single
+            # failed request does not raise immediately; gather handles failures collectively
+            return await execute_concurrently(
+                self._execute_single_request,
+                concurrency_count,
+                prompt,
+                temperature,
+                max_tokens,
+                tools,
+                raise_when_empty=False,
+            )
         except Exception as e:
+            logger.error(f"All {concurrency_count} concurrent requests failed: {e}")
             if raise_when_empty:
                 raise e
-            return "Concurrent request exception", ("", "unknown", None)
+            return "All concurrent requests failed", ("", "unknown", None)
 
     async def _execute_single_request(
         self,
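The dispatch between the single and concurrent paths relies on `getattr` with a default, so model configurations that predate the `concurrency_count` field silently fall back to a single request. A small sketch with a hypothetical model type:

    from dataclasses import dataclass

    @dataclass
    class TaskModel:
        # Hypothetical stand-in for model_for_task; older configs may not
        # define concurrency_count at all.
        name: str = "example-model"

    model = TaskModel()
    # getattr with a default avoids AttributeError and keeps the original
    # single-request behaviour (count = 1) when the field is absent.
    count = getattr(model, "concurrency_count", 1)
    print(count)  # 1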
@@ -264,15 +296,13 @@ class LLMRequest:
                     request_type=RequestType.RESPONSE,
                     model_info=model_info,
                     message_list=messages,
+                    tool_options=tool_built,
                     temperature=temperature,
                     max_tokens=max_tokens,
-                    tool_options=tool_built,
                 )
-                content = response.content
+                content = response.content or ""
                 reasoning_content = response.reasoning_content or ""
                 tool_calls = response.tool_calls
 
                 # Extract reasoning content from <think> tags (backward compatible)
                 if not reasoning_content and content:
                     content, extracted_reasoning = self._extract_reasoning(content)
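The `_extract_reasoning` implementation is outside this diff; as a rough sketch of what a `<think>`-tag extractor of this shape might do (an assumption, not the project's actual code):

    import re

    THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)

    def extract_reasoning(content: str) -> tuple[str, str]:
        # Split a "<think>...</think>answer" payload into (answer, reasoning).
        match = THINK_RE.search(content)
        if not match:
            return content, ""
        reasoning = match.group(1).strip()
        cleaned = THINK_RE.sub("", content, count=1).strip()
        return cleaned, reasoning

    print(extract_reasoning("<think>step by step</think>The answer is 4."))
    # ('The answer is 4.', 'step by step')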
@@ -313,14 +343,26 @@ class LLMRequest:
                 return content, (reasoning_content, model_info.name, tool_calls)
 
             except Exception as e:
-                if empty_retry_count == 0:
-                    raise e
-                else:
-                    logger.error(f"Error during retry: {e}")
-                empty_retry_count += 1
-                if empty_retry_count <= max_empty_retry and empty_retry_interval > 0:
-                    await asyncio.sleep(empty_retry_interval)
-                    continue
+                logger.error(f"Request execution failed: {e}")
+                if raise_when_empty:
+                    # In non-concurrent mode, raise immediately if the very first attempt fails
+                    if empty_retry_count == 0:
+                        raise
+                    # A failure during the retry phase keeps retrying
+                    empty_retry_count += 1
+                    if empty_retry_count <= max_empty_retry:
+                        logger.warning(f"Request failed, retrying in {empty_retry_interval} seconds (attempt {empty_retry_count}/{max_empty_retry})...")
+                        if empty_retry_interval > 0:
+                            await asyncio.sleep(empty_retry_interval)
+                        continue
+                    else:
+                        logger.error(f"Still failing after {max_empty_retry} retries")
+                        raise RuntimeError(f"Could not generate a valid response after {max_empty_retry} retries") from e
+                else:
+                    # In concurrent mode, a single request failure must not interrupt the whole
+                    # concurrent flow; return the exception to the caller (execute_concurrently)
+                    # for collective handling
+                    raise  # re-raise; caught by gather inside execute_concurrently
 
         # Retries failed
         if raise_when_empty:
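The retry branch above is a counted-retry loop: on failure, increment the counter, sleep, and try again until the limit is hit. A stripped-down, self-contained version of the same pattern (the names are illustrative, not the module's):

    import asyncio
    import random

    async def unreliable_request() -> str:
        # Hypothetical request that fails roughly 30% of the time.
        if random.random() < 0.3:
            raise RuntimeError("empty or failed response")
        return "ok"

    async def request_with_retries(max_empty_retry: int = 3,
                                   empty_retry_interval: float = 0.2) -> str:
        empty_retry_count = 0
        while True:
            try:
                return await unreliable_request()
            except Exception as e:
                empty_retry_count += 1
                if empty_retry_count <= max_empty_retry:
                    if empty_retry_interval > 0:
                        await asyncio.sleep(empty_retry_interval)
                    continue
                raise RuntimeError(
                    f"still failing after {max_empty_retry} retries"
                ) from e

    print(asyncio.run(request_with_retries()))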