fix: update CQCode and Message classes for async initialization and processing
This commit is contained in:
@@ -58,6 +58,7 @@ class ChatBot:
|
||||
plain_text=event.get_plaintext(),
|
||||
reply_message=event.reply,
|
||||
)
|
||||
await message.initialize()
|
||||
|
||||
# 过滤词
|
||||
for word in global_config.ban_words:
|
||||
|
||||
@@ -10,11 +10,11 @@ from nonebot.adapters.onebot.v11 import Bot
|
||||
from .config import global_config
|
||||
import time
|
||||
import asyncio
|
||||
from .utils_image import storage_image,storage_emoji
|
||||
from .utils_image import storage_image, storage_emoji
|
||||
from .utils_user import get_user_nickname
|
||||
from ..models.utils_model import LLM_request
|
||||
#解析各种CQ码
|
||||
#包含CQ码类
|
||||
# 解析各种CQ码
|
||||
# 包含CQ码类
|
||||
import urllib3
|
||||
from urllib3.util import create_urllib3_context
|
||||
from nonebot import get_driver
|
||||
@@ -27,6 +27,7 @@ ctx = create_urllib3_context()
|
||||
ctx.load_default_certs()
|
||||
ctx.set_ciphers("AES128-GCM-SHA256")
|
||||
|
||||
|
||||
class TencentSSLAdapter(requests.adapters.HTTPAdapter):
|
||||
def __init__(self, ssl_context=None, **kwargs):
|
||||
self.ssl_context = ssl_context
|
||||
@@ -37,6 +38,7 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
|
||||
num_pools=connections, maxsize=maxsize,
|
||||
block=block, ssl_context=self.ssl_context)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CQCode:
|
||||
"""
|
||||
@@ -80,13 +82,13 @@ class CQCode:
|
||||
else:
|
||||
self.translated_plain_text = f"@某人"
|
||||
elif self.type == 'reply':
|
||||
self.translated_plain_text = self.translate_reply()
|
||||
self.translated_plain_text = await self.translate_reply()
|
||||
elif self.type == 'face':
|
||||
face_id = self.params.get('id', '')
|
||||
# self.translated_plain_text = f"[表情{face_id}]"
|
||||
self.translated_plain_text = f"[表情]"
|
||||
elif self.type == 'forward':
|
||||
self.translated_plain_text = self.translate_forward()
|
||||
self.translated_plain_text = await self.translate_forward()
|
||||
else:
|
||||
self.translated_plain_text = f"[{self.type}]"
|
||||
|
||||
@@ -171,10 +173,9 @@ class CQCode:
|
||||
else:
|
||||
return '[表情包]'
|
||||
|
||||
|
||||
async def translate_image(self) -> str:
|
||||
"""处理图片类型的CQ码,区分普通图片和表情包"""
|
||||
#没有url,直接返回默认文本
|
||||
# 没有url,直接返回默认文本
|
||||
if 'url' not in self.params:
|
||||
return '[图片]'
|
||||
base64_str = self.get_img()
|
||||
@@ -207,7 +208,7 @@ class CQCode:
|
||||
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
||||
return "[图片]"
|
||||
|
||||
def translate_forward(self) -> str:
|
||||
async def translate_forward(self) -> str:
|
||||
"""处理转发消息"""
|
||||
try:
|
||||
if 'content' not in self.params:
|
||||
@@ -251,6 +252,7 @@ class CQCode:
|
||||
plain_text=raw_message,
|
||||
group_id=msg.get('group_id', 0)
|
||||
)
|
||||
await message_obj.initialize()
|
||||
content = message_obj.processed_plain_text
|
||||
else:
|
||||
content = '[空消息]'
|
||||
@@ -265,6 +267,7 @@ class CQCode:
|
||||
plain_text=raw_message,
|
||||
group_id=msg.get('group_id', 0)
|
||||
)
|
||||
await message_obj.initialize()
|
||||
content = message_obj.processed_plain_text
|
||||
else:
|
||||
content = '[空消息]'
|
||||
@@ -281,7 +284,7 @@ class CQCode:
|
||||
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}")
|
||||
return '[转发消息]'
|
||||
|
||||
def translate_reply(self) -> str:
|
||||
async def translate_reply(self) -> str:
|
||||
"""处理回复类型的CQ码"""
|
||||
|
||||
# 创建Message对象
|
||||
@@ -297,6 +300,7 @@ class CQCode:
|
||||
raw_message=str(self.reply_message.message),
|
||||
group_id=self.group_id
|
||||
)
|
||||
await message_obj.initialize()
|
||||
if message_obj.user_id == global_config.BOT_QQ:
|
||||
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
|
||||
else:
|
||||
@@ -333,6 +337,7 @@ class CQCode:
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
||||
|
||||
|
||||
class CQCode_tool:
|
||||
@staticmethod
|
||||
async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
|
||||
|
||||
@@ -33,50 +33,52 @@ class Message:
|
||||
|
||||
user_id: int = None
|
||||
user_nickname: str = None # 用户昵称
|
||||
user_cardname: str=None # 用户群昵称
|
||||
user_cardname: str = None # 用户群昵称
|
||||
|
||||
raw_message: str = None # 原始消息,包含未解析的cq码
|
||||
plain_text: str = None # 纯文本
|
||||
|
||||
reply_message: Dict = None # 存储 回复的 源消息
|
||||
|
||||
# 延迟初始化字段
|
||||
_initialized: bool = False
|
||||
message_segments: List[Dict] = None # 存储解析后的消息片段
|
||||
processed_plain_text: str = None # 用于存储处理后的plain_text
|
||||
detailed_plain_text: str = None # 用于存储详细可读文本
|
||||
|
||||
reply_message: Dict = None # 存储 回复的 源消息
|
||||
# 状态标志
|
||||
is_emoji: bool = False
|
||||
has_emoji: bool = False
|
||||
translate_cq: bool = True
|
||||
|
||||
is_emoji: bool = False # 是否是表情包
|
||||
has_emoji: bool = False # 是否包含表情包
|
||||
async def initialize(self):
|
||||
"""显式异步初始化方法(必须调用)"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
translate_cq: bool = True # 是否翻译cq码
|
||||
# 异步获取补充信息
|
||||
self.group_name = self.group_name or get_groupname(self.group_id)
|
||||
self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
|
||||
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
|
||||
|
||||
async def __post_init__(self):
|
||||
if self.time is None:
|
||||
self.time = int(time.time())
|
||||
|
||||
if not self.group_name:
|
||||
self.group_name = get_groupname(self.group_id)
|
||||
|
||||
if not self.user_nickname:
|
||||
self.user_nickname = get_user_nickname(self.user_id)
|
||||
|
||||
if not self.user_cardname:
|
||||
self.user_cardname=get_user_cardname(self.user_id)
|
||||
|
||||
if not self.processed_plain_text:
|
||||
# 消息解析
|
||||
if self.raw_message:
|
||||
self.message_segments = await self.parse_message_segments(str(self.raw_message))
|
||||
self.message_segments = await self.parse_message_segments(self.raw_message)
|
||||
self.processed_plain_text = ' '.join(
|
||||
seg.translated_plain_text
|
||||
for seg in self.message_segments
|
||||
)
|
||||
#将详细翻译为详细可读文本
|
||||
|
||||
# 构建详细文本
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
|
||||
try:
|
||||
name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
|
||||
except:
|
||||
name = self.user_nickname or f"用户{self.user_id}"
|
||||
content = self.processed_plain_text
|
||||
self.detailed_plain_text = f"[{time_str}] {name}: {content}\n"
|
||||
name = (
|
||||
f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
|
||||
if self.user_cardname
|
||||
else f"{self.user_nickname or f'用户{self.user_id}'}"
|
||||
)
|
||||
self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def parse_message_segments(self, message: str) -> List[CQCode]:
|
||||
"""
|
||||
|
||||
@@ -131,7 +131,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
return ''
|
||||
|
||||
|
||||
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
"""从数据库获取群组最近的消息记录
|
||||
|
||||
Args:
|
||||
@@ -173,6 +173,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
processed_plain_text=msg_data.get("processed_text", ""),
|
||||
group_id=group_id
|
||||
)
|
||||
await msg.initialize()
|
||||
message_objects.append(msg)
|
||||
except KeyError:
|
||||
print("[WARNING] 数据库中存在无效的消息")
|
||||
|
||||
@@ -65,7 +65,8 @@ class LLM_request:
|
||||
}
|
||||
|
||||
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
|
||||
logger.info(f"发送请求到URL: {api_url}{self.model_name}")
|
||||
logger.info(f"发送请求到URL: {api_url}")
|
||||
logger.info(f"使用模型: {self.model_name}")
|
||||
|
||||
# 构建请求体
|
||||
if image_base64:
|
||||
@@ -81,8 +82,7 @@ class LLM_request:
|
||||
headers = await self._build_headers()
|
||||
|
||||
async with session_method as session:
|
||||
response = await session.post(api_url, headers=headers, json=payload)
|
||||
|
||||
async with session.post(api_url, headers=headers, json=payload) as response:
|
||||
# 处理需要重试的状态码
|
||||
if response.status in policy["retry_codes"]:
|
||||
wait_time = policy["base_wait"] * (2 ** retry)
|
||||
@@ -116,7 +116,7 @@ class LLM_request:
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.critical(f"请求失败: {str(e)}")
|
||||
logger.critical(f"请求头: {self._build_headers()} 请求体: {payload}")
|
||||
logger.critical(f"请求头: {await self._build_headers()} 请求体: {payload}")
|
||||
raise RuntimeError(f"API请求失败: {str(e)}")
|
||||
|
||||
logger.error("达到最大重试次数,请求仍然失败")
|
||||
|
||||
Reference in New Issue
Block a user