fix: 增大了默认的maxtoken防止溢出,messagecq改异步get_image防止阻塞

This commit is contained in:
tcmofashi
2025-03-14 15:38:33 +08:00
parent e2c5d42634
commit d3fe02e467
7 changed files with 207 additions and 286 deletions

View File

@@ -1,48 +1,28 @@
import base64
import html
import time
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import ssl
import os
import requests
# 解析各种CQ码
# 包含CQ码类
import urllib3
import aiohttp
from loguru import logger
from nonebot import get_driver
from urllib3.util import create_urllib3_context
from ..models.utils_model import LLM_request
from .config import global_config
from .mapper import emojimapper
from .message_base import Seg
from .utils_user import get_user_nickname,get_groupname
from .utils_user import get_user_nickname, get_groupname
from .message_base import GroupInfo, UserInfo
driver = get_driver()
config = driver.config
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
ctx = create_urllib3_context()
ctx.load_default_certs()
ctx.set_ciphers("AES128-GCM-SHA256")
class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def __init__(self, ssl_context=None, **kwargs):
self.ssl_context = ssl_context
super().__init__(**kwargs)
def init_poolmanager(self, connections, maxsize, block=False):
self.poolmanager = urllib3.poolmanager.PoolManager(
num_pools=connections,
maxsize=maxsize,
block=block,
ssl_context=self.ssl_context,
)
# 创建SSL上下文
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("AES128-GCM-SHA256")
@dataclass
@@ -70,14 +50,12 @@ class CQCode:
"""初始化LLM实例"""
pass
def translate(self):
async def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == "text":
self.translated_segments = Seg(
type="text", data=self.params.get("text", "")
)
self.translated_segments = Seg(type="text", data=self.params.get("text", ""))
elif self.type == "image":
base64_data = self.translate_image()
base64_data = await self.translate_image()
if base64_data:
if self.params.get("sub_type") == "0":
self.translated_segments = Seg(type="image", data=base64_data)
@@ -88,24 +66,20 @@ class CQCode:
elif self.type == "at":
if self.params.get("qq") == "all":
self.translated_segments = Seg(type="text", data="@[全体成员]")
else:
else:
user_nickname = get_user_nickname(self.params.get("qq", ""))
self.translated_segments = Seg(
type="text", data=f"[@{user_nickname or '某人'}]"
)
self.translated_segments = Seg(type="text", data=f"[@{user_nickname or '某人'}]")
elif self.type == "reply":
reply_segments = self.translate_reply()
reply_segments = await self.translate_reply()
if reply_segments:
self.translated_segments = Seg(type="seglist", data=reply_segments)
else:
self.translated_segments = Seg(type="text", data="[回复某人消息]")
elif self.type == "face":
face_id = self.params.get("id", "")
self.translated_segments = Seg(
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
)
self.translated_segments = Seg(type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]")
elif self.type == "forward":
forward_segments = self.translate_forward()
forward_segments = await self.translate_forward()
if forward_segments:
self.translated_segments = Seg(type="seglist", data=forward_segments)
else:
@@ -113,18 +87,8 @@ class CQCode:
else:
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
def get_img(self):
"""
headers = {
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
'Accept': 'image/*;q=0.8',
'Accept-Encoding': 'gzip, deflate, br',
'Connection': 'keep-alive',
'Cache-Control': 'no-cache',
'Pragma': 'no-cache'
}
"""
# 腾讯专用请求头配置
async def get_img(self) -> Optional[str]:
"""异步获取图片并转换为base64"""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
"Accept": "text/html, application/xhtml xml, */*",
@@ -133,61 +97,63 @@ class CQCode:
"Content-Type": "application/x-www-form-urlencoded",
"Cache-Control": "no-cache",
}
url = html.unescape(self.params["url"])
if not url.startswith(("http://", "https://")):
return None
# 创建专用会话
session = requests.session()
session.adapters.pop("https://", None)
session.mount("https://", TencentSSLAdapter(ctx))
max_retries = 3
for retry in range(max_retries):
try:
response = session.get(
url,
headers=headers,
timeout=15,
allow_redirects=True,
stream=True, # 流式传输避免大内存问题
)
logger.debug(f"获取图片中: {url}")
# 设置SSL上下文和创建连接器
conn = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=conn) as session:
async with session.get(
url,
headers=headers,
timeout=aiohttp.ClientTimeout(total=15),
allow_redirects=True,
) as response:
# 腾讯服务器特殊状态码处理
if response.status == 400 and "multimedia.nt.qq.com.cn" in url:
return None
# 腾讯服务器特殊状态码处理
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
return None
if response.status != 200:
raise aiohttp.ClientError(f"HTTP {response.status}")
if response.status_code != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
# 验证内容类型
content_type = response.headers.get("Content-Type", "")
if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}")
# 验证内容类型
content_type = response.headers.get("Content-Type", "")
if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}")
# 读取响应内容
content = await response.read()
logger.debug(f"获取图片成功: {url}")
# 转换为Base64
image_base64 = base64.b64encode(response.content).decode("utf-8")
self.image_base64 = image_base64
return image_base64
# 转换为Base64
image_base64 = base64.b64encode(content).decode("utf-8")
self.image_base64 = image_base64
return image_base64
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
except (aiohttp.ClientError, ValueError) as e:
if retry == max_retries - 1:
logger.error(f"最终请求失败: {str(e)}")
time.sleep(1.5**retry) # 指数退避
await asyncio.sleep(1.5**retry) # 指数退避
except Exception:
logger.exception("[未知错误]")
except Exception as e:
logger.exception(f"获取图片时发生未知错误: {str(e)}")
return None
return None
def translate_image(self) -> Optional[str]:
async def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串"""
if "url" not in self.params:
return None
return self.get_img()
return await self.get_img()
def translate_forward(self) -> Optional[List[Seg]]:
async def translate_forward(self) -> Optional[List[Seg]]:
"""处理转发消息返回Seg列表"""
try:
if "content" not in self.params:
@@ -217,15 +183,16 @@ class CQCode:
else:
if raw_message:
from .message_cq import MessageRecvCQ
user_info=UserInfo(
platform='qq',
user_info = UserInfo(
platform="qq",
user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info=GroupInfo(
platform='qq',
group_info = GroupInfo(
platform="qq",
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0))
group_name=get_groupname(msg.get("group_id", 0)),
)
message_obj = MessageRecvCQ(
@@ -235,24 +202,23 @@ class CQCode:
plain_text=raw_message,
group_info=group_info,
)
content_seg = Seg(
type="seglist", data=[message_obj.message_segment]
)
await message_obj.initialize()
content_seg = Seg(type="seglist", data=[message_obj.message_segment])
else:
content_seg = Seg(type="text", data="[空消息]")
else:
if raw_message:
from .message_cq import MessageRecvCQ
user_info=UserInfo(
platform='qq',
user_info = UserInfo(
platform="qq",
user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info=GroupInfo(
platform='qq',
group_info = GroupInfo(
platform="qq",
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0))
group_name=get_groupname(msg.get("group_id", 0)),
)
message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0),
@@ -261,9 +227,8 @@ class CQCode:
plain_text=raw_message,
group_info=group_info,
)
content_seg = Seg(
type="seglist", data=[message_obj.message_segment]
)
await message_obj.initialize()
content_seg = Seg(type="seglist", data=[message_obj.message_segment])
else:
content_seg = Seg(type="text", data="[空消息]")
@@ -277,7 +242,7 @@ class CQCode:
logger.error(f"处理转发消息失败: {str(e)}")
return None
def translate_reply(self) -> Optional[List[Seg]]:
async def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ
@@ -285,22 +250,19 @@ class CQCode:
return None
if self.reply_message.sender.user_id:
message_obj = MessageRecvCQ(
user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.nickname),
user_info=UserInfo(
user_id=self.reply_message.sender.user_id, user_nickname=self.reply_message.sender.nickname
),
message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message),
group_info=GroupInfo(group_id=self.reply_message.group_id),
)
await message_obj.initialize()
segments = []
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
segments.append(
Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
)
)
segments.append(Seg(type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
else:
segments.append(
Seg(
@@ -318,16 +280,12 @@ class CQCode:
@staticmethod
def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符"""
return (
text.replace(",", ",")
.replace("[", "[")
.replace("]", "]")
.replace("&", "&")
)
return text.replace(",", ",").replace("[", "[").replace("]", "]").replace("&", "&")
class CQCode_tool:
@staticmethod
def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode:
def cq_from_dict_to_class(cq_code: Dict, msg, reply: Optional[Dict] = None) -> CQCode:
"""
将CQ码字典转换为CQCode对象
@@ -353,11 +311,9 @@ class CQCode_tool:
params=params,
group_info=msg.message_info.group_info,
user_info=msg.message_info.user_info,
reply_message=reply
reply_message=reply,
)
# 进行翻译处理
instance.translate()
return instance
@staticmethod
@@ -383,12 +339,7 @@ class CQCode_tool:
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = (
abs_path.replace("&", "&")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
)
escaped_path = abs_path.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@@ -403,14 +354,11 @@ class CQCode_tool:
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@staticmethod
def create_image_cq_base64(base64_data: str) -> str:
"""
@@ -422,10 +370,7 @@ class CQCode_tool:
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"