feat(napcat_adapter): 添加请求处理程序、发送处理程序、视频处理程序以及实用函数
- 实现了request_handler.py来处理对核心的请求。 - 创建了send_handler.py文件,用于处理并向Napcat发送消息。 - 添加了video_handler.py文件,用于从QQ消息中下载和处理视频文件。 - 开发了utils.py,用于缓存和实现与Napcat操作相关的实用函数。 - 为群组、成员和自身信息引入了带有生存时间(TTL)设置的缓存机制。 - 新模块中增强了错误处理和日志记录功能。
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -345,4 +345,4 @@ package.json
|
||||
src/chat/planner_actions/新建 文本文档.txt
|
||||
/backup
|
||||
mofox_bot_statistics.html
|
||||
src/plugins/built_in/NEW_napcat_adapter/src/handlers/napcat_cache.json
|
||||
src/plugins/built_in/napcat_adapter/src/handlers/napcat_cache.json
|
||||
|
||||
@@ -92,6 +92,8 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st
|
||||
|
||||
# 构造消息数据字典(基于 TypedDict 风格)
|
||||
message_time = message_info.get("time", time.time())
|
||||
if isinstance(message_time,int):
|
||||
message_time = float(message_time / 1000)
|
||||
message_id = message_info.get("message_id", "")
|
||||
|
||||
# 处理 is_mentioned
|
||||
@@ -215,15 +217,9 @@ async def _process_single_segment(
|
||||
|
||||
try:
|
||||
if seg_type == "text":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
return str(seg_data) if seg_data else ""
|
||||
|
||||
elif seg_type == "at":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
state["is_at"] = True
|
||||
# 处理at消息,格式为"@<昵称:QQ号>"
|
||||
if isinstance(seg_data, str):
|
||||
@@ -242,8 +238,6 @@ async def _process_single_segment(
|
||||
if isinstance(seg_data, str):
|
||||
state["has_picid"] = True
|
||||
state["is_picid"] = True
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
image_manager = get_image_manager()
|
||||
_, processed_text = await image_manager.process_image(seg_data)
|
||||
return processed_text
|
||||
@@ -252,18 +246,12 @@ async def _process_single_segment(
|
||||
elif seg_type == "emoji":
|
||||
state["has_emoji"] = True
|
||||
state["is_emoji"] = True
|
||||
state["is_picid"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = False
|
||||
if isinstance(seg_data, str):
|
||||
return await get_image_manager().get_emoji_description(seg_data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
|
||||
elif seg_type == "voice":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = True
|
||||
state["is_video"] = False
|
||||
|
||||
# 检查消息是否由机器人自己发送
|
||||
user_info = message_info.get("user_info", {})
|
||||
@@ -284,18 +272,11 @@ async def _process_single_segment(
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
|
||||
elif seg_type == "mention_bot":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = False
|
||||
if isinstance(seg_data, (int, float)):
|
||||
state["is_mentioned"] = float(seg_data)
|
||||
return ""
|
||||
|
||||
elif seg_type == "priority_info":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
if isinstance(seg_data, dict):
|
||||
# 处理优先级信息
|
||||
state["priority_mode"] = "priority"
|
||||
@@ -310,9 +291,6 @@ async def _process_single_segment(
|
||||
return "[收到一个文件]"
|
||||
|
||||
elif seg_type == "video":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = True
|
||||
logger.info(f"接收到视频消息,数据类型: {type(seg_data)}")
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,230 +0,0 @@
|
||||
"""消息处理器 - 将 Napcat OneBot 消息转换为 MessageEnvelope"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from mofox_wire import MessageBuilder
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
from mofox_wire import (
|
||||
MessageEnvelope,
|
||||
SegPayload,
|
||||
MessageInfoPayload,
|
||||
UserInfoPayload,
|
||||
GroupInfoPayload,
|
||||
)
|
||||
|
||||
from ...event_models import ACCEPT_FORMAT, QQ_FACE
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
get_image_base64,
|
||||
get_self_info,
|
||||
get_member_info,
|
||||
get_message_detail,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ....plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.message_handler")
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
"""处理来自 Napcat 的消息事件"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_raw_message(self, raw: Dict[str, Any]):
|
||||
"""
|
||||
处理原始消息并转换为 MessageEnvelope
|
||||
|
||||
Args:
|
||||
raw: OneBot 原始消息数据
|
||||
|
||||
Returns:
|
||||
MessageEnvelope (dict)
|
||||
"""
|
||||
|
||||
message_type = raw.get("message_type")
|
||||
message_id = str(raw.get("message_id", ""))
|
||||
message_time = time.time()
|
||||
|
||||
msg_builder = MessageBuilder()
|
||||
|
||||
# 构造用户信息
|
||||
sender_info = raw.get("sender", {})
|
||||
|
||||
(
|
||||
msg_builder.direction("incoming")
|
||||
.message_id(message_id)
|
||||
.timestamp_ms(int(message_time * 1000))
|
||||
.from_user(
|
||||
user_id=str(sender_info.get("user_id", "")),
|
||||
platform="qq",
|
||||
nickname=sender_info.get("nickname", ""),
|
||||
cardname=sender_info.get("card", ""),
|
||||
user_avatar=sender_info.get("avatar", ""),
|
||||
)
|
||||
)
|
||||
|
||||
# 构造群组信息(如果是群消息)
|
||||
if message_type == "group":
|
||||
group_id = raw.get("group_id")
|
||||
if group_id:
|
||||
fetched_group_info = await get_group_info(group_id)
|
||||
(
|
||||
msg_builder.from_group(
|
||||
group_id=str(group_id),
|
||||
platform="qq",
|
||||
name=(
|
||||
fetched_group_info.get("group_name", "")
|
||||
if fetched_group_info
|
||||
else raw.get("group_name", "")
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# 解析消息段
|
||||
message_segments = raw.get("message", [])
|
||||
seg_list: List[SegPayload] = []
|
||||
|
||||
for segment in message_segments:
|
||||
seg_message = await self.handle_single_segment(segment, raw)
|
||||
if seg_message:
|
||||
seg_list.append(seg_message)
|
||||
|
||||
msg_builder.format_info(
|
||||
content_format=[seg["type"] for seg in seg_list],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
)
|
||||
|
||||
msg_builder.seg_list(seg_list)
|
||||
|
||||
return msg_builder.build()
|
||||
|
||||
async def handle_single_segment(
|
||||
self, segment: dict, raw_message: dict, in_reply: bool = False
|
||||
) -> SegPayload | None:
|
||||
"""
|
||||
处理单一消息段并转换为 MessageEnvelope
|
||||
|
||||
Args:
|
||||
segment: 单一原始消息段
|
||||
raw_message: 完整的原始消息数据
|
||||
|
||||
Returns:
|
||||
SegPayload | List[SegPayload] | None
|
||||
"""
|
||||
seg_type = segment.get("type")
|
||||
seg_data: dict = segment.get("data", {})
|
||||
match seg_type:
|
||||
case "text":
|
||||
return {"type": "text", "data": seg_data.get("text", "")}
|
||||
case "image":
|
||||
image_sub_type = seg_data.get("sub_type")
|
||||
try:
|
||||
image_base64 = await get_image_base64(seg_data.get("url", ""))
|
||||
except Exception as e:
|
||||
logger.error(f"图片消息处理失败: {str(e)}")
|
||||
return None
|
||||
if image_sub_type == 0:
|
||||
"""这部分认为是图片"""
|
||||
return {"type": "image", "data": image_base64}
|
||||
elif image_sub_type not in [4, 9]:
|
||||
"""这部分认为是表情包"""
|
||||
return {"type": "emoji", "data": image_base64}
|
||||
else:
|
||||
logger.warning(f"不支持的图片子类型:{image_sub_type}")
|
||||
return None
|
||||
case "face":
|
||||
message_data: dict = segment.get("data", {})
|
||||
face_raw_id: str = str(message_data.get("id"))
|
||||
if face_raw_id in QQ_FACE:
|
||||
face_content: str = QQ_FACE.get(face_raw_id, "[未知表情]")
|
||||
return {"type": "text", "data": face_content}
|
||||
else:
|
||||
logger.warning(f"不支持的表情:{face_raw_id}")
|
||||
return None
|
||||
case "at":
|
||||
if seg_data:
|
||||
qq_id = seg_data.get("qq")
|
||||
self_id = raw_message.get("self_id")
|
||||
group_id = raw_message.get("group_id")
|
||||
if str(self_id) == str(qq_id):
|
||||
logger.debug("机器人被at")
|
||||
self_info = await get_self_info()
|
||||
if self_info:
|
||||
# 返回包含昵称和用户ID的at格式,便于后续处理
|
||||
return {
|
||||
"type": "at",
|
||||
"data": f"{self_info.get('nickname')}:{self_info.get('user_id')}",
|
||||
}
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
if qq_id and group_id:
|
||||
member_info = await get_member_info(
|
||||
group_id=group_id, user_id=qq_id
|
||||
)
|
||||
if member_info:
|
||||
# 返回包含昵称和用户ID的at格式,便于后续处理
|
||||
return {
|
||||
"type": "at",
|
||||
"data": f"{member_info.get('nickname')}:{member_info.get('user_id')}",
|
||||
}
|
||||
else:
|
||||
return None
|
||||
case "emoji":
|
||||
seg_data = segment.get("id", "")
|
||||
case "reply":
|
||||
if not in_reply:
|
||||
message_id = None
|
||||
if seg_data:
|
||||
message_id = seg_data.get("id")
|
||||
else:
|
||||
return None
|
||||
message_detail = await get_message_detail(message_id)
|
||||
if not message_detail:
|
||||
logger.warning("获取被引用的消息详情失败")
|
||||
return None
|
||||
reply_message = await self.handle_single_segment(
|
||||
message_detail, raw_message, in_reply=True
|
||||
)
|
||||
if reply_message is None:
|
||||
reply_message = [
|
||||
{"type": "text", "data": "[无法获取被引用的消息]"}
|
||||
]
|
||||
sender_info: dict = message_detail.get("sender", {})
|
||||
sender_nickname: str = sender_info.get("nickname", "")
|
||||
sender_id = sender_info.get("user_id")
|
||||
if not sender_nickname:
|
||||
logger.warning("无法获取被引用的人的昵称,返回默认值")
|
||||
return {
|
||||
"type": "text",
|
||||
"data": f"[回复<未知用户>:{reply_message}],说:",
|
||||
}
|
||||
|
||||
else:
|
||||
if sender_id:
|
||||
return {
|
||||
"type": "text",
|
||||
"data": f"[回复<{sender_nickname}({sender_id})>:{reply_message}],说:",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"type": "text",
|
||||
"data": f"[回复<{sender_nickname}>:{reply_message}],说:",
|
||||
}
|
||||
|
||||
case "voice":
|
||||
seg_data = segment.get("url", "")
|
||||
case _:
|
||||
logger.warning(f"Unsupported segment type: {seg_type}")
|
||||
@@ -1,29 +0,0 @@
|
||||
"""元事件处理器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.meta_event_handler")
|
||||
|
||||
|
||||
class MetaEventHandler:
|
||||
"""处理 Napcat 元事件(心跳、生命周期)"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_meta_event(self, raw: Dict[str, Any]):
|
||||
"""处理元事件"""
|
||||
# 简化版本:返回一个空的 MessageEnvelope
|
||||
pass
|
||||
@@ -1,77 +0,0 @@
|
||||
"""发送处理器 - 将 MessageEnvelope 转换并发送到 Napcat"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.send_handler")
|
||||
|
||||
|
||||
class SendHandler:
|
||||
"""处理向 Napcat 发送消息"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_message(self, envelope) -> None:
|
||||
"""
|
||||
处理发送消息
|
||||
|
||||
将 MessageEnvelope 转换为 OneBot API 调用
|
||||
"""
|
||||
message_info = envelope.get("message_info", {})
|
||||
message_segment = envelope.get("message_segment", {})
|
||||
|
||||
# 获取群组和用户信息
|
||||
group_info = message_info.get("group_info")
|
||||
user_info = message_info.get("user_info")
|
||||
|
||||
# 构造消息内容
|
||||
message = self._convert_seg_to_onebot(message_segment)
|
||||
|
||||
# 发送消息
|
||||
if group_info:
|
||||
# 发送群消息
|
||||
group_id = group_info.get("group_id")
|
||||
if group_id:
|
||||
await self.adapter.send_napcat_api("send_group_msg", {
|
||||
"group_id": int(group_id),
|
||||
"message": message,
|
||||
})
|
||||
elif user_info:
|
||||
# 发送私聊消息
|
||||
user_id = user_info.get("user_id")
|
||||
if user_id:
|
||||
await self.adapter.send_napcat_api("send_private_msg", {
|
||||
"user_id": int(user_id),
|
||||
"message": message,
|
||||
})
|
||||
|
||||
def _convert_seg_to_onebot(self, seg: Dict[str, Any]) -> list:
|
||||
"""将 SegPayload 转换为 OneBot 消息格式"""
|
||||
seg_type = seg.get("type", "")
|
||||
seg_data = seg.get("data", "")
|
||||
|
||||
if seg_type == "text":
|
||||
return [{"type": "text", "data": {"text": seg_data}}]
|
||||
elif seg_type == "image":
|
||||
return [{"type": "image", "data": {"file": f"base64://{seg_data}"}}]
|
||||
elif seg_type == "seglist":
|
||||
# 递归处理列表
|
||||
result = []
|
||||
for sub_seg in seg_data:
|
||||
result.extend(self._convert_seg_to_onebot(sub_seg))
|
||||
return result
|
||||
else:
|
||||
# 默认作为文本
|
||||
return [{"type": "text", "data": {"text": str(seg_data)}}]
|
||||
@@ -1,350 +0,0 @@
|
||||
"""
|
||||
按聊天流分配消费者的消息路由系统
|
||||
|
||||
核心思想:
|
||||
- 为每个活跃的聊天流(stream_id)创建独立的消息队列和消费者协程
|
||||
- 同一聊天流的消息由同一个 worker 处理,保证顺序性
|
||||
- 不同聊天流的消息并发处理,提高吞吐量
|
||||
- 动态管理流的生命周期,自动清理不活跃的流
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("stream_router")
|
||||
|
||||
|
||||
class StreamConsumer:
|
||||
"""单个聊天流的消息消费者
|
||||
|
||||
维护独立的消息队列和处理协程
|
||||
"""
|
||||
|
||||
def __init__(self, stream_id: str, queue_maxsize: int = 100):
|
||||
self.stream_id = stream_id
|
||||
self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
|
||||
self.worker_task: Optional[asyncio.Task] = None
|
||||
self.last_active_time = time.time()
|
||||
self.is_running = False
|
||||
|
||||
# 性能统计
|
||||
self.stats = {
|
||||
"total_messages": 0,
|
||||
"total_processing_time": 0.0,
|
||||
"queue_overflow_count": 0,
|
||||
}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动消费者"""
|
||||
if not self.is_running:
|
||||
self.is_running = True
|
||||
self.worker_task = asyncio.create_task(self._process_loop())
|
||||
logger.debug(f"Stream Consumer 启动: {self.stream_id}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止消费者"""
|
||||
self.is_running = False
|
||||
if self.worker_task:
|
||||
self.worker_task.cancel()
|
||||
try:
|
||||
await self.worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.debug(f"Stream Consumer 停止: {self.stream_id}")
|
||||
|
||||
async def enqueue(self, message: dict) -> None:
|
||||
"""将消息加入队列"""
|
||||
self.last_active_time = time.time()
|
||||
|
||||
try:
|
||||
# 使用 put_nowait 避免阻塞路由器
|
||||
self.queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
self.stats["queue_overflow_count"] += 1
|
||||
logger.warning(
|
||||
f"Stream {self.stream_id} 队列已满 "
|
||||
f"({self.queue.qsize()}/{self.queue.maxsize}),"
|
||||
)
|
||||
|
||||
try:
|
||||
self.queue.get_nowait()
|
||||
self.queue.put_nowait(message)
|
||||
logger.debug(f"Stream {self.stream_id} 丢弃最旧消息,添加新消息")
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
|
||||
async def _process_loop(self) -> None:
|
||||
"""消息处理循环"""
|
||||
# 延迟导入,避免循环依赖
|
||||
from .recv_handler.message_handler import message_handler
|
||||
from .recv_handler.meta_event_handler import meta_event_handler
|
||||
from .recv_handler.notice_handler import notice_handler
|
||||
|
||||
logger.info(f"Stream {self.stream_id} 处理循环启动")
|
||||
|
||||
try:
|
||||
while self.is_running:
|
||||
try:
|
||||
# 等待消息,1秒超时
|
||||
message = await asyncio.wait_for(
|
||||
self.queue.get(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 处理消息
|
||||
post_type = message.get("post_type")
|
||||
if post_type == "message":
|
||||
await message_handler.handle_raw_message(message)
|
||||
elif post_type == "meta_event":
|
||||
await meta_event_handler.handle_meta_event(message)
|
||||
elif post_type == "notice":
|
||||
await notice_handler.handle_notice(message)
|
||||
else:
|
||||
logger.warning(f"未知的 post_type: {post_type}")
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_messages"] += 1
|
||||
self.stats["total_processing_time"] += processing_time
|
||||
self.last_active_time = time.time()
|
||||
self.queue.task_done()
|
||||
|
||||
# 性能监控(每100条消息输出一次)
|
||||
if self.stats["total_messages"] % 100 == 0:
|
||||
avg_time = self.stats["total_processing_time"] / self.stats["total_messages"]
|
||||
logger.info(
|
||||
f"Stream {self.stream_id[:30]}... 统计: "
|
||||
f"消息数={self.stats['total_messages']}, "
|
||||
f"平均耗时={avg_time:.3f}秒, "
|
||||
f"队列长度={self.queue.qsize()}"
|
||||
)
|
||||
|
||||
# 动态延迟:队列空时短暂休眠
|
||||
if self.queue.qsize() == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时是正常的,继续循环
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Stream {self.stream_id} 处理循环被取消")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Stream {self.stream_id} 处理消息时出错: {e}", exc_info=True)
|
||||
# 继续处理下一条消息
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
finally:
|
||||
logger.info(f"Stream {self.stream_id} 处理循环结束")
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取性能统计"""
|
||||
avg_time = (
|
||||
self.stats["total_processing_time"] / self.stats["total_messages"]
|
||||
if self.stats["total_messages"] > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"queue_size": self.queue.qsize(),
|
||||
"total_messages": self.stats["total_messages"],
|
||||
"avg_processing_time": avg_time,
|
||||
"queue_overflow_count": self.stats["queue_overflow_count"],
|
||||
"last_active_time": self.last_active_time,
|
||||
}
|
||||
|
||||
|
||||
class StreamRouter:
|
||||
"""流路由器
|
||||
|
||||
负责将消息路由到对应的聊天流队列
|
||||
动态管理聊天流的生命周期
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_streams: int = 500,
|
||||
stream_timeout: int = 600,
|
||||
stream_queue_size: int = 100,
|
||||
cleanup_interval: int = 60,
|
||||
):
|
||||
self.streams: Dict[str, StreamConsumer] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
self.max_streams = max_streams
|
||||
self.stream_timeout = stream_timeout
|
||||
self.stream_queue_size = stream_queue_size
|
||||
self.cleanup_interval = cleanup_interval
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
self.is_running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动路由器"""
|
||||
if not self.is_running:
|
||||
self.is_running = True
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info(
|
||||
f"StreamRouter 已启动 - "
|
||||
f"最大流数: {self.max_streams}, "
|
||||
f"超时: {self.stream_timeout}秒, "
|
||||
f"队列大小: {self.stream_queue_size}"
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止路由器"""
|
||||
self.is_running = False
|
||||
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 停止所有流消费者
|
||||
logger.info(f"正在停止 {len(self.streams)} 个流消费者...")
|
||||
for consumer in self.streams.values():
|
||||
await consumer.stop()
|
||||
|
||||
self.streams.clear()
|
||||
logger.info("StreamRouter 已停止")
|
||||
|
||||
async def route_message(self, message: dict) -> None:
|
||||
"""路由消息到对应的流"""
|
||||
stream_id = self._extract_stream_id(message)
|
||||
|
||||
# 快速路径:流已存在
|
||||
if stream_id in self.streams:
|
||||
await self.streams[stream_id].enqueue(message)
|
||||
return
|
||||
|
||||
# 慢路径:需要创建新流
|
||||
async with self.lock:
|
||||
# 双重检查
|
||||
if stream_id not in self.streams:
|
||||
# 检查流数量限制
|
||||
if len(self.streams) >= self.max_streams:
|
||||
logger.warning(
|
||||
f"达到最大流数量限制 ({self.max_streams}),"
|
||||
f"尝试清理不活跃的流..."
|
||||
)
|
||||
await self._cleanup_inactive_streams()
|
||||
|
||||
# 清理后仍然超限,记录警告但继续创建
|
||||
if len(self.streams) >= self.max_streams:
|
||||
logger.error(
|
||||
f"清理后仍达到最大流数量 ({len(self.streams)}/{self.max_streams})!"
|
||||
)
|
||||
|
||||
# 创建新流
|
||||
consumer = StreamConsumer(stream_id, self.stream_queue_size)
|
||||
self.streams[stream_id] = consumer
|
||||
await consumer.start()
|
||||
logger.info(f"创建新的 Stream Consumer: {stream_id} (总流数: {len(self.streams)})")
|
||||
|
||||
await self.streams[stream_id].enqueue(message)
|
||||
|
||||
def _extract_stream_id(self, message: dict) -> str:
|
||||
"""从消息中提取 stream_id
|
||||
|
||||
返回格式: platform:id:type
|
||||
例如: qq:123456:group 或 qq:789012:private
|
||||
"""
|
||||
post_type = message.get("post_type")
|
||||
|
||||
# 非消息类型,使用默认流(避免创建过多流)
|
||||
if post_type not in ["message", "notice"]:
|
||||
return "system:meta_event"
|
||||
|
||||
# 消息类型
|
||||
if post_type == "message":
|
||||
message_type = message.get("message_type")
|
||||
if message_type == "group":
|
||||
group_id = message.get("group_id")
|
||||
return f"qq:{group_id}:group"
|
||||
elif message_type == "private":
|
||||
user_id = message.get("user_id")
|
||||
return f"qq:{user_id}:private"
|
||||
|
||||
# notice 类型
|
||||
elif post_type == "notice":
|
||||
group_id = message.get("group_id")
|
||||
if group_id:
|
||||
return f"qq:{group_id}:group"
|
||||
user_id = message.get("user_id")
|
||||
if user_id:
|
||||
return f"qq:{user_id}:private"
|
||||
|
||||
# 未知类型,使用通用流
|
||||
return "unknown:unknown"
|
||||
|
||||
async def _cleanup_inactive_streams(self) -> None:
|
||||
"""清理不活跃的流"""
|
||||
current_time = time.time()
|
||||
to_remove = []
|
||||
|
||||
for stream_id, consumer in self.streams.items():
|
||||
if current_time - consumer.last_active_time > self.stream_timeout:
|
||||
to_remove.append(stream_id)
|
||||
|
||||
for stream_id in to_remove:
|
||||
await self.streams[stream_id].stop()
|
||||
del self.streams[stream_id]
|
||||
logger.debug(f"清理不活跃的流: {stream_id}")
|
||||
|
||||
if to_remove:
|
||||
logger.info(
|
||||
f"清理了 {len(to_remove)} 个不活跃的流 "
|
||||
f"(当前活跃流: {len(self.streams)}/{self.max_streams})"
|
||||
)
|
||||
|
||||
async def _cleanup_loop(self) -> None:
|
||||
"""定期清理循环"""
|
||||
logger.info(f"清理循环已启动,间隔: {self.cleanup_interval}秒")
|
||||
try:
|
||||
while self.is_running:
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
await self._cleanup_inactive_streams()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("清理循环已停止")
|
||||
|
||||
def get_all_stats(self) -> list[dict]:
|
||||
"""获取所有流的统计信息"""
|
||||
return [consumer.get_stats() for consumer in self.streams.values()]
|
||||
|
||||
def get_summary(self) -> dict:
|
||||
"""获取路由器摘要"""
|
||||
total_messages = sum(c.stats["total_messages"] for c in self.streams.values())
|
||||
total_queue_size = sum(c.queue.qsize() for c in self.streams.values())
|
||||
total_overflows = sum(c.stats["queue_overflow_count"] for c in self.streams.values())
|
||||
|
||||
# 计算平均队列长度
|
||||
avg_queue_size = total_queue_size / len(self.streams) if self.streams else 0
|
||||
|
||||
# 找出最繁忙的流
|
||||
busiest_stream = None
|
||||
if self.streams:
|
||||
busiest_stream = max(
|
||||
self.streams.values(),
|
||||
key=lambda c: c.stats["total_messages"]
|
||||
).stream_id
|
||||
|
||||
return {
|
||||
"total_streams": len(self.streams),
|
||||
"max_streams": self.max_streams,
|
||||
"total_messages_processed": total_messages,
|
||||
"total_queue_size": total_queue_size,
|
||||
"avg_queue_size": avg_queue_size,
|
||||
"total_queue_overflows": total_overflows,
|
||||
"busiest_stream": busiest_stream,
|
||||
}
|
||||
|
||||
|
||||
# 全局路由器实例
|
||||
stream_router = StreamRouter()
|
||||
@@ -110,7 +110,7 @@ class NapcatAdapter(BaseAdapter):
|
||||
|
||||
logger.info("Napcat 适配器已关闭")
|
||||
|
||||
async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope: # type: ignore[override]
|
||||
async def from_platform_message(self, raw: Dict[str, Any]) -> MessageEnvelope | None: # type: ignore[override]
|
||||
"""
|
||||
将 Napcat/OneBot 原始消息转换为 MessageEnvelope
|
||||
|
||||
@@ -144,7 +144,7 @@ class NapcatAdapter(BaseAdapter):
|
||||
|
||||
# 未知事件类型
|
||||
else:
|
||||
logger.warning(f"未知的事件类型: {post_type}")
|
||||
return
|
||||
|
||||
async def _send_platform_message(self, envelope: MessageEnvelope) -> None: # type: ignore[override]
|
||||
"""
|
||||
@@ -0,0 +1,715 @@
|
||||
"""消息处理器 - 将 Napcat OneBot 消息转换为 MessageEnvelope"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
import uuid
|
||||
|
||||
from mofox_wire import MessageBuilder
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
from mofox_wire import (
|
||||
MessageEnvelope,
|
||||
SegPayload,
|
||||
MessageInfoPayload,
|
||||
UserInfoPayload,
|
||||
GroupInfoPayload,
|
||||
)
|
||||
|
||||
from ...event_models import ACCEPT_FORMAT, QQ_FACE, RealMessageType
|
||||
from ..utils import *
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ....plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.message_handler")
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
"""处理来自 Napcat 的消息事件"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_raw_message(self, raw: Dict[str, Any]):
|
||||
"""
|
||||
处理原始消息并转换为 MessageEnvelope
|
||||
|
||||
Args:
|
||||
raw: OneBot 原始消息数据
|
||||
|
||||
Returns:
|
||||
MessageEnvelope (dict)
|
||||
"""
|
||||
|
||||
message_type = raw.get("message_type")
|
||||
message_id = str(raw.get("message_id", ""))
|
||||
message_time = time.time()
|
||||
|
||||
msg_builder = MessageBuilder()
|
||||
|
||||
# 构造用户信息
|
||||
sender_info = raw.get("sender", {})
|
||||
|
||||
(
|
||||
msg_builder.direction("incoming")
|
||||
.message_id(message_id)
|
||||
.timestamp_ms(int(message_time * 1000))
|
||||
.from_user(
|
||||
user_id=str(sender_info.get("user_id", "")),
|
||||
platform="qq",
|
||||
nickname=sender_info.get("nickname", ""),
|
||||
cardname=sender_info.get("card", ""),
|
||||
user_avatar=sender_info.get("avatar", ""),
|
||||
)
|
||||
)
|
||||
|
||||
# 构造群组信息(如果是群消息)
|
||||
if message_type == "group":
|
||||
group_id = raw.get("group_id")
|
||||
if group_id:
|
||||
fetched_group_info = await get_group_info(group_id)
|
||||
(
|
||||
msg_builder.from_group(
|
||||
group_id=str(group_id),
|
||||
platform="qq",
|
||||
name=(
|
||||
fetched_group_info.get("group_name", "")
|
||||
if fetched_group_info
|
||||
else raw.get("group_name", "")
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# 解析消息段
|
||||
message_segments = raw.get("message", [])
|
||||
seg_list: List[SegPayload] = []
|
||||
|
||||
for segment in message_segments:
|
||||
seg_message = await self.handle_single_segment(segment, raw)
|
||||
if seg_message:
|
||||
seg_list.append(seg_message)
|
||||
|
||||
msg_builder.format_info(
|
||||
content_format=[seg["type"] for seg in seg_list],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
)
|
||||
|
||||
msg_builder.seg_list(seg_list)
|
||||
|
||||
return msg_builder.build()
|
||||
|
||||
async def handle_single_segment(
|
||||
self, segment: dict, raw_message: dict, in_reply: bool = False
|
||||
) -> SegPayload | None:
|
||||
"""
|
||||
处理单一消息段并转换为 MessageEnvelope
|
||||
|
||||
Args:
|
||||
segment: 单一原始消息段
|
||||
raw_message: 完整的原始消息数据
|
||||
|
||||
Returns:
|
||||
SegPayload | None
|
||||
"""
|
||||
seg_type = segment.get("type")
|
||||
|
||||
match seg_type:
|
||||
case RealMessageType.text:
|
||||
return await self._handle_text_message(segment)
|
||||
case RealMessageType.image:
|
||||
return await self._handle_image_message(segment)
|
||||
case RealMessageType.face:
|
||||
return await self._handle_face_message(segment)
|
||||
case RealMessageType.at:
|
||||
return await self._handle_at_message(segment, raw_message)
|
||||
case RealMessageType.reply:
|
||||
return await self._handle_reply_message(segment, raw_message, in_reply)
|
||||
case RealMessageType.record:
|
||||
return await self._handle_record_message(segment)
|
||||
case RealMessageType.video:
|
||||
return await self._handle_video_message(segment)
|
||||
case RealMessageType.rps:
|
||||
return await self._handle_rps_message(segment)
|
||||
case RealMessageType.dice:
|
||||
return await self._handle_dice_message(segment)
|
||||
case RealMessageType.forward:
|
||||
messages = await get_forward_message(segment, adapter=self.adapter)
|
||||
if not messages:
|
||||
logger.warning("转发消息内容为空或获取失败")
|
||||
return None
|
||||
return await self.handle_forward_message(messages)
|
||||
case RealMessageType.json:
|
||||
return await self._handle_json_message(segment)
|
||||
case RealMessageType.file:
|
||||
return await self._handle_file_message(segment)
|
||||
|
||||
case _:
|
||||
logger.warning(f"Unsupported segment type: {seg_type}")
|
||||
return None
|
||||
|
||||
# Utility methods for handling different message types
|
||||
|
||||
async def _handle_text_message(self, segment: dict) -> SegPayload:
|
||||
"""处理纯文本消息"""
|
||||
message_data = segment.get("data", {})
|
||||
plain_text = message_data.get("text", "")
|
||||
return {"type": "text", "data": plain_text}
|
||||
|
||||
async def _handle_face_message(self, segment: dict) -> SegPayload | None:
|
||||
"""处理表情消息"""
|
||||
message_data = segment.get("data", {})
|
||||
face_raw_id = str(message_data.get("id", ""))
|
||||
if face_raw_id in QQ_FACE:
|
||||
face_content = QQ_FACE.get(face_raw_id, "[未知表情]")
|
||||
return {"type": "text", "data": face_content}
|
||||
else:
|
||||
logger.warning(f"不支持的表情:{face_raw_id}")
|
||||
return None
|
||||
|
||||
async def _handle_image_message(self, segment: dict) -> SegPayload | None:
|
||||
"""处理图片消息与表情包消息"""
|
||||
message_data = segment.get("data", {})
|
||||
image_sub_type = message_data.get("sub_type")
|
||||
try:
|
||||
image_base64 = await get_image_base64(message_data.get("url", ""))
|
||||
except Exception as e:
|
||||
logger.error(f"图片消息处理失败: {str(e)}")
|
||||
return None
|
||||
if image_sub_type == 0:
|
||||
return {"type": "image", "data": image_base64}
|
||||
elif image_sub_type not in [4, 9]:
|
||||
return {"type": "emoji", "data": image_base64}
|
||||
else:
|
||||
logger.warning(f"不支持的图片子类型:{image_sub_type}")
|
||||
return None
|
||||
|
||||
async def _handle_at_message(self, segment: dict, raw_message: dict) -> SegPayload | None:
|
||||
"""处理@消息"""
|
||||
seg_data = segment.get("data", {})
|
||||
if not seg_data:
|
||||
return None
|
||||
|
||||
qq_id = seg_data.get("qq")
|
||||
self_id = raw_message.get("self_id")
|
||||
group_id = raw_message.get("group_id")
|
||||
|
||||
if str(self_id) == str(qq_id):
|
||||
logger.debug("机器人被at")
|
||||
self_info = await get_self_info()
|
||||
if self_info:
|
||||
return {"type": "at", "data": f"{self_info.get('nickname')}:{self_info.get('user_id')}"}
|
||||
return None
|
||||
else:
|
||||
if qq_id and group_id:
|
||||
member_info = await get_member_info(group_id=group_id, user_id=qq_id)
|
||||
if member_info:
|
||||
return {"type": "at", "data": f"{member_info.get('nickname')}:{member_info.get('user_id')}"}
|
||||
return None
|
||||
|
||||
async def _handle_reply_message(self, segment: dict, raw_message: dict, in_reply: bool) -> SegPayload | None:
|
||||
"""处理回复消息"""
|
||||
if in_reply:
|
||||
return None
|
||||
|
||||
seg_data = segment.get("data", {})
|
||||
if not seg_data:
|
||||
return None
|
||||
|
||||
message_id = seg_data.get("id")
|
||||
if not message_id:
|
||||
return None
|
||||
|
||||
message_detail = await get_message_detail(message_id)
|
||||
if not message_detail:
|
||||
logger.warning("获取被引用的消息详情失败")
|
||||
return {"type": "text", "data": "[无法获取被引用的消息]"}
|
||||
|
||||
# 递归处理被引用的消息
|
||||
reply_segments = []
|
||||
for reply_seg in message_detail.get("message", []):
|
||||
if isinstance(reply_seg, dict):
|
||||
reply_result = await self.handle_single_segment(reply_seg, raw_message, in_reply=True)
|
||||
if reply_result:
|
||||
reply_segments.append(reply_result)
|
||||
|
||||
if not reply_segments:
|
||||
reply_text = "[无法获取被引用的消息]"
|
||||
else:
|
||||
# 简化处理,只取第一个segment的data
|
||||
reply_text = reply_segments[0].get("data", "") if reply_segments else ""
|
||||
|
||||
sender_info = message_detail.get("sender", {})
|
||||
sender_nickname = sender_info.get("nickname", "未知用户")
|
||||
sender_id = sender_info.get("user_id")
|
||||
|
||||
if sender_id:
|
||||
return {"type": "text", "data": f"[回复<{sender_nickname}({sender_id})>:{reply_text}],说:"}
|
||||
else:
|
||||
return {"type": "text", "data": f"[回复<{sender_nickname}>:{reply_text}],说:"}
|
||||
|
||||
async def _handle_record_message(self, segment: dict) -> SegPayload | None:
|
||||
"""处理语音消息"""
|
||||
message_data = segment.get("data", {})
|
||||
file = message_data.get("file", "")
|
||||
if not file:
|
||||
logger.warning("语音消息缺少文件信息")
|
||||
return None
|
||||
|
||||
try:
|
||||
record_detail = await get_record_detail(file)
|
||||
if not record_detail:
|
||||
logger.warning("获取语音消息详情失败")
|
||||
return None
|
||||
audio_base64 = record_detail.get("base64", "")
|
||||
except Exception as e:
|
||||
logger.error(f"语音消息处理失败: {str(e)}")
|
||||
return None
|
||||
|
||||
if not audio_base64:
|
||||
logger.error("语音消息处理失败,未获取到音频数据")
|
||||
return None
|
||||
|
||||
return {"type": "voice", "data": audio_base64}
|
||||
|
||||
async def _handle_video_message(self, segment: dict) -> SegPayload | None:
|
||||
"""处理视频消息"""
|
||||
message_data = segment.get("data", {})
|
||||
|
||||
video_url = message_data.get("url")
|
||||
file_path = message_data.get("filePath") or message_data.get("file_path")
|
||||
|
||||
video_source = file_path if file_path else video_url
|
||||
if not video_source:
|
||||
logger.warning("视频消息缺少URL或文件路径信息")
|
||||
return None
|
||||
|
||||
try:
|
||||
if file_path and Path(file_path).exists():
|
||||
# 本地文件处理
|
||||
with open(file_path, "rb") as f:
|
||||
video_data = f.read()
|
||||
video_base64 = base64.b64encode(video_data).decode("utf-8")
|
||||
logger.debug(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
|
||||
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {
|
||||
"base64": video_base64,
|
||||
"filename": Path(file_path).name,
|
||||
"size_mb": len(video_data) / (1024 * 1024),
|
||||
},
|
||||
}
|
||||
elif video_url:
|
||||
# URL下载处理
|
||||
from ..video_handler import get_video_downloader
|
||||
video_downloader = get_video_downloader()
|
||||
download_result = await video_downloader.download_video(video_url)
|
||||
|
||||
if not download_result["success"]:
|
||||
logger.warning(f"视频下载失败: {download_result.get('error', '未知错误')}")
|
||||
return None
|
||||
|
||||
video_base64 = base64.b64encode(download_result["data"]).decode("utf-8")
|
||||
logger.debug(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")
|
||||
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {
|
||||
"base64": video_base64,
|
||||
"filename": download_result.get("filename", "video.mp4"),
|
||||
"size_mb": len(download_result["data"]) / (1024 * 1024),
|
||||
"url": video_url,
|
||||
},
|
||||
}
|
||||
else:
|
||||
logger.warning("既没有有效的本地文件路径,也没有有效的视频URL")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"视频消息处理失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _handle_rps_message(self, segment: dict) -> SegPayload:
|
||||
"""处理猜拳消息"""
|
||||
message_data = segment.get("data", {})
|
||||
res = message_data.get("result", "")
|
||||
shape_map = {"1": "布", "2": "剪刀"}
|
||||
shape = shape_map.get(res, "石头")
|
||||
return {"type": "text", "data": f"[发送了一个魔法猜拳表情,结果是:{shape}]"}
|
||||
|
||||
async def _handle_dice_message(self, segment: dict) -> SegPayload:
|
||||
"""处理骰子消息"""
|
||||
message_data = segment.get("data", {})
|
||||
res = message_data.get("result", "")
|
||||
return {"type": "text", "data": f"[扔了一个骰子,点数是{res}]"}
|
||||
|
||||
|
||||
async def handle_forward_message(self, message_list: list) -> SegPayload | None:
|
||||
"""
|
||||
递归处理转发消息,并按照动态方式确定图片处理方式
|
||||
Parameters:
|
||||
message_list: list: 转发消息列表
|
||||
"""
|
||||
handled_message, image_count = await self._handle_forward_message(message_list, 0)
|
||||
if not handled_message:
|
||||
return None
|
||||
|
||||
if 0 < image_count < 5:
|
||||
logger.debug("图片数量小于5,开始解析图片为base64")
|
||||
processed_message = await self._recursive_parse_image_seg(handled_message, True)
|
||||
elif image_count > 0:
|
||||
logger.debug("图片数量大于等于5,开始解析图片为占位符")
|
||||
processed_message = await self._recursive_parse_image_seg(handled_message, False)
|
||||
else:
|
||||
logger.debug("没有图片,直接返回")
|
||||
processed_message = handled_message
|
||||
|
||||
forward_hint = {"type": "text", "data": "这是一条转发消息:\n"}
|
||||
return {"type": "seglist", "data": [forward_hint, processed_message]}
|
||||
|
||||
async def _recursive_parse_image_seg(self, seg_data: SegPayload, to_image: bool) -> SegPayload:
|
||||
# sourcery skip: merge-else-if-into-elif
|
||||
if seg_data.get("type") == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.get("data", []):
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return {"type": "seglist", "data": new_seg_list}
|
||||
|
||||
if to_image:
|
||||
if seg_data.get("type") == "image":
|
||||
image_url = seg_data.get("data")
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return {"type": "text", "data": "[图片]"}
|
||||
return {"type": "image", "data": encoded_image}
|
||||
if seg_data.get("type") == "emoji":
|
||||
image_url = seg_data.get("data")
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return {"type": "text", "data": "[表情包]"}
|
||||
return {"type": "emoji", "data": encoded_image}
|
||||
logger.debug(f"不处理类型: {seg_data.get('type')}")
|
||||
return seg_data
|
||||
|
||||
if seg_data.get("type") == "image":
|
||||
return {"type": "text", "data": "[图片]"}
|
||||
if seg_data.get("type") == "emoji":
|
||||
return {"type": "text", "data": "[动画表情]"}
|
||||
logger.debug(f"不处理类型: {seg_data.get('type')}")
|
||||
return seg_data
|
||||
|
||||
async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[SegPayload | None, int]:
|
||||
# sourcery skip: low-code-quality
|
||||
"""
|
||||
递归处理实际转发消息
|
||||
Parameters:
|
||||
message_list: list: 转发消息列表,首层对应messages字段,后面对应content字段
|
||||
layer: int: 当前层级
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
image_count: int: 图片数量
|
||||
"""
|
||||
seg_list: List[SegPayload] = []
|
||||
image_count = 0
|
||||
if message_list is None:
|
||||
return None, 0
|
||||
for sub_message in message_list:
|
||||
sender_info: dict = sub_message.get("sender", {})
|
||||
user_nickname: str = sender_info.get("nickname", "QQ用户")
|
||||
user_nickname_str = f"【{user_nickname}】:"
|
||||
break_seg: SegPayload = {"type": "text", "data": "\n"}
|
||||
message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
|
||||
if not message_of_sub_message_list:
|
||||
logger.warning("转发消息内容为空")
|
||||
continue
|
||||
message_of_sub_message = message_of_sub_message_list[0]
|
||||
message_type = message_of_sub_message.get("type")
|
||||
if message_type == RealMessageType.forward:
|
||||
if layer >= 3:
|
||||
full_seg_data: SegPayload = {
|
||||
"type": "text",
|
||||
"data": ("--" * layer) + f"【{user_nickname}】:【转发消息】\n",
|
||||
}
|
||||
else:
|
||||
sub_message_data = message_of_sub_message.get("data")
|
||||
if not sub_message_data:
|
||||
continue
|
||||
contents = sub_message_data.get("content")
|
||||
seg_data, count = await self._handle_forward_message(contents, layer + 1)
|
||||
if seg_data is None:
|
||||
continue
|
||||
image_count += count
|
||||
head_tip: SegPayload = {
|
||||
"type": "text",
|
||||
"data": ("--" * layer) + f"【{user_nickname}】: 合并转发消息内容:\n",
|
||||
}
|
||||
full_seg_data = {"type": "seglist", "data": [head_tip, seg_data]}
|
||||
seg_list.append(full_seg_data)
|
||||
elif message_type == RealMessageType.text:
|
||||
sub_message_data = message_of_sub_message.get("data")
|
||||
if not sub_message_data:
|
||||
continue
|
||||
text_message = sub_message_data.get("text")
|
||||
seg_data: SegPayload = {"type": "text", "data": text_message}
|
||||
nickname_prefix = ("--" * layer) + user_nickname_str if layer > 0 else user_nickname_str
|
||||
data_list: List[SegPayload] = [
|
||||
{"type": "text", "data": nickname_prefix},
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
seg_list.append({"type": "seglist", "data": data_list})
|
||||
elif message_type == RealMessageType.image:
|
||||
image_count += 1
|
||||
image_data = message_of_sub_message.get("data", {})
|
||||
image_url = image_data.get("url")
|
||||
if not image_url:
|
||||
logger.warning("转发消息图片缺少URL")
|
||||
continue
|
||||
sub_type = image_data.get("sub_type")
|
||||
if sub_type == 0:
|
||||
seg_data = {"type": "image", "data": image_url}
|
||||
else:
|
||||
seg_data = {"type": "emoji", "data": image_url}
|
||||
nickname_prefix = ("--" * layer) + user_nickname_str if layer > 0 else user_nickname_str
|
||||
data_list = [
|
||||
{"type": "text", "data": nickname_prefix},
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
full_seg_data = {"type": "seglist", "data": data_list}
|
||||
seg_list.append(full_seg_data)
|
||||
return {"type": "seglist", "data": seg_list}, image_count
|
||||
|
||||
async def _handle_file_message(self, segment: dict) -> SegPayload | None:
|
||||
"""处理文件消息"""
|
||||
message_data = segment.get("data", {})
|
||||
if not message_data:
|
||||
logger.warning("文件消息缺少 data 字段")
|
||||
return None
|
||||
|
||||
# 提取文件信息
|
||||
file_name = message_data.get("file")
|
||||
file_size = message_data.get("file_size")
|
||||
file_id = message_data.get("file_id")
|
||||
|
||||
logger.info(f"收到文件消息: name={file_name}, size={file_size}, id={file_id}")
|
||||
|
||||
# 将文件信息打包成字典
|
||||
file_data = {
|
||||
"name": file_name,
|
||||
"size": file_size,
|
||||
"id": file_id,
|
||||
}
|
||||
|
||||
return {"type": "file", "data": file_data}
|
||||
|
||||
async def _handle_json_message(self, segment: dict) -> SegPayload | None:
|
||||
"""
|
||||
处理JSON消息
|
||||
Parameters:
|
||||
segment: dict: 消息段
|
||||
Returns:
|
||||
SegPayload | None: 处理后的消息段
|
||||
"""
|
||||
message_data = segment.get("data", {})
|
||||
json_data = message_data.get("data", "")
|
||||
|
||||
# 检查JSON消息格式
|
||||
if not message_data or "data" not in message_data:
|
||||
logger.warning("JSON消息格式不正确")
|
||||
return {"type": "json", "data": str(message_data)}
|
||||
|
||||
try:
|
||||
# 尝试将json_data解析为Python对象
|
||||
nested_data = orjson.loads(json_data)
|
||||
|
||||
# 检查是否是机器人自己上传文件的回声
|
||||
if self._is_file_upload_echo(nested_data):
|
||||
logger.info("检测到机器人发送文件的回声消息,将作为文件消息处理")
|
||||
# 从回声消息中提取文件信息
|
||||
file_info = self._extract_file_info_from_echo(nested_data)
|
||||
if file_info:
|
||||
return {"type": "file", "data": file_info}
|
||||
|
||||
# 检查是否是QQ小程序分享消息
|
||||
if "app" in nested_data and "com.tencent.miniapp" in str(nested_data.get("app", "")):
|
||||
logger.debug("检测到QQ小程序分享消息,开始提取信息")
|
||||
|
||||
# 提取目标字段
|
||||
extracted_info = {}
|
||||
|
||||
# 提取 meta.detail_1 中的信息
|
||||
meta = nested_data.get("meta", {})
|
||||
detail_1 = meta.get("detail_1", {})
|
||||
|
||||
if detail_1:
|
||||
extracted_info["title"] = detail_1.get("title", "")
|
||||
extracted_info["desc"] = detail_1.get("desc", "")
|
||||
qqdocurl = detail_1.get("qqdocurl", "")
|
||||
|
||||
# 从qqdocurl中提取b23.tv短链接
|
||||
if qqdocurl and "b23.tv" in qqdocurl:
|
||||
# 查找b23.tv链接的起始位置
|
||||
start_pos = qqdocurl.find("https://b23.tv/")
|
||||
if start_pos != -1:
|
||||
# 提取从https://b23.tv/开始的部分
|
||||
b23_part = qqdocurl[start_pos:]
|
||||
# 查找第一个?的位置,截取到?之前
|
||||
question_pos = b23_part.find("?")
|
||||
if question_pos != -1:
|
||||
extracted_info["short_url"] = b23_part[:question_pos]
|
||||
else:
|
||||
extracted_info["short_url"] = b23_part
|
||||
else:
|
||||
extracted_info["short_url"] = qqdocurl
|
||||
else:
|
||||
extracted_info["short_url"] = qqdocurl
|
||||
|
||||
# 如果成功提取到关键信息,返回格式化的文本
|
||||
if extracted_info.get("title") or extracted_info.get("desc") or extracted_info.get("short_url"):
|
||||
content_parts = []
|
||||
|
||||
if extracted_info.get("title"):
|
||||
content_parts.append(f"来源: {extracted_info['title']}")
|
||||
|
||||
if extracted_info.get("desc"):
|
||||
content_parts.append(f"标题: {extracted_info['desc']}")
|
||||
|
||||
if extracted_info.get("short_url"):
|
||||
content_parts.append(f"链接: {extracted_info['short_url']}")
|
||||
|
||||
formatted_content = "\n".join(content_parts)
|
||||
return{
|
||||
"type": "text",
|
||||
"data": f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}",
|
||||
}
|
||||
|
||||
|
||||
|
||||
# 检查是否是音乐分享 (QQ音乐类型)
|
||||
if nested_data.get("view") == "music" and "com.tencent.music" in str(nested_data.get("app", "")):
|
||||
meta = nested_data.get("meta", {})
|
||||
music = meta.get("music", {})
|
||||
if music:
|
||||
tag = music.get("tag", "未知来源")
|
||||
logger.debug(f"检测到【{tag}】音乐分享消息 (music view),开始提取信息")
|
||||
|
||||
title = music.get("title", "未知歌曲")
|
||||
desc = music.get("desc", "未知艺术家")
|
||||
jump_url = music.get("jumpUrl", "")
|
||||
preview_url = music.get("preview", "")
|
||||
|
||||
artist = "未知艺术家"
|
||||
song_title = title
|
||||
|
||||
if "网易云音乐" in tag:
|
||||
artist = desc
|
||||
elif "QQ音乐" in tag:
|
||||
if " - " in title:
|
||||
parts = title.split(" - ", 1)
|
||||
song_title = parts[0]
|
||||
artist = parts[1]
|
||||
else:
|
||||
artist = desc
|
||||
|
||||
formatted_content = (
|
||||
f"这是一张来自【{tag}】的音乐分享卡片:\n"
|
||||
f"歌曲: {song_title}\n"
|
||||
f"艺术家: {artist}\n"
|
||||
f"跳转链接: {jump_url}\n"
|
||||
f"封面图: {preview_url}"
|
||||
)
|
||||
return {"type": "text", "data": formatted_content}
|
||||
|
||||
# 检查是否是新闻/图文分享 (网易云音乐可能伪装成这种)
|
||||
elif nested_data.get("view") == "news" and "com.tencent.tuwen" in str(nested_data.get("app", "")):
|
||||
meta = nested_data.get("meta", {})
|
||||
news = meta.get("news", {})
|
||||
if news and "网易云音乐" in news.get("tag", ""):
|
||||
tag = news.get("tag")
|
||||
logger.debug(f"检测到【{tag}】音乐分享消息 (news view),开始提取信息")
|
||||
|
||||
title = news.get("title", "未知歌曲")
|
||||
desc = news.get("desc", "未知艺术家")
|
||||
jump_url = news.get("jumpUrl", "")
|
||||
preview_url = news.get("preview", "")
|
||||
|
||||
formatted_content = (
|
||||
f"这是一张来自【{tag}】的音乐分享卡片:\n"
|
||||
f"标题: {title}\n"
|
||||
f"描述: {desc}\n"
|
||||
f"跳转链接: {jump_url}\n"
|
||||
f"封面图: {preview_url}"
|
||||
)
|
||||
return {"type": "text", "data": formatted_content}
|
||||
|
||||
# 如果没有提取到关键信息,返回None
|
||||
return None
|
||||
|
||||
except orjson.JSONDecodeError:
|
||||
# 如果解析失败,我们假设它不是我们关心的任何一种结构化JSON,
|
||||
# 而是普通的文本或者无法解析的格式。
|
||||
logger.debug(f"无法将data字段解析为JSON: {json_data}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"处理JSON消息时发生未知错误: {e}")
|
||||
return None
|
||||
|
||||
def _is_file_upload_echo(self, nested_data: Any) -> bool:
|
||||
"""检查一个JSON对象是否是机器人自己上传文件的回声消息"""
|
||||
if not isinstance(nested_data, dict):
|
||||
return False
|
||||
|
||||
# 检查 'app' 和 'meta' 字段是否存在
|
||||
if "app" not in nested_data or "meta" not in nested_data:
|
||||
return False
|
||||
|
||||
# 检查 'app' 字段是否包含 'com.tencent.miniapp'
|
||||
if "com.tencent.miniapp" not in str(nested_data.get("app", "")):
|
||||
return False
|
||||
|
||||
# 检查 'meta' 内部的 'detail_1' 的 'busi_id' 是否为 '1014'
|
||||
meta = nested_data.get("meta", {})
|
||||
detail_1 = meta.get("detail_1", {})
|
||||
if detail_1.get("busi_id") == "1014":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_file_info_from_echo(self, nested_data: dict) -> Optional[dict]:
|
||||
"""从文件上传的回声消息中提取文件信息"""
|
||||
try:
|
||||
meta = nested_data.get("meta", {})
|
||||
detail_1 = meta.get("detail_1", {})
|
||||
|
||||
# 文件名在 'desc' 字段
|
||||
file_name = detail_1.get("desc")
|
||||
|
||||
# 文件大小在 'summary' 字段,格式为 "大小:1.7MB"
|
||||
summary = detail_1.get("summary", "")
|
||||
file_size_str = summary.replace("大小:", "").strip() # 移除前缀和空格
|
||||
|
||||
# QQ API有时返回的大小不标准,这里我们只提取它给的字符串
|
||||
# 实际大小已经由Napcat在发送时记录,这里主要是为了保持格式一致
|
||||
|
||||
if file_name and file_size_str:
|
||||
return {"file": file_name, "file_size": file_size_str, "file_id": None} # file_id在回声中不可用
|
||||
except Exception as e:
|
||||
logger.error(f"从文件回声中提取信息失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
"""元事件处理器"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ...event_models import MetaEventType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ....plugin import NapcatAdapter
|
||||
|
||||
logger = get_logger("napcat_adapter.meta_event_handler")
|
||||
|
||||
|
||||
class MetaEventHandler:
|
||||
"""处理 Napcat 元事件(心跳、生命周期)"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
self._interval_checking = False
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_meta_event(self, raw: Dict[str, Any]):
|
||||
event_type = raw.get("meta_event_type")
|
||||
if event_type == MetaEventType.lifecycle:
|
||||
sub_type = raw.get("sub_type")
|
||||
if sub_type == MetaEventType.Lifecycle.connect:
|
||||
self_id = raw.get("self_id")
|
||||
self.last_heart_beat = time.time()
|
||||
logger.info(f"Bot {self_id} 连接成功")
|
||||
# 不在连接时立即启动心跳检查,等第一个心跳包到达后再启动
|
||||
elif event_type == MetaEventType.heartbeat:
|
||||
if raw["status"].get("online") and raw["status"].get("good"):
|
||||
self_id = raw.get("self_id")
|
||||
if not self._interval_checking and self_id:
|
||||
# 第一次收到心跳包时才启动心跳检查
|
||||
asyncio.create_task(self.check_heartbeat(self_id))
|
||||
self.last_heart_beat = time.time()
|
||||
interval = raw.get("interval")
|
||||
if interval:
|
||||
self.interval = interval / 1000
|
||||
else:
|
||||
self_id = raw.get("self_id")
|
||||
logger.warning(f"Bot {self_id} Napcat 端异常!")
|
||||
|
||||
async def check_heartbeat(self, id: int) -> None:
|
||||
self._interval_checking = True
|
||||
while True:
|
||||
now_time = time.time()
|
||||
if now_time - self.last_heart_beat > self.interval * 2:
|
||||
logger.error(f"Bot {id} 可能发生了连接断开,被下线,或者Napcat卡死!")
|
||||
break
|
||||
await asyncio.sleep(self.interval)
|
||||
@@ -0,0 +1,579 @@
|
||||
"""发送处理器 - 将 MessageEnvelope 转换并发送到 Napcat"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from mofox_wire import MessageEnvelope, SegPayload, GroupInfoPayload, UserInfoPayload, MessageInfoPayload
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
from ...event_models import CommandType
|
||||
from ..utils import convert_image_to_gif, get_image_format
|
||||
|
||||
logger = get_logger("napcat_adapter.send_handler")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ....plugin import NapcatAdapter
|
||||
|
||||
|
||||
class SendHandler:
|
||||
"""负责向 Napcat 发送消息"""
|
||||
|
||||
def __init__(self, adapter: "NapcatAdapter"):
|
||||
self.adapter = adapter
|
||||
self.plugin_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = config
|
||||
|
||||
async def handle_message(self, envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
处理来自核心的消息,将其转换为 Napcat 可接受的格式并发送
|
||||
"""
|
||||
logger.info("接收到来自MoFox-Bot的消息,处理中")
|
||||
|
||||
if not envelope:
|
||||
logger.warning("空的消息,跳过处理")
|
||||
return
|
||||
|
||||
message_segment = envelope.get("message_segment")
|
||||
if isinstance(message_segment, list):
|
||||
segment: SegPayload = {"type": "seglist", "data": message_segment}
|
||||
else:
|
||||
segment = message_segment or {}
|
||||
|
||||
if segment:
|
||||
seg_type = segment.get("type")
|
||||
|
||||
if seg_type == "command":
|
||||
logger.info("处理命令")
|
||||
return await self.send_command(envelope)
|
||||
if seg_type == "adapter_command":
|
||||
logger.info("处理适配器命令")
|
||||
return await self.handle_adapter_command(envelope)
|
||||
if seg_type == "adapter_response":
|
||||
logger.info("收到adapter_response消息,此消息应该由Bot端处理,跳过")
|
||||
return None
|
||||
|
||||
logger.info("处理普通消息")
|
||||
return await self.send_normal_message(envelope)
|
||||
|
||||
async def send_normal_message(self, envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
处理普通消息发送
|
||||
"""
|
||||
logger.info("处理普通信息中")
|
||||
message_info: MessageInfoPayload = envelope.get("message_info", {})
|
||||
message_segment: SegPayload = envelope.get("message_segment", {}) # type: ignore[assignment]
|
||||
|
||||
if isinstance(message_segment, list):
|
||||
seg_data: SegPayload = {"type": "seglist", "data": message_segment}
|
||||
else:
|
||||
seg_data = message_segment
|
||||
|
||||
group_info: Optional[GroupInfoPayload] = message_info.get("group_info")
|
||||
user_info: Optional[UserInfoPayload] = message_info.get("user_info")
|
||||
target_id: Optional[int] = None
|
||||
action: Optional[str] = None
|
||||
id_name: Optional[str] = None
|
||||
processed_message: list = []
|
||||
try:
|
||||
processed_message = await self.handle_seg_recursive(seg_data, user_info or {})
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时发生错误: {e}")
|
||||
return None
|
||||
|
||||
if not processed_message:
|
||||
logger.critical("现在暂时不支持解析此回复!")
|
||||
return None
|
||||
|
||||
if group_info and group_info.get("group_id"):
|
||||
logger.debug("发送群聊消息")
|
||||
target_id = int(group_info["group_id"])
|
||||
action = "send_group_msg"
|
||||
id_name = "group_id"
|
||||
elif user_info and user_info.get("user_id"):
|
||||
logger.debug("发送私聊消息")
|
||||
target_id = int(user_info["user_id"])
|
||||
action = "send_private_msg"
|
||||
id_name = "user_id"
|
||||
else:
|
||||
logger.error("无法识别的消息类型")
|
||||
return
|
||||
logger.info("尝试发送到napcat")
|
||||
logger.debug(
|
||||
f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'"
|
||||
)
|
||||
response = await self.send_message_to_napcat(
|
||||
action or "",
|
||||
{
|
||||
id_name or "target_id": target_id,
|
||||
"message": processed_message,
|
||||
},
|
||||
)
|
||||
if response.get("status") == "ok":
|
||||
logger.info("消息发送成功")
|
||||
else:
|
||||
logger.warning(f"消息发送失败,napcat返回:{str(response)}")
|
||||
|
||||
async def send_command(self, envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
处理命令类
|
||||
"""
|
||||
logger.info("处理命令中")
|
||||
message_info: Dict[str, Any] = envelope.get("message_info", {})
|
||||
group_info: Optional[Dict[str, Any]] = message_info.get("group_info")
|
||||
segment: SegPayload = envelope.get("message_segment", {}) # type: ignore[assignment]
|
||||
seg_data: Dict[str, Any] = segment.get("data", {}) if isinstance(segment, dict) else {}
|
||||
command_name: Optional[str] = seg_data.get("name")
|
||||
try:
|
||||
args = seg_data.get("args", {})
|
||||
if not isinstance(args, dict):
|
||||
args = {}
|
||||
|
||||
if command_name == CommandType.GROUP_BAN.name:
|
||||
command, args_dict = self.handle_ban_command(args, group_info)
|
||||
elif command_name == CommandType.GROUP_WHOLE_BAN.name:
|
||||
command, args_dict = self.handle_whole_ban_command(args, group_info)
|
||||
elif command_name == CommandType.GROUP_KICK.name:
|
||||
command, args_dict = self.handle_kick_command(args, group_info)
|
||||
elif command_name == CommandType.SEND_POKE.name:
|
||||
command, args_dict = self.handle_poke_command(args, group_info)
|
||||
elif command_name == CommandType.DELETE_MSG.name:
|
||||
command, args_dict = self.delete_msg_command(args)
|
||||
elif command_name == CommandType.AI_VOICE_SEND.name:
|
||||
command, args_dict = self.handle_ai_voice_send_command(args, group_info)
|
||||
elif command_name == CommandType.SET_EMOJI_LIKE.name:
|
||||
command, args_dict = self.handle_set_emoji_like_command(args)
|
||||
elif command_name == CommandType.SEND_AT_MESSAGE.name:
|
||||
command, args_dict = self.handle_at_message_command(args, group_info)
|
||||
elif command_name == CommandType.SEND_LIKE.name:
|
||||
command, args_dict = self.handle_send_like_command(args)
|
||||
else:
|
||||
logger.error(f"未知命令: {command_name}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时发生错误: {e}")
|
||||
return None
|
||||
|
||||
if not command or not args_dict:
|
||||
logger.error("命令或参数缺失")
|
||||
return None
|
||||
|
||||
logger.info(f"准备向 Napcat 发送命令: command='{command}', args_dict='{args_dict}'")
|
||||
response = await self.send_message_to_napcat(command, args_dict)
|
||||
logger.info(f"收到 Napcat 的命令响应: {response}")
|
||||
|
||||
if response.get("status") == "ok":
|
||||
logger.info(f"命令 {command_name} 执行成功")
|
||||
else:
|
||||
logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}")
|
||||
|
||||
async def handle_adapter_command(self, envelope: MessageEnvelope) -> None:
|
||||
"""
|
||||
处理适配器命令类 - 用于直接向Napcat发送命令并返回结果
|
||||
"""
|
||||
logger.info("处理适配器命令中")
|
||||
segment: SegPayload = envelope.get("message_segment", {}) # type: ignore[assignment]
|
||||
seg_data: Dict[str, Any] = segment.get("data", {}) if isinstance(segment, dict) else {}
|
||||
|
||||
try:
|
||||
action = seg_data.get("action")
|
||||
params = seg_data.get("params", {})
|
||||
request_id = seg_data.get("request_id")
|
||||
timeout = float(seg_data.get("timeout", 20.0))
|
||||
|
||||
if not action:
|
||||
logger.error("适配器命令缺少action参数")
|
||||
return
|
||||
|
||||
logger.info(f"执行适配器命令: {action}")
|
||||
|
||||
if action == "get_cookies":
|
||||
response = await self.send_message_to_napcat(action, params, timeout=40.0)
|
||||
else:
|
||||
response = await self.send_message_to_napcat(action, params, timeout=timeout)
|
||||
|
||||
try:
|
||||
from src.plugin_system.apis.send_api import put_adapter_response
|
||||
|
||||
if request_id:
|
||||
put_adapter_response(str(request_id), response)
|
||||
except Exception as e:
|
||||
logger.debug(f"回填 adapter 响应失败: {e}")
|
||||
|
||||
if response.get("status") == "ok":
|
||||
logger.info(f"适配器命令 {action} 执行成功")
|
||||
else:
|
||||
logger.warning(f"适配器命令 {action} 执行失败,napcat返回:{str(response)}")
|
||||
logger.debug(f"适配器命令 {action} 的完整响应: {response}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器命令时发生错误: {e}")
|
||||
|
||||
def get_level(self, seg_data: SegPayload) -> int:
|
||||
if seg_data.get("type") == "seglist":
|
||||
return 1 + max(self.get_level(seg) for seg in seg_data.get("data", []) if isinstance(seg, dict))
|
||||
return 1
|
||||
|
||||
async def handle_seg_recursive(self, seg_data: SegPayload, user_info: UserInfoPayload) -> list:
|
||||
payload: list = []
|
||||
if seg_data.get("type") == "seglist":
|
||||
if not seg_data.get("data"):
|
||||
return []
|
||||
for seg in seg_data["data"]:
|
||||
if not isinstance(seg, dict):
|
||||
continue
|
||||
payload = await self.process_message_by_type(seg, payload, user_info)
|
||||
else:
|
||||
payload = await self.process_message_by_type(seg_data, payload, user_info)
|
||||
return payload
|
||||
|
||||
async def process_message_by_type(self, seg: SegPayload, payload: list, user_info: UserInfoPayload) -> list:
|
||||
new_payload = payload
|
||||
seg_type = seg.get("type")
|
||||
if seg_type == "reply":
|
||||
target_id = seg.get("data")
|
||||
target_id = str(target_id)
|
||||
if target_id == "notice":
|
||||
return payload
|
||||
logger.info(target_id if isinstance(target_id, str) else "")
|
||||
new_payload = self.build_payload(payload, await self.handle_reply_message(target_id, user_info), True)
|
||||
elif seg_type == "text":
|
||||
text = seg.get("data")
|
||||
if not text:
|
||||
return payload
|
||||
new_payload = self.build_payload(payload, self.handle_text_message(str(text)), False)
|
||||
elif seg_type == "face":
|
||||
logger.warning("MoFox-Bot 发送了qq原生表情,暂时不支持")
|
||||
elif seg_type == "image":
|
||||
image = seg.get("data")
|
||||
new_payload = self.build_payload(payload, self.handle_image_message(str(image)), False)
|
||||
elif seg_type == "emoji":
|
||||
emoji = seg.get("data")
|
||||
new_payload = self.build_payload(payload, self.handle_emoji_message(str(emoji)), False)
|
||||
elif seg_type == "voice":
|
||||
voice = seg.get("data")
|
||||
new_payload = self.build_payload(payload, self.handle_voice_message(str(voice)), False)
|
||||
elif seg_type == "voiceurl":
|
||||
voice_url = seg.get("data")
|
||||
new_payload = self.build_payload(payload, self.handle_voiceurl_message(str(voice_url)), False)
|
||||
elif seg_type == "music":
|
||||
song_id = seg.get("data")
|
||||
new_payload = self.build_payload(payload, self.handle_music_message(str(song_id)), False)
|
||||
elif seg_type == "videourl":
|
||||
video_url = seg.get("data")
|
||||
new_payload = self.build_payload(payload, self.handle_videourl_message(str(video_url)), False)
|
||||
elif seg_type == "file":
|
||||
file_path = seg.get("data")
|
||||
new_payload = self.build_payload(payload, self.handle_file_message(str(file_path)), False)
|
||||
elif seg_type == "seglist":
|
||||
# 嵌套列表继续递归
|
||||
nested_payload: list = []
|
||||
for sub_seg in seg.get("data", []):
|
||||
if not isinstance(sub_seg, dict):
|
||||
continue
|
||||
nested_payload = await self.process_message_by_type(sub_seg, nested_payload, user_info)
|
||||
new_payload = self.build_payload(payload, nested_payload, False)
|
||||
return new_payload
|
||||
|
||||
def build_payload(self, payload: list, addon: dict | list, is_reply: bool = False) -> list:
|
||||
"""构建发送的消息体"""
|
||||
if is_reply:
|
||||
temp_list = []
|
||||
if isinstance(addon, list):
|
||||
temp_list.extend(addon)
|
||||
else:
|
||||
temp_list.append(addon)
|
||||
for i in payload:
|
||||
if isinstance(i, dict) and i.get("type") == "reply":
|
||||
logger.debug("检测到多个回复,使用最新的回复")
|
||||
continue
|
||||
temp_list.append(i)
|
||||
return temp_list
|
||||
|
||||
if isinstance(addon, list):
|
||||
payload.extend(addon)
|
||||
else:
|
||||
payload.append(addon)
|
||||
return payload
|
||||
|
||||
async def handle_reply_message(self, message_id: str, user_info: UserInfoPayload) -> dict | list:
|
||||
"""处理回复消息"""
|
||||
logger.debug(f"开始处理回复消息,消息ID: {message_id}")
|
||||
reply_seg = {"type": "reply", "data": {"id": message_id}}
|
||||
|
||||
# 检查是否启用引用艾特功能
|
||||
if not config_api.get_plugin_config(self.plugin_config, "features.enable_reply_at", False):
|
||||
logger.info("引用艾特功能未启用,仅发送普通回复")
|
||||
return reply_seg
|
||||
|
||||
try:
|
||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": message_id})
|
||||
logger.debug(f"获取消息 {message_id} 的详情响应: {msg_info_response}")
|
||||
|
||||
replied_user_id = None
|
||||
if msg_info_response and msg_info_response.get("status") == "ok":
|
||||
sender_info = msg_info_response.get("data", {}).get("sender")
|
||||
if sender_info:
|
||||
replied_user_id = sender_info.get("user_id")
|
||||
|
||||
if not replied_user_id:
|
||||
logger.warning(f"无法获取消息 {message_id} 的发送者信息,跳过 @")
|
||||
logger.info(f"最终返回的回复段: {reply_seg}")
|
||||
return reply_seg
|
||||
|
||||
if random.random() < config_api.get_plugin_config(self.plugin_config, "features.reply_at_rate", 0.5):
|
||||
at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}}
|
||||
text_seg = {"type": "text", "data": {"text": " "}}
|
||||
result_seg = [reply_seg, at_seg, text_seg]
|
||||
logger.info(f"最终返回的回复段: {result_seg}")
|
||||
return result_seg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理引用回复并尝试@时出错: {e}")
|
||||
logger.info(f"最终返回的回复段: {reply_seg}")
|
||||
return reply_seg
|
||||
|
||||
logger.info(f"最终返回的回复段: {reply_seg}")
|
||||
return reply_seg
|
||||
|
||||
def handle_text_message(self, message: str) -> dict:
|
||||
"""处理文本消息"""
|
||||
return {"type": "text", "data": {"text": message}}
|
||||
|
||||
def handle_image_message(self, encoded_image: str) -> dict:
|
||||
"""处理图片消息"""
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {
|
||||
"file": f"base64://{encoded_image}",
|
||||
"subtype": 0,
|
||||
},
|
||||
}
|
||||
|
||||
def handle_emoji_message(self, encoded_emoji: str) -> dict:
|
||||
"""处理表情消息"""
|
||||
encoded_image = encoded_emoji
|
||||
image_format = get_image_format(encoded_emoji)
|
||||
if image_format != "gif":
|
||||
encoded_image = convert_image_to_gif(encoded_emoji)
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {
|
||||
"file": f"base64://{encoded_image}",
|
||||
"subtype": 1,
|
||||
"summary": "[动画表情]",
|
||||
},
|
||||
}
|
||||
|
||||
def handle_voice_message(self, encoded_voice: str) -> dict:
|
||||
"""处理语音消息"""
|
||||
use_tts = False
|
||||
if self.plugin_config:
|
||||
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
|
||||
|
||||
if not use_tts:
|
||||
logger.warning("未启用语音消息处理")
|
||||
return {}
|
||||
if not encoded_voice:
|
||||
return {}
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": f"base64://{encoded_voice}"},
|
||||
}
|
||||
|
||||
def handle_voiceurl_message(self, voice_url: str) -> dict:
|
||||
"""处理语音链接消息"""
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": voice_url},
|
||||
}
|
||||
|
||||
def handle_music_message(self, song_id: str) -> dict:
|
||||
"""处理音乐消息"""
|
||||
return {
|
||||
"type": "music",
|
||||
"data": {"type": "163", "id": song_id},
|
||||
}
|
||||
|
||||
def handle_videourl_message(self, video_url: str) -> dict:
|
||||
"""处理视频链接消息"""
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {"file": video_url},
|
||||
}
|
||||
|
||||
def handle_file_message(self, file_path: str) -> dict:
|
||||
"""处理文件消息"""
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {"file": f"file://{file_path}"},
|
||||
}
|
||||
|
||||
def delete_msg_command(self, args: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
|
||||
"""处理删除消息命令"""
|
||||
return "delete_msg", {"message_id": args["message_id"]}
|
||||
|
||||
def handle_ban_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]:
|
||||
"""处理封禁命令"""
|
||||
duration: int = int(args["duration"])
|
||||
user_id: int = int(args["qq_id"])
|
||||
group_id: int = int(group_info["group_id"]) if group_info and group_info.get("group_id") else 0
|
||||
if duration < 0:
|
||||
raise ValueError("封禁时间必须大于等于0")
|
||||
if not user_id or not group_id:
|
||||
raise ValueError("封禁命令缺少必要参数")
|
||||
if duration > 2592000:
|
||||
raise ValueError("封禁时间不能超过30天")
|
||||
return (
|
||||
CommandType.GROUP_BAN.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"duration": duration,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]:
|
||||
"""处理全体禁言命令"""
|
||||
enable = args["enable"]
|
||||
assert isinstance(enable, bool), "enable参数必须是布尔值"
|
||||
group_id: int = int(group_info["group_id"]) if group_info and group_info.get("group_id") else 0
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
return (
|
||||
CommandType.GROUP_WHOLE_BAN.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"enable": enable,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_kick_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]:
|
||||
"""处理群成员踢出命令"""
|
||||
user_id: int = int(args["qq_id"])
|
||||
group_id: int = int(group_info["group_id"]) if group_info and group_info.get("group_id") else 0
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
if user_id <= 0:
|
||||
raise ValueError("用户ID无效")
|
||||
return (
|
||||
CommandType.GROUP_KICK.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"reject_add_request": False,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_poke_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]:
|
||||
"""处理戳一戳命令"""
|
||||
user_id: int = int(args["qq_id"])
|
||||
group_id: Optional[int] = None
|
||||
if group_info and group_info.get("group_id"):
|
||||
group_id = int(group_info["group_id"])
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
if user_id <= 0:
|
||||
raise ValueError("用户ID无效")
|
||||
return (
|
||||
CommandType.SEND_POKE.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
|
||||
"""处理设置表情回应命令"""
|
||||
logger.info(f"开始处理表情回应命令, 接收到参数: {args}")
|
||||
try:
|
||||
message_id = int(args["message_id"])
|
||||
emoji_id = int(args["emoji_id"])
|
||||
set_like = bool(args["set"])
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.error(f"处理表情回应命令时发生错误: {e}, 原始参数: {args}")
|
||||
raise ValueError(f"缺少必需参数或参数类型错误: {e}")
|
||||
|
||||
return (
|
||||
CommandType.SET_EMOJI_LIKE.value,
|
||||
{"message_id": message_id, "emoji_id": emoji_id, "set": set_like},
|
||||
)
|
||||
|
||||
def handle_send_like_command(self, args: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
|
||||
"""处理发送点赞命令的逻辑。"""
|
||||
try:
|
||||
user_id: int = int(args["qq_id"])
|
||||
times: int = int(args["times"])
|
||||
except (KeyError, ValueError):
|
||||
raise ValueError("缺少必需参数: qq_id 或 times")
|
||||
|
||||
return (
|
||||
CommandType.SEND_LIKE.value,
|
||||
{"user_id": user_id, "times": times},
|
||||
)
|
||||
|
||||
def handle_at_message_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]:
|
||||
"""处理艾特并发送消息命令"""
|
||||
at_user_id = args.get("qq_id")
|
||||
text = args.get("text")
|
||||
|
||||
if not at_user_id or not text:
|
||||
raise ValueError("艾特消息命令缺少 qq_id 或 text 参数")
|
||||
|
||||
if not group_info or not group_info.get("group_id"):
|
||||
raise ValueError("艾特消息命令必须在群聊上下文中使用")
|
||||
|
||||
message_payload = [
|
||||
{"type": "at", "data": {"qq": str(at_user_id)}},
|
||||
{"type": "text", "data": {"text": " " + str(text)}},
|
||||
]
|
||||
|
||||
return (
|
||||
"send_group_msg",
|
||||
{
|
||||
"group_id": group_info["group_id"],
|
||||
"message": message_payload,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: Optional[Dict[str, Any]]) -> tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
处理AI语音发送命令的逻辑。
|
||||
并返回 NapCat 兼容的 (action, params) 元组。
|
||||
"""
|
||||
if not group_info or not group_info.get("group_id"):
|
||||
raise ValueError("AI语音发送命令必须在群聊上下文中使用")
|
||||
if not args:
|
||||
raise ValueError("AI语音发送命令缺少参数")
|
||||
|
||||
group_id: int = int(group_info["group_id"])
|
||||
character_id = args.get("character")
|
||||
text_content = args.get("text")
|
||||
|
||||
if not character_id or not text_content:
|
||||
raise ValueError(f"AI语音发送命令参数不完整: character='{character_id}', text='{text_content}'")
|
||||
|
||||
return (
|
||||
CommandType.AI_VOICE_SEND.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"text": text_content,
|
||||
"character": character_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict:
|
||||
"""通过 adapter API 发送到 napcat"""
|
||||
try:
|
||||
response = await self.adapter.send_napcat_api(action, params, timeout=timeout)
|
||||
return response or {"status": "error", "message": "no response"}
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import io
|
||||
import ssl
|
||||
import time
|
||||
import uuid
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
@@ -97,7 +98,9 @@ def _get_adapter(adapter: "NapcatAdapter | None" = None) -> "NapcatAdapter":
|
||||
if target is None and _adapter_ref:
|
||||
target = _adapter_ref()
|
||||
if target is None:
|
||||
raise RuntimeError("NapcatAdapter 未注册,请确保已调用 utils.register_adapter 注册")
|
||||
raise RuntimeError(
|
||||
"NapcatAdapter 未注册,请确保已调用 utils.register_adapter 注册"
|
||||
)
|
||||
return target
|
||||
|
||||
|
||||
@@ -136,6 +139,14 @@ class SSLAdapter(urllib3.PoolManager):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
async def get_respose(
|
||||
action: str,
|
||||
params: Dict[str, Any],
|
||||
adapter: "NapcatAdapter | None" = None,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
return await _call_adapter_api(action, params, adapter=adapter, timeout=timeout)
|
||||
|
||||
async def get_group_info(
|
||||
group_id: int,
|
||||
*,
|
||||
@@ -317,7 +328,9 @@ async def get_stranger_info(
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
response = await _call_adapter_api("get_stranger_info", {"user_id": user_id}, adapter=adapter)
|
||||
response = await _call_adapter_api(
|
||||
"get_stranger_info", {"user_id": user_id}, adapter=adapter
|
||||
)
|
||||
data = response.get("data") if response else None
|
||||
if data is not None and use_cache:
|
||||
await _set_cached("stranger_info", cache_key, data)
|
||||
@@ -359,3 +372,40 @@ async def get_record_detail(
|
||||
timeout=30,
|
||||
)
|
||||
return response.get("data") if response else None
|
||||
|
||||
|
||||
async def get_forward_message(
|
||||
raw_message: dict, *, adapter: "NapcatAdapter | None" = None
|
||||
) -> dict[str, Any] | None:
|
||||
forward_message_data: dict = raw_message.get("data", {})
|
||||
if not forward_message_data:
|
||||
logger.warning("转发消息内容为空")
|
||||
return None
|
||||
forward_message_id = forward_message_data.get("id")
|
||||
|
||||
try:
|
||||
response = await _call_adapter_api(
|
||||
"get_forward_msg",
|
||||
{"message_id": forward_message_id},
|
||||
timeout=10.0,
|
||||
adapter=adapter,
|
||||
)
|
||||
if response is None:
|
||||
logger.error("获取转发消息失败,返回值为空")
|
||||
return None
|
||||
except TimeoutError:
|
||||
logger.error("获取转发消息超时")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取转发消息失败: {str(e)}")
|
||||
return None
|
||||
logger.debug(
|
||||
f"转发消息原始格式:{orjson.dumps(response).decode('utf-8')[:80]}..."
|
||||
if len(orjson.dumps(response).decode("utf-8")) > 80
|
||||
else orjson.dumps(response).decode("utf-8")
|
||||
)
|
||||
response_data: Dict = response.get("data")
|
||||
if not response_data:
|
||||
logger.warning("转发消息内容为空或获取失败")
|
||||
return None
|
||||
return response_data.get("messages")
|
||||
@@ -0,0 +1,179 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
视频下载和处理模块
|
||||
用于从QQ消息中下载视频并转发给Bot进行分析
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("video_handler")
|
||||
|
||||
|
||||
class VideoDownloader:
|
||||
def __init__(self, max_size_mb: int = 100, download_timeout: int = 60):
|
||||
self.max_size_mb = max_size_mb
|
||||
self.download_timeout = download_timeout
|
||||
self.supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}
|
||||
|
||||
def is_video_url(self, url: str) -> bool:
|
||||
"""检查URL是否为视频文件"""
|
||||
try:
|
||||
# QQ视频URL可能没有扩展名,所以先检查Content-Type
|
||||
# 对于QQ视频,我们先假设是视频,稍后通过Content-Type验证
|
||||
|
||||
# 检查URL中是否包含视频相关的关键字
|
||||
video_keywords = ["video", "mp4", "avi", "mov", "mkv", "flv", "wmv", "webm", "m4v"]
|
||||
url_lower = url.lower()
|
||||
|
||||
# 如果URL包含视频关键字,认为是视频
|
||||
if any(keyword in url_lower for keyword in video_keywords):
|
||||
return True
|
||||
|
||||
# 检查文件扩展名(传统方法)
|
||||
path = Path(url.split("?")[0]) # 移除查询参数
|
||||
if path.suffix.lower() in self.supported_formats:
|
||||
return True
|
||||
|
||||
# 对于QQ等特殊平台,URL可能没有扩展名
|
||||
# 我们允许这些URL通过,稍后通过HTTP头Content-Type验证
|
||||
qq_domains = ["qpic.cn", "gtimg.cn", "qq.com", "tencent.com"]
|
||||
if any(domain in url_lower for domain in qq_domains):
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
# 如果解析失败,默认允许尝试下载(稍后验证)
|
||||
return True
|
||||
|
||||
def check_file_size(self, content_length: Optional[str]) -> bool:
|
||||
"""检查文件大小是否在允许范围内"""
|
||||
if content_length is None:
|
||||
return True # 无法获取大小时允许下载
|
||||
|
||||
try:
|
||||
size_bytes = int(content_length)
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
return size_mb <= self.max_size_mb
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
async def download_video(self, url: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
下载视频文件
|
||||
|
||||
Args:
|
||||
url: 视频URL
|
||||
filename: 可选的文件名
|
||||
|
||||
Returns:
|
||||
dict: 下载结果,包含success、data、filename、error等字段
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始下载视频: {url}")
|
||||
|
||||
# 检查URL格式
|
||||
if not self.is_video_url(url):
|
||||
logger.warning(f"URL格式检查失败: {url}")
|
||||
return {"success": False, "error": "不支持的视频格式", "url": url}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# 先发送HEAD请求检查文件大小
|
||||
try:
|
||||
async with session.head(url, timeout=aiohttp.ClientTimeout(total=10)) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"HEAD请求失败,状态码: {response.status}")
|
||||
else:
|
||||
content_length = response.headers.get("Content-Length")
|
||||
if not self.check_file_size(content_length):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"视频文件过大,超过{self.max_size_mb}MB限制",
|
||||
"url": url,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"HEAD请求失败: {e},继续尝试下载")
|
||||
|
||||
# 下载文件
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=self.download_timeout)) as response:
|
||||
if response.status != 200:
|
||||
return {"success": False, "error": f"下载失败,HTTP状态码: {response.status}", "url": url}
|
||||
|
||||
# 检查Content-Type是否为视频
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
if content_type:
|
||||
# 检查是否为视频类型
|
||||
video_mime_types = [
|
||||
"video/",
|
||||
"application/octet-stream",
|
||||
"application/x-msvideo",
|
||||
"video/x-msvideo",
|
||||
]
|
||||
is_video_content = any(mime in content_type for mime in video_mime_types)
|
||||
|
||||
if not is_video_content:
|
||||
logger.warning(f"Content-Type不是视频格式: {content_type}")
|
||||
# 如果不是明确的视频类型,但可能是QQ的特殊格式,继续尝试
|
||||
if "text/" in content_type or "application/json" in content_type:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"URL返回的不是视频内容,Content-Type: {content_type}",
|
||||
"url": url,
|
||||
}
|
||||
|
||||
# 再次检查Content-Length
|
||||
content_length = response.headers.get("Content-Length")
|
||||
if not self.check_file_size(content_length):
|
||||
return {"success": False, "error": f"视频文件过大,超过{self.max_size_mb}MB限制", "url": url}
|
||||
|
||||
# 读取文件内容
|
||||
video_data = await response.read()
|
||||
|
||||
# 检查实际文件大小
|
||||
actual_size_mb = len(video_data) / (1024 * 1024)
|
||||
if actual_size_mb > self.max_size_mb:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"视频文件过大,实际大小: {actual_size_mb:.2f}MB",
|
||||
"url": url,
|
||||
}
|
||||
|
||||
# 确定文件名
|
||||
if filename is None:
|
||||
filename = Path(url.split("?")[0]).name
|
||||
if not filename or "." not in filename:
|
||||
filename = "video.mp4"
|
||||
|
||||
logger.info(f"视频下载成功: {filename}, 大小: {actual_size_mb:.2f}MB")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": video_data,
|
||||
"filename": filename,
|
||||
"size_mb": actual_size_mb,
|
||||
"url": url,
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return {"success": False, "error": "下载超时", "url": url}
|
||||
except Exception as e:
|
||||
logger.error(f"下载视频时出错: {e}")
|
||||
return {"success": False, "error": str(e), "url": url}
|
||||
|
||||
|
||||
# 全局实例
|
||||
_video_downloader = None
|
||||
|
||||
|
||||
def get_video_downloader(max_size_mb: int = 100, download_timeout: int = 60) -> VideoDownloader:
|
||||
"""获取视频下载器实例"""
|
||||
global _video_downloader
|
||||
if _video_downloader is None:
|
||||
_video_downloader = VideoDownloader(max_size_mb, download_timeout)
|
||||
return _video_downloader
|
||||
Reference in New Issue
Block a user