fix: update CQCode and Message classes for async initialization and processing

This commit is contained in:
KawaiiYusora
2025-03-07 01:31:03 +08:00
parent 26f99664ee
commit e0e3ee4177
5 changed files with 112 additions and 103 deletions

View File

@@ -58,6 +58,7 @@ class ChatBot:
plain_text=event.get_plaintext(),
reply_message=event.reply,
)
await message.initialize()
# 过滤词
for word in global_config.ban_words:

View File

@@ -27,6 +27,7 @@ ctx = create_urllib3_context()
ctx.load_default_certs()
ctx.set_ciphers("AES128-GCM-SHA256")
class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def __init__(self, ssl_context=None, **kwargs):
self.ssl_context = ssl_context
@@ -37,6 +38,7 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
num_pools=connections, maxsize=maxsize,
block=block, ssl_context=self.ssl_context)
@dataclass
class CQCode:
"""
@@ -80,13 +82,13 @@ class CQCode:
else:
self.translated_plain_text = f"@某人"
elif self.type == 'reply':
self.translated_plain_text = self.translate_reply()
self.translated_plain_text = await self.translate_reply()
elif self.type == 'face':
face_id = self.params.get('id', '')
# self.translated_plain_text = f"[表情{face_id}]"
self.translated_plain_text = f"[表情]"
elif self.type == 'forward':
self.translated_plain_text = self.translate_forward()
self.translated_plain_text = await self.translate_forward()
else:
self.translated_plain_text = f"[{self.type}]"
@@ -171,7 +173,6 @@ class CQCode:
else:
return '[表情包]'
async def translate_image(self) -> str:
"""处理图片类型的CQ码区分普通图片和表情包"""
# 没有url直接返回默认文本
@@ -207,7 +208,7 @@ class CQCode:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[图片]"
def translate_forward(self) -> str:
async def translate_forward(self) -> str:
"""处理转发消息"""
try:
if 'content' not in self.params:
@@ -251,6 +252,7 @@ class CQCode:
plain_text=raw_message,
group_id=msg.get('group_id', 0)
)
await message_obj.initialize()
content = message_obj.processed_plain_text
else:
content = '[空消息]'
@@ -265,6 +267,7 @@ class CQCode:
plain_text=raw_message,
group_id=msg.get('group_id', 0)
)
await message_obj.initialize()
content = message_obj.processed_plain_text
else:
content = '[空消息]'
@@ -281,7 +284,7 @@ class CQCode:
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}")
return '[转发消息]'
def translate_reply(self) -> str:
async def translate_reply(self) -> str:
"""处理回复类型的CQ码"""
# 创建Message对象
@@ -297,6 +300,7 @@ class CQCode:
raw_message=str(self.reply_message.message),
group_id=self.group_id
)
await message_obj.initialize()
if message_obj.user_id == global_config.BOT_QQ:
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
else:
@@ -333,6 +337,7 @@ class CQCode:
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
class CQCode_tool:
@staticmethod
async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:

View File

@@ -38,45 +38,47 @@ class Message:
raw_message: str = None # 原始消息包含未解析的cq码
plain_text: str = None # 纯文本
reply_message: Dict = None # 存储 回复的 源消息
# 延迟初始化字段
_initialized: bool = False
message_segments: List[Dict] = None # 存储解析后的消息片段
processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本
reply_message: Dict = None # 存储 回复的 源消息
# 状态标志
is_emoji: bool = False
has_emoji: bool = False
translate_cq: bool = True
is_emoji: bool = False # 是否是表情包
has_emoji: bool = False # 是否包含表情包
async def initialize(self):
"""显式异步初始化方法(必须调用)"""
if self._initialized:
return
translate_cq: bool = True # 是否翻译cq码
# 异步获取补充信息
self.group_name = self.group_name or get_groupname(self.group_id)
self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
async def __post_init__(self):
if self.time is None:
self.time = int(time.time())
if not self.group_name:
self.group_name = get_groupname(self.group_id)
if not self.user_nickname:
self.user_nickname = get_user_nickname(self.user_id)
if not self.user_cardname:
self.user_cardname=get_user_cardname(self.user_id)
if not self.processed_plain_text:
# 消息解析
if self.raw_message:
self.message_segments = await self.parse_message_segments(str(self.raw_message))
self.message_segments = await self.parse_message_segments(self.raw_message)
self.processed_plain_text = ' '.join(
seg.translated_plain_text
for seg in self.message_segments
)
#将详细翻译为详细可读文本
# 构建详细文本
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
try:
name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
except:
name = self.user_nickname or f"用户{self.user_id}"
content = self.processed_plain_text
self.detailed_plain_text = f"[{time_str}] {name}: {content}\n"
name = (
f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
if self.user_cardname
else f"{self.user_nickname or f'用户{self.user_id}'}"
)
self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
self._initialized = True
async def parse_message_segments(self, message: str) -> List[CQCode]:
"""

View File

@@ -131,7 +131,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
return ''
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
Args:
@@ -173,6 +173,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
)
await msg.initialize()
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")

View File

@@ -65,7 +65,8 @@ class LLM_request:
}
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
logger.info(f"发送请求到URL: {api_url}{self.model_name}")
logger.info(f"发送请求到URL: {api_url}")
logger.info(f"使用模型: {self.model_name}")
# 构建请求体
if image_base64:
@@ -81,8 +82,7 @@ class LLM_request:
headers = await self._build_headers()
async with session_method as session:
response = await session.post(api_url, headers=headers, json=payload)
async with session.post(api_url, headers=headers, json=payload) as response:
# 处理需要重试的状态码
if response.status in policy["retry_codes"]:
wait_time = policy["base_wait"] * (2 ** retry)
@@ -116,7 +116,7 @@ class LLM_request:
await asyncio.sleep(wait_time)
else:
logger.critical(f"请求失败: {str(e)}")
logger.critical(f"请求头: {self._build_headers()} 请求体: {payload}")
logger.critical(f"请求头: {await self._build_headers()} 请求体: {payload}")
raise RuntimeError(f"API请求失败: {str(e)}")
logger.error("达到最大重试次数,请求仍然失败")