From 3288051b42eb37e9e425dd486dbf1d2dadb28611 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=98=A5=E6=B2=B3=E6=99=B4?=
Date: Thu, 5 Jun 2025 15:52:28 +0900
Subject: [PATCH] temp fix https://github.com/crate/crate-python/issues/708

---
 src/common/tcp_connector.py                | 13 +++++++++++++
 src/individuality/not_using/offline_llm.py |  3 ++-
 src/llm_models/utils_model.py              | 18 +++++++++---------
 3 files changed, 24 insertions(+), 10 deletions(-)
 create mode 100644 src/common/tcp_connector.py

diff --git a/src/common/tcp_connector.py b/src/common/tcp_connector.py
new file mode 100644
index 000000000..0eba4997e
--- /dev/null
+++ b/src/common/tcp_connector.py
@@ -0,0 +1,13 @@
+import ssl
+import certifi
+import aiohttp
+
+ssl_context = ssl.create_default_context(cafile=certifi.where())
+connector = None
+
+
+async def get_tcp_connector():
+    global connector
+    if connector is None:
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+    return connector
diff --git a/src/individuality/not_using/offline_llm.py b/src/individuality/not_using/offline_llm.py
index cc9560011..40ec0889d 100644
--- a/src/individuality/not_using/offline_llm.py
+++ b/src/individuality/not_using/offline_llm.py
@@ -6,6 +6,7 @@ from typing import Tuple, Union
 import aiohttp
 import requests
 from src.common.logger import get_module_logger
+from src.common.tcp_connector import get_tcp_connector
 from rich.traceback import install
 
 install(extra_lines=3)
@@ -94,7 +95,7 @@ class LLMRequestOff:
         max_retries = 3
         base_wait_time = 15
 
-        async with aiohttp.ClientSession() as session:
+        async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
             for retry in range(max_retries):
                 try:
                     async with session.post(api_url, headers=headers, json=data) as response:
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index 24cc9731a..4022f9367 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -13,6 +13,7 @@ import os
 from src.common.database.database import db  # ensure db is imported for create_tables
 from src.common.database.database_model import LLMUsage  # import the LLMUsage model
 from src.config.config import global_config
+from src.common.tcp_connector import get_tcp_connector
 from rich.traceback import install
 
 install(extra_lines=3)
@@ -244,7 +245,7 @@ class LLMRequest:
 
         if stream_mode:
             payload["stream"] = stream_mode
-        
+
         if self.temp != 0.7:
             payload["temperature"] = self.temp
 
@@ -257,13 +258,12 @@ class LLMRequest:
 
         if self.max_tokens:
             payload["max_tokens"] = self.max_tokens
-        
+
         # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
-        #     payload["max_tokens"] = global_config.model.model_max_output_length
+        #     payload["max_tokens"] = global_config.model.model_max_output_length
         # if max_tokens is still present in the payload and needs conversion, re-check it here
         if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
-            payload["max_completion_tokens"] = payload.pop("max_tokens")
-        
+            payload["max_completion_tokens"] = payload.pop("max_tokens")
 
         return {
             "policy": policy,
@@ -312,7 +312,7 @@ class LLMRequest:
             # seems to be required for OpenAI streaming; adding it has no effect on Aliyun's qwq-plus
             if request_content["stream_mode"]:
                 headers["Accept"] = "text/event-stream"
-            async with aiohttp.ClientSession() as session:
+            async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
                 async with session.post(
                     request_content["api_url"], headers=headers, json=request_content["payload"]
                 ) as response:
@@ -653,7 +653,7 @@ class LLMRequest:
             ]
         else:
             messages = [{"role": "user", "content": prompt}]
-        
+
         payload = {
             "model": self.model_name,
             "messages": messages,
@@ -673,9 +673,9 @@ class LLMRequest:
 
         if self.max_tokens:
             payload["max_tokens"] = self.max_tokens
-        
+
         # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
-        #     payload["max_tokens"] = global_config.model.model_max_output_length
+        #     payload["max_tokens"] = global_config.model.model_max_output_length
         # if max_tokens is still present in the payload and needs conversion, re-check it here
         if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
            payload["max_completion_tokens"] = payload.pop("max_tokens")
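
Note: the new src/common/tcp_connector.py caches a single aiohttp.TCPConnector built on
certifi's CA bundle, so every session verifies TLS against a known root store instead of
whatever the host system exposes. A minimal sketch of how a caller would use it follows;
the URL, the fetch_status helper, and the connector_owner=False flag are illustrative
additions, not part of the patch.

import asyncio

import aiohttp

from src.common.tcp_connector import get_tcp_connector


async def fetch_status(url: str) -> int:
    # Reuse the module-level connector; connector_owner=False keeps this session
    # from closing the shared connector when the session exits.
    async with aiohttp.ClientSession(
        connector=await get_tcp_connector(), connector_owner=False
    ) as session:
        async with session.get(url) as response:
            return response.status


if __name__ == "__main__":
    # Any HTTPS endpoint works as a smoke test for certificate verification.
    print(asyncio.run(fetch_status("https://example.com")))

The patched call sites pass the shared connector without connector_owner=False, and
aiohttp's default (connector_owner=True) closes a passed-in connector when the session
closes; that is worth keeping in mind if later sessions reuse the cached connector,
consistent with this being labeled a temporary fix.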