diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 89c15b388..3a2d43f0e 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -58,6 +58,7 @@ class ChatBot: plain_text=event.get_plaintext(), reply_message=event.reply, ) + await message.initialize() # 过滤词 for word in global_config.ban_words: diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index 43d9d0862..0077f7295 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -10,11 +10,11 @@ from nonebot.adapters.onebot.v11 import Bot from .config import global_config import time import asyncio -from .utils_image import storage_image,storage_emoji +from .utils_image import storage_image, storage_emoji from .utils_user import get_user_nickname from ..models.utils_model import LLM_request -#解析各种CQ码 -#包含CQ码类 +# 解析各种CQ码 +# 包含CQ码类 import urllib3 from urllib3.util import create_urllib3_context from nonebot import get_driver @@ -27,6 +27,7 @@ ctx = create_urllib3_context() ctx.load_default_certs() ctx.set_ciphers("AES128-GCM-SHA256") + class TencentSSLAdapter(requests.adapters.HTTPAdapter): def __init__(self, ssl_context=None, **kwargs): self.ssl_context = ssl_context @@ -37,6 +38,7 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter): num_pools=connections, maxsize=maxsize, block=block, ssl_context=self.ssl_context) + @dataclass class CQCode: """ @@ -80,13 +82,13 @@ class CQCode: else: self.translated_plain_text = f"@某人" elif self.type == 'reply': - self.translated_plain_text = self.translate_reply() + self.translated_plain_text = await self.translate_reply() elif self.type == 'face': face_id = self.params.get('id', '') # self.translated_plain_text = f"[表情{face_id}]" self.translated_plain_text = f"[表情]" elif self.type == 'forward': - self.translated_plain_text = self.translate_forward() + self.translated_plain_text = await self.translate_forward() else: self.translated_plain_text = f"[{self.type}]" @@ -133,7 +135,7 @@ class CQCode: # 腾讯服务器特殊状态码处理 if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url: return None - + if response.status_code != 200: raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") @@ -157,7 +159,7 @@ class CQCode: return None return None - + async def translate_emoji(self) -> str: """处理表情包类型的CQ码""" if 'url' not in self.params: @@ -170,11 +172,10 @@ class CQCode: return await self.get_emoji_description(base64_str) else: return '[表情包]' - - + async def translate_image(self) -> str: """处理图片类型的CQ码,区分普通图片和表情包""" - #没有url,直接返回默认文本 + # 没有url,直接返回默认文本 if 'url' not in self.params: return '[图片]' base64_str = self.get_img() @@ -206,13 +207,13 @@ class CQCode: except Exception as e: print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") return "[图片]" - - def translate_forward(self) -> str: + + async def translate_forward(self) -> str: """处理转发消息""" try: if 'content' not in self.params: return '[转发消息]' - + # 解析content内容(需要先反转义) content = self.unescape(self.params['content']) # print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}") @@ -223,17 +224,17 @@ class CQCode: except ValueError as e: print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") return '[转发消息]' - + # 处理每条消息 formatted_messages = [] for msg in messages: sender = msg.get('sender', {}) nickname = sender.get('card') or sender.get('nickname', '未知用户') - + # 获取消息内容并使用Message类处理 raw_message = msg.get('raw_message', '') message_array = msg.get('message', []) - + if message_array and isinstance(message_array, list): # 检查是否包含嵌套的转发消息 for message_part in message_array: @@ -251,6 +252,7 @@ class CQCode: plain_text=raw_message, group_id=msg.get('group_id', 0) ) + await message_obj.initialize() content = message_obj.processed_plain_text else: content = '[空消息]' @@ -265,23 +267,24 @@ class CQCode: plain_text=raw_message, group_id=msg.get('group_id', 0) ) + await message_obj.initialize() content = message_obj.processed_plain_text else: content = '[空消息]' - + formatted_msg = f"{nickname}: {content}" formatted_messages.append(formatted_msg) - + # 合并所有消息 combined_messages = '\n'.join(formatted_messages) print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}") return f"[转发消息:\n{combined_messages}]" - + except Exception as e: print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") return '[转发消息]' - def translate_reply(self) -> str: + async def translate_reply(self) -> str: """处理回复类型的CQ码""" # 创建Message对象 @@ -289,7 +292,7 @@ class CQCode: if self.reply_message == None: # print(f"\033[1;31m[错误]\033[0m 回复消息为空") return '[回复某人消息]' - + if self.reply_message.sender.user_id: message_obj = Message( user_id=self.reply_message.sender.user_id, @@ -297,6 +300,7 @@ class CQCode: raw_message=str(self.reply_message.message), group_id=self.group_id ) + await message_obj.initialize() if message_obj.user_id == global_config.BOT_QQ: return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]" else: @@ -310,9 +314,9 @@ class CQCode: def unescape(text: str) -> str: """反转义CQ码中的特殊字符""" return text.replace(',', ',') \ - .replace('[', '[') \ - .replace(']', ']') \ - .replace('&', '&') + .replace('[', '[') \ + .replace(']', ']') \ + .replace('&', '&') @staticmethod def create_emoji_cq(file_path: str) -> str: @@ -327,12 +331,13 @@ class CQCode: abs_path = os.path.abspath(file_path) # 转义特殊字符 escaped_path = abs_path.replace('&', '&') \ - .replace('[', '[') \ - .replace(']', ']') \ - .replace(',', ',') + .replace('[', '[') \ + .replace(']', ']') \ + .replace(',', ',') # 生成CQ码,设置sub_type=1表示这是表情包 return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" - + + class CQCode_tool: @staticmethod async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: @@ -354,7 +359,7 @@ class CQCode_tool: params['text'] = cq_code.get('data', {}).get('text', '') else: params = cq_code.get('data', {}) - + instance = CQCode( type=cq_type, params=params, @@ -362,11 +367,11 @@ class CQCode_tool: user_id=0, reply_message=reply ) - + # 进行翻译处理 await instance.translate() return instance - + @staticmethod def create_reply_cq(message_id: int) -> str: """ @@ -377,6 +382,6 @@ class CQCode_tool: 回复CQ码字符串 """ return f"[CQ:reply,id={message_id}]" - - + + cq_code_tool = CQCode_tool() diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index 02f56b975..e1d36568c 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -27,56 +27,58 @@ class Message: """消息数据类""" message_id: int = None time: float = None - + group_id: int = None - group_name: str = None # 群名称 - + group_name: str = None # 群名称 + user_id: int = None user_nickname: str = None # 用户昵称 - user_cardname: str=None # 用户群昵称 - - raw_message: str = None # 原始消息,包含未解析的cq码 - plain_text: str = None # 纯文本 - + user_cardname: str = None # 用户群昵称 + + raw_message: str = None # 原始消息,包含未解析的cq码 + plain_text: str = None # 纯文本 + + reply_message: Dict = None # 存储 回复的 源消息 + + # 延迟初始化字段 + _initialized: bool = False message_segments: List[Dict] = None # 存储解析后的消息片段 processed_plain_text: str = None # 用于存储处理后的plain_text detailed_plain_text: str = None # 用于存储详细可读文本 - - reply_message: Dict = None # 存储 回复的 源消息 - - is_emoji: bool = False # 是否是表情包 - has_emoji: bool = False # 是否包含表情包 - - translate_cq: bool = True # 是否翻译cq码 - - async def __post_init__(self): - if self.time is None: - self.time = int(time.time()) - - if not self.group_name: - self.group_name = get_groupname(self.group_id) - - if not self.user_nickname: - self.user_nickname = get_user_nickname(self.user_id) - - if not self.user_cardname: - self.user_cardname=get_user_cardname(self.user_id) - - if not self.processed_plain_text: - if self.raw_message: - self.message_segments = await self.parse_message_segments(str(self.raw_message)) - self.processed_plain_text = ' '.join( - seg.translated_plain_text - for seg in self.message_segments - ) - #将详细翻译为详细可读文本 + + # 状态标志 + is_emoji: bool = False + has_emoji: bool = False + translate_cq: bool = True + + async def initialize(self): + """显式异步初始化方法(必须调用)""" + if self._initialized: + return + + # 异步获取补充信息 + self.group_name = self.group_name or get_groupname(self.group_id) + self.user_nickname = self.user_nickname or get_user_nickname(self.user_id) + self.user_cardname = self.user_cardname or get_user_cardname(self.user_id) + + # 消息解析 + if self.raw_message: + self.message_segments = await self.parse_message_segments(self.raw_message) + self.processed_plain_text = ' '.join( + seg.translated_plain_text + for seg in self.message_segments + ) + + # 构建详细文本 time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time)) - try: - name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" - except: - name = self.user_nickname or f"用户{self.user_id}" - content = self.processed_plain_text - self.detailed_plain_text = f"[{time_str}] {name}: {content}\n" + name = ( + f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" + if self.user_cardname + else f"{self.user_nickname or f'用户{self.user_id}'}" + ) + self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n" + + self._initialized = True async def parse_message_segments(self, message: str) -> List[CQCode]: """ diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 38aeefd21..42c91b93c 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -131,7 +131,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str): return '' -def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: +async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: """从数据库获取群组最近的消息记录 Args: @@ -173,6 +173,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: processed_plain_text=msg_data.get("processed_text", ""), group_id=group_id ) + await msg.initialize() message_objects.append(msg) except KeyError: print("[WARNING] 数据库中存在无效的消息") diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 3e4d7f1a2..8addf6a46 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -65,7 +65,8 @@ class LLM_request: } api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" - logger.info(f"发送请求到URL: {api_url}{self.model_name}") + logger.info(f"发送请求到URL: {api_url}") + logger.info(f"使用模型: {self.model_name}") # 构建请求体 if image_base64: @@ -81,33 +82,32 @@ class LLM_request: headers = await self._build_headers() async with session_method as session: - response = await session.post(api_url, headers=headers, json=payload) + async with session.post(api_url, headers=headers, json=payload) as response: + # 处理需要重试的状态码 + if response.status in policy["retry_codes"]: + wait_time = policy["base_wait"] * (2 ** retry) + logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试") + if response.status == 413: + logger.warning("请求体过大,尝试压缩...") + image_base64 = compress_base64_image_by_scale(image_base64) + payload = await self._build_payload(prompt, image_base64) + elif response.status in [500, 503]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError("服务器负载过高,模型恢复失败QAQ") + else: + logger.warning(f"请求限制(429),等待{wait_time}秒后重试...") - # 处理需要重试的状态码 - if response.status in policy["retry_codes"]: - wait_time = policy["base_wait"] * (2 ** retry) - logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试") - if response.status == 413: - logger.warning("请求体过大,尝试压缩...") - image_base64 = compress_base64_image_by_scale(image_base64) - payload = await self._build_payload(prompt, image_base64) - elif response.status in [500, 503]: - logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") - raise RuntimeError("服务器负载过高,模型恢复失败QAQ") - else: - logger.warning(f"请求限制(429),等待{wait_time}秒后重试...") + await asyncio.sleep(wait_time) + continue + elif response.status in policy["abort_codes"]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") - await asyncio.sleep(wait_time) - continue - elif response.status in policy["abort_codes"]: - logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") - raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") + response.raise_for_status() + result = await response.json() - response.raise_for_status() - result = await response.json() - - # 使用自定义处理器或默认处理 - return response_handler(result) if response_handler else self._default_response_handler(result) + # 使用自定义处理器或默认处理 + return response_handler(result) if response_handler else self._default_response_handler(result) except Exception as e: if retry < policy["max_retries"] - 1: @@ -116,7 +116,7 @@ class LLM_request: await asyncio.sleep(wait_time) else: logger.critical(f"请求失败: {str(e)}") - logger.critical(f"请求头: {self._build_headers()} 请求体: {payload}") + logger.critical(f"请求头: {await self._build_headers()} 请求体: {payload}") raise RuntimeError(f"API请求失败: {str(e)}") logger.error("达到最大重试次数,请求仍然失败")