refactor(llm_models): 重构并发请求逻辑以提高代码复用性

将并发请求的通用逻辑提取到一个新的 `execute_concurrently` 辅助函数中。此举简化了 `LLMRequest.get_response` 方法,使其更易于阅读和维护。

现在,`get_response` 方法调用 `execute_concurrently` 来处理并发执行,而不是在方法内部直接管理任务创建和结果收集。同时,改进了单个请求失败时的异常处理和重试逻辑,使其在并发和非并发模式下都更加健壮。

Co-authored-by: 雅诺狐 <foxcyber907@users.noreply.github.com>
This commit is contained in:
minecraft1024a
2025-08-17 12:12:12 +08:00
parent 5e2485dde0
commit 95bbcaff18

View File

@@ -5,7 +5,7 @@ import random
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
from src.common.logger import get_logger
from src.config.config import model_config
@@ -74,6 +74,50 @@ class RequestType(Enum):
AUDIO = "audio"
async def execute_concurrently(
coro_callable: Callable[..., Coroutine[Any, Any, Any]],
concurrency_count: int,
*args,
**kwargs,
) -> Any:
"""
执行并发请求并从成功的结果中随机选择一个。
Args:
coro_callable (Callable): 要并发执行的协程函数。
concurrency_count (int): 并发执行的次数。
*args: 传递给协程函数的位置参数。
**kwargs: 传递给协程函数的关键字参数。
Returns:
Any: 其中一个成功执行的结果。
Raises:
RuntimeError: 如果所有并发请求都失败。
"""
logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)]
results = await asyncio.gather(*tasks, return_exceptions=True)
successful_results = [res for res in results if not isinstance(res, Exception)]
if successful_results:
selected = random.choice(successful_results)
logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个")
return selected
# 如果所有请求都失败了,记录所有异常并抛出第一个
for i, res in enumerate(results):
if isinstance(res, Exception):
logger.error(f"并发任务 {i+1}/{concurrency_count} 失败: {res}")
first_exception = next((res for res in results if isinstance(res, Exception)), None)
if first_exception:
raise first_exception
raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")
class LLMRequest:
"""LLM请求类"""
@@ -194,43 +238,31 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
start_time = time.time()
# 检查是否需要并发请求
concurrency_count = getattr(self.model_for_task, 'concurrency_count', 1)
concurrency_count = getattr(self.model_for_task, "concurrency_count", 1)
if concurrency_count <= 1:
# 单次请求,原有逻辑
# 单次请求
return await self._execute_single_request(prompt, temperature, max_tokens, tools, raise_when_empty)
# 并发请求
logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
tasks = [
self._execute_single_request(prompt, temperature, max_tokens, tools, False)
for _ in range(concurrency_count)
]
try:
results = await asyncio.gather(*tasks, return_exceptions=True)
successful_results = []
for result in results:
if not isinstance(result, Exception):
successful_results.append(result)
if successful_results:
# 随机选择一个成功结果
selected = random.choice(successful_results) if len(successful_results) > 1 else successful_results[0]
logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个")
return selected
elif raise_when_empty:
raise RuntimeError(f"所有{concurrency_count}个并发请求都失败了")
else:
return "所有并发请求都失败了", ("", "unknown", None)
# 为 _execute_single_request 传递参数时,将 raise_when_empty 设为 False,
# 这样单个请求失败时不会立即抛出异常,而是由 gather 统一处理
return await execute_concurrently(
self._execute_single_request,
concurrency_count,
prompt,
temperature,
max_tokens,
tools,
raise_when_empty=False,
)
except Exception as e:
logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}")
if raise_when_empty:
raise e
return "并发请求异常", ("", "unknown", None)
return "所有并发请求都失败了", ("", "unknown", None)
async def _execute_single_request(
self,
@@ -264,33 +296,31 @@ class LLMRequest:
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
tool_options=tool_built,
temperature=temperature,
max_tokens=max_tokens,
tool_options=tool_built,
)
content = response.content
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
# 从内容中提取<think>标签的推理内容(向后兼容)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
# 检测是否为空回复
is_empty_reply = not content or content.strip() == ""
if is_empty_reply and empty_retry_count < max_empty_retry:
empty_retry_count += 1
logger.warning(f"检测到空回复,正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成")
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
model_info, api_provider, client = self._select_model()
continue
# 记录使用情况
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
@@ -301,7 +331,7 @@ class LLMRequest:
request_type=self.request_type,
endpoint="/chat/completions",
)
# 处理空回复
if not content:
if raise_when_empty:
@@ -311,16 +341,28 @@ class LLMRequest:
logger.info(f"经过 {empty_retry_count} 次重试后成功生成回复")
return content, (reasoning_content, model_info.name, tool_calls)
except Exception as e:
if empty_retry_count == 0:
raise e
else:
logger.error(f"重试过程中出错: {e}")
logger.error(f"请求执行失败: {e}")
if raise_when_empty:
# 在非并发模式下,如果第一次尝试就失败,则直接抛出异常
if empty_retry_count == 0:
raise
# 如果在重试过程中失败,则继续重试
empty_retry_count += 1
if empty_retry_count <= max_empty_retry and empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
continue
if empty_retry_count <= max_empty_retry:
logger.warning(f"请求失败,将在 {empty_retry_interval} 秒后进行第 {empty_retry_count}/{max_empty_retry} 次重试...")
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
continue
else:
logger.error(f"经过 {max_empty_retry} 次重试后仍然失败")
raise RuntimeError(f"经过 {max_empty_retry} 次重试后仍然无法生成有效回复") from e
else:
# 在并发模式下,单个请求的失败不应中断整个并发流程,
# 而是将异常返回给调用者(即 execute_concurrently进行统一处理
raise # 重新抛出异常,由 execute_concurrently 中的 gather 捕获
# 重试失败
if raise_when_empty: