feat(client): 优化连接池配置以支持高并发embedding请求

refactor(request): 移除全局锁,改用信号量控制并发度
This commit is contained in:
Windpicker-owo
2025-11-09 20:16:47 +08:00
parent 58b746f217
commit ee44b02f93
2 changed files with 44 additions and 27 deletions

View File

@@ -433,11 +433,22 @@ class OpenaiClient(BaseClient):
f"创建新的 AsyncOpenAI 客户端实例 (base_url={self.api_provider.base_url}, config_hash={self._config_hash}, loop_id={current_loop_id})" f"创建新的 AsyncOpenAI 客户端实例 (base_url={self.api_provider.base_url}, config_hash={self._config_hash}, loop_id={current_loop_id})"
) )
# 🔧 优化增加连接池限制支持高并发embedding请求
# 默认httpx限制为100对于高频embedding场景不够用
import httpx
limits = httpx.Limits(
max_keepalive_connections=200, # 保持活跃连接数原100
max_connections=300, # 最大总连接数原100
keepalive_expiry=30.0, # 连接保活时间
)
client = AsyncOpenAI( client = AsyncOpenAI(
base_url=self.api_provider.base_url, base_url=self.api_provider.base_url,
api_key=self.api_provider.get_api_key(), api_key=self.api_provider.get_api_key(),
max_retries=0, max_retries=0,
timeout=self.api_provider.timeout, timeout=self.api_provider.timeout,
http_client=httpx.AsyncClient(limits=limits), # 🔧 自定义连接池配置
) )
# 存入全局缓存带事件循环ID # 存入全局缓存带事件循环ID

View File

@@ -802,7 +802,11 @@ class LLMRequest:
for model in self.model_for_task.model_list for model in self.model_for_task.model_list
} }
"""模型使用量记录""" """模型使用量记录"""
self._lock = asyncio.Lock() # 🔧 优化:移除全局锁,改用信号量控制并发度(允许多个请求并行)
# 默认允许50个并发请求可通过配置调整
max_concurrent = getattr(model_set, "max_concurrent_requests", 50)
self._semaphore = asyncio.Semaphore(max_concurrent)
self._stats_lock = asyncio.Lock() # 只保护统计数据的写入
# 初始化辅助类 # 初始化辅助类
self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage) self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
@@ -931,23 +935,24 @@ class LLMRequest:
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
raise_when_empty: bool = True, raise_when_empty: bool = True,
) -> tuple[str, tuple[str, str, list[ToolCall] | None]]: ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
async with self._lock: """
""" 执行单次文本生成请求的内部方法。
执行单次文本生成请求的内部方法。 这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期,
这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期, 包括工具构建、故障转移执行和用量记录。
包括工具构建、故障转移执行和用量记录。
Args: Args:
prompt (str): 用户的提示。 prompt (str): 用户的提示。
temperature (Optional[float]): 生成温度。 temperature (Optional[float]): 生成温度。
max_tokens (Optional[int]): 最大生成令牌数。 max_tokens (Optional[int]): 最大生成令牌数。
tools (Optional[List[Dict[str, Any]]]): 可用工具列表。 tools (Optional[List[Dict[str, Any]]]): 可用工具列表。
raise_when_empty (bool): 如果响应为空是否引发异常。 raise_when_empty (bool): 如果响应为空是否引发异常。
Returns: Returns:
Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
(响应内容, (推理过程, 模型名称, 工具调用)) (响应内容, (推理过程, 模型名称, 工具调用))
""" """
# 🔧 优化:使用信号量控制并发,允许多个请求并行执行
async with self._semaphore:
start_time = time.time() start_time = time.time()
tool_options = await self._build_tool_options(tools) tool_options = await self._build_tool_options(tools)
@@ -1006,20 +1011,21 @@ class LLMRequest:
endpoint (str): 请求的API端点 (e.g., "/chat/completions")。 endpoint (str): 请求的API端点 (e.g., "/chat/completions")。
""" """
if usage: if usage:
# 步骤1: 更新内存中的统计数据,用于负载均衡 # 步骤1: 更新内存中的统计数据,用于负载均衡(需要加锁保护)
stats = self.model_usage[model_info.name] async with self._stats_lock:
stats = self.model_usage[model_info.name]
# 计算新的平均延迟 # 计算新的平均延迟
new_request_count = stats.request_count + 1 new_request_count = stats.request_count + 1
new_avg_latency = (stats.avg_latency * stats.request_count + time_cost) / new_request_count new_avg_latency = (stats.avg_latency * stats.request_count + time_cost) / new_request_count
self.model_usage[model_info.name] = stats._replace( self.model_usage[model_info.name] = stats._replace(
total_tokens=stats.total_tokens + usage.total_tokens, total_tokens=stats.total_tokens + usage.total_tokens,
avg_latency=new_avg_latency, avg_latency=new_avg_latency,
request_count=new_request_count, request_count=new_request_count,
) )
# 步骤2: 创建一个后台任务,将用量数据异步写入数据库 # 步骤2: 创建一个后台任务,将用量数据异步写入数据库(无需等待)
asyncio.create_task( # noqa: RUF006 asyncio.create_task( # noqa: RUF006
llm_usage_recorder.record_usage_to_database( llm_usage_recorder.record_usage_to_database(
model_info=model_info, model_info=model_info,