ruff
This commit is contained in:
@@ -189,11 +189,11 @@ class ClientRegistry:
|
||||
bool: 事件循环是否变化
|
||||
"""
|
||||
current_loop_id = self._get_current_loop_id()
|
||||
|
||||
|
||||
# 如果没有缓存的循环ID,说明是首次创建
|
||||
if provider_name not in self._event_loop_cache:
|
||||
return False
|
||||
|
||||
|
||||
# 比较当前循环ID与缓存的循环ID
|
||||
cached_loop_id = self._event_loop_cache[provider_name]
|
||||
return current_loop_id != cached_loop_id
|
||||
@@ -208,7 +208,7 @@ class ClientRegistry:
|
||||
BaseClient: 注册的API客户端实例
|
||||
"""
|
||||
provider_name = api_provider.name
|
||||
|
||||
|
||||
# 如果强制创建新实例,直接创建不使用缓存
|
||||
if force_new:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
@@ -224,7 +224,7 @@ class ClientRegistry:
|
||||
# 事件循环已变化,需要重新创建实例
|
||||
logger.debug(f"检测到事件循环变化,为 {provider_name} 重新创建客户端实例")
|
||||
self._loop_change_count += 1
|
||||
|
||||
|
||||
# 移除旧实例
|
||||
if provider_name in self.client_instance_cache:
|
||||
del self.client_instance_cache[provider_name]
|
||||
@@ -237,7 +237,7 @@ class ClientRegistry:
|
||||
self._event_loop_cache[provider_name] = self._get_current_loop_id()
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
|
||||
|
||||
return self.client_instance_cache[provider_name]
|
||||
|
||||
def get_cache_stats(self) -> dict:
|
||||
|
||||
@@ -50,14 +50,16 @@ def _convert_messages_to_mcp(messages: list[Message]) -> list[dict[str, Any]]:
|
||||
for item in message.content:
|
||||
if isinstance(item, tuple):
|
||||
# 图片内容
|
||||
content_parts.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": f"image/{item[0].lower()}",
|
||||
"data": item[1],
|
||||
},
|
||||
})
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": f"image/{item[0].lower()}",
|
||||
"data": item[1],
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
# 文本内容
|
||||
content_parts.append({"type": "text", "text": item})
|
||||
@@ -138,9 +140,7 @@ async def _parse_sse_stream(
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RespNotOkException(
|
||||
response.status, f"MCP SSE请求失败: {error_text}"
|
||||
)
|
||||
raise RespNotOkException(response.status, f"MCP SSE请求失败: {error_text}")
|
||||
|
||||
# 解析SSE流
|
||||
async for line in response.content:
|
||||
@@ -258,10 +258,7 @@ async def _parse_sse_stream(
|
||||
response.reasoning_content = reasoning_buffer.getvalue()
|
||||
|
||||
if tool_calls_buffer:
|
||||
response.tool_calls = [
|
||||
ToolCall(call_id, func_name, args)
|
||||
for call_id, func_name, args in tool_calls_buffer
|
||||
]
|
||||
response.tool_calls = [ToolCall(call_id, func_name, args) for call_id, func_name, args in tool_calls_buffer]
|
||||
|
||||
# 关闭缓冲区
|
||||
content_buffer.close()
|
||||
@@ -351,9 +348,7 @@ class MCPSSEClient(BaseClient):
|
||||
url = f"{self.api_provider.base_url}/v1/messages"
|
||||
|
||||
try:
|
||||
response, usage_record = await _parse_sse_stream(
|
||||
session, url, payload, headers, interrupt_flag
|
||||
)
|
||||
response, usage_record = await _parse_sse_stream(session, url, payload, headers, interrupt_flag)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP SSE请求失败: {e}")
|
||||
raise
|
||||
|
||||
@@ -378,7 +378,7 @@ class OpenaiClient(BaseClient):
|
||||
# 类级别的全局缓存:所有 OpenaiClient 实例共享
|
||||
_global_client_cache: dict[int, AsyncOpenAI] = {}
|
||||
"""全局 AsyncOpenAI 客户端缓存:config_hash -> AsyncOpenAI 实例"""
|
||||
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
super().__init__(api_provider)
|
||||
self._config_hash = self._calculate_config_hash()
|
||||
@@ -396,33 +396,31 @@ class OpenaiClient(BaseClient):
|
||||
def _create_client(self) -> AsyncOpenAI:
|
||||
"""
|
||||
获取或创建 OpenAI 客户端实例(全局缓存)
|
||||
|
||||
|
||||
多个 OpenaiClient 实例如果配置相同(base_url + api_key + timeout),
|
||||
将共享同一个 AsyncOpenAI 客户端实例,最大化连接池复用。
|
||||
"""
|
||||
# 检查全局缓存
|
||||
if self._config_hash in self._global_client_cache:
|
||||
return self._global_client_cache[self._config_hash]
|
||||
|
||||
|
||||
# 创建新的 AsyncOpenAI 实例
|
||||
logger.debug(
|
||||
f"创建新的 AsyncOpenAI 客户端实例 "
|
||||
f"(base_url={self.api_provider.base_url}, "
|
||||
f"config_hash={self._config_hash})"
|
||||
f"创建新的 AsyncOpenAI 客户端实例 (base_url={self.api_provider.base_url}, config_hash={self._config_hash})"
|
||||
)
|
||||
|
||||
|
||||
client = AsyncOpenAI(
|
||||
base_url=self.api_provider.base_url,
|
||||
api_key=self.api_provider.get_api_key(),
|
||||
max_retries=0,
|
||||
timeout=self.api_provider.timeout,
|
||||
)
|
||||
|
||||
|
||||
# 存入全局缓存
|
||||
self._global_client_cache[self._config_hash] = client
|
||||
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_cache_stats(cls) -> dict:
|
||||
"""获取全局缓存统计信息"""
|
||||
|
||||
@@ -280,7 +280,9 @@ class _PromptProcessor:
|
||||
这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
|
||||
"""
|
||||
|
||||
async def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
|
||||
async def prepare_prompt(
|
||||
self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str
|
||||
) -> str:
|
||||
"""
|
||||
为请求准备最终的提示词。
|
||||
|
||||
|
||||
Reference in New Issue
Block a user