This commit is contained in:
sunbiz1024
2025-10-06 09:38:16 +08:00
parent 28afc09d31
commit 8f4f7d19af
66 changed files with 487 additions and 497 deletions

View File

@@ -189,11 +189,11 @@ class ClientRegistry:
bool: 事件循环是否变化
"""
current_loop_id = self._get_current_loop_id()
# 如果没有缓存的循环ID说明是首次创建
if provider_name not in self._event_loop_cache:
return False
# 比较当前循环ID与缓存的循环ID
cached_loop_id = self._event_loop_cache[provider_name]
return current_loop_id != cached_loop_id
@@ -208,7 +208,7 @@ class ClientRegistry:
BaseClient: 注册的API客户端实例
"""
provider_name = api_provider.name
# 如果强制创建新实例,直接创建不使用缓存
if force_new:
if client_class := self.client_registry.get(api_provider.client_type):
@@ -224,7 +224,7 @@ class ClientRegistry:
# 事件循环已变化,需要重新创建实例
logger.debug(f"检测到事件循环变化,为 {provider_name} 重新创建客户端实例")
self._loop_change_count += 1
# 移除旧实例
if provider_name in self.client_instance_cache:
del self.client_instance_cache[provider_name]
@@ -237,7 +237,7 @@ class ClientRegistry:
self._event_loop_cache[provider_name] = self._get_current_loop_id()
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
return self.client_instance_cache[provider_name]
def get_cache_stats(self) -> dict:

View File

@@ -50,14 +50,16 @@ def _convert_messages_to_mcp(messages: list[Message]) -> list[dict[str, Any]]:
for item in message.content:
if isinstance(item, tuple):
# 图片内容
content_parts.append({
"type": "image",
"source": {
"type": "base64",
"media_type": f"image/{item[0].lower()}",
"data": item[1],
},
})
content_parts.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": f"image/{item[0].lower()}",
"data": item[1],
},
}
)
elif isinstance(item, str):
# 文本内容
content_parts.append({"type": "text", "text": item})
@@ -138,9 +140,7 @@ async def _parse_sse_stream(
async with session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
raise RespNotOkException(
response.status, f"MCP SSE请求失败: {error_text}"
)
raise RespNotOkException(response.status, f"MCP SSE请求失败: {error_text}")
# 解析SSE流
async for line in response.content:
@@ -258,10 +258,7 @@ async def _parse_sse_stream(
response.reasoning_content = reasoning_buffer.getvalue()
if tool_calls_buffer:
response.tool_calls = [
ToolCall(call_id, func_name, args)
for call_id, func_name, args in tool_calls_buffer
]
response.tool_calls = [ToolCall(call_id, func_name, args) for call_id, func_name, args in tool_calls_buffer]
# 关闭缓冲区
content_buffer.close()
@@ -351,9 +348,7 @@ class MCPSSEClient(BaseClient):
url = f"{self.api_provider.base_url}/v1/messages"
try:
response, usage_record = await _parse_sse_stream(
session, url, payload, headers, interrupt_flag
)
response, usage_record = await _parse_sse_stream(session, url, payload, headers, interrupt_flag)
except Exception as e:
logger.error(f"MCP SSE请求失败: {e}")
raise

View File

@@ -378,7 +378,7 @@ class OpenaiClient(BaseClient):
# 类级别的全局缓存:所有 OpenaiClient 实例共享
_global_client_cache: dict[int, AsyncOpenAI] = {}
"""全局 AsyncOpenAI 客户端缓存config_hash -> AsyncOpenAI 实例"""
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
self._config_hash = self._calculate_config_hash()
@@ -396,33 +396,31 @@ class OpenaiClient(BaseClient):
def _create_client(self) -> AsyncOpenAI:
"""
获取或创建 OpenAI 客户端实例(全局缓存)
多个 OpenaiClient 实例如果配置相同base_url + api_key + timeout
将共享同一个 AsyncOpenAI 客户端实例,最大化连接池复用。
"""
# 检查全局缓存
if self._config_hash in self._global_client_cache:
return self._global_client_cache[self._config_hash]
# 创建新的 AsyncOpenAI 实例
logger.debug(
f"创建新的 AsyncOpenAI 客户端实例 "
f"(base_url={self.api_provider.base_url}, "
f"config_hash={self._config_hash})"
f"创建新的 AsyncOpenAI 客户端实例 (base_url={self.api_provider.base_url}, config_hash={self._config_hash})"
)
client = AsyncOpenAI(
base_url=self.api_provider.base_url,
api_key=self.api_provider.get_api_key(),
max_retries=0,
timeout=self.api_provider.timeout,
)
# 存入全局缓存
self._global_client_cache[self._config_hash] = client
return client
@classmethod
def get_cache_stats(cls) -> dict:
"""获取全局缓存统计信息"""

View File

@@ -280,7 +280,9 @@ class _PromptProcessor:
这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
"""
async def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
async def prepare_prompt(
self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str
) -> str:
"""
为请求准备最终的提示词。