fix: update CQCode and Message classes for async initialization and processing

This commit is contained in:
KawaiiYusora
2025-03-07 01:31:03 +08:00
parent 26f99664ee
commit e0e3ee4177
5 changed files with 112 additions and 103 deletions

View File

@@ -58,6 +58,7 @@ class ChatBot:
plain_text=event.get_plaintext(), plain_text=event.get_plaintext(),
reply_message=event.reply, reply_message=event.reply,
) )
await message.initialize()
# 过滤词 # 过滤词
for word in global_config.ban_words: for word in global_config.ban_words:

View File

@@ -10,11 +10,11 @@ from nonebot.adapters.onebot.v11 import Bot
from .config import global_config from .config import global_config
import time import time
import asyncio import asyncio
from .utils_image import storage_image,storage_emoji from .utils_image import storage_image, storage_emoji
from .utils_user import get_user_nickname from .utils_user import get_user_nickname
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
#解析各种CQ码 # 解析各种CQ码
#包含CQ码类 # 包含CQ码类
import urllib3 import urllib3
from urllib3.util import create_urllib3_context from urllib3.util import create_urllib3_context
from nonebot import get_driver from nonebot import get_driver
@@ -27,6 +27,7 @@ ctx = create_urllib3_context()
ctx.load_default_certs() ctx.load_default_certs()
ctx.set_ciphers("AES128-GCM-SHA256") ctx.set_ciphers("AES128-GCM-SHA256")
class TencentSSLAdapter(requests.adapters.HTTPAdapter): class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def __init__(self, ssl_context=None, **kwargs): def __init__(self, ssl_context=None, **kwargs):
self.ssl_context = ssl_context self.ssl_context = ssl_context
@@ -37,6 +38,7 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
num_pools=connections, maxsize=maxsize, num_pools=connections, maxsize=maxsize,
block=block, ssl_context=self.ssl_context) block=block, ssl_context=self.ssl_context)
@dataclass @dataclass
class CQCode: class CQCode:
""" """
@@ -80,13 +82,13 @@ class CQCode:
else: else:
self.translated_plain_text = f"@某人" self.translated_plain_text = f"@某人"
elif self.type == 'reply': elif self.type == 'reply':
self.translated_plain_text = self.translate_reply() self.translated_plain_text = await self.translate_reply()
elif self.type == 'face': elif self.type == 'face':
face_id = self.params.get('id', '') face_id = self.params.get('id', '')
# self.translated_plain_text = f"[表情{face_id}]" # self.translated_plain_text = f"[表情{face_id}]"
self.translated_plain_text = f"[表情]" self.translated_plain_text = f"[表情]"
elif self.type == 'forward': elif self.type == 'forward':
self.translated_plain_text = self.translate_forward() self.translated_plain_text = await self.translate_forward()
else: else:
self.translated_plain_text = f"[{self.type}]" self.translated_plain_text = f"[{self.type}]"
@@ -133,7 +135,7 @@ class CQCode:
# 腾讯服务器特殊状态码处理 # 腾讯服务器特殊状态码处理
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url: if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url:
return None return None
if response.status_code != 200: if response.status_code != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
@@ -157,7 +159,7 @@ class CQCode:
return None return None
return None return None
async def translate_emoji(self) -> str: async def translate_emoji(self) -> str:
"""处理表情包类型的CQ码""" """处理表情包类型的CQ码"""
if 'url' not in self.params: if 'url' not in self.params:
@@ -170,11 +172,10 @@ class CQCode:
return await self.get_emoji_description(base64_str) return await self.get_emoji_description(base64_str)
else: else:
return '[表情包]' return '[表情包]'
async def translate_image(self) -> str: async def translate_image(self) -> str:
"""处理图片类型的CQ码区分普通图片和表情包""" """处理图片类型的CQ码区分普通图片和表情包"""
#没有url直接返回默认文本 # 没有url直接返回默认文本
if 'url' not in self.params: if 'url' not in self.params:
return '[图片]' return '[图片]'
base64_str = self.get_img() base64_str = self.get_img()
@@ -206,13 +207,13 @@ class CQCode:
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[图片]" return "[图片]"
def translate_forward(self) -> str: async def translate_forward(self) -> str:
"""处理转发消息""" """处理转发消息"""
try: try:
if 'content' not in self.params: if 'content' not in self.params:
return '[转发消息]' return '[转发消息]'
# 解析content内容需要先反转义 # 解析content内容需要先反转义
content = self.unescape(self.params['content']) content = self.unescape(self.params['content'])
# print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}") # print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}")
@@ -223,17 +224,17 @@ class CQCode:
except ValueError as e: except ValueError as e:
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}")
return '[转发消息]' return '[转发消息]'
# 处理每条消息 # 处理每条消息
formatted_messages = [] formatted_messages = []
for msg in messages: for msg in messages:
sender = msg.get('sender', {}) sender = msg.get('sender', {})
nickname = sender.get('card') or sender.get('nickname', '未知用户') nickname = sender.get('card') or sender.get('nickname', '未知用户')
# 获取消息内容并使用Message类处理 # 获取消息内容并使用Message类处理
raw_message = msg.get('raw_message', '') raw_message = msg.get('raw_message', '')
message_array = msg.get('message', []) message_array = msg.get('message', [])
if message_array and isinstance(message_array, list): if message_array and isinstance(message_array, list):
# 检查是否包含嵌套的转发消息 # 检查是否包含嵌套的转发消息
for message_part in message_array: for message_part in message_array:
@@ -251,6 +252,7 @@ class CQCode:
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_id=msg.get('group_id', 0)
) )
await message_obj.initialize()
content = message_obj.processed_plain_text content = message_obj.processed_plain_text
else: else:
content = '[空消息]' content = '[空消息]'
@@ -265,23 +267,24 @@ class CQCode:
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_id=msg.get('group_id', 0)
) )
await message_obj.initialize()
content = message_obj.processed_plain_text content = message_obj.processed_plain_text
else: else:
content = '[空消息]' content = '[空消息]'
formatted_msg = f"{nickname}: {content}" formatted_msg = f"{nickname}: {content}"
formatted_messages.append(formatted_msg) formatted_messages.append(formatted_msg)
# 合并所有消息 # 合并所有消息
combined_messages = '\n'.join(formatted_messages) combined_messages = '\n'.join(formatted_messages)
print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}") print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}")
return f"[转发消息:\n{combined_messages}]" return f"[转发消息:\n{combined_messages}]"
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}")
return '[转发消息]' return '[转发消息]'
def translate_reply(self) -> str: async def translate_reply(self) -> str:
"""处理回复类型的CQ码""" """处理回复类型的CQ码"""
# 创建Message对象 # 创建Message对象
@@ -289,7 +292,7 @@ class CQCode:
if self.reply_message == None: if self.reply_message == None:
# print(f"\033[1;31m[错误]\033[0m 回复消息为空") # print(f"\033[1;31m[错误]\033[0m 回复消息为空")
return '[回复某人消息]' return '[回复某人消息]'
if self.reply_message.sender.user_id: if self.reply_message.sender.user_id:
message_obj = Message( message_obj = Message(
user_id=self.reply_message.sender.user_id, user_id=self.reply_message.sender.user_id,
@@ -297,6 +300,7 @@ class CQCode:
raw_message=str(self.reply_message.message), raw_message=str(self.reply_message.message),
group_id=self.group_id group_id=self.group_id
) )
await message_obj.initialize()
if message_obj.user_id == global_config.BOT_QQ: if message_obj.user_id == global_config.BOT_QQ:
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]" return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
else: else:
@@ -310,9 +314,9 @@ class CQCode:
def unescape(text: str) -> str: def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符""" """反转义CQ码中的特殊字符"""
return text.replace(',', ',') \ return text.replace(',', ',') \
.replace('[', '[') \ .replace('[', '[') \
.replace(']', ']') \ .replace(']', ']') \
.replace('&', '&') .replace('&', '&')
@staticmethod @staticmethod
def create_emoji_cq(file_path: str) -> str: def create_emoji_cq(file_path: str) -> str:
@@ -327,12 +331,13 @@ class CQCode:
abs_path = os.path.abspath(file_path) abs_path = os.path.abspath(file_path)
# 转义特殊字符 # 转义特殊字符
escaped_path = abs_path.replace('&', '&') \ escaped_path = abs_path.replace('&', '&') \
.replace('[', '[') \ .replace('[', '[') \
.replace(']', ']') \ .replace(']', ']') \
.replace(',', ',') .replace(',', ',')
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
class CQCode_tool: class CQCode_tool:
@staticmethod @staticmethod
async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
@@ -354,7 +359,7 @@ class CQCode_tool:
params['text'] = cq_code.get('data', {}).get('text', '') params['text'] = cq_code.get('data', {}).get('text', '')
else: else:
params = cq_code.get('data', {}) params = cq_code.get('data', {})
instance = CQCode( instance = CQCode(
type=cq_type, type=cq_type,
params=params, params=params,
@@ -362,11 +367,11 @@ class CQCode_tool:
user_id=0, user_id=0,
reply_message=reply reply_message=reply
) )
# 进行翻译处理 # 进行翻译处理
await instance.translate() await instance.translate()
return instance return instance
@staticmethod @staticmethod
def create_reply_cq(message_id: int) -> str: def create_reply_cq(message_id: int) -> str:
""" """
@@ -377,6 +382,6 @@ class CQCode_tool:
回复CQ码字符串 回复CQ码字符串
""" """
return f"[CQ:reply,id={message_id}]" return f"[CQ:reply,id={message_id}]"
cq_code_tool = CQCode_tool() cq_code_tool = CQCode_tool()

View File

@@ -27,56 +27,58 @@ class Message:
"""消息数据类""" """消息数据类"""
message_id: int = None message_id: int = None
time: float = None time: float = None
group_id: int = None group_id: int = None
group_name: str = None # 群名称 group_name: str = None # 群名称
user_id: int = None user_id: int = None
user_nickname: str = None # 用户昵称 user_nickname: str = None # 用户昵称
user_cardname: str=None # 用户群昵称 user_cardname: str = None # 用户群昵称
raw_message: str = None # 原始消息包含未解析的cq码 raw_message: str = None # 原始消息包含未解析的cq码
plain_text: str = None # 纯文本 plain_text: str = None # 纯文本
reply_message: Dict = None # 存储 回复的 源消息
# 延迟初始化字段
_initialized: bool = False
message_segments: List[Dict] = None # 存储解析后的消息片段 message_segments: List[Dict] = None # 存储解析后的消息片段
processed_plain_text: str = None # 用于存储处理后的plain_text processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本 detailed_plain_text: str = None # 用于存储详细可读文本
reply_message: Dict = None # 存储 回复的 源消息 # 状态标志
is_emoji: bool = False
is_emoji: bool = False # 是否是表情包 has_emoji: bool = False
has_emoji: bool = False # 是否包含表情包 translate_cq: bool = True
translate_cq: bool = True # 是否翻译cq码 async def initialize(self):
"""显式异步初始化方法(必须调用)"""
async def __post_init__(self): if self._initialized:
if self.time is None: return
self.time = int(time.time())
# 异步获取补充信息
if not self.group_name: self.group_name = self.group_name or get_groupname(self.group_id)
self.group_name = get_groupname(self.group_id) self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
if not self.user_nickname:
self.user_nickname = get_user_nickname(self.user_id) # 消息解析
if self.raw_message:
if not self.user_cardname: self.message_segments = await self.parse_message_segments(self.raw_message)
self.user_cardname=get_user_cardname(self.user_id) self.processed_plain_text = ' '.join(
seg.translated_plain_text
if not self.processed_plain_text: for seg in self.message_segments
if self.raw_message: )
self.message_segments = await self.parse_message_segments(str(self.raw_message))
self.processed_plain_text = ' '.join( # 构建详细文本
seg.translated_plain_text
for seg in self.message_segments
)
#将详细翻译为详细可读文本
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time)) time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
try: name = (
name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
except: if self.user_cardname
name = self.user_nickname or f"用户{self.user_id}" else f"{self.user_nickname or f'用户{self.user_id}'}"
content = self.processed_plain_text )
self.detailed_plain_text = f"[{time_str}] {name}: {content}\n" self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
self._initialized = True
async def parse_message_segments(self, message: str) -> List[CQCode]: async def parse_message_segments(self, message: str) -> List[CQCode]:
""" """

View File

@@ -131,7 +131,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
return '' return ''
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录 """从数据库获取群组最近的消息记录
Args: Args:
@@ -173,6 +173,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
processed_plain_text=msg_data.get("processed_text", ""), processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id group_id=group_id
) )
await msg.initialize()
message_objects.append(msg) message_objects.append(msg)
except KeyError: except KeyError:
print("[WARNING] 数据库中存在无效的消息") print("[WARNING] 数据库中存在无效的消息")

View File

@@ -65,7 +65,8 @@ class LLM_request:
} }
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
logger.info(f"发送请求到URL: {api_url}{self.model_name}") logger.info(f"发送请求到URL: {api_url}")
logger.info(f"使用模型: {self.model_name}")
# 构建请求体 # 构建请求体
if image_base64: if image_base64:
@@ -81,33 +82,32 @@ class LLM_request:
headers = await self._build_headers() headers = await self._build_headers()
async with session_method as session: async with session_method as session:
response = await session.post(api_url, headers=headers, json=payload) async with session.post(api_url, headers=headers, json=payload) as response:
# 处理需要重试的状态码
if response.status in policy["retry_codes"]:
wait_time = policy["base_wait"] * (2 ** retry)
logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试")
if response.status == 413:
logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64)
elif response.status in [500, 503]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
else:
logger.warning(f"请求限制(429),等待{wait_time}秒后重试...")
# 处理需要重试的状态码 await asyncio.sleep(wait_time)
if response.status in policy["retry_codes"]: continue
wait_time = policy["base_wait"] * (2 ** retry) elif response.status in policy["abort_codes"]:
logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试") logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
if response.status == 413: raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64)
elif response.status in [500, 503]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
else:
logger.warning(f"请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time) response.raise_for_status()
continue result = await response.json()
elif response.status in policy["abort_codes"]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
response.raise_for_status() # 使用自定义处理器或默认处理
result = await response.json() return response_handler(result) if response_handler else self._default_response_handler(result)
# 使用自定义处理器或默认处理
return response_handler(result) if response_handler else self._default_response_handler(result)
except Exception as e: except Exception as e:
if retry < policy["max_retries"] - 1: if retry < policy["max_retries"] - 1:
@@ -116,7 +116,7 @@ class LLM_request:
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
else: else:
logger.critical(f"请求失败: {str(e)}") logger.critical(f"请求失败: {str(e)}")
logger.critical(f"请求头: {self._build_headers()} 请求体: {payload}") logger.critical(f"请求头: {await self._build_headers()} 请求体: {payload}")
raise RuntimeError(f"API请求失败: {str(e)}") raise RuntimeError(f"API请求失败: {str(e)}")
logger.error("达到最大重试次数,请求仍然失败") logger.error("达到最大重试次数,请求仍然失败")