This commit is contained in:
13
src/common/tcp_connector.py
Normal file
13
src/common/tcp_connector.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import ssl
|
||||||
|
import certifi
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
# Module-wide SSL context built once at import time; it verifies server
# certificates against the CA bundle file located by certifi.where().
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||||
|
connector = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_tcp_connector():
    """Return a new ``aiohttp.TCPConnector`` that uses the shared SSL context.

    A fresh connector is created on every call instead of caching one in a
    module-level global. Callers pass the result to
    ``aiohttp.ClientSession(connector=...)`` inside ``async with``; the
    session owns its connector by default and closes it on exit. A cached
    connector would therefore already be closed on the second request, and
    could be closed underneath another request that is still in flight.

    The function stays ``async`` so that existing ``await get_tcp_connector()``
    call sites keep working and the connector is always created while an
    event loop is running.

    Returns:
        aiohttp.TCPConnector: an open connector configured with the
        module-level ``ssl_context``. The caller (normally the owning
        ``ClientSession``) is responsible for closing it.
    """
    # Only the SSL context is shared between calls; it is not tied to any
    # connector's lifetime, so reusing it is safe.
    return aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
@@ -6,6 +6,7 @@ from typing import Tuple, Union
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
from src.common.tcp_connector import get_tcp_connector
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -94,7 +95,7 @@ class LLMRequestOff:
|
|||||||
max_retries = 3
|
max_retries = 3
|
||||||
base_wait_time = 15
|
base_wait_time = 15
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||||
for retry in range(max_retries):
|
for retry in range(max_retries):
|
||||||
try:
|
try:
|
||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import os
|
|||||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||||
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.common.tcp_connector import get_tcp_connector
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -264,7 +265,6 @@ class LLMRequest:
|
|||||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
||||||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||||
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"policy": policy,
|
"policy": policy,
|
||||||
"payload": payload,
|
"payload": payload,
|
||||||
@@ -312,7 +312,7 @@ class LLMRequest:
|
|||||||
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
|
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
|
||||||
if request_content["stream_mode"]:
|
if request_content["stream_mode"]:
|
||||||
headers["Accept"] = "text/event-stream"
|
headers["Accept"] = "text/event-stream"
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
request_content["api_url"], headers=headers, json=request_content["payload"]
|
request_content["api_url"], headers=headers, json=request_content["payload"]
|
||||||
) as response:
|
) as response:
|
||||||
|
|||||||
Reference in New Issue
Block a user