至少让插件跑起来了
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -326,6 +326,7 @@ run_pet.bat
|
||||
!/plugins/permission_example
|
||||
!/plugins/hello_world_plugin
|
||||
!/plugins/take_picture_plugin
|
||||
!/plugins/napcat_adapter_plugin
|
||||
|
||||
config.toml
|
||||
|
||||
|
||||
279
plugins/napcat_adapter_plugin/.gitignore
vendored
Normal file
279
plugins/napcat_adapter_plugin/.gitignore
vendored
Normal file
@@ -0,0 +1,279 @@
|
||||
|
||||
log/
|
||||
logs/
|
||||
out/
|
||||
|
||||
.env
|
||||
.env.*
|
||||
.cursor
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
uv.lock
|
||||
llm_statistics.txt
|
||||
mongodb
|
||||
napcat
|
||||
run_dev.bat
|
||||
elua.confirmed
|
||||
# C extensions
|
||||
*.so
|
||||
/results
|
||||
config_backup/
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# jieba
|
||||
jieba.cache
|
||||
|
||||
# .vscode
|
||||
!.vscode/settings.json
|
||||
|
||||
# direnv
|
||||
/.direnv
|
||||
|
||||
# JetBrains
|
||||
.idea
|
||||
*.iml
|
||||
*.ipr
|
||||
|
||||
# PyEnv
|
||||
# If using PyEnv and configured to use a specific Python version locally
|
||||
# a .local-version file will be created in the root of the project to specify the version.
|
||||
.python-version
|
||||
|
||||
OtherRes.txt
|
||||
|
||||
/eula.confirmed
|
||||
/privacy.confirmed
|
||||
|
||||
logs
|
||||
|
||||
.ruff_cache
|
||||
|
||||
.vscode
|
||||
|
||||
/config/*
|
||||
config/old/bot_config_20250405_212257.toml
|
||||
temp/
|
||||
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# Windows thumbnail cache files
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
|
||||
# Dump file
|
||||
*.stackdump
|
||||
|
||||
# Folder config file
|
||||
[Dd]esktop.ini
|
||||
|
||||
# Recycle Bin used on file shares
|
||||
$RECYCLE.BIN/
|
||||
|
||||
# Windows Installer files
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
config.toml
|
||||
feature.toml
|
||||
config.toml.back
|
||||
test
|
||||
data/NapcatAdapter.db
|
||||
data/NapcatAdapter.db-shm
|
||||
data/NapcatAdapter.db-wal
|
||||
42
plugins/napcat_adapter_plugin/_manifest.json
Normal file
42
plugins/napcat_adapter_plugin/_manifest.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "napcat_plugin",
|
||||
"version": "1.0.0",
|
||||
"description": "基于OneBot 11协议的NapCat QQ协议插件,提供完整的QQ机器人API接口,使用现有adapter连接",
|
||||
"author": {
|
||||
"name": "Windpicker_owo",
|
||||
"url": "https://github.com/Windpicker-owo"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.0",
|
||||
"max_version": "0.10.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
|
||||
"repository_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
|
||||
"keywords": ["qq", "bot", "napcat", "onebot", "api", "websocket"],
|
||||
"categories": ["protocol"],
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"components": [
|
||||
{
|
||||
"type": "tool",
|
||||
"name": "napcat_tool",
|
||||
"description": "NapCat QQ协议综合工具,提供消息发送、群管理、好友管理、文件操作等完整功能"
|
||||
}
|
||||
],
|
||||
"features": [
|
||||
"消息发送与接收",
|
||||
"群管理功能",
|
||||
"好友管理功能",
|
||||
"文件上传下载",
|
||||
"AI语音功能",
|
||||
"群签到与戳一戳",
|
||||
"现有adapter连接"
|
||||
]
|
||||
}
|
||||
}
|
||||
127
plugins/napcat_adapter_plugin/plugin.py
Normal file
127
plugins/napcat_adapter_plugin/plugin.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import json
|
||||
import websockets as Server
|
||||
from typing import List, Tuple
|
||||
|
||||
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField, BaseAction, ActionActivationType
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
# 添加当前目录到Python路径,这样可以识别src包
|
||||
current_dir = Path(__file__).parent
|
||||
sys.path.insert(0, str(current_dir))
|
||||
|
||||
from .src.recv_handler.message_handler import message_handler
|
||||
from .src.recv_handler.meta_event_handler import meta_event_handler
|
||||
from .src.recv_handler.notice_handler import notice_handler
|
||||
from .src.recv_handler.message_sending import message_send_instance
|
||||
from .src.send_handler import send_handler
|
||||
from .src.config import global_config
|
||||
from .src.config.features_config import features_manager
|
||||
from .src.config.migrate_features import auto_migrate_features
|
||||
from .src.mmc_com_layer import mmc_start_com, mmc_stop_com, router
|
||||
from .src.response_pool import put_response, check_timeout_response
|
||||
from .src.websocket_manager import websocket_manager
|
||||
|
||||
message_queue = asyncio.Queue()
|
||||
|
||||
class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
"""自动启动Adapter"""
|
||||
|
||||
handler_name: str = "launch_napcat_adapter_handler"
|
||||
handler_description: str = "自动启动napcat adapter"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [EventType.ON_START]
|
||||
|
||||
async def message_recv(self, server_connection: Server.ServerConnection):
|
||||
await message_handler.set_server_connection(server_connection)
|
||||
asyncio.create_task(notice_handler.set_server_connection(server_connection))
|
||||
await send_handler.set_server_connection(server_connection)
|
||||
async for raw_message in server_connection:
|
||||
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
|
||||
decoded_raw_message: dict = json.loads(raw_message)
|
||||
post_type = decoded_raw_message.get("post_type")
|
||||
if post_type in ["meta_event", "message", "notice"]:
|
||||
await message_queue.put(decoded_raw_message)
|
||||
elif post_type is None:
|
||||
await put_response(decoded_raw_message)
|
||||
|
||||
|
||||
async def message_process(self):
|
||||
while True:
|
||||
message = await message_queue.get()
|
||||
post_type = message.get("post_type")
|
||||
if post_type == "message":
|
||||
await message_handler.handle_raw_message(message)
|
||||
elif post_type == "meta_event":
|
||||
await meta_event_handler.handle_meta_event(message)
|
||||
elif post_type == "notice":
|
||||
await notice_handler.handle_notice(message)
|
||||
else:
|
||||
logger.warning(f"未知的post_type: {post_type}")
|
||||
message_queue.task_done()
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
async def napcat_server(self):
|
||||
"""启动 Napcat WebSocket 连接(支持正向和反向连接)"""
|
||||
mode = global_config.napcat_server.mode
|
||||
logger.info(f"正在启动 adapter,连接模式: {mode}")
|
||||
|
||||
try:
|
||||
await websocket_manager.start_connection(self.message_recv)
|
||||
except Exception as e:
|
||||
logger.error(f"启动 WebSocket 连接失败: {e}")
|
||||
raise
|
||||
|
||||
async def execute(self, kwargs):
|
||||
# 执行功能配置迁移(如果需要)
|
||||
logger.info("检查功能配置迁移...")
|
||||
auto_migrate_features()
|
||||
|
||||
# 初始化功能管理器
|
||||
logger.info("正在初始化功能管理器...")
|
||||
features_manager.load_config()
|
||||
await features_manager.start_file_watcher(check_interval=2.0)
|
||||
logger.info("功能管理器初始化完成")
|
||||
logger.info("开始启动Napcat Adapter")
|
||||
message_send_instance.maibot_router = router
|
||||
# 创建单独的异步任务,防止阻塞主线程
|
||||
asyncio.create_task(self.napcat_server())
|
||||
asyncio.create_task(mmc_start_com())
|
||||
asyncio.create_task(self.message_process())
|
||||
asyncio.create_task(check_timeout_response())
|
||||
|
||||
@register_plugin
|
||||
class NapcatAdapterPlugin(BasePlugin):
|
||||
plugin_name = "napcat_adapter"
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_plugin_components(self):
|
||||
return [
|
||||
(LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler)
|
||||
]
|
||||
47
plugins/napcat_adapter_plugin/pyproject.toml
Normal file
47
plugins/napcat_adapter_plugin/pyproject.toml
Normal file
@@ -0,0 +1,47 @@
|
||||
[project]
|
||||
name = "MaiBotNapcatAdapter"
|
||||
version = "0.4.8"
|
||||
description = "A MaiBot adapter for Napcat"
|
||||
dependencies = [
|
||||
"ruff>=0.12.9",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
include = ["*.py"]
|
||||
|
||||
# 行长度设置
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
fixable = ["ALL"]
|
||||
unfixable = []
|
||||
|
||||
# 启用的规则
|
||||
select = [
|
||||
"E", # pycodestyle 错误
|
||||
"F", # pyflakes
|
||||
"B", # flake8-bugbear
|
||||
]
|
||||
|
||||
ignore = ["E711","E501"]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
indent-style = "space"
|
||||
|
||||
|
||||
# 使用双引号表示字符串
|
||||
quote-style = "double"
|
||||
|
||||
# 尊重魔法尾随逗号
|
||||
# 例如:
|
||||
# items = [
|
||||
# "apple",
|
||||
# "banana",
|
||||
# "cherry",
|
||||
# ]
|
||||
skip-magic-trailing-comma = false
|
||||
|
||||
# 自动检测合适的换行符
|
||||
line-ending = "auto"
|
||||
31
plugins/napcat_adapter_plugin/src/__init__.py
Normal file
31
plugins/napcat_adapter_plugin/src/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from enum import Enum
|
||||
import tomlkit
|
||||
import os
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
class CommandType(Enum):
|
||||
"""命令类型"""
|
||||
|
||||
GROUP_BAN = "set_group_ban" # 禁言用户
|
||||
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
|
||||
GROUP_KICK = "set_group_kick" # 踢出群聊
|
||||
SEND_POKE = "send_poke" # 戳一戳
|
||||
DELETE_MSG = "delete_msg" # 撤回消息
|
||||
AI_VOICE_SEND = "send_group_ai_record" # 发送群AI语音
|
||||
SET_EMOJI_LIKE = "set_emoji_like" # 设置表情回应
|
||||
SEND_AT_MESSAGE = "send_at_message" # 艾特用户并发送消息
|
||||
SEND_LIKE = "send_like" # 点赞
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
pyproject_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "pyproject.toml"
|
||||
)
|
||||
toml_data = tomlkit.parse(open(pyproject_path, "r", encoding="utf-8").read())
|
||||
project_data = toml_data.get("project", {})
|
||||
version = project_data.get("version", "unknown")
|
||||
logger.info(f"版本\n\nMaiBot-Napcat-Adapter 版本: {version}\n喜欢的话点个star喵~\n")
|
||||
5
plugins/napcat_adapter_plugin/src/config/__init__.py
Normal file
5
plugins/napcat_adapter_plugin/src/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import global_config
|
||||
|
||||
__all__ = [
|
||||
"global_config",
|
||||
]
|
||||
148
plugins/napcat_adapter_plugin/src/config/config.py
Normal file
148
plugins/napcat_adapter_plugin/src/config/config.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
import tomlkit
|
||||
import shutil
|
||||
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from rich.traceback import install
|
||||
|
||||
from .config_base import ConfigBase
|
||||
from .official_configs import (
|
||||
DebugConfig,
|
||||
MaiBotServerConfig,
|
||||
NapcatServerConfig,
|
||||
NicknameConfig,
|
||||
VoiceConfig,
|
||||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
TEMPLATE_DIR = "plugins/napcat_adapter_plugin/template"
|
||||
CONFIG_DIR = "plugins/napcat_adapter_plugin/config"
|
||||
OLD_CONFIG_DIR = "plugins/napcat_adapter_plugin/config/old"
|
||||
|
||||
|
||||
def ensure_config_directories():
|
||||
"""确保配置目录存在"""
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True)
|
||||
os.makedirs(OLD_CONFIG_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def update_config():
|
||||
"""更新配置文件,统一使用 config/old 目录进行备份"""
|
||||
# 确保目录存在
|
||||
ensure_config_directories()
|
||||
|
||||
# 定义文件路径
|
||||
template_path = f"{TEMPLATE_DIR}/template_config.toml"
|
||||
config_path = f"{CONFIG_DIR}/config.toml"
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
logger.info("主配置文件不存在,从模板创建新配置")
|
||||
shutil.copy2(template_path, config_path)
|
||||
logger.info(f"已创建新配置文件: {config_path}")
|
||||
logger.info("程序将退出,请检查配置文件后重启")
|
||||
|
||||
# 读取配置文件和模板文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version")
|
||||
new_version = new_config["inner"].get("version")
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
else:
|
||||
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 创建备份文件
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = os.path.join(OLD_CONFIG_DIR, f"config.toml.bak.{timestamp}")
|
||||
|
||||
# 备份旧配置文件
|
||||
shutil.copy2(config_path, backup_path)
|
||||
logger.info(f"已备份旧配置文件到: {backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path, config_path)
|
||||
logger.info(f"已创建新配置文件: {config_path}")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
|
||||
"""将source字典的值更新到target字典中(如果target中存在相同的键)"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
|
||||
update_dict(target[key], value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
# 如果是空数组,确保它保持为空数组
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info("开始合并新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(ConfigBase):
|
||||
"""总配置类"""
|
||||
|
||||
nickname: NicknameConfig
|
||||
napcat_server: NapcatServerConfig
|
||||
maibot_server: MaiBotServerConfig
|
||||
voice: VoiceConfig
|
||||
debug: DebugConfig
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
"""
|
||||
加载配置文件
|
||||
:param config_path: 配置文件路径
|
||||
:return: Config对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建Config对象
|
||||
try:
|
||||
return Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
# 更新配置
|
||||
update_config()
|
||||
|
||||
logger.info("正在品鉴配置文件...")
|
||||
global_config = load_config(config_path=f"{CONFIG_DIR}/config.toml")
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
136
plugins/napcat_adapter_plugin/src/config/config_base.py
Normal file
136
plugins/napcat_adapter_plugin/src/config/config_base.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
|
||||
TOML_DICT_TYPE = {
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
bool,
|
||||
list,
|
||||
dict,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigBase:
|
||||
"""配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
|
||||
"""从字典加载配置字段"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
|
||||
init_args: Dict[str, Any] = {}
|
||||
|
||||
for f in fields(cls):
|
||||
field_name = f.name
|
||||
field_type = f.type
|
||||
if field_name.startswith("_"):
|
||||
# 跳过以 _ 开头的字段
|
||||
continue
|
||||
|
||||
if field_name not in data:
|
||||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||||
# 跳过未提供且有默认值/默认构造方法的字段
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Missing required field: '{field_name}'")
|
||||
|
||||
value = data[field_name]
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type)
|
||||
except TypeError as e:
|
||||
raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
|
||||
|
||||
return cls(**init_args)
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
|
||||
"""
|
||||
转换字段值为指定类型
|
||||
|
||||
1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法
|
||||
2. 对于泛型集合类型(list, set, tuple),递归转换每个元素
|
||||
3. 对于基础类型(int, str, float, bool),直接转换
|
||||
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
|
||||
"""
|
||||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||||
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
|
||||
return field_type.from_dict(value)
|
||||
|
||||
field_origin_type = get_origin(field_type)
|
||||
field_args_type = get_args(field_type)
|
||||
|
||||
# 处理泛型集合类型(list, set, tuple)
|
||||
if field_origin_type in {list, set, tuple}:
|
||||
# 检查提供的value是否为list
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if field_origin_type is list:
|
||||
return [cls._convert_field(item, field_args_type[0]) for item in value]
|
||||
if field_origin_type is set:
|
||||
return {cls._convert_field(item, field_args_type[0]) for item in value}
|
||||
if field_origin_type is tuple:
|
||||
# 检查提供的value长度是否与类型参数一致
|
||||
if len(value) != len(field_args_type):
|
||||
raise TypeError(
|
||||
f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
|
||||
|
||||
if field_origin_type is dict:
|
||||
# 检查提供的value是否为dict
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
# 检查字典的键值类型
|
||||
if len(field_args_type) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||||
key_type, value_type = field_args_type
|
||||
|
||||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||||
|
||||
# 处理Optional类型
|
||||
if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union
|
||||
if value is None:
|
||||
return None
|
||||
# 如果有数据,检查实际类型
|
||||
if type(value) not in field_args_type:
|
||||
raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
|
||||
return cls._convert_field(value, field_args_type[0])
|
||||
|
||||
# 处理int, str, float, bool等基础类型
|
||||
if field_origin_type is None:
|
||||
if isinstance(value, field_type):
|
||||
return field_type(value)
|
||||
else:
|
||||
raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
# 处理Literal类型
|
||||
if field_origin_type is Literal:
|
||||
# 获取Literal的允许值
|
||||
allowed_values = get_args(field_type)
|
||||
if value in allowed_values:
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||||
|
||||
# 处理其他类型
|
||||
if field_type is Any:
|
||||
return value
|
||||
|
||||
# 其他类型直接转换
|
||||
try:
|
||||
return field_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
|
||||
|
||||
def __str__(self):
|
||||
"""返回配置类的字符串表示"""
|
||||
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
|
||||
146
plugins/napcat_adapter_plugin/src/config/config_utils.py
Normal file
146
plugins/napcat_adapter_plugin/src/config/config_utils.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
配置文件工具模块
|
||||
提供统一的配置文件生成和管理功能
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
def ensure_config_directories():
|
||||
"""确保配置目录存在"""
|
||||
os.makedirs("config", exist_ok=True)
|
||||
os.makedirs("config/old", exist_ok=True)
|
||||
|
||||
|
||||
def create_config_from_template(
|
||||
config_path: str,
|
||||
template_path: str,
|
||||
config_name: str = "配置文件",
|
||||
should_exit: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
从模板创建配置文件的统一函数
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
template_path: 模板文件路径
|
||||
config_name: 配置文件名称(用于日志显示)
|
||||
should_exit: 创建后是否退出程序
|
||||
|
||||
Returns:
|
||||
bool: 是否成功创建配置文件
|
||||
"""
|
||||
try:
|
||||
# 确保配置目录存在
|
||||
ensure_config_directories()
|
||||
|
||||
config_path_obj = Path(config_path)
|
||||
template_path_obj = Path(template_path)
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if config_path_obj.exists():
|
||||
return False # 配置文件已存在,无需创建
|
||||
|
||||
logger.info(f"{config_name}不存在,从模板创建新配置")
|
||||
|
||||
# 检查模板文件是否存在
|
||||
if not template_path_obj.exists():
|
||||
logger.error(f"模板文件不存在: {template_path}")
|
||||
if should_exit:
|
||||
logger.critical("无法创建配置文件,程序退出")
|
||||
quit(1)
|
||||
return False
|
||||
|
||||
# 确保配置文件目录存在
|
||||
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path_obj, config_path_obj)
|
||||
logger.info(f"已创建新{config_name}: {config_path}")
|
||||
|
||||
if should_exit:
|
||||
logger.info("程序将退出,请检查配置文件后重启")
|
||||
quit(0)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建{config_name}失败: {e}")
|
||||
if should_exit:
|
||||
logger.critical("无法创建配置文件,程序退出")
|
||||
quit(1)
|
||||
return False
|
||||
|
||||
|
||||
def create_default_config_dict(default_values: dict, config_path: str, config_name: str = "配置文件") -> bool:
|
||||
"""
|
||||
创建默认配置文件(使用字典数据)
|
||||
|
||||
Args:
|
||||
default_values: 默认配置值字典
|
||||
config_path: 配置文件路径
|
||||
config_name: 配置文件名称(用于日志显示)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功创建配置文件
|
||||
"""
|
||||
try:
|
||||
import tomlkit
|
||||
|
||||
config_path_obj = Path(config_path)
|
||||
|
||||
# 确保配置文件目录存在
|
||||
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 写入默认配置
|
||||
with open(config_path_obj, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(default_values, f)
|
||||
|
||||
logger.info(f"已创建默认{config_name}: {config_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建默认{config_name}失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def backup_config_file(config_path: str, backup_dir: str = "config/old") -> Optional[str]:
|
||||
"""
|
||||
备份配置文件
|
||||
|
||||
Args:
|
||||
config_path: 要备份的配置文件路径
|
||||
backup_dir: 备份目录
|
||||
|
||||
Returns:
|
||||
Optional[str]: 备份文件路径,失败时返回None
|
||||
"""
|
||||
try:
|
||||
config_path_obj = Path(config_path)
|
||||
if not config_path_obj.exists():
|
||||
return None
|
||||
|
||||
# 确保备份目录存在
|
||||
backup_dir_obj = Path(backup_dir)
|
||||
backup_dir_obj.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建备份文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_filename = f"{config_path_obj.stem}.toml.bak.{timestamp}"
|
||||
backup_path = backup_dir_obj / backup_filename
|
||||
|
||||
# 备份文件
|
||||
shutil.copy2(config_path_obj, backup_path)
|
||||
logger.info(f"已备份配置文件到: {backup_path}")
|
||||
|
||||
return str(backup_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"备份配置文件失败: {e}")
|
||||
return None
|
||||
359
plugins/napcat_adapter_plugin/src/config/features_config.py
Normal file
359
plugins/napcat_adapter_plugin/src/config/features_config.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from pathlib import Path
|
||||
import tomlkit
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .config_base import ConfigBase
|
||||
from .config_utils import create_config_from_template, create_default_config_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeaturesConfig(ConfigBase):
|
||||
"""功能配置类"""
|
||||
|
||||
group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
|
||||
"""群聊列表类型 白名单/黑名单"""
|
||||
|
||||
group_list: list[int] = field(default_factory=list)
|
||||
"""群聊列表"""
|
||||
|
||||
private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
|
||||
"""私聊列表类型 白名单/黑名单"""
|
||||
|
||||
private_list: list[int] = field(default_factory=list)
|
||||
"""私聊列表"""
|
||||
|
||||
ban_user_id: list[int] = field(default_factory=list)
|
||||
"""被封禁的用户ID列表,封禁后将无法与其进行交互"""
|
||||
|
||||
ban_qq_bot: bool = False
|
||||
"""是否屏蔽QQ官方机器人,若为True,则所有QQ官方机器人将无法与MaiMCore进行交互"""
|
||||
|
||||
enable_poke: bool = True
|
||||
"""是否启用戳一戳功能"""
|
||||
|
||||
ignore_non_self_poke: bool = False
|
||||
"""是否无视不是针对自己的戳一戳"""
|
||||
|
||||
poke_debounce_seconds: int = 3
|
||||
"""戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"""
|
||||
|
||||
enable_reply_at: bool = True
|
||||
"""是否启用引用回复时艾特用户的功能"""
|
||||
|
||||
reply_at_rate: float = 0.5
|
||||
"""引用回复时艾特用户的几率 (0.0 ~ 1.0)"""
|
||||
|
||||
enable_video_analysis: bool = True
|
||||
"""是否启用视频识别功能"""
|
||||
|
||||
max_video_size_mb: int = 100
|
||||
"""视频文件最大大小限制(MB)"""
|
||||
|
||||
download_timeout: int = 60
|
||||
"""视频下载超时时间(秒)"""
|
||||
|
||||
supported_formats: list[str] = field(default_factory=lambda: ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"])
|
||||
"""支持的视频格式"""
|
||||
|
||||
# 消息缓冲配置
|
||||
enable_message_buffer: bool = True
|
||||
"""是否启用消息缓冲合并功能"""
|
||||
|
||||
message_buffer_enable_group: bool = True
|
||||
"""是否启用群消息缓冲合并"""
|
||||
|
||||
message_buffer_enable_private: bool = True
|
||||
"""是否启用私聊消息缓冲合并"""
|
||||
|
||||
message_buffer_interval: float = 3.0
|
||||
"""消息合并间隔时间(秒),在此时间内的连续消息将被合并"""
|
||||
|
||||
message_buffer_initial_delay: float = 0.5
|
||||
"""消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"""
|
||||
|
||||
message_buffer_max_components: int = 50
|
||||
"""单个会话最大缓冲消息组件数量,超过此数量将强制合并"""
|
||||
|
||||
message_buffer_block_prefixes: list[str] = field(default_factory=lambda: ["/", "!", "!", ".", "。", "#", "%"])
|
||||
"""消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"""
|
||||
|
||||
|
||||
class FeaturesManager:
|
||||
"""功能管理器,支持热重载"""
|
||||
|
||||
def __init__(self, config_path: str = "plugins/napcat_adapter_plugin/config/features.toml"):
|
||||
self.config_path = Path(config_path)
|
||||
self.config: Optional[FeaturesConfig] = None
|
||||
self._file_watcher_task: Optional[asyncio.Task] = None
|
||||
self._last_modified: Optional[float] = None
|
||||
self._callbacks: list = []
|
||||
|
||||
def add_reload_callback(self, callback):
|
||||
"""添加配置重载回调函数"""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def remove_reload_callback(self, callback):
|
||||
"""移除配置重载回调函数"""
|
||||
if callback in self._callbacks:
|
||||
self._callbacks.remove(callback)
|
||||
|
||||
async def _notify_callbacks(self):
|
||||
"""通知所有回调函数配置已重载"""
|
||||
for callback in self._callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(self.config)
|
||||
else:
|
||||
callback(self.config)
|
||||
except Exception as e:
|
||||
logger.error(f"配置重载回调执行失败: {e}")
|
||||
|
||||
def load_config(self) -> FeaturesConfig:
|
||||
"""加载功能配置文件"""
|
||||
try:
|
||||
# 检查配置文件是否存在,如果不存在则创建并退出程序
|
||||
if not self.config_path.exists():
|
||||
logger.info(f"功能配置文件不存在: {self.config_path}")
|
||||
self._create_default_config()
|
||||
# 配置文件创建后程序应该退出,让用户检查配置
|
||||
logger.info("程序将退出,请检查功能配置文件后重启")
|
||||
quit(0)
|
||||
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
self.config = FeaturesConfig.from_dict(config_data)
|
||||
self._last_modified = self.config_path.stat().st_mtime
|
||||
logger.info(f"功能配置加载成功: {self.config_path}")
|
||||
return self.config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"功能配置加载失败: {e}")
|
||||
logger.critical("无法加载功能配置文件,程序退出")
|
||||
quit(1)
|
||||
|
||||
def _create_default_config(self):
|
||||
"""创建默认功能配置文件"""
|
||||
template_path = "template/features_template.toml"
|
||||
|
||||
# 尝试从模板创建配置文件
|
||||
if create_config_from_template(
|
||||
str(self.config_path),
|
||||
template_path,
|
||||
"功能配置文件",
|
||||
should_exit=False # 不在这里退出,由调用方决定
|
||||
):
|
||||
return
|
||||
|
||||
# 如果模板文件不存在,创建基本配置
|
||||
logger.info("模板文件不存在,创建基本功能配置")
|
||||
default_config = {
|
||||
"group_list_type": "whitelist",
|
||||
"group_list": [],
|
||||
"private_list_type": "whitelist",
|
||||
"private_list": [],
|
||||
"ban_user_id": [],
|
||||
"ban_qq_bot": False,
|
||||
"enable_poke": True,
|
||||
"ignore_non_self_poke": False,
|
||||
"poke_debounce_seconds": 3,
|
||||
"enable_reply_at": True,
|
||||
"reply_at_rate": 0.5,
|
||||
"enable_video_analysis": True,
|
||||
"max_video_size_mb": 100,
|
||||
"download_timeout": 60,
|
||||
"supported_formats": ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"],
|
||||
# 消息缓冲配置
|
||||
"enable_message_buffer": True,
|
||||
"message_buffer_enable_group": True,
|
||||
"message_buffer_enable_private": True,
|
||||
"message_buffer_interval": 3.0,
|
||||
"message_buffer_initial_delay": 0.5,
|
||||
"message_buffer_max_components": 50,
|
||||
"message_buffer_block_prefixes": ["/", "!", "!", ".", "。", "#", "%"]
|
||||
}
|
||||
|
||||
if not create_default_config_dict(default_config, str(self.config_path), "功能配置文件"):
|
||||
logger.critical("无法创建功能配置文件")
|
||||
quit(1)
|
||||
|
||||
async def reload_config(self) -> bool:
|
||||
"""重新加载配置文件"""
|
||||
try:
|
||||
if not self.config_path.exists():
|
||||
logger.warning(f"功能配置文件不存在,无法重载: {self.config_path}")
|
||||
return False
|
||||
|
||||
current_modified = self.config_path.stat().st_mtime
|
||||
if self._last_modified and current_modified <= self._last_modified:
|
||||
return False # 文件未修改
|
||||
|
||||
old_config = self.config
|
||||
new_config = self.load_config()
|
||||
|
||||
# 检查配置是否真的发生了变化
|
||||
if old_config and self._configs_equal(old_config, new_config):
|
||||
return False
|
||||
|
||||
logger.info("功能配置已重载")
|
||||
await self._notify_callbacks()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"功能配置重载失败: {e}")
|
||||
return False
|
||||
|
||||
def _configs_equal(self, config1: FeaturesConfig, config2: FeaturesConfig) -> bool:
|
||||
"""比较两个配置是否相等"""
|
||||
return (
|
||||
config1.group_list_type == config2.group_list_type and
|
||||
set(config1.group_list) == set(config2.group_list) and
|
||||
config1.private_list_type == config2.private_list_type and
|
||||
set(config1.private_list) == set(config2.private_list) and
|
||||
set(config1.ban_user_id) == set(config2.ban_user_id) and
|
||||
config1.ban_qq_bot == config2.ban_qq_bot and
|
||||
config1.enable_poke == config2.enable_poke and
|
||||
config1.ignore_non_self_poke == config2.ignore_non_self_poke and
|
||||
config1.poke_debounce_seconds == config2.poke_debounce_seconds and
|
||||
config1.enable_reply_at == config2.enable_reply_at and
|
||||
config1.reply_at_rate == config2.reply_at_rate and
|
||||
config1.enable_video_analysis == config2.enable_video_analysis and
|
||||
config1.max_video_size_mb == config2.max_video_size_mb and
|
||||
config1.download_timeout == config2.download_timeout and
|
||||
set(config1.supported_formats) == set(config2.supported_formats) and
|
||||
# 消息缓冲配置比较
|
||||
config1.enable_message_buffer == config2.enable_message_buffer and
|
||||
config1.message_buffer_enable_group == config2.message_buffer_enable_group and
|
||||
config1.message_buffer_enable_private == config2.message_buffer_enable_private and
|
||||
config1.message_buffer_interval == config2.message_buffer_interval and
|
||||
config1.message_buffer_initial_delay == config2.message_buffer_initial_delay and
|
||||
config1.message_buffer_max_components == config2.message_buffer_max_components and
|
||||
set(config1.message_buffer_block_prefixes) == set(config2.message_buffer_block_prefixes)
|
||||
)
|
||||
|
||||
async def start_file_watcher(self, check_interval: float = 1.0):
|
||||
"""启动文件监控,定期检查配置文件变化"""
|
||||
if self._file_watcher_task and not self._file_watcher_task.done():
|
||||
logger.warning("文件监控已在运行")
|
||||
return
|
||||
|
||||
self._file_watcher_task = asyncio.create_task(
|
||||
self._file_watcher_loop(check_interval)
|
||||
)
|
||||
logger.info(f"功能配置文件监控已启动,检查间隔: {check_interval}秒")
|
||||
|
||||
async def stop_file_watcher(self):
|
||||
"""停止文件监控"""
|
||||
if self._file_watcher_task and not self._file_watcher_task.done():
|
||||
self._file_watcher_task.cancel()
|
||||
try:
|
||||
await self._file_watcher_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("功能配置文件监控已停止")
|
||||
|
||||
async def _file_watcher_loop(self, check_interval: float):
|
||||
"""文件监控循环"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(check_interval)
|
||||
await self.reload_config()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"文件监控循环出错: {e}")
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
def get_config(self) -> FeaturesConfig:
|
||||
"""获取当前功能配置"""
|
||||
if self.config is None:
|
||||
return self.load_config()
|
||||
return self.config
|
||||
|
||||
def is_group_allowed(self, group_id: int) -> bool:
|
||||
"""检查群聊是否被允许"""
|
||||
config = self.get_config()
|
||||
if config.group_list_type == "whitelist":
|
||||
return group_id in config.group_list
|
||||
else: # blacklist
|
||||
return group_id not in config.group_list
|
||||
|
||||
def is_private_allowed(self, user_id: int) -> bool:
|
||||
"""检查私聊是否被允许"""
|
||||
config = self.get_config()
|
||||
if config.private_list_type == "whitelist":
|
||||
return user_id in config.private_list
|
||||
else: # blacklist
|
||||
return user_id not in config.private_list
|
||||
|
||||
def is_user_banned(self, user_id: int) -> bool:
|
||||
"""检查用户是否被全局禁止"""
|
||||
config = self.get_config()
|
||||
return user_id in config.ban_user_id
|
||||
|
||||
def is_qq_bot_banned(self) -> bool:
|
||||
"""检查是否禁止QQ官方机器人"""
|
||||
config = self.get_config()
|
||||
return config.ban_qq_bot
|
||||
|
||||
def is_poke_enabled(self) -> bool:
|
||||
"""检查戳一戳功能是否启用"""
|
||||
config = self.get_config()
|
||||
return config.enable_poke
|
||||
|
||||
def is_non_self_poke_ignored(self) -> bool:
|
||||
"""检查是否忽略非自己戳一戳"""
|
||||
config = self.get_config()
|
||||
return config.ignore_non_self_poke
|
||||
|
||||
def is_message_buffer_enabled(self) -> bool:
|
||||
"""检查消息缓冲功能是否启用"""
|
||||
config = self.get_config()
|
||||
return config.enable_message_buffer
|
||||
|
||||
def is_message_buffer_group_enabled(self) -> bool:
|
||||
"""检查群消息缓冲是否启用"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_group
|
||||
|
||||
def is_message_buffer_private_enabled(self) -> bool:
|
||||
"""检查私聊消息缓冲是否启用"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_private
|
||||
|
||||
def get_message_buffer_interval(self) -> float:
|
||||
"""获取消息缓冲间隔时间"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_interval
|
||||
|
||||
def get_message_buffer_initial_delay(self) -> float:
|
||||
"""获取消息缓冲初始延迟"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_initial_delay
|
||||
|
||||
def get_message_buffer_max_components(self) -> int:
|
||||
"""获取消息缓冲最大组件数量"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_max_components
|
||||
|
||||
def is_message_buffer_group_enabled(self) -> bool:
|
||||
"""检查是否启用群聊消息缓冲"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_group
|
||||
|
||||
def is_message_buffer_private_enabled(self) -> bool:
|
||||
"""检查是否启用私聊消息缓冲"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_private
|
||||
|
||||
def get_message_buffer_block_prefixes(self) -> list[str]:
|
||||
"""获取消息缓冲屏蔽前缀列表"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_block_prefixes
|
||||
|
||||
|
||||
# 全局功能管理器实例
|
||||
features_manager = FeaturesManager()
|
||||
194
plugins/napcat_adapter_plugin/src/config/migrate_features.py
Normal file
194
plugins/napcat_adapter_plugin/src/config/migrate_features.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
功能配置迁移脚本
|
||||
用于将旧的配置文件中的聊天、权限、视频处理等设置迁移到新的独立功能配置文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import tomlkit
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
def migrate_features_from_config(old_config_path: str = "plugins/napcat_adapter_plugin/config/config.toml",
|
||||
new_features_path: str = "plugins/napcat_adapter_plugin/config/features.toml",
|
||||
template_path: str = "plugins/napcat_adapter_plugin/template/features_template.toml"):
|
||||
"""
|
||||
从旧配置文件迁移功能设置到新的功能配置文件
|
||||
|
||||
Args:
|
||||
old_config_path: 旧配置文件路径
|
||||
new_features_path: 新功能配置文件路径
|
||||
template_path: 功能配置模板路径
|
||||
"""
|
||||
try:
|
||||
# 检查旧配置文件是否存在
|
||||
if not os.path.exists(old_config_path):
|
||||
logger.warning(f"旧配置文件不存在: {old_config_path}")
|
||||
return False
|
||||
|
||||
# 读取旧配置文件
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
|
||||
# 检查是否有chat配置段和video配置段
|
||||
chat_config = old_config.get("chat", {})
|
||||
video_config = old_config.get("video", {})
|
||||
|
||||
# 检查是否有权限相关配置
|
||||
permission_keys = ["group_list_type", "group_list", "private_list_type",
|
||||
"private_list", "ban_user_id", "ban_qq_bot",
|
||||
"enable_poke", "ignore_non_self_poke", "poke_debounce_seconds"]
|
||||
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
|
||||
|
||||
has_permission_config = any(key in chat_config for key in permission_keys)
|
||||
has_video_config = any(key in video_config for key in video_keys)
|
||||
|
||||
if not has_permission_config and not has_video_config:
|
||||
logger.info("旧配置文件中没有找到功能相关配置,无需迁移")
|
||||
return False
|
||||
|
||||
# 确保新功能配置目录存在
|
||||
new_features_dir = Path(new_features_path).parent
|
||||
new_features_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 如果新功能配置文件已存在,先备份
|
||||
if os.path.exists(new_features_path):
|
||||
backup_path = f"{new_features_path}.backup"
|
||||
shutil.copy2(new_features_path, backup_path)
|
||||
logger.info(f"已备份现有功能配置文件到: {backup_path}")
|
||||
|
||||
# 创建新的功能配置
|
||||
new_features_config = {
|
||||
"group_list_type": chat_config.get("group_list_type", "whitelist"),
|
||||
"group_list": chat_config.get("group_list", []),
|
||||
"private_list_type": chat_config.get("private_list_type", "whitelist"),
|
||||
"private_list": chat_config.get("private_list", []),
|
||||
"ban_user_id": chat_config.get("ban_user_id", []),
|
||||
"ban_qq_bot": chat_config.get("ban_qq_bot", False),
|
||||
"enable_poke": chat_config.get("enable_poke", True),
|
||||
"ignore_non_self_poke": chat_config.get("ignore_non_self_poke", False),
|
||||
"poke_debounce_seconds": chat_config.get("poke_debounce_seconds", 3),
|
||||
"enable_video_analysis": video_config.get("enable_video_analysis", True),
|
||||
"max_video_size_mb": video_config.get("max_video_size_mb", 100),
|
||||
"download_timeout": video_config.get("download_timeout", 60),
|
||||
"supported_formats": video_config.get("supported_formats", ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"])
|
||||
}
|
||||
|
||||
# 写入新的功能配置文件
|
||||
with open(new_features_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(new_features_config, f)
|
||||
|
||||
logger.info(f"功能配置已成功迁移到: {new_features_path}")
|
||||
|
||||
# 显示迁移的配置内容
|
||||
logger.info("迁移的配置内容:")
|
||||
for key, value in new_features_config.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"功能配置迁移失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def remove_features_from_old_config(config_path: str = "plugins/napcat_adapter_plugin/config/config.toml"):
|
||||
"""
|
||||
从旧配置文件中移除功能相关配置,并将旧配置移动到 config/old/ 目录
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(config_path):
|
||||
logger.warning(f"配置文件不存在: {config_path}")
|
||||
return False
|
||||
|
||||
# 确保 config/old 目录存在
|
||||
old_config_dir = "plugins/napcat_adapter_plugin/config/old"
|
||||
os.makedirs(old_config_dir, exist_ok=True)
|
||||
|
||||
# 备份原配置文件到 config/old 目录
|
||||
old_config_path = os.path.join(old_config_dir, "config_with_features.toml")
|
||||
shutil.copy2(config_path, old_config_path)
|
||||
logger.info(f"已备份包含功能配置的原文件到: {old_config_path}")
|
||||
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
# 移除chat段中的功能相关配置
|
||||
removed_keys = []
|
||||
if "chat" in config:
|
||||
chat_config = config["chat"]
|
||||
permission_keys = ["group_list_type", "group_list", "private_list_type",
|
||||
"private_list", "ban_user_id", "ban_qq_bot",
|
||||
"enable_poke", "ignore_non_self_poke", "poke_debounce_seconds"]
|
||||
|
||||
for key in permission_keys:
|
||||
if key in chat_config:
|
||||
del chat_config[key]
|
||||
removed_keys.append(key)
|
||||
|
||||
if removed_keys:
|
||||
logger.info(f"已从chat配置段中移除功能相关配置: {removed_keys}")
|
||||
|
||||
# 移除video段中的配置
|
||||
if "video" in config:
|
||||
video_config = config["video"]
|
||||
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
|
||||
|
||||
video_removed_keys = []
|
||||
for key in video_keys:
|
||||
if key in video_config:
|
||||
del video_config[key]
|
||||
video_removed_keys.append(key)
|
||||
|
||||
if video_removed_keys:
|
||||
logger.info(f"已从video配置段中移除配置: {video_removed_keys}")
|
||||
removed_keys.extend(video_removed_keys)
|
||||
|
||||
# 如果video段为空,则删除整个段
|
||||
if not video_config:
|
||||
del config["video"]
|
||||
logger.info("已删除空的video配置段")
|
||||
|
||||
if removed_keys:
|
||||
logger.info(f"总共移除的配置项: {removed_keys}")
|
||||
|
||||
# 写回配置文件
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(config))
|
||||
|
||||
logger.info(f"已更新配置文件: {config_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"移除功能配置失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def auto_migrate_features():
|
||||
"""
|
||||
自动执行功能配置迁移
|
||||
"""
|
||||
logger.info("开始自动功能配置迁移...")
|
||||
|
||||
# 执行迁移
|
||||
if migrate_features_from_config():
|
||||
logger.info("功能配置迁移成功")
|
||||
|
||||
# 询问是否要从旧配置文件中移除功能配置
|
||||
logger.info("功能配置已迁移到独立文件,建议从主配置文件中移除相关配置")
|
||||
# 在实际使用中,这里可以添加用户确认逻辑
|
||||
# 为了自动化,这里直接执行移除
|
||||
remove_features_from_old_config()
|
||||
|
||||
else:
|
||||
logger.info("功能配置迁移跳过或失败")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_migrate_features()
|
||||
67
plugins/napcat_adapter_plugin/src/config/official_configs.py
Normal file
67
plugins/napcat_adapter_plugin/src/config/official_configs.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
from .config_base import ConfigBase
|
||||
|
||||
"""
|
||||
须知:
|
||||
1. 本文件中记录了所有的配置项
|
||||
2. 所有新增的class都需要继承自ConfigBase
|
||||
3. 所有新增的class都应在config.py中的Config类中添加字段
|
||||
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
|
||||
"""
|
||||
|
||||
ADAPTER_PLATFORM = "qq"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NicknameConfig(ConfigBase):
|
||||
nickname: str
|
||||
"""机器人昵称"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NapcatServerConfig(ConfigBase):
|
||||
mode: Literal["reverse", "forward"] = "reverse"
|
||||
"""连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)"""
|
||||
|
||||
host: str = "localhost"
|
||||
"""主机地址"""
|
||||
|
||||
port: int = 8095
|
||||
"""端口号"""
|
||||
|
||||
url: str = ""
|
||||
"""正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws"""
|
||||
|
||||
access_token: str = ""
|
||||
"""WebSocket 连接的访问令牌,用于身份验证"""
|
||||
|
||||
heartbeat_interval: int = 30
|
||||
"""心跳间隔时间,单位为秒"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaiBotServerConfig(ConfigBase):
|
||||
platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
|
||||
"""平台名称,“qq”"""
|
||||
|
||||
host: str = "localhost"
|
||||
"""MaiMCore的主机地址"""
|
||||
|
||||
port: int = 8000
|
||||
"""MaiMCore的端口号"""
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
use_tts: bool = False
|
||||
"""是否启用TTS功能"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugConfig(ConfigBase):
|
||||
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
"""日志级别,默认为INFO"""
|
||||
163
plugins/napcat_adapter_plugin/src/database.py
Normal file
163
plugins/napcat_adapter_plugin/src/database.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import os
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
"""
|
||||
表记录的方式:
|
||||
| group_id | user_id | lift_time |
|
||||
|----------|---------|-----------|
|
||||
|
||||
其中使用 user_id == 0 表示群全体禁言
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BanUser:
|
||||
"""
|
||||
程序处理使用的实例
|
||||
"""
|
||||
|
||||
user_id: int
|
||||
group_id: int
|
||||
lift_time: Optional[int] = Field(default=-1)
|
||||
|
||||
|
||||
class DB_BanUser(SQLModel, table=True):
|
||||
"""
|
||||
表示数据库中的用户禁言记录。
|
||||
使用双重主键
|
||||
"""
|
||||
|
||||
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
|
||||
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
|
||||
lift_time: Optional[int] # 禁言解除的时间(时间戳)
|
||||
|
||||
|
||||
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
|
||||
"""
|
||||
检查两个 BanUser 对象是否相同。
|
||||
"""
|
||||
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
数据库管理类,负责与数据库交互。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
|
||||
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
|
||||
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
|
||||
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
|
||||
self._ensure_database() # 确保数据库和表已创建
|
||||
|
||||
def _ensure_database(self) -> None:
|
||||
"""
|
||||
确保数据库和表已创建。
|
||||
"""
|
||||
logger.info("确保数据库文件和表已创建...")
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
logger.info("数据库和表已创建或已存在")
|
||||
|
||||
def update_ban_record(self, ban_list: List[BanUser]) -> None:
|
||||
# sourcery skip: class-extract-method
|
||||
"""
|
||||
更新禁言列表到数据库。
|
||||
支持在不存在时创建新记录,对于多余的项目自动删除。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
all_records = session.exec(select(DB_BanUser)).all()
|
||||
for ban_user in ban_list:
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
|
||||
)
|
||||
if existing_record := session.exec(statement).first():
|
||||
if existing_record.lift_time == ban_user.lift_time:
|
||||
logger.debug(f"禁言记录未变更: {existing_record}")
|
||||
continue
|
||||
# 更新现有记录的 lift_time
|
||||
existing_record.lift_time = ban_user.lift_time
|
||||
session.add(existing_record)
|
||||
logger.debug(f"更新禁言记录: {existing_record}")
|
||||
else:
|
||||
# 创建新记录
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
|
||||
)
|
||||
session.add(db_record)
|
||||
logger.debug(f"创建新禁言记录: {ban_user}")
|
||||
# 删除不在 ban_list 中的记录
|
||||
for db_record in all_records:
|
||||
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
|
||||
if not any(is_identical(record, ban_user) for ban_user in ban_list):
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
|
||||
)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
|
||||
logger.debug(f"删除禁言记录: {ban_record}")
|
||||
else:
|
||||
logger.info(f"未找到禁言记录: {ban_record}")
|
||||
|
||||
|
||||
logger.info("禁言记录已更新")
|
||||
|
||||
def get_ban_records(self) -> List[BanUser]:
|
||||
"""
|
||||
读取所有禁言记录。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
statement = select(DB_BanUser)
|
||||
records = session.exec(statement).all()
|
||||
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
|
||||
|
||||
def create_ban_record(self, ban_record: BanUser) -> None:
|
||||
"""
|
||||
为特定群组中的用户创建禁言记录。
|
||||
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
|
||||
其同时还是简化版的更新方式。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
# 检查记录是否已存在
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
|
||||
)
|
||||
existing_record = session.exec(statement).first()
|
||||
if existing_record:
|
||||
# 如果记录已存在,更新 lift_time
|
||||
existing_record.lift_time = ban_record.lift_time
|
||||
session.add(existing_record)
|
||||
logger.debug(f"更新禁言记录: {ban_record}")
|
||||
else:
|
||||
# 如果记录不存在,创建新记录
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
|
||||
)
|
||||
session.add(db_record)
|
||||
logger.debug(f"创建新禁言记录: {ban_record}")
|
||||
|
||||
|
||||
def delete_ban_record(self, ban_record: BanUser):
|
||||
"""
|
||||
删除特定用户在特定群组中的禁言记录。
|
||||
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
|
||||
"""
|
||||
user_id = ban_record.user_id
|
||||
group_id = ban_record.group_id
|
||||
with Session(self.engine) as session:
|
||||
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
|
||||
logger.debug(f"删除禁言记录: {ban_record}")
|
||||
else:
|
||||
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
|
||||
|
||||
|
||||
db_manager = DatabaseManager()
|
||||
320
plugins/napcat_adapter_plugin/src/message_buffer.py
Normal file
320
plugins/napcat_adapter_plugin/src/message_buffer.py
Normal file
@@ -0,0 +1,320 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from .config.features_config import features_manager
|
||||
from .recv_handler import RealMessageType
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextMessage:
|
||||
"""文本消息"""
|
||||
text: str
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BufferedSession:
|
||||
"""缓冲会话数据"""
|
||||
session_id: str
|
||||
messages: List[TextMessage] = field(default_factory=list)
|
||||
timer_task: Optional[asyncio.Task] = None
|
||||
delay_task: Optional[asyncio.Task] = None
|
||||
original_event: Any = None
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class SimpleMessageBuffer:
|
||||
|
||||
def __init__(self, merge_callback=None):
|
||||
"""
|
||||
初始化消息缓冲器
|
||||
|
||||
Args:
|
||||
merge_callback: 消息合并后的回调函数,接收(session_id, merged_text, original_event)参数
|
||||
"""
|
||||
self.buffer_pool: Dict[str, BufferedSession] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
self.merge_callback = merge_callback
|
||||
self._shutdown = False
|
||||
|
||||
def get_session_id(self, event_data: Dict[str, Any]) -> str:
|
||||
"""根据事件数据生成会话ID"""
|
||||
message_type = event_data.get("message_type", "unknown")
|
||||
user_id = event_data.get("user_id", "unknown")
|
||||
|
||||
if message_type == "private":
|
||||
return f"private_{user_id}"
|
||||
elif message_type == "group":
|
||||
group_id = event_data.get("group_id", "unknown")
|
||||
return f"group_{group_id}_{user_id}"
|
||||
else:
|
||||
return f"{message_type}_{user_id}"
|
||||
|
||||
def extract_text_from_message(self, message: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""从OneBot消息中提取纯文本,如果包含非文本内容则返回None"""
|
||||
text_parts = []
|
||||
has_non_text = False
|
||||
|
||||
logger.debug(f"正在提取消息文本,消息段数量: {len(message)}")
|
||||
|
||||
for msg_seg in message:
|
||||
msg_type = msg_seg.get("type", "")
|
||||
logger.debug(f"处理消息段类型: {msg_type}")
|
||||
|
||||
if msg_type == RealMessageType.text:
|
||||
text = msg_seg.get("data", {}).get("text", "").strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取到文本: {text[:50]}...")
|
||||
else:
|
||||
# 发现非文本消息段,标记为包含非文本内容
|
||||
has_non_text = True
|
||||
logger.debug(f"发现非文本消息段: {msg_type},跳过缓冲")
|
||||
|
||||
# 如果包含非文本内容,则不进行缓冲
|
||||
if has_non_text:
|
||||
logger.debug("消息包含非文本内容,不进行缓冲")
|
||||
return None
|
||||
|
||||
if text_parts:
|
||||
combined_text = " ".join(text_parts).strip()
|
||||
logger.debug(f"成功提取纯文本: {combined_text[:50]}...")
|
||||
return combined_text
|
||||
|
||||
logger.debug("没有找到有效的文本内容")
|
||||
return None
|
||||
|
||||
def should_skip_message(self, text: str) -> bool:
|
||||
"""判断消息是否应该跳过缓冲"""
|
||||
if not text or not text.strip():
|
||||
return True
|
||||
|
||||
# 检查屏蔽前缀
|
||||
config = features_manager.get_config()
|
||||
block_prefixes = tuple(config.message_buffer_block_prefixes)
|
||||
|
||||
text = text.strip()
|
||||
if text.startswith(block_prefixes):
|
||||
logger.debug(f"消息以屏蔽前缀开头,跳过缓冲: {text[:20]}...")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def add_text_message(self, event_data: Dict[str, Any], message: List[Dict[str, Any]],
|
||||
original_event: Any = None) -> bool:
|
||||
"""
|
||||
添加文本消息到缓冲区
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
message: OneBot消息数组
|
||||
original_event: 原始事件对象
|
||||
|
||||
Returns:
|
||||
是否成功添加到缓冲区
|
||||
"""
|
||||
if self._shutdown:
|
||||
return False
|
||||
|
||||
config = features_manager.get_config()
|
||||
if not config.enable_message_buffer:
|
||||
return False
|
||||
|
||||
# 检查是否启用对应类型的缓冲
|
||||
message_type = event_data.get("message_type", "")
|
||||
if message_type == "group" and not config.message_buffer_enable_group:
|
||||
return False
|
||||
elif message_type == "private" and not config.message_buffer_enable_private:
|
||||
return False
|
||||
|
||||
# 提取文本
|
||||
text = self.extract_text_from_message(message)
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# 检查是否应该跳过
|
||||
if self.should_skip_message(text):
|
||||
return False
|
||||
|
||||
session_id = self.get_session_id(event_data)
|
||||
|
||||
async with self.lock:
|
||||
# 获取或创建会话
|
||||
if session_id not in self.buffer_pool:
|
||||
self.buffer_pool[session_id] = BufferedSession(
|
||||
session_id=session_id,
|
||||
original_event=original_event
|
||||
)
|
||||
|
||||
session = self.buffer_pool[session_id]
|
||||
|
||||
# 检查是否超过最大组件数量
|
||||
if len(session.messages) >= config.message_buffer_max_components:
|
||||
logger.info(f"会话 {session_id} 消息数量达到上限,强制合并")
|
||||
asyncio.create_task(self._force_merge_session(session_id))
|
||||
self.buffer_pool[session_id] = BufferedSession(
|
||||
session_id=session_id,
|
||||
original_event=original_event
|
||||
)
|
||||
session = self.buffer_pool[session_id]
|
||||
|
||||
# 添加文本消息
|
||||
session.messages.append(TextMessage(text=text))
|
||||
session.original_event = original_event # 更新事件
|
||||
|
||||
# 取消之前的定时器
|
||||
await self._cancel_session_timers(session)
|
||||
|
||||
# 设置新的延迟任务
|
||||
session.delay_task = asyncio.create_task(
|
||||
self._wait_and_start_merge(session_id)
|
||||
)
|
||||
|
||||
logger.debug(f"文本消息已添加到缓冲器 {session_id}: {text[:50]}...")
|
||||
return True
|
||||
|
||||
async def _cancel_session_timers(self, session: BufferedSession):
|
||||
"""取消会话的所有定时器"""
|
||||
for task_name in ['timer_task', 'delay_task']:
|
||||
task = getattr(session, task_name)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
setattr(session, task_name, None)
|
||||
|
||||
async def _wait_and_start_merge(self, session_id: str):
|
||||
"""等待初始延迟后开始合并定时器"""
|
||||
config = features_manager.get_config()
|
||||
await asyncio.sleep(config.message_buffer_initial_delay)
|
||||
|
||||
async with self.lock:
|
||||
session = self.buffer_pool.get(session_id)
|
||||
if session and session.messages:
|
||||
# 取消旧的定时器
|
||||
if session.timer_task and not session.timer_task.done():
|
||||
session.timer_task.cancel()
|
||||
try:
|
||||
await session.timer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 设置合并定时器
|
||||
session.timer_task = asyncio.create_task(
|
||||
self._wait_and_merge(session_id)
|
||||
)
|
||||
|
||||
async def _wait_and_merge(self, session_id: str):
|
||||
"""等待合并间隔后执行合并"""
|
||||
config = features_manager.get_config()
|
||||
await asyncio.sleep(config.message_buffer_interval)
|
||||
await self._merge_session(session_id)
|
||||
|
||||
async def _force_merge_session(self, session_id: str):
|
||||
"""强制合并会话(不等待定时器)"""
|
||||
await self._merge_session(session_id, force=True)
|
||||
|
||||
async def _merge_session(self, session_id: str, force: bool = False):
|
||||
"""合并会话中的消息"""
|
||||
async with self.lock:
|
||||
session = self.buffer_pool.get(session_id)
|
||||
if not session or not session.messages:
|
||||
self.buffer_pool.pop(session_id, None)
|
||||
return
|
||||
|
||||
try:
|
||||
# 合并文本消息
|
||||
text_parts = []
|
||||
for msg in session.messages:
|
||||
if msg.text.strip():
|
||||
text_parts.append(msg.text.strip())
|
||||
|
||||
if not text_parts:
|
||||
self.buffer_pool.pop(session_id, None)
|
||||
return
|
||||
|
||||
merged_text = ",".join(text_parts) # 使用中文逗号连接
|
||||
message_count = len(session.messages)
|
||||
|
||||
logger.info(f"合并会话 {session_id} 的 {message_count} 条文本消息: {merged_text[:100]}...")
|
||||
|
||||
# 调用回调函数
|
||||
if self.merge_callback:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(self.merge_callback):
|
||||
await self.merge_callback(session_id, merged_text, session.original_event)
|
||||
else:
|
||||
self.merge_callback(session_id, merged_text, session.original_event)
|
||||
except Exception as e:
|
||||
logger.error(f"消息合并回调执行失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"合并会话 {session_id} 时出错: {e}")
|
||||
finally:
|
||||
# 清理会话
|
||||
await self._cancel_session_timers(session)
|
||||
self.buffer_pool.pop(session_id, None)
|
||||
|
||||
async def flush_session(self, session_id: str):
|
||||
"""强制刷新指定会话的缓冲区"""
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def flush_all(self):
|
||||
"""强制刷新所有会话的缓冲区"""
|
||||
session_ids = list(self.buffer_pool.keys())
|
||||
for session_id in session_ids:
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def get_buffer_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓冲区统计信息"""
|
||||
async with self.lock:
|
||||
stats = {
|
||||
"total_sessions": len(self.buffer_pool),
|
||||
"sessions": {}
|
||||
}
|
||||
|
||||
for session_id, session in self.buffer_pool.items():
|
||||
stats["sessions"][session_id] = {
|
||||
"message_count": len(session.messages),
|
||||
"created_at": session.created_at,
|
||||
"age": time.time() - session.created_at
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
async def clear_expired_sessions(self, max_age: float = 300.0):
|
||||
"""清理过期的会话"""
|
||||
current_time = time.time()
|
||||
expired_sessions = []
|
||||
|
||||
async with self.lock:
|
||||
for session_id, session in self.buffer_pool.items():
|
||||
if current_time - session.created_at > max_age:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
logger.info(f"清理过期会话: {session_id}")
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭消息缓冲器"""
|
||||
self._shutdown = True
|
||||
logger.info("正在关闭简化消息缓冲器...")
|
||||
|
||||
# 刷新所有缓冲区
|
||||
await self.flush_all()
|
||||
|
||||
# 确保所有任务都被取消
|
||||
async with self.lock:
|
||||
for session in list(self.buffer_pool.values()):
|
||||
await self._cancel_session_timers(session)
|
||||
self.buffer_pool.clear()
|
||||
|
||||
logger.info("简化消息缓冲器已关闭")
|
||||
26
plugins/napcat_adapter_plugin/src/mmc_com_layer.py
Normal file
26
plugins/napcat_adapter_plugin/src/mmc_com_layer.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from maim_message import Router, RouteConfig, TargetConfig
|
||||
from .config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from .send_handler import send_handler
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
route_config = RouteConfig(
|
||||
route_config={
|
||||
global_config.maibot_server.platform_name: TargetConfig(
|
||||
url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
|
||||
token=None,
|
||||
)
|
||||
}
|
||||
)
|
||||
router = Router(route_config)
|
||||
|
||||
|
||||
async def mmc_start_com():
|
||||
logger.info("正在连接MaiBot")
|
||||
router.register_class_handler(send_handler.handle_message)
|
||||
await router.run()
|
||||
|
||||
|
||||
async def mmc_stop_com():
|
||||
await router.stop()
|
||||
87
plugins/napcat_adapter_plugin/src/recv_handler/__init__.py
Normal file
87
plugins/napcat_adapter_plugin/src/recv_handler/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MetaEventType:
|
||||
lifecycle = "lifecycle" # 生命周期
|
||||
|
||||
class Lifecycle:
|
||||
connect = "connect" # 生命周期 - WebSocket 连接成功
|
||||
|
||||
heartbeat = "heartbeat" # 心跳
|
||||
|
||||
|
||||
class MessageType: # 接受消息大类
|
||||
private = "private" # 私聊消息
|
||||
|
||||
class Private:
|
||||
friend = "friend" # 私聊消息 - 好友
|
||||
group = "group" # 私聊消息 - 群临时
|
||||
group_self = "group_self" # 私聊消息 - 群中自身发送
|
||||
other = "other" # 私聊消息 - 其他
|
||||
|
||||
group = "group" # 群聊消息
|
||||
|
||||
class Group:
|
||||
normal = "normal" # 群聊消息 - 普通
|
||||
anonymous = "anonymous" # 群聊消息 - 匿名消息
|
||||
notice = "notice" # 群聊消息 - 系统提示
|
||||
|
||||
|
||||
class NoticeType: # 通知事件
|
||||
friend_recall = "friend_recall" # 私聊消息撤回
|
||||
group_recall = "group_recall" # 群聊消息撤回
|
||||
notify = "notify"
|
||||
group_ban = "group_ban" # 群禁言
|
||||
|
||||
class Notify:
|
||||
poke = "poke" # 戳一戳
|
||||
|
||||
class GroupBan:
|
||||
ban = "ban" # 禁言
|
||||
lift_ban = "lift_ban" # 解除禁言
|
||||
|
||||
|
||||
class RealMessageType: # 实际消息分类
|
||||
text = "text" # 纯文本
|
||||
face = "face" # qq表情
|
||||
image = "image" # 图片
|
||||
record = "record" # 语音
|
||||
video = "video" # 视频
|
||||
at = "at" # @某人
|
||||
rps = "rps" # 猜拳魔法表情
|
||||
dice = "dice" # 骰子
|
||||
shake = "shake" # 私聊窗口抖动(只收)
|
||||
poke = "poke" # 群聊戳一戳
|
||||
share = "share" # 链接分享(json形式)
|
||||
reply = "reply" # 回复消息
|
||||
forward = "forward" # 转发消息
|
||||
node = "node" # 转发消息节点
|
||||
|
||||
|
||||
class MessageSentType:
|
||||
private = "private"
|
||||
|
||||
class Private:
|
||||
friend = "friend"
|
||||
group = "group"
|
||||
|
||||
group = "group"
|
||||
|
||||
class Group:
|
||||
normal = "normal"
|
||||
|
||||
|
||||
class CommandType(Enum):
|
||||
"""命令类型"""
|
||||
|
||||
GROUP_BAN = "set_group_ban" # 禁言用户
|
||||
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
|
||||
GROUP_KICK = "set_group_kick" # 踢出群聊
|
||||
SEND_POKE = "send_poke" # 戳一戳
|
||||
DELETE_MSG = "delete_msg" # 撤回消息
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command", "voiceurl", "music", "videourl", "file"]
|
||||
@@ -0,0 +1,884 @@
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from ..config import global_config
|
||||
from ..config.features_config import features_manager
|
||||
from ..message_buffer import SimpleMessageBuffer
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
get_member_info,
|
||||
get_image_base64,
|
||||
get_record_detail,
|
||||
get_self_info,
|
||||
get_message_detail,
|
||||
)
|
||||
from .qq_emoji_list import qq_face
|
||||
from .message_sending import message_send_instance
|
||||
from . import RealMessageType, MessageType, ACCEPT_FORMAT
|
||||
from ..video_handler import get_video_downloader
|
||||
from ..websocket_manager import websocket_manager
|
||||
|
||||
import time
|
||||
import json
|
||||
import websockets as Server
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import uuid
|
||||
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
TemplateInfo,
|
||||
FormatInfo,
|
||||
)
|
||||
|
||||
|
||||
from ..response_pool import get_response
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection = None
|
||||
self.bot_id_list: Dict[int, bool] = {}
|
||||
# 初始化简化消息缓冲器,传入回调函数
|
||||
self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭消息处理器,清理资源"""
|
||||
if self.message_buffer:
|
||||
await self.message_buffer.shutdown()
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
def get_server_connection(self) -> Server.ServerConnection:
|
||||
"""获取当前的服务器连接"""
|
||||
# 优先使用直接设置的连接,否则从 websocket_manager 获取
|
||||
if self.server_connection:
|
||||
return self.server_connection
|
||||
return websocket_manager.get_connection()
|
||||
|
||||
async def check_allow_to_chat(
|
||||
self,
|
||||
user_id: int,
|
||||
group_id: Optional[int] = None,
|
||||
ignore_bot: Optional[bool] = False,
|
||||
ignore_global_list: Optional[bool] = False,
|
||||
) -> bool:
|
||||
# sourcery skip: hoist-statement-from-if, merge-else-if-into-elif
|
||||
"""
|
||||
检查是否允许聊天
|
||||
Parameters:
|
||||
user_id: int: 用户ID
|
||||
group_id: int: 群ID
|
||||
ignore_bot: bool: 是否忽略机器人检查
|
||||
ignore_global_list: bool: 是否忽略全局黑名单检查
|
||||
Returns:
|
||||
bool: 是否允许聊天
|
||||
"""
|
||||
logger.debug(f"群聊id: {group_id}, 用户id: {user_id}")
|
||||
logger.debug("开始检查聊天白名单/黑名单")
|
||||
|
||||
# 使用新的权限管理器检查权限
|
||||
if group_id:
|
||||
if not features_manager.is_group_allowed(group_id):
|
||||
logger.warning("群聊不在聊天权限范围内,消息被丢弃")
|
||||
return False
|
||||
else:
|
||||
if not features_manager.is_private_allowed(user_id):
|
||||
logger.warning("私聊不在聊天权限范围内,消息被丢弃")
|
||||
return False
|
||||
|
||||
# 检查全局禁止名单
|
||||
if not ignore_global_list and features_manager.is_user_banned(user_id):
|
||||
logger.warning("用户在全局黑名单中,消息被丢弃")
|
||||
return False
|
||||
|
||||
# 检查QQ官方机器人
|
||||
if features_manager.is_qq_bot_banned() and group_id and not ignore_bot:
|
||||
logger.debug("开始判断是否为机器人")
|
||||
member_info = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if member_info:
|
||||
is_bot = member_info.get("is_robot")
|
||||
if is_bot is None:
|
||||
logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新")
|
||||
else:
|
||||
if is_bot:
|
||||
logger.warning("QQ官方机器人消息拦截已启用,消息被丢弃,新机器人加入拦截名单")
|
||||
self.bot_id_list[user_id] = True
|
||||
return False
|
||||
else:
|
||||
self.bot_id_list[user_id] = False
|
||||
|
||||
return True
|
||||
|
||||
async def handle_raw_message(self, raw_message: dict) -> None:
|
||||
# sourcery skip: low-code-quality, remove-unreachable-code
|
||||
"""
|
||||
从Napcat接受的原始消息处理
|
||||
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
"""
|
||||
message_type: str = raw_message.get("message_type")
|
||||
message_id: int = raw_message.get("message_id")
|
||||
# message_time: int = raw_message.get("time")
|
||||
message_time: float = time.time() # 应可乐要求,现在是float了
|
||||
|
||||
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
|
||||
format_info: FormatInfo = FormatInfo(
|
||||
content_format=["text", "image", "emoji", "voice"],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
) # 格式化信息
|
||||
if message_type == MessageType.private:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
if sub_type == MessageType.Private.friend:
|
||||
sender_info: dict = raw_message.get("sender")
|
||||
|
||||
if not await self.check_allow_to_chat(sender_info.get("user_id"), None):
|
||||
return None
|
||||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
)
|
||||
|
||||
# 不存在群信息
|
||||
group_info: GroupInfo = None
|
||||
elif sub_type == MessageType.Private.group:
|
||||
"""
|
||||
本部分暂时不做支持,先放着
|
||||
"""
|
||||
logger.warning("群临时消息类型不支持")
|
||||
return None
|
||||
|
||||
sender_info: dict = raw_message.get("sender")
|
||||
|
||||
# 由于临时会话中,Napcat默认不发送成员昵称,所以需要单独获取
|
||||
fetched_member_info: dict = await get_member_info(
|
||||
self.get_server_connection(),
|
||||
raw_message.get("group_id"),
|
||||
sender_info.get("user_id"),
|
||||
)
|
||||
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=nickname,
|
||||
user_cardname=None,
|
||||
)
|
||||
|
||||
# -------------------这里需要群信息吗?-------------------
|
||||
|
||||
# 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有
|
||||
fetched_group_info: dict = await get_group_info(self.get_server_connection(), raw_message.get("group_id"))
|
||||
group_name = ""
|
||||
if fetched_group_info.get("group_name"):
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"私聊消息类型 {sub_type} 不支持")
|
||||
return None
|
||||
elif message_type == MessageType.group:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
if sub_type == MessageType.Group.normal:
|
||||
sender_info: dict = raw_message.get("sender")
|
||||
|
||||
if not await self.check_allow_to_chat(sender_info.get("user_id"), raw_message.get("group_id")):
|
||||
return None
|
||||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
)
|
||||
|
||||
# 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有
|
||||
fetched_group_info = await get_group_info(self.get_server_connection(), raw_message.get("group_id"))
|
||||
group_name: str = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"群聊消息类型 {sub_type} 不支持")
|
||||
return None
|
||||
|
||||
additional_config: dict = {}
|
||||
if global_config.voice.use_tts:
|
||||
additional_config["allow_tts"] = True
|
||||
|
||||
# 消息信息
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id=message_id,
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
template_info=template_info,
|
||||
format_info=format_info,
|
||||
additional_config=additional_config,
|
||||
)
|
||||
|
||||
# 处理实际信息
|
||||
if not raw_message.get("message"):
|
||||
logger.warning("原始消息内容为空")
|
||||
return None
|
||||
|
||||
# 获取Seg列表
|
||||
seg_message: List[Seg] = await self.handle_real_message(raw_message)
|
||||
if not seg_message:
|
||||
logger.warning("处理后消息内容为空")
|
||||
return None
|
||||
|
||||
# 检查是否需要使用消息缓冲
|
||||
if features_manager.is_message_buffer_enabled():
|
||||
# 检查消息类型是否启用缓冲
|
||||
message_type = raw_message.get("message_type")
|
||||
should_use_buffer = False
|
||||
|
||||
if message_type == "group" and features_manager.is_message_buffer_group_enabled():
|
||||
should_use_buffer = True
|
||||
elif message_type == "private" and features_manager.is_message_buffer_private_enabled():
|
||||
should_use_buffer = True
|
||||
|
||||
if should_use_buffer:
|
||||
logger.debug(f"尝试缓冲消息,消息类型: {message_type}, 用户: {user_info.user_id}")
|
||||
logger.debug(f"原始消息段: {raw_message.get('message', [])}")
|
||||
|
||||
# 尝试添加到缓冲器
|
||||
buffered = await self.message_buffer.add_text_message(
|
||||
event_data={
|
||||
"message_type": message_type,
|
||||
"user_id": user_info.user_id,
|
||||
"group_id": group_info.group_id if group_info else None,
|
||||
},
|
||||
message=raw_message.get("message", []),
|
||||
original_event={
|
||||
"message_info": message_info,
|
||||
"raw_message": raw_message
|
||||
}
|
||||
)
|
||||
|
||||
if buffered:
|
||||
logger.info(f"✅ 文本消息已成功缓冲: {user_info.user_id}")
|
||||
return None # 缓冲成功,不立即发送
|
||||
# 如果缓冲失败(消息包含非文本元素),走正常处理流程
|
||||
logger.info(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}")
|
||||
# 缓冲失败时继续执行后面的正常处理流程,不要直接返回
|
||||
|
||||
logger.debug(f"准备发送消息到MaiBot,消息段数量: {len(seg_message)}")
|
||||
for i, seg in enumerate(seg_message):
|
||||
logger.debug(f"消息段 {i}: type={seg.type}, data={str(seg.data)[:100]}...")
|
||||
|
||||
submit_seg: Seg = Seg(
|
||||
type="seglist",
|
||||
data=seg_message,
|
||||
)
|
||||
# MessageBase创建
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=submit_seg,
|
||||
raw_message=raw_message.get("raw_message"),
|
||||
)
|
||||
|
||||
logger.info("发送到Maibot处理信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None:
|
||||
# sourcery skip: low-code-quality
|
||||
"""
|
||||
处理实际消息
|
||||
Parameters:
|
||||
real_message: dict: 实际消息
|
||||
Returns:
|
||||
seg_message: list[Seg]: 处理后的消息段列表
|
||||
"""
|
||||
real_message: list = raw_message.get("message")
|
||||
if not real_message:
|
||||
return None
|
||||
seg_message: List[Seg] = []
|
||||
for sub_message in real_message:
|
||||
sub_message: dict
|
||||
sub_message_type = sub_message.get("type")
|
||||
match sub_message_type:
|
||||
case RealMessageType.text:
|
||||
ret_seg = await self.handle_text_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("text处理失败")
|
||||
case RealMessageType.face:
|
||||
ret_seg = await self.handle_face_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("face处理失败或不支持")
|
||||
case RealMessageType.reply:
|
||||
if not in_reply:
|
||||
ret_seg = await self.handle_reply_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message += ret_seg
|
||||
else:
|
||||
logger.warning("reply处理失败")
|
||||
case RealMessageType.image:
|
||||
logger.debug(f"开始处理图片消息段")
|
||||
ret_seg = await self.handle_image_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
logger.debug(f"图片处理成功,添加到消息段")
|
||||
else:
|
||||
logger.warning("image处理失败")
|
||||
logger.debug(f"图片消息段处理完成")
|
||||
case RealMessageType.record:
|
||||
ret_seg = await self.handle_record_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.clear()
|
||||
seg_message.append(ret_seg)
|
||||
break # 使得消息只有record消息
|
||||
else:
|
||||
logger.warning("record处理失败或不支持")
|
||||
case RealMessageType.video:
|
||||
ret_seg = await self.handle_video_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("video处理失败")
|
||||
case RealMessageType.at:
|
||||
ret_seg = await self.handle_at_message(
|
||||
sub_message,
|
||||
raw_message.get("self_id"),
|
||||
raw_message.get("group_id"),
|
||||
)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("at处理失败")
|
||||
case RealMessageType.rps:
|
||||
logger.warning("暂时不支持猜拳魔法表情解析")
|
||||
case RealMessageType.dice:
|
||||
logger.warning("暂时不支持骰子表情解析")
|
||||
case RealMessageType.shake:
|
||||
# 预计等价于戳一戳
|
||||
logger.warning("暂时不支持窗口抖动解析")
|
||||
case RealMessageType.share:
|
||||
logger.warning("暂时不支持链接解析")
|
||||
case RealMessageType.forward:
|
||||
messages = await self._get_forward_message(sub_message)
|
||||
if not messages:
|
||||
logger.warning("转发消息内容为空或获取失败")
|
||||
return None
|
||||
ret_seg = await self.handle_forward_message(messages)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("转发消息处理失败")
|
||||
case RealMessageType.node:
|
||||
logger.warning("不支持转发消息节点解析")
|
||||
case _:
|
||||
logger.warning(f"未知消息类型: {sub_message_type}")
|
||||
|
||||
logger.debug(f"handle_real_message完成,处理了{len(real_message)}个消息段,生成了{len(seg_message)}个seg")
|
||||
return seg_message
|
||||
|
||||
async def handle_text_message(self, raw_message: dict) -> Seg:
|
||||
"""
|
||||
处理纯文本信息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
plain_text: str = message_data.get("text")
|
||||
return Seg(type="text", data=plain_text)
|
||||
|
||||
async def handle_face_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理表情消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
face_raw_id: str = str(message_data.get("id"))
|
||||
if face_raw_id in qq_face:
|
||||
face_content: str = qq_face.get(face_raw_id)
|
||||
return Seg(type="text", data=face_content)
|
||||
else:
|
||||
logger.warning(f"不支持的表情:{face_raw_id}")
|
||||
return None
|
||||
|
||||
async def handle_image_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理图片消息与表情包消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
image_sub_type = message_data.get("sub_type")
|
||||
try:
|
||||
logger.debug(f"开始下载图片: {message_data.get('url')}")
|
||||
image_base64 = await get_image_base64(message_data.get("url"))
|
||||
logger.debug(f"图片下载成功,大小: {len(image_base64)} 字符")
|
||||
except Exception as e:
|
||||
logger.error(f"图片消息处理失败: {str(e)}")
|
||||
return None
|
||||
if image_sub_type == 0:
|
||||
"""这部分认为是图片"""
|
||||
return Seg(type="image", data=image_base64)
|
||||
elif image_sub_type not in [4, 9]:
|
||||
"""这部分认为是表情包"""
|
||||
return Seg(type="emoji", data=image_base64)
|
||||
else:
|
||||
logger.warning(f"不支持的图片子类型:{image_sub_type}")
|
||||
return None
|
||||
|
||||
async def handle_at_message(self, raw_message: dict, self_id: int, group_id: int) -> Seg | None:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
处理at消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
self_id: int: 机器人QQ号
|
||||
group_id: int: 群号
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
if message_data:
|
||||
qq_id = message_data.get("qq")
|
||||
if str(self_id) == str(qq_id):
|
||||
logger.debug("机器人被at")
|
||||
self_info: dict = await get_self_info(self.get_server_connection())
|
||||
if self_info:
|
||||
return Seg(type="text", data=f"@<{self_info.get('nickname')}:{self_info.get('user_id')}>")
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
member_info: dict = await get_member_info(self.get_server_connection(), group_id=group_id, user_id=qq_id)
|
||||
if member_info:
|
||||
return Seg(type="text", data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>")
|
||||
else:
|
||||
return None
|
||||
|
||||
async def handle_record_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理语音消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
file: str = message_data.get("file")
|
||||
if not file:
|
||||
logger.warning("语音消息缺少文件信息")
|
||||
return None
|
||||
try:
|
||||
record_detail = await get_record_detail(self.get_server_connection(), file)
|
||||
if not record_detail:
|
||||
logger.warning("获取语音消息详情失败")
|
||||
return None
|
||||
audio_base64: str = record_detail.get("base64")
|
||||
except Exception as e:
|
||||
logger.error(f"语音消息处理失败: {str(e)}")
|
||||
return None
|
||||
if not audio_base64:
|
||||
logger.error("语音消息处理失败,未获取到音频数据")
|
||||
return None
|
||||
return Seg(type="voice", data=audio_base64)
|
||||
|
||||
async def handle_video_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理视频消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
|
||||
# 添加详细的调试信息
|
||||
logger.debug(f"视频消息原始数据: {raw_message}")
|
||||
logger.debug(f"视频消息数据: {message_data}")
|
||||
|
||||
# QQ视频消息可能包含url或filePath字段
|
||||
video_url = message_data.get("url")
|
||||
file_path = message_data.get("filePath") or message_data.get("file_path")
|
||||
|
||||
logger.info(f"视频URL: {video_url}")
|
||||
logger.info(f"视频文件路径: {file_path}")
|
||||
|
||||
# 优先使用本地文件路径,其次使用URL
|
||||
video_source = file_path if file_path else video_url
|
||||
|
||||
if not video_source:
|
||||
logger.warning("视频消息缺少URL或文件路径信息")
|
||||
logger.warning(f"完整消息数据: {message_data}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 检查是否为本地文件路径
|
||||
if file_path and Path(file_path).exists():
|
||||
logger.info(f"使用本地视频文件: {file_path}")
|
||||
# 直接读取本地文件
|
||||
with open(file_path, "rb") as f:
|
||||
video_data = f.read()
|
||||
|
||||
# 将视频数据编码为base64用于传输
|
||||
video_base64 = base64.b64encode(video_data).decode('utf-8')
|
||||
logger.info(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
|
||||
|
||||
# 返回包含详细信息的字典格式
|
||||
return Seg(type="video", data={
|
||||
"base64": video_base64,
|
||||
"filename": Path(file_path).name,
|
||||
"size_mb": len(video_data) / (1024 * 1024)
|
||||
})
|
||||
|
||||
elif video_url:
|
||||
logger.info(f"使用视频URL下载: {video_url}")
|
||||
# 使用video_handler下载视频
|
||||
video_downloader = get_video_downloader()
|
||||
download_result = await video_downloader.download_video(video_url)
|
||||
|
||||
if not download_result["success"]:
|
||||
logger.warning(f"视频下载失败: {download_result.get('error', '未知错误')}")
|
||||
logger.warning(f"失败的URL: {video_url}")
|
||||
return None
|
||||
|
||||
# 将视频数据编码为base64用于传输
|
||||
video_base64 = base64.b64encode(download_result["data"]).decode('utf-8')
|
||||
logger.info(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")
|
||||
|
||||
# 返回包含详细信息的字典格式
|
||||
return Seg(type="video", data={
|
||||
"base64": video_base64,
|
||||
"filename": download_result.get("filename", "video.mp4"),
|
||||
"size_mb": len(download_result["data"]) / (1024 * 1024),
|
||||
"url": video_url
|
||||
})
|
||||
|
||||
else:
|
||||
logger.warning("既没有有效的本地文件路径,也没有有效的视频URL")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"视频消息处理失败: {str(e)}")
|
||||
logger.error(f"视频源: {video_source}")
|
||||
return None
|
||||
|
||||
async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None:
|
||||
# sourcery skip: move-assign-in-block, use-named-expression
|
||||
"""
|
||||
处理回复消息
|
||||
|
||||
"""
|
||||
raw_message_data: dict = raw_message.get("data")
|
||||
message_id: int = None
|
||||
if raw_message_data:
|
||||
message_id = raw_message_data.get("id")
|
||||
else:
|
||||
return None
|
||||
message_detail: dict = await get_message_detail(self.get_server_connection(), message_id)
|
||||
if not message_detail:
|
||||
logger.warning("获取被引用的消息详情失败")
|
||||
return None
|
||||
reply_message = await self.handle_real_message(message_detail, in_reply=True)
|
||||
if reply_message is None:
|
||||
reply_message = [Seg(type="text", data="(获取发言内容失败)")]
|
||||
sender_info: dict = message_detail.get("sender")
|
||||
sender_nickname: str = sender_info.get("nickname")
|
||||
sender_id: str = sender_info.get("user_id")
|
||||
seg_message: List[Seg] = []
|
||||
if not sender_nickname:
|
||||
logger.warning("无法获取被引用的人的昵称,返回默认值")
|
||||
seg_message.append(Seg(type="text", data="[回复 未知用户:"))
|
||||
else:
|
||||
seg_message.append(Seg(type="text", data=f"[回复<{sender_nickname}:{sender_id}>:"))
|
||||
seg_message += reply_message
|
||||
seg_message.append(Seg(type="text", data="],说:"))
|
||||
return seg_message
|
||||
|
||||
async def handle_forward_message(self, message_list: list) -> Seg | None:
|
||||
"""
|
||||
递归处理转发消息,并按照动态方式确定图片处理方式
|
||||
Parameters:
|
||||
message_list: list: 转发消息列表
|
||||
"""
|
||||
handled_message, image_count = await self._handle_forward_message(
|
||||
message_list, 0
|
||||
)
|
||||
handled_message: Seg
|
||||
image_count: int
|
||||
if not handled_message:
|
||||
return None
|
||||
|
||||
processed_message: Seg
|
||||
if image_count < 5 and image_count > 0:
|
||||
# 处理图片数量小于5的情况,此时解析图片为base64
|
||||
logger.trace("图片数量小于5,开始解析图片为base64")
|
||||
processed_message = await self._recursive_parse_image_seg(
|
||||
handled_message, True
|
||||
)
|
||||
elif image_count > 0:
|
||||
logger.trace("图片数量大于等于5,开始解析图片为占位符")
|
||||
# 处理图片数量大于等于5的情况,此时解析图片为占位符
|
||||
processed_message = await self._recursive_parse_image_seg(
|
||||
handled_message, False
|
||||
)
|
||||
else:
|
||||
# 处理没有图片的情况,此时直接返回
|
||||
logger.trace("没有图片,直接返回")
|
||||
processed_message = handled_message
|
||||
|
||||
# 添加转发消息提示
|
||||
forward_hint = Seg(type="text", data="这是一条转发消息:\n")
|
||||
return Seg(type="seglist", data=[forward_hint, processed_message])
|
||||
|
||||
async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg:
|
||||
# sourcery skip: merge-else-if-into-elif
|
||||
if to_image:
|
||||
if seg_data.type == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.data:
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return Seg(type="seglist", data=new_seg_list)
|
||||
elif seg_data.type == "image":
|
||||
image_url = seg_data.data
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return Seg(type="text", data="[图片]")
|
||||
return Seg(type="image", data=encoded_image)
|
||||
elif seg_data.type == "emoji":
|
||||
image_url = seg_data.data
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return Seg(type="text", data="[表情包]")
|
||||
return Seg(type="emoji", data=encoded_image)
|
||||
else:
|
||||
logger.trace(f"不处理类型: {seg_data.type}")
|
||||
return seg_data
|
||||
else:
|
||||
if seg_data.type == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.data:
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return Seg(type="seglist", data=new_seg_list)
|
||||
elif seg_data.type == "image":
|
||||
return Seg(type="text", data="[图片]")
|
||||
elif seg_data.type == "emoji":
|
||||
return Seg(type="text", data="[动画表情]")
|
||||
else:
|
||||
logger.trace(f"不处理类型: {seg_data.type}")
|
||||
return seg_data
|
||||
|
||||
async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]:
|
||||
# sourcery skip: low-code-quality
|
||||
"""
|
||||
递归处理实际转发消息
|
||||
Parameters:
|
||||
message_list: list: 转发消息列表,首层对应messages字段,后面对应content字段
|
||||
layer: int: 当前层级
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
image_count: int: 图片数量
|
||||
"""
|
||||
seg_list: List[Seg] = []
|
||||
image_count = 0
|
||||
if message_list is None:
|
||||
return None, 0
|
||||
for sub_message in message_list:
|
||||
sub_message: dict
|
||||
sender_info: dict = sub_message.get("sender")
|
||||
user_nickname: str = sender_info.get("nickname", "QQ用户")
|
||||
user_nickname_str = f"【{user_nickname}】:"
|
||||
break_seg = Seg(type="text", data="\n")
|
||||
message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
|
||||
if not message_of_sub_message_list:
|
||||
logger.warning("转发消息内容为空")
|
||||
continue
|
||||
message_of_sub_message = message_of_sub_message_list[0]
|
||||
if message_of_sub_message.get("type") == RealMessageType.forward:
|
||||
if layer >= 3:
|
||||
full_seg_data = Seg(
|
||||
type="text",
|
||||
data=("--" * layer) + f"【{user_nickname}】:【转发消息】\n",
|
||||
)
|
||||
else:
|
||||
sub_message_data = message_of_sub_message.get("data")
|
||||
if not sub_message_data:
|
||||
continue
|
||||
contents = sub_message_data.get("content")
|
||||
seg_data, count = await self._handle_forward_message(contents, layer + 1)
|
||||
image_count += count
|
||||
head_tip = Seg(
|
||||
type="text",
|
||||
data=("--" * layer) + f"【{user_nickname}】: 合并转发消息内容:\n",
|
||||
)
|
||||
full_seg_data = Seg(type="seglist", data=[head_tip, seg_data])
|
||||
seg_list.append(full_seg_data)
|
||||
elif message_of_sub_message.get("type") == RealMessageType.text:
|
||||
sub_message_data = message_of_sub_message.get("data")
|
||||
if not sub_message_data:
|
||||
continue
|
||||
text_message = sub_message_data.get("text")
|
||||
seg_data = Seg(type="text", data=text_message)
|
||||
data_list: List[Any] = []
|
||||
if layer > 0:
|
||||
data_list = [
|
||||
Seg(type="text", data=("--" * layer) + user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
else:
|
||||
data_list = [
|
||||
Seg(type="text", data=user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
seg_list.append(Seg(type="seglist", data=data_list))
|
||||
elif message_of_sub_message.get("type") == RealMessageType.image:
|
||||
image_count += 1
|
||||
image_data = message_of_sub_message.get("data")
|
||||
sub_type = image_data.get("sub_type")
|
||||
image_url = image_data.get("url")
|
||||
data_list: List[Any] = []
|
||||
if sub_type == 0:
|
||||
seg_data = Seg(type="image", data=image_url)
|
||||
else:
|
||||
seg_data = Seg(type="emoji", data=image_url)
|
||||
if layer > 0:
|
||||
data_list = [
|
||||
Seg(type="text", data=("--" * layer) + user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
else:
|
||||
data_list = [
|
||||
Seg(type="text", data=user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
full_seg_data = Seg(type="seglist", data=data_list)
|
||||
seg_list.append(full_seg_data)
|
||||
return Seg(type="seglist", data=seg_list), image_count
|
||||
|
||||
async def _get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None:
|
||||
forward_message_data: Dict = raw_message.get("data")
|
||||
if not forward_message_data:
|
||||
logger.warning("转发消息内容为空")
|
||||
return None
|
||||
forward_message_id = forward_message_data.get("id")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_forward_msg",
|
||||
"params": {"message_id": forward_message_id},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
connection = self.get_server_connection()
|
||||
if not connection:
|
||||
logger.error("没有可用的 WebSocket 连接")
|
||||
return None
|
||||
await connection.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error("获取转发消息超时")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取转发消息失败: {str(e)}")
|
||||
return None
|
||||
logger.debug(
|
||||
f"转发消息原始格式:{json.dumps(response)[:80]}..."
|
||||
if len(json.dumps(response)) > 80
|
||||
else json.dumps(response)
|
||||
)
|
||||
response_data: Dict = response.get("data")
|
||||
if not response_data:
|
||||
logger.warning("转发消息内容为空或获取失败")
|
||||
return None
|
||||
return response_data.get("messages")
|
||||
|
||||
async def _send_buffered_message(self, session_id: str, merged_text: str, original_event: Dict[str, Any]):
|
||||
"""发送缓冲的合并消息"""
|
||||
try:
|
||||
# 从原始事件数据中提取信息
|
||||
message_info = original_event.get("message_info")
|
||||
raw_message = original_event.get("raw_message")
|
||||
|
||||
if not message_info or not raw_message:
|
||||
logger.error("缓冲消息缺少必要信息")
|
||||
return
|
||||
|
||||
# 创建合并后的消息段 - 将合并的文本转换为Seg格式
|
||||
from maim_message import Seg
|
||||
merged_seg = Seg(type="text", data=merged_text)
|
||||
submit_seg = Seg(type="seglist", data=[merged_seg])
|
||||
|
||||
# 创建新的消息ID
|
||||
import time
|
||||
new_message_id = f"buffered-{message_info.message_id}-{int(time.time() * 1000)}"
|
||||
|
||||
# 更新消息信息
|
||||
from maim_message import BaseMessageInfo, MessageBase
|
||||
buffered_message_info = BaseMessageInfo(
|
||||
platform=message_info.platform,
|
||||
message_id=new_message_id,
|
||||
time=time.time(),
|
||||
user_info=message_info.user_info,
|
||||
group_info=message_info.group_info,
|
||||
template_info=message_info.template_info,
|
||||
format_info=message_info.format_info,
|
||||
additional_config=message_info.additional_config,
|
||||
)
|
||||
|
||||
# 创建MessageBase
|
||||
message_base = MessageBase(
|
||||
message_info=buffered_message_info,
|
||||
message_segment=submit_seg,
|
||||
raw_message=raw_message.get("raw_message", ""),
|
||||
)
|
||||
|
||||
logger.info(f"发送缓冲合并消息到Maibot处理: {session_id}")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送缓冲消息失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
message_handler = MessageHandler()
|
||||
@@ -0,0 +1,32 @@
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from maim_message import MessageBase, Router
|
||||
|
||||
|
||||
class MessageSending:
|
||||
"""
|
||||
负责把消息发送到麦麦
|
||||
"""
|
||||
|
||||
maibot_router: Router = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def message_send(self, message_base: MessageBase) -> bool:
|
||||
"""
|
||||
发送消息
|
||||
Parameters:
|
||||
message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息
|
||||
"""
|
||||
try:
|
||||
send_status = await self.maibot_router.send_message(message_base)
|
||||
if not send_status:
|
||||
raise RuntimeError("可能是路由未正确配置或连接异常")
|
||||
return send_status
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {str(e)}")
|
||||
logger.error("请检查与MaiBot之间的连接")
|
||||
|
||||
|
||||
message_send_instance = MessageSending()
|
||||
@@ -0,0 +1,50 @@
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from ..config import global_config
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from . import MetaEventType
|
||||
|
||||
|
||||
class MetaEventHandler:
|
||||
"""
|
||||
处理Meta事件
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.interval = global_config.napcat_server.heartbeat_interval
|
||||
self._interval_checking = False
|
||||
|
||||
async def handle_meta_event(self, message: dict) -> None:
|
||||
event_type = message.get("meta_event_type")
|
||||
if event_type == MetaEventType.lifecycle:
|
||||
sub_type = message.get("sub_type")
|
||||
if sub_type == MetaEventType.Lifecycle.connect:
|
||||
self_id = message.get("self_id")
|
||||
self.last_heart_beat = time.time()
|
||||
logger.info(f"Bot {self_id} 连接成功")
|
||||
asyncio.create_task(self.check_heartbeat(self_id))
|
||||
elif event_type == MetaEventType.heartbeat:
|
||||
if message["status"].get("online") and message["status"].get("good"):
|
||||
if not self._interval_checking:
|
||||
asyncio.create_task(self.check_heartbeat())
|
||||
self.last_heart_beat = time.time()
|
||||
self.interval = message.get("interval") / 1000
|
||||
else:
|
||||
self_id = message.get("self_id")
|
||||
logger.warning(f"Bot {self_id} Napcat 端异常!")
|
||||
|
||||
async def check_heartbeat(self, id: int) -> None:
|
||||
self._interval_checking = True
|
||||
while True:
|
||||
now_time = time.time()
|
||||
if now_time - self.last_heart_beat > self.interval * 2:
|
||||
logger.error(f"Bot {id} 可能发生了连接断开,被下线,或者Napcat卡死!")
|
||||
break
|
||||
else:
|
||||
logger.debug("心跳正常")
|
||||
await asyncio.sleep(self.interval)
|
||||
|
||||
|
||||
meta_event_handler = MetaEventHandler()
|
||||
549
plugins/napcat_adapter_plugin/src/recv_handler/notice_handler.py
Normal file
549
plugins/napcat_adapter_plugin/src/recv_handler/notice_handler.py
Normal file
@@ -0,0 +1,549 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import websockets as Server
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from ..config import global_config
|
||||
from ..config.features_config import features_manager
|
||||
from ..database import BanUser, db_manager, is_identical
|
||||
from . import NoticeType, ACCEPT_FORMAT
|
||||
from .message_sending import message_send_instance
|
||||
from .message_handler import message_handler
|
||||
from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
|
||||
from ..websocket_manager import websocket_manager
|
||||
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
get_member_info,
|
||||
get_self_info,
|
||||
get_stranger_info,
|
||||
read_ban_list,
|
||||
)
|
||||
|
||||
notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100)
|
||||
unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3)
|
||||
|
||||
|
||||
class NoticeHandler:
|
||||
banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表
|
||||
lifted_list: list[BanUser] = [] # 已经自然解除禁言
|
||||
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection | None = None
|
||||
self.last_poke_time: float = 0.0 # 记录最后一次针对机器人的戳一戳时间
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
while self.server_connection.state != Server.State.OPEN:
|
||||
await asyncio.sleep(0.5)
|
||||
self.banned_list, self.lifted_list = await read_ban_list(self.server_connection)
|
||||
|
||||
asyncio.create_task(self.auto_lift_detect())
|
||||
asyncio.create_task(self.send_notice())
|
||||
asyncio.create_task(self.handle_natural_lift())
|
||||
|
||||
def get_server_connection(self) -> Server.ServerConnection:
|
||||
"""获取当前的服务器连接"""
|
||||
# 优先使用直接设置的连接,否则从 websocket_manager 获取
|
||||
if self.server_connection:
|
||||
return self.server_connection
|
||||
return websocket_manager.get_connection()
|
||||
|
||||
def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
|
||||
"""
|
||||
将用户禁言记录添加到self.banned_list中
|
||||
如果是全体禁言,则user_id为0
|
||||
"""
|
||||
if user_id is None:
|
||||
user_id = 0 # 使用0表示全体禁言
|
||||
lift_time = -1
|
||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time)
|
||||
for record in self.banned_list:
|
||||
if is_identical(record, ban_record):
|
||||
self.banned_list.remove(record)
|
||||
self.banned_list.append(ban_record)
|
||||
db_manager.create_ban_record(ban_record) # 作为更新
|
||||
return
|
||||
self.banned_list.append(ban_record)
|
||||
db_manager.create_ban_record(ban_record) # 添加到数据库
|
||||
|
||||
def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
|
||||
"""
|
||||
从self.lifted_group_list中移除已经解除全体禁言的群
|
||||
"""
|
||||
if user_id is None:
|
||||
user_id = 0 # 使用0表示全体禁言
|
||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1)
|
||||
self.lifted_list.append(ban_record)
|
||||
db_manager.delete_ban_record(ban_record) # 删除数据库中的记录
|
||||
|
||||
async def handle_notice(self, raw_message: dict) -> None:
|
||||
notice_type = raw_message.get("notice_type")
|
||||
# message_time: int = raw_message.get("time")
|
||||
message_time: float = time.time() # 应可乐要求,现在是float了
|
||||
|
||||
group_id = raw_message.get("group_id")
|
||||
user_id = raw_message.get("user_id")
|
||||
target_id = raw_message.get("target_id")
|
||||
|
||||
handled_message: Seg = None
|
||||
user_info: UserInfo = None
|
||||
system_notice: bool = False
|
||||
|
||||
match notice_type:
|
||||
case NoticeType.friend_recall:
|
||||
logger.info("好友撤回一条消息")
|
||||
logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
|
||||
logger.warning("暂时不支持撤回消息处理")
|
||||
case NoticeType.group_recall:
|
||||
logger.info("群内用户撤回一条消息")
|
||||
logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
|
||||
logger.warning("暂时不支持撤回消息处理")
|
||||
case NoticeType.notify:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.Notify.poke:
|
||||
if features_manager.is_poke_enabled() and await message_handler.check_allow_to_chat(
|
||||
user_id, group_id, False, False
|
||||
):
|
||||
logger.info("处理戳一戳消息")
|
||||
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
|
||||
else:
|
||||
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
|
||||
case _:
|
||||
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
|
||||
case NoticeType.group_ban:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.GroupBan.ban:
|
||||
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
return None
|
||||
logger.info("处理群禁言")
|
||||
handled_message, user_info = await self.handle_ban_notify(raw_message, group_id)
|
||||
system_notice = True
|
||||
case NoticeType.GroupBan.lift_ban:
|
||||
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
return None
|
||||
logger.info("处理解除群禁言")
|
||||
handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id)
|
||||
system_notice = True
|
||||
case _:
|
||||
logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}")
|
||||
case _:
|
||||
logger.warning(f"不支持的notice类型: {notice_type}")
|
||||
return None
|
||||
if not handled_message or not user_info:
|
||||
logger.warning("notice处理失败或不支持")
|
||||
return None
|
||||
|
||||
group_info: GroupInfo = None
|
||||
if group_id:
|
||||
fetched_group_info = await get_group_info(self.get_server_connection(), group_id)
|
||||
group_name: str = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
else:
|
||||
logger.warning("无法获取notice消息所在群的名称")
|
||||
group_info = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id="notice",
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
template_info=None,
|
||||
format_info=FormatInfo(
|
||||
content_format=["text", "notify"],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
),
|
||||
additional_config={"target_id": target_id}, # 在这里塞了一个target_id,方便mmc那边知道被戳的人是谁
|
||||
)
|
||||
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=handled_message,
|
||||
raw_message=json.dumps(raw_message),
|
||||
)
|
||||
|
||||
if system_notice:
|
||||
await self.put_notice(message_base)
|
||||
else:
|
||||
logger.info("发送到Maibot处理通知信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
async def handle_poke_notify(
|
||||
self, raw_message: dict, group_id: int, user_id: int
|
||||
) -> Tuple[Seg | None, UserInfo | None]:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches
|
||||
self_info: dict = await get_self_info(self.get_server_connection())
|
||||
|
||||
if not self_info:
|
||||
logger.error("自身信息获取失败")
|
||||
return None, None
|
||||
|
||||
self_id = raw_message.get("self_id")
|
||||
target_id = raw_message.get("target_id")
|
||||
|
||||
# 防抖检查:如果是针对机器人的戳一戳,检查防抖时间
|
||||
if self_id == target_id:
|
||||
current_time = time.time()
|
||||
debounce_seconds = features_manager.get_config().poke_debounce_seconds
|
||||
|
||||
if self.last_poke_time > 0:
|
||||
time_diff = current_time - self.last_poke_time
|
||||
if time_diff < debounce_seconds:
|
||||
logger.info(f"戳一戳防抖:用户 {user_id} 的戳一戳被忽略(距离上次戳一戳 {time_diff:.2f} 秒)")
|
||||
return None, None
|
||||
|
||||
# 记录这次戳一戳的时间
|
||||
self.last_poke_time = current_time
|
||||
|
||||
target_name: str = None
|
||||
raw_info: list = raw_message.get("raw_info")
|
||||
|
||||
if group_id:
|
||||
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
else:
|
||||
user_qq_info: dict = await get_stranger_info(self.get_server_connection(), user_id)
|
||||
if user_qq_info:
|
||||
user_name = user_qq_info.get("nickname")
|
||||
user_cardname = user_qq_info.get("card")
|
||||
else:
|
||||
user_name = "QQ用户"
|
||||
user_cardname = "QQ用户"
|
||||
logger.info("无法获取戳一戳对方的用户昵称")
|
||||
|
||||
# 计算Seg
|
||||
if self_id == target_id:
|
||||
display_name = ""
|
||||
target_name = self_info.get("nickname")
|
||||
|
||||
elif self_id == user_id:
|
||||
# 让ada不发送麦麦戳别人的消息
|
||||
return None, None
|
||||
|
||||
else:
|
||||
# 如果配置为忽略不是针对自己的戳一戳,则直接返回None
|
||||
if features_manager.is_non_self_poke_ignored():
|
||||
logger.info("忽略不是针对自己的戳一戳消息")
|
||||
return None, None
|
||||
|
||||
# 老实说这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳,但是感觉可以有这个判定来强限制群聊环境
|
||||
if group_id:
|
||||
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, target_id)
|
||||
if fetched_member_info:
|
||||
target_name = fetched_member_info.get("nickname")
|
||||
else:
|
||||
target_name = "QQ用户"
|
||||
logger.info("无法获取被戳一戳方的用户昵称")
|
||||
display_name = user_name
|
||||
else:
|
||||
return None, None
|
||||
|
||||
first_txt: str = "戳了戳"
|
||||
second_txt: str = ""
|
||||
try:
|
||||
first_txt = raw_info[2].get("txt", "戳了戳")
|
||||
second_txt = raw_info[4].get("txt", "")
|
||||
except Exception as e:
|
||||
logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本")
|
||||
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_name,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="text",
|
||||
data=f"{display_name}{first_txt}{target_name}{second_txt}(这是QQ的一个功能,用于提及某人,但没那么明显)",
|
||||
)
|
||||
return seg_data, user_info
|
||||
|
||||
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理禁言通知")
|
||||
return None, None
|
||||
|
||||
# 计算user_info
|
||||
operator_id = raw_message.get("operator_id")
|
||||
operator_nickname: str = None
|
||||
operator_cardname: str = None
|
||||
|
||||
member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id)
|
||||
if member_info:
|
||||
operator_nickname = member_info.get("nickname")
|
||||
operator_cardname = member_info.get("card")
|
||||
else:
|
||||
logger.warning("无法获取禁言执行者的昵称,消息可能会无效")
|
||||
operator_nickname = "QQ用户"
|
||||
|
||||
operator_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=operator_id,
|
||||
user_nickname=operator_nickname,
|
||||
user_cardname=operator_cardname,
|
||||
)
|
||||
|
||||
# 计算Seg
|
||||
user_id = raw_message.get("user_id")
|
||||
banned_user_info: UserInfo = None
|
||||
user_nickname: str = "QQ用户"
|
||||
user_cardname: str = None
|
||||
sub_type: str = None
|
||||
|
||||
duration = raw_message.get("duration")
|
||||
if duration is None:
|
||||
logger.error("禁言时长不能为空,无法处理禁言通知")
|
||||
return None, None
|
||||
|
||||
if user_id == 0: # 为全体禁言
|
||||
sub_type: str = "whole_ban"
|
||||
self._ban_operation(group_id)
|
||||
else: # 为单人禁言
|
||||
# 获取被禁言人的信息
|
||||
sub_type: str = "ban"
|
||||
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if fetched_member_info:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
banned_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
self._ban_operation(group_id, user_id, int(time.time() + duration))
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": sub_type,
|
||||
"duration": duration,
|
||||
"banned_user_info": banned_user_info.to_dict() if banned_user_info else None,
|
||||
},
|
||||
)
|
||||
|
||||
return seg_data, operator_info
|
||||
|
||||
async def handle_lift_ban_notify(
|
||||
self, raw_message: dict, group_id: int
|
||||
) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理解除禁言通知")
|
||||
return None, None
|
||||
|
||||
# 计算user_info
|
||||
operator_id = raw_message.get("operator_id")
|
||||
operator_nickname: str = None
|
||||
operator_cardname: str = None
|
||||
|
||||
member_info: dict = await get_member_info(self.get_server_connection(), group_id, operator_id)
|
||||
if member_info:
|
||||
operator_nickname = member_info.get("nickname")
|
||||
operator_cardname = member_info.get("card")
|
||||
else:
|
||||
logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效")
|
||||
operator_nickname = "QQ用户"
|
||||
|
||||
operator_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=operator_id,
|
||||
user_nickname=operator_nickname,
|
||||
user_cardname=operator_cardname,
|
||||
)
|
||||
|
||||
# 计算Seg
|
||||
sub_type: str = None
|
||||
user_nickname: str = "QQ用户"
|
||||
user_cardname: str = None
|
||||
lifted_user_info: UserInfo = None
|
||||
|
||||
user_id = raw_message.get("user_id")
|
||||
if user_id == 0: # 全体禁言解除
|
||||
sub_type = "whole_lift_ban"
|
||||
self._lift_operation(group_id)
|
||||
else: # 单人禁言解除
|
||||
sub_type = "lift_ban"
|
||||
# 获取被解除禁言人的信息
|
||||
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if fetched_member_info:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
else:
|
||||
logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效")
|
||||
lifted_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
self._lift_operation(group_id, user_id)
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": sub_type,
|
||||
"lifted_user_info": lifted_user_info.to_dict() if lifted_user_info else None,
|
||||
},
|
||||
)
|
||||
return seg_data, operator_info
|
||||
|
||||
async def put_notice(self, message_base: MessageBase) -> None:
|
||||
"""
|
||||
将处理后的通知消息放入通知队列
|
||||
"""
|
||||
if notice_queue.full() or unsuccessful_notice_queue.full():
|
||||
logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
|
||||
else:
|
||||
await notice_queue.put(message_base)
|
||||
|
||||
async def handle_natural_lift(self) -> None:
|
||||
while True:
|
||||
if len(self.lifted_list) != 0:
|
||||
lift_record = self.lifted_list.pop()
|
||||
group_id = lift_record.group_id
|
||||
user_id = lift_record.user_id
|
||||
|
||||
db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录
|
||||
|
||||
seg_message: Seg = await self.natural_lift(group_id, user_id)
|
||||
|
||||
fetched_group_info = await get_group_info(self.get_server_connection(), group_id)
|
||||
group_name: str = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
else:
|
||||
logger.warning("无法获取notice消息所在群的名称")
|
||||
group_info = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id="notice",
|
||||
time=time.time(),
|
||||
user_info=None, # 自然解除禁言没有操作者
|
||||
group_info=group_info,
|
||||
template_info=None,
|
||||
format_info=None,
|
||||
)
|
||||
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=seg_message,
|
||||
raw_message=json.dumps(
|
||||
{
|
||||
"post_type": "notice",
|
||||
"notice_type": "group_ban",
|
||||
"sub_type": "lift_ban",
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"operator_id": None, # 自然解除禁言没有操作者
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
await self.put_notice(message_base)
|
||||
await asyncio.sleep(0.5) # 确保队列处理间隔
|
||||
else:
|
||||
await asyncio.sleep(5) # 每5秒检查一次
|
||||
|
||||
async def natural_lift(self, group_id: int, user_id: int) -> Seg | None:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理解除禁言通知")
|
||||
return None
|
||||
|
||||
if user_id == 0: # 理论上永远不会触发
|
||||
return Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": "whole_lift_ban",
|
||||
"lifted_user_info": None,
|
||||
},
|
||||
)
|
||||
|
||||
user_nickname: str = "QQ用户"
|
||||
user_cardname: str = None
|
||||
fetched_member_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if fetched_member_info:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
|
||||
lifted_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
|
||||
return Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": "lift_ban",
|
||||
"lifted_user_info": lifted_user_info.to_dict(),
|
||||
},
|
||||
)
|
||||
|
||||
async def auto_lift_detect(self) -> None:
|
||||
while True:
|
||||
if len(self.banned_list) == 0:
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
for ban_record in self.banned_list:
|
||||
if ban_record.user_id == 0 or ban_record.lift_time == -1:
|
||||
continue
|
||||
if ban_record.lift_time <= int(time.time()):
|
||||
# 触发自然解除禁言
|
||||
logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
|
||||
self.lifted_list.append(ban_record)
|
||||
self.banned_list.remove(ban_record)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def send_notice(self) -> None:
|
||||
"""
|
||||
发送通知消息到Napcat
|
||||
"""
|
||||
while True:
|
||||
if not unsuccessful_notice_queue.empty():
|
||||
to_be_send: MessageBase = await unsuccessful_notice_queue.get()
|
||||
try:
|
||||
send_status = await message_send_instance.message_send(to_be_send)
|
||||
if send_status:
|
||||
unsuccessful_notice_queue.task_done()
|
||||
else:
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
except Exception as e:
|
||||
logger.error(f"发送通知消息失败: {str(e)}")
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
to_be_send: MessageBase = await notice_queue.get()
|
||||
try:
|
||||
send_status = await message_send_instance.message_send(to_be_send)
|
||||
if send_status:
|
||||
notice_queue.task_done()
|
||||
else:
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
except Exception as e:
|
||||
logger.error(f"发送通知消息失败: {str(e)}")
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
|
||||
|
||||
notice_handler = NoticeHandler()
|
||||
250
plugins/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py
Normal file
250
plugins/napcat_adapter_plugin/src/recv_handler/qq_emoji_list.py
Normal file
@@ -0,0 +1,250 @@
|
||||
qq_face: dict = {
|
||||
"0": "[表情:惊讶]",
|
||||
"1": "[表情:撇嘴]",
|
||||
"2": "[表情:色]",
|
||||
"3": "[表情:发呆]",
|
||||
"4": "[表情:得意]",
|
||||
"5": "[表情:流泪]",
|
||||
"6": "[表情:害羞]",
|
||||
"7": "[表情:闭嘴]",
|
||||
"8": "[表情:睡]",
|
||||
"9": "[表情:大哭]",
|
||||
"10": "[表情:尴尬]",
|
||||
"11": "[表情:发怒]",
|
||||
"12": "[表情:调皮]",
|
||||
"13": "[表情:呲牙]",
|
||||
"14": "[表情:微笑]",
|
||||
"15": "[表情:难过]",
|
||||
"16": "[表情:酷]",
|
||||
"18": "[表情:抓狂]",
|
||||
"19": "[表情:吐]",
|
||||
"20": "[表情:偷笑]",
|
||||
"21": "[表情:可爱]",
|
||||
"22": "[表情:白眼]",
|
||||
"23": "[表情:傲慢]",
|
||||
"24": "[表情:饥饿]",
|
||||
"25": "[表情:困]",
|
||||
"26": "[表情:惊恐]",
|
||||
"27": "[表情:流汗]",
|
||||
"28": "[表情:憨笑]",
|
||||
"29": "[表情:悠闲]",
|
||||
"30": "[表情:奋斗]",
|
||||
"31": "[表情:咒骂]",
|
||||
"32": "[表情:疑问]",
|
||||
"33": "[表情: 嘘]",
|
||||
"34": "[表情:晕]",
|
||||
"35": "[表情:折磨]",
|
||||
"36": "[表情:衰]",
|
||||
"37": "[表情:骷髅]",
|
||||
"38": "[表情:敲打]",
|
||||
"39": "[表情:再见]",
|
||||
"41": "[表情:发抖]",
|
||||
"42": "[表情:爱情]",
|
||||
"43": "[表情:跳跳]",
|
||||
"46": "[表情:猪头]",
|
||||
"49": "[表情:拥抱]",
|
||||
"53": "[表情:蛋糕]",
|
||||
"56": "[表情:刀]",
|
||||
"59": "[表情:便便]",
|
||||
"60": "[表情:咖啡]",
|
||||
"63": "[表情:玫瑰]",
|
||||
"64": "[表情:凋谢]",
|
||||
"66": "[表情:爱心]",
|
||||
"67": "[表情:心碎]",
|
||||
"74": "[表情:太阳]",
|
||||
"75": "[表情:月亮]",
|
||||
"76": "[表情:赞]",
|
||||
"77": "[表情:踩]",
|
||||
"78": "[表情:握手]",
|
||||
"79": "[表情:胜利]",
|
||||
"85": "[表情:飞吻]",
|
||||
"86": "[表情:怄火]",
|
||||
"89": "[表情:西瓜]",
|
||||
"96": "[表情:冷汗]",
|
||||
"97": "[表情:擦汗]",
|
||||
"98": "[表情:抠鼻]",
|
||||
"99": "[表情:鼓掌]",
|
||||
"100": "[表情:糗大了]",
|
||||
"101": "[表情:坏笑]",
|
||||
"102": "[表情:左哼哼]",
|
||||
"103": "[表情:右哼哼]",
|
||||
"104": "[表情:哈欠]",
|
||||
"105": "[表情:鄙视]",
|
||||
"106": "[表情:委屈]",
|
||||
"107": "[表情:快哭了]",
|
||||
"108": "[表情:阴险]",
|
||||
"109": "[表情:左亲亲]",
|
||||
"110": "[表情:吓]",
|
||||
"111": "[表情:可怜]",
|
||||
"112": "[表情:菜刀]",
|
||||
"114": "[表情:篮球]",
|
||||
"116": "[表情:示爱]",
|
||||
"118": "[表情:抱拳]",
|
||||
"119": "[表情:勾引]",
|
||||
"120": "[表情:拳头]",
|
||||
"121": "[表情:差劲]",
|
||||
"123": "[表情:NO]",
|
||||
"124": "[表情:OK]",
|
||||
"125": "[表情:转圈]",
|
||||
"129": "[表情:挥手]",
|
||||
"137": "[表情:鞭炮]",
|
||||
"144": "[表情:喝彩]",
|
||||
"146": "[表情:爆筋]",
|
||||
"147": "[表情:棒棒糖]",
|
||||
"169": "[表情:手枪]",
|
||||
"171": "[表情:茶]",
|
||||
"172": "[表情:眨眼睛]",
|
||||
"173": "[表情:泪奔]",
|
||||
"174": "[表情:无奈]",
|
||||
"175": "[表情:卖萌]",
|
||||
"176": "[表情:小纠结]",
|
||||
"177": "[表情:喷血]",
|
||||
"178": "[表情:斜眼笑]",
|
||||
"179": "[表情:doge]",
|
||||
"181": "[表情:戳一戳]",
|
||||
"182": "[表情:笑哭]",
|
||||
"183": "[表情:我最美]",
|
||||
"185": "[表情:羊驼]",
|
||||
"187": "[表情:幽灵]",
|
||||
"201": "[表情:点赞]",
|
||||
"212": "[表情:托腮]",
|
||||
"262": "[表情:脑阔疼]",
|
||||
"263": "[表情:沧桑]",
|
||||
"264": "[表情:捂脸]",
|
||||
"265": "[表情:辣眼睛]",
|
||||
"266": "[表情:哦哟]",
|
||||
"267": "[表情:头秃]",
|
||||
"268": "[表情:问号脸]",
|
||||
"269": "[表情:暗中观察]",
|
||||
"270": "[表情:emm]",
|
||||
"271": "[表情:吃 瓜]",
|
||||
"272": "[表情:呵呵哒]",
|
||||
"273": "[表情:我酸了]",
|
||||
"277": "[表情:汪汪]",
|
||||
"281": "[表情:无眼笑]",
|
||||
"282": "[表情:敬礼]",
|
||||
"283": "[表情:狂笑]",
|
||||
"284": "[表情:面无表情]",
|
||||
"285": "[表情:摸鱼]",
|
||||
"286": "[表情:魔鬼笑]",
|
||||
"287": "[表情:哦]",
|
||||
"289": "[表情:睁眼]",
|
||||
"293": "[表情:摸锦鲤]",
|
||||
"294": "[表情:期待]",
|
||||
"295": "[表情:拿到红包]",
|
||||
"297": "[表情:拜谢]",
|
||||
"298": "[表情:元宝]",
|
||||
"299": "[表情:牛啊]",
|
||||
"300": "[表情:胖三斤]",
|
||||
"302": "[表情:左拜年]",
|
||||
"303": "[表情:右拜年]",
|
||||
"305": "[表情:右亲亲]",
|
||||
"306": "[表情:牛气冲天]",
|
||||
"307": "[表情:喵喵]",
|
||||
"311": "[表情:打call]",
|
||||
"312": "[表情:变形]",
|
||||
"314": "[表情:仔细分析]",
|
||||
"317": "[表情:菜汪]",
|
||||
"318": "[表情:崇拜]",
|
||||
"319": "[表情: 比心]",
|
||||
"320": "[表情:庆祝]",
|
||||
"323": "[表情:嫌弃]",
|
||||
"324": "[表情:吃糖]",
|
||||
"325": "[表情:惊吓]",
|
||||
"326": "[表情:生气]",
|
||||
"332": "[表情:举牌牌]",
|
||||
"333": "[表情:烟花]",
|
||||
"334": "[表情:虎虎生威]",
|
||||
"336": "[表情:豹富]",
|
||||
"337": "[表情:花朵脸]",
|
||||
"338": "[表情:我想开了]",
|
||||
"339": "[表情:舔屏]",
|
||||
"341": "[表情:打招呼]",
|
||||
"342": "[表情:酸Q]",
|
||||
"343": "[表情:我方了]",
|
||||
"344": "[表情:大怨种]",
|
||||
"345": "[表情:红包多多]",
|
||||
"346": "[表情:你真棒棒]",
|
||||
"347": "[表情:大展宏兔]",
|
||||
"349": "[表情:坚强]",
|
||||
"350": "[表情:贴贴]",
|
||||
"351": "[表情:敲敲]",
|
||||
"352": "[表情:咦]",
|
||||
"353": "[表情:拜托]",
|
||||
"354": "[表情:尊嘟假嘟]",
|
||||
"355": "[表情:耶]",
|
||||
"356": "[表情:666]",
|
||||
"357": "[表情:裂开]",
|
||||
"392": "[表情:龙年 快乐]",
|
||||
"393": "[表情:新年中龙]",
|
||||
"394": "[表情:新年大龙]",
|
||||
"395": "[表情:略略略]",
|
||||
"😊": "[表情:嘿嘿]",
|
||||
"😌": "[表情:羞涩]",
|
||||
"😚": "[ 表情:亲亲]",
|
||||
"😓": "[表情:汗]",
|
||||
"😰": "[表情:紧张]",
|
||||
"😝": "[表情:吐舌]",
|
||||
"😁": "[表情:呲牙]",
|
||||
"😜": "[表情:淘气]",
|
||||
"☺": "[表情:可爱]",
|
||||
"😍": "[表情:花痴]",
|
||||
"😔": "[表情:失落]",
|
||||
"😄": "[表情:高兴]",
|
||||
"😏": "[表情:哼哼]",
|
||||
"😒": "[表情:不屑]",
|
||||
"😳": "[表情:瞪眼]",
|
||||
"😘": "[表情:飞吻]",
|
||||
"😭": "[表情:大哭]",
|
||||
"😱": "[表情:害怕]",
|
||||
"😂": "[表情:激动]",
|
||||
"💪": "[表情:肌肉]",
|
||||
"👊": "[表情:拳头]",
|
||||
"👍": "[表情 :厉害]",
|
||||
"👏": "[表情:鼓掌]",
|
||||
"👎": "[表情:鄙视]",
|
||||
"🙏": "[表情:合十]",
|
||||
"👌": "[表情:好的]",
|
||||
"👆": "[表情:向上]",
|
||||
"👀": "[表情:眼睛]",
|
||||
"🍜": "[表情:拉面]",
|
||||
"🍧": "[表情:刨冰]",
|
||||
"🍞": "[表情:面包]",
|
||||
"🍺": "[表情:啤酒]",
|
||||
"🍻": "[表情:干杯]",
|
||||
"☕": "[表情:咖啡]",
|
||||
"🍎": "[表情:苹果]",
|
||||
"🍓": "[表情:草莓]",
|
||||
"🍉": "[表情:西瓜]",
|
||||
"🚬": "[表情:吸烟]",
|
||||
"🌹": "[表情:玫瑰]",
|
||||
"🎉": "[表情:庆祝]",
|
||||
"💝": "[表情:礼物]",
|
||||
"💣": "[表情:炸弹]",
|
||||
"✨": "[表情:闪光]",
|
||||
"💨": "[表情:吹气]",
|
||||
"💦": "[表情:水]",
|
||||
"🔥": "[表情:火]",
|
||||
"💤": "[表情:睡觉]",
|
||||
"💩": "[表情:便便]",
|
||||
"💉": "[表情:打针]",
|
||||
"📫": "[表情:邮箱]",
|
||||
"🐎": "[表情:骑马]",
|
||||
"👧": "[表情:女孩]",
|
||||
"👦": "[表情:男孩]",
|
||||
"🐵": "[表情:猴]",
|
||||
"🐷": "[表情:猪]",
|
||||
"🐮": "[表情:牛]",
|
||||
"🐔": "[表情:公鸡]",
|
||||
"🐸": "[表情:青蛙]",
|
||||
"👻": "[表情:幽灵]",
|
||||
"🐛": "[表情:虫]",
|
||||
"🐶": "[表情:狗]",
|
||||
"🐳": "[表情:鲸鱼]",
|
||||
"👢": "[表情:靴子]",
|
||||
"☀": "[表情:晴天]",
|
||||
"❔": "[表情:问号]",
|
||||
"🔫": "[表情:手枪]",
|
||||
"💓": "[表情:爱 心]",
|
||||
"🏪": "[表情:便利店]",
|
||||
}
|
||||
45
plugins/napcat_adapter_plugin/src/response_pool.py
Normal file
45
plugins/napcat_adapter_plugin/src/response_pool.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict
|
||||
from .config import global_config
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
response_dict: Dict = {}
|
||||
response_time_dict: Dict = {}
|
||||
|
||||
|
||||
async def get_response(request_id: str, timeout: int = 10) -> dict:
|
||||
response = await asyncio.wait_for(_get_response(request_id), timeout)
|
||||
_ = response_time_dict.pop(request_id)
|
||||
logger.trace(f"响应信息id: {request_id} 已从响应字典中取出")
|
||||
return response
|
||||
|
||||
async def _get_response(request_id: str) -> dict:
|
||||
"""
|
||||
内部使用的获取响应函数,主要用于在需要时获取响应
|
||||
"""
|
||||
while request_id not in response_dict:
|
||||
await asyncio.sleep(0.2)
|
||||
return response_dict.pop(request_id)
|
||||
|
||||
async def put_response(response: dict):
|
||||
echo_id = response.get("echo")
|
||||
now_time = time.time()
|
||||
response_dict[echo_id] = response
|
||||
response_time_dict[echo_id] = now_time
|
||||
logger.trace(f"响应信息id: {echo_id} 已存入响应字典")
|
||||
|
||||
|
||||
async def check_timeout_response() -> None:
|
||||
while True:
|
||||
cleaned_message_count: int = 0
|
||||
now_time = time.time()
|
||||
for echo_id, response_time in list(response_time_dict.items()):
|
||||
if now_time - response_time > global_config.napcat_server.heartbeat_interval:
|
||||
cleaned_message_count += 1
|
||||
response_dict.pop(echo_id)
|
||||
response_time_dict.pop(echo_id)
|
||||
logger.warning(f"响应消息 {echo_id} 超时,已删除")
|
||||
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
|
||||
await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
|
||||
711
plugins/napcat_adapter_plugin/src/send_handler.py
Normal file
711
plugins/napcat_adapter_plugin/src/send_handler.py
Normal file
@@ -0,0 +1,711 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import websockets as Server
|
||||
import uuid
|
||||
import asyncio
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
)
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
|
||||
from . import CommandType
|
||||
from .config import global_config
|
||||
from .response_pool import get_response
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .utils import get_image_format, convert_image_to_gif
|
||||
from .recv_handler.message_sending import message_send_instance
|
||||
from .websocket_manager import websocket_manager
|
||||
from .config.features_config import features_manager
|
||||
|
||||
|
||||
class SendHandler:
|
||||
def __init__(self):
|
||||
self.server_connection: Optional[Server.ServerConnection] = None
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
def get_server_connection(self) -> Optional[Server.ServerConnection]:
|
||||
"""获取当前的服务器连接"""
|
||||
# 优先使用直接设置的连接,否则从 websocket_manager 获取
|
||||
if self.server_connection:
|
||||
return self.server_connection
|
||||
return websocket_manager.get_connection()
|
||||
|
||||
async def handle_message(self, raw_message_base_dict: dict) -> None:
|
||||
raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict)
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
logger.info("接收到来自MaiBot的消息,处理中")
|
||||
if message_segment.type == "command":
|
||||
logger.info("处理命令")
|
||||
return await self.send_command(raw_message_base)
|
||||
elif message_segment.type == "adapter_command":
|
||||
logger.info("处理适配器命令")
|
||||
return await self.handle_adapter_command(raw_message_base)
|
||||
else:
|
||||
logger.info("处理普通消息")
|
||||
return await self.send_normal_message(raw_message_base)
|
||||
|
||||
async def send_normal_message(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
处理普通消息发送
|
||||
"""
|
||||
logger.info("处理普通信息中")
|
||||
message_info: BaseMessageInfo = raw_message_base.message_info
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
group_info: Optional[GroupInfo] = message_info.group_info
|
||||
user_info: Optional[UserInfo] = message_info.user_info
|
||||
target_id: Optional[int] = None
|
||||
action: Optional[str] = None
|
||||
id_name: Optional[str] = None
|
||||
processed_message: list = []
|
||||
try:
|
||||
if user_info:
|
||||
processed_message = await self.handle_seg_recursive(
|
||||
message_segment, user_info
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时发生错误: {e}")
|
||||
return
|
||||
|
||||
if not processed_message:
|
||||
logger.critical("现在暂时不支持解析此回复!")
|
||||
return None
|
||||
|
||||
if group_info and user_info:
|
||||
logger.debug("发送群聊消息")
|
||||
target_id = int(group_info.group_id) if group_info.group_id else None
|
||||
action = "send_group_msg"
|
||||
id_name = "group_id"
|
||||
elif user_info:
|
||||
logger.debug("发送私聊消息")
|
||||
target_id = int(user_info.user_id) if user_info.user_id else None
|
||||
action = "send_private_msg"
|
||||
id_name = "user_id"
|
||||
else:
|
||||
logger.error("无法识别的消息类型")
|
||||
return
|
||||
logger.info("尝试发送到napcat")
|
||||
response = await self.send_message_to_napcat(
|
||||
action,
|
||||
{
|
||||
id_name: target_id,
|
||||
"message": processed_message,
|
||||
},
|
||||
)
|
||||
if response.get("status") == "ok":
|
||||
logger.info("消息发送成功")
|
||||
qq_message_id = response.get("data", {}).get("message_id")
|
||||
await self.message_sent_back(raw_message_base, qq_message_id)
|
||||
else:
|
||||
logger.warning(f"消息发送失败,napcat返回:{str(response)}")
|
||||
|
||||
async def send_command(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
处理命令类
|
||||
"""
|
||||
logger.info("处理命令中")
|
||||
message_info: BaseMessageInfo = raw_message_base.message_info
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
group_info: Optional[GroupInfo] = message_info.group_info
|
||||
seg_data: Dict[str, Any] = (
|
||||
message_segment.data
|
||||
if isinstance(message_segment.data, dict)
|
||||
else {}
|
||||
)
|
||||
command_name: Optional[str] = seg_data.get("name")
|
||||
try:
|
||||
args = seg_data.get("args", {})
|
||||
if not isinstance(args, dict):
|
||||
args = {}
|
||||
|
||||
match command_name:
|
||||
case CommandType.GROUP_BAN.name:
|
||||
command, args_dict = self.handle_ban_command(args, group_info)
|
||||
case CommandType.GROUP_WHOLE_BAN.name:
|
||||
command, args_dict = self.handle_whole_ban_command(
|
||||
args, group_info
|
||||
)
|
||||
case CommandType.GROUP_KICK.name:
|
||||
command, args_dict = self.handle_kick_command(args, group_info)
|
||||
case CommandType.SEND_POKE.name:
|
||||
command, args_dict = self.handle_poke_command(args, group_info)
|
||||
case CommandType.DELETE_MSG.name:
|
||||
command, args_dict = self.delete_msg_command(args)
|
||||
case CommandType.AI_VOICE_SEND.name:
|
||||
command, args_dict = self.handle_ai_voice_send_command(
|
||||
args, group_info
|
||||
)
|
||||
case CommandType.SET_EMOJI_LIKE.name:
|
||||
command, args_dict = self.handle_set_emoji_like_command(args)
|
||||
case CommandType.SEND_AT_MESSAGE.name:
|
||||
command, args_dict = self.handle_at_message_command(
|
||||
args, group_info
|
||||
)
|
||||
case CommandType.SEND_LIKE.name:
|
||||
command, args_dict = self.handle_send_like_command(args)
|
||||
case _:
|
||||
logger.error(f"未知命令: {command_name}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时发生错误: {e}")
|
||||
return None
|
||||
|
||||
if not command or not args_dict:
|
||||
logger.error("命令或参数缺失")
|
||||
return None
|
||||
|
||||
response = await self.send_message_to_napcat(command, args_dict)
|
||||
if response.get("status") == "ok":
|
||||
logger.info(f"命令 {command_name} 执行成功")
|
||||
else:
|
||||
logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}")
|
||||
|
||||
async def handle_adapter_command(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
处理适配器命令类 - 用于直接向Napcat发送命令并返回结果
|
||||
"""
|
||||
logger.info("处理适配器命令中")
|
||||
message_info: BaseMessageInfo = raw_message_base.message_info
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
seg_data: Dict[str, Any] = (
|
||||
message_segment.data
|
||||
if isinstance(message_segment.data, dict)
|
||||
else {}
|
||||
)
|
||||
|
||||
try:
|
||||
action = seg_data.get("action")
|
||||
params = seg_data.get("params", {})
|
||||
request_id = seg_data.get("request_id")
|
||||
|
||||
if not action:
|
||||
logger.error("适配器命令缺少action参数")
|
||||
await self.send_adapter_command_response(
|
||||
raw_message_base,
|
||||
{"status": "error", "message": "缺少action参数"},
|
||||
request_id
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"执行适配器命令: {action}")
|
||||
|
||||
# 直接向Napcat发送命令并获取响应
|
||||
response_task = asyncio.create_task(self.send_message_to_napcat(action, params))
|
||||
response = await response_task
|
||||
|
||||
# 发送响应回MaiBot
|
||||
await self.send_adapter_command_response(raw_message_base, response, request_id)
|
||||
|
||||
if response.get("status") == "ok":
|
||||
logger.info(f"适配器命令 {action} 执行成功")
|
||||
else:
|
||||
logger.warning(f"适配器命令 {action} 执行失败,napcat返回:{str(response)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器命令时发生错误: {e}")
|
||||
error_response = {"status": "error", "message": str(e)}
|
||||
await self.send_adapter_command_response(
|
||||
raw_message_base,
|
||||
error_response,
|
||||
seg_data.get("request_id")
|
||||
)
|
||||
|
||||
def get_level(self, seg_data: Seg) -> int:
|
||||
if seg_data.type == "seglist":
|
||||
return 1 + max(self.get_level(seg) for seg in seg_data.data)
|
||||
else:
|
||||
return 1
|
||||
|
||||
async def handle_seg_recursive(self, seg_data: Seg, user_info: UserInfo) -> list:
|
||||
payload: list = []
|
||||
if seg_data.type == "seglist":
|
||||
# level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用
|
||||
if not seg_data.data:
|
||||
return []
|
||||
for seg in seg_data.data:
|
||||
payload = await self.process_message_by_type(seg, payload, user_info)
|
||||
else:
|
||||
payload = await self.process_message_by_type(seg_data, payload, user_info)
|
||||
return payload
|
||||
|
||||
async def process_message_by_type(
|
||||
self, seg: Seg, payload: list, user_info: UserInfo
|
||||
) -> list:
|
||||
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
new_payload = payload
|
||||
if seg.type == "reply":
|
||||
target_id = seg.data
|
||||
if target_id == "notice":
|
||||
return payload
|
||||
new_payload = self.build_payload(
|
||||
payload,
|
||||
await self.handle_reply_message(
|
||||
target_id if isinstance(target_id, str) else "", user_info
|
||||
),
|
||||
True,
|
||||
)
|
||||
elif seg.type == "text":
|
||||
text = seg.data
|
||||
if not text:
|
||||
return payload
|
||||
new_payload = self.build_payload(
|
||||
payload,
|
||||
self.handle_text_message(text if isinstance(text, str) else ""),
|
||||
False,
|
||||
)
|
||||
elif seg.type == "face":
|
||||
logger.warning("MaiBot 发送了qq原生表情,暂时不支持")
|
||||
elif seg.type == "image":
|
||||
image = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_image_message(image), False)
|
||||
elif seg.type == "emoji":
|
||||
emoji = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False)
|
||||
elif seg.type == "voice":
|
||||
voice = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_voice_message(voice), False)
|
||||
elif seg.type == "voiceurl":
|
||||
voice_url = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_voiceurl_message(voice_url), False)
|
||||
elif seg.type == "music":
|
||||
song_id = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_music_message(song_id), False)
|
||||
elif seg.type == "videourl":
|
||||
video_url = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_videourl_message(video_url), False)
|
||||
elif seg.type == "file":
|
||||
file_path = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
|
||||
return new_payload
|
||||
|
||||
def build_payload(
|
||||
self, payload: list, addon: dict | list, is_reply: bool = False
|
||||
) -> list:
|
||||
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
|
||||
"""构建发送的消息体"""
|
||||
if is_reply:
|
||||
temp_list = []
|
||||
if isinstance(addon, list):
|
||||
temp_list.extend(addon)
|
||||
else:
|
||||
temp_list.append(addon)
|
||||
for i in payload:
|
||||
if i.get("type") == "reply":
|
||||
logger.debug("检测到多个回复,使用最新的回复")
|
||||
continue
|
||||
temp_list.append(i)
|
||||
return temp_list
|
||||
else:
|
||||
if isinstance(addon, list):
|
||||
payload.extend(addon)
|
||||
else:
|
||||
payload.append(addon)
|
||||
return payload
|
||||
|
||||
async def handle_reply_message(self, id: str, user_info: UserInfo) -> dict | list:
|
||||
"""处理回复消息"""
|
||||
reply_seg = {"type": "reply", "data": {"id": id}}
|
||||
|
||||
# 获取功能配置
|
||||
ft_config = features_manager.get_config()
|
||||
|
||||
# 检查是否启用引用艾特功能
|
||||
if not ft_config.enable_reply_at:
|
||||
return reply_seg
|
||||
|
||||
try:
|
||||
# 尝试通过 message_id 获取消息详情
|
||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
|
||||
|
||||
replied_user_id = None
|
||||
if msg_info_response and msg_info_response.get("status") == "ok":
|
||||
sender_info = msg_info_response.get("data", {}).get("sender")
|
||||
if sender_info:
|
||||
replied_user_id = sender_info.get("user_id")
|
||||
|
||||
# 如果没有获取到被回复者的ID,则直接返回,不进行@
|
||||
if not replied_user_id:
|
||||
logger.warning(f"无法获取消息 {id} 的发送者信息,跳过 @")
|
||||
return reply_seg
|
||||
|
||||
# 根据概率决定是否艾特用户
|
||||
if random.random() < ft_config.reply_at_rate:
|
||||
at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}}
|
||||
# 在艾特后面添加一个空格
|
||||
text_seg = {"type": "text", "data": {"text": " "}}
|
||||
return [reply_seg, at_seg, text_seg]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理引用回复并尝试@时出错: {e}")
|
||||
# 出现异常时,只发送普通的回复,避免程序崩溃
|
||||
return reply_seg
|
||||
|
||||
return reply_seg
|
||||
|
||||
def handle_text_message(self, message: str) -> dict:
|
||||
"""处理文本消息"""
|
||||
return {"type": "text", "data": {"text": message}}
|
||||
|
||||
def handle_image_message(self, encoded_image: str) -> dict:
|
||||
"""处理图片消息"""
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {
|
||||
"file": f"base64://{encoded_image}",
|
||||
"subtype": 0,
|
||||
},
|
||||
} # base64 编码的图片
|
||||
|
||||
def handle_emoji_message(self, encoded_emoji: str) -> dict:
|
||||
"""处理表情消息"""
|
||||
encoded_image = encoded_emoji
|
||||
image_format = get_image_format(encoded_emoji)
|
||||
if image_format != "gif":
|
||||
encoded_image = convert_image_to_gif(encoded_emoji)
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {
|
||||
"file": f"base64://{encoded_image}",
|
||||
"subtype": 1,
|
||||
"summary": "[动画表情]",
|
||||
},
|
||||
}
|
||||
|
||||
def handle_voice_message(self, encoded_voice: str) -> dict:
|
||||
"""处理语音消息"""
|
||||
if not global_config.voice.use_tts:
|
||||
logger.warning("未启用语音消息处理")
|
||||
return {}
|
||||
if not encoded_voice:
|
||||
return {}
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": f"base64://{encoded_voice}"},
|
||||
}
|
||||
|
||||
def handle_voiceurl_message(self, voice_url: str) -> dict:
|
||||
"""处理语音链接消息"""
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": voice_url},
|
||||
}
|
||||
|
||||
def handle_music_message(self, song_id: str) -> dict:
|
||||
"""处理音乐消息"""
|
||||
return {
|
||||
"type": "music",
|
||||
"data": {"type": "163", "id": song_id},
|
||||
}
|
||||
def handle_videourl_message(self, video_url: str) -> dict:
|
||||
"""处理视频链接消息"""
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {"file": video_url},
|
||||
}
|
||||
|
||||
def handle_file_message(self, file_path: str) -> dict:
|
||||
"""处理文件消息"""
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {"file": f"file://{file_path}"},
|
||||
}
|
||||
|
||||
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理删除消息命令"""
|
||||
return "delete_msg", {"message_id": args["message_id"]}
|
||||
|
||||
def handle_ban_command(
|
||||
self, args: Dict[str, Any], group_info: GroupInfo
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理封禁命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
duration: int = int(args["duration"])
|
||||
user_id: int = int(args["qq_id"])
|
||||
group_id: int = int(group_info.group_id)
|
||||
if duration < 0:
|
||||
raise ValueError("封禁时间必须大于等于0")
|
||||
if not user_id or not group_id:
|
||||
raise ValueError("封禁命令缺少必要参数")
|
||||
if duration > 2592000:
|
||||
raise ValueError("封禁时间不能超过30天")
|
||||
return (
|
||||
CommandType.GROUP_BAN.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"duration": duration,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理全体禁言命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
enable = args["enable"]
|
||||
assert isinstance(enable, bool), "enable参数必须是布尔值"
|
||||
group_id: int = int(group_info.group_id)
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
return (
|
||||
CommandType.GROUP_WHOLE_BAN.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"enable": enable,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理群成员踢出命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
user_id: int = int(args["qq_id"])
|
||||
group_id: int = int(group_info.group_id)
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
if user_id <= 0:
|
||||
raise ValueError("用户ID无效")
|
||||
return (
|
||||
CommandType.GROUP_KICK.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"reject_add_request": False, # 不拒绝加群请求
|
||||
},
|
||||
)
|
||||
|
||||
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理戳一戳命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
user_id: int = int(args["qq_id"])
|
||||
if group_info is None:
|
||||
group_id = None
|
||||
else:
|
||||
group_id: int = int(group_info.group_id)
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
if user_id <= 0:
|
||||
raise ValueError("用户ID无效")
|
||||
return (
|
||||
CommandType.SEND_POKE.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理设置表情回应命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
try:
|
||||
message_id = int(args["message_id"])
|
||||
emoji_id = int(args["emoji_id"])
|
||||
set_like = str(args["set"])
|
||||
except:
|
||||
raise ValueError("缺少必需参数: message_id 或 emoji_id")
|
||||
|
||||
return (
|
||||
CommandType.SET_EMOJI_LIKE.value,
|
||||
{
|
||||
"message_id": message_id,
|
||||
"emoji_id": emoji_id,
|
||||
"set": set_like
|
||||
},
|
||||
)
|
||||
|
||||
def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
处理发送点赞命令的逻辑。
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
|
||||
Returns:
|
||||
Tuple[CommandType, Dict[str, Any]]
|
||||
"""
|
||||
try:
|
||||
user_id: int = int(args["qq_id"])
|
||||
times: int = int(args["times"])
|
||||
except (KeyError, ValueError):
|
||||
raise ValueError("缺少必需参数: qq_id 或 times")
|
||||
|
||||
return (
|
||||
CommandType.SEND_LIKE.value,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"times": times
|
||||
},
|
||||
)
|
||||
|
||||
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
处理AI语音发送命令的逻辑。
|
||||
并返回 NapCat 兼容的 (action, params) 元组。
|
||||
"""
|
||||
if not group_info or not group_info.group_id:
|
||||
raise ValueError("AI语音发送命令必须在群聊上下文中使用")
|
||||
if not args:
|
||||
raise ValueError("AI语音发送命令缺少参数")
|
||||
|
||||
group_id: int = int(group_info.group_id)
|
||||
character_id = args.get("character")
|
||||
text_content = args.get("text")
|
||||
|
||||
if not character_id or not text_content:
|
||||
raise ValueError(f"AI语音发送命令参数不完整: character='{character_id}', text='{text_content}'")
|
||||
|
||||
return (
|
||||
CommandType.AI_VOICE_SEND.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"text": text_content,
|
||||
"character": character_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
|
||||
|
||||
# 获取当前连接
|
||||
connection = self.get_server_connection()
|
||||
if not connection:
|
||||
logger.error("没有可用的 Napcat 连接")
|
||||
return {"status": "error", "message": "no connection"}
|
||||
|
||||
try:
|
||||
await connection.send(payload)
|
||||
response = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error("发送消息超时,未收到响应")
|
||||
return {"status": "error", "message": "timeout"}
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
return response
|
||||
|
||||
async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
|
||||
# 修改 additional_config,添加 echo 字段
|
||||
if message_base.message_info.additional_config is None:
|
||||
message_base.message_info.additional_config = {}
|
||||
|
||||
message_base.message_info.additional_config["echo"] = True
|
||||
|
||||
# 获取原始的 mmc_message_id
|
||||
mmc_message_id = message_base.message_info.message_id
|
||||
|
||||
# 修改 message_segment 为 notify 类型
|
||||
message_base.message_segment = Seg(
|
||||
type="notify", data={"sub_type": "echo", "echo": mmc_message_id, "actual_id": qq_message_id}
|
||||
)
|
||||
await message_send_instance.message_send(message_base)
|
||||
logger.debug("已回送消息ID")
|
||||
return
|
||||
|
||||
async def send_adapter_command_response(
|
||||
self, original_message: MessageBase, response_data: dict, request_id: str
|
||||
) -> None:
|
||||
"""
|
||||
发送适配器命令响应回MaiBot
|
||||
|
||||
Args:
|
||||
original_message: 原始消息
|
||||
response_data: 响应数据
|
||||
request_id: 请求ID
|
||||
"""
|
||||
try:
|
||||
# 修改 additional_config,添加 echo 字段
|
||||
if original_message.message_info.additional_config is None:
|
||||
original_message.message_info.additional_config = {}
|
||||
|
||||
original_message.message_info.additional_config["echo"] = True
|
||||
|
||||
# 修改 message_segment 为 adapter_response 类型
|
||||
original_message.message_segment = Seg(
|
||||
type="adapter_response",
|
||||
data={
|
||||
"request_id": request_id,
|
||||
"response": response_data,
|
||||
"timestamp": int(time.time() * 1000)
|
||||
}
|
||||
)
|
||||
|
||||
await message_send_instance.message_send(original_message)
|
||||
logger.debug(f"已发送适配器命令响应,request_id: {request_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送适配器命令响应时出错: {e}")
|
||||
|
||||
def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理艾特并发送消息命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典, 包含 qq_id 和 text
|
||||
group_info (GroupInfo): 群聊信息
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]: (action, params)
|
||||
"""
|
||||
at_user_id = args.get("qq_id")
|
||||
text = args.get("text")
|
||||
|
||||
if not at_user_id or not text:
|
||||
raise ValueError("艾特消息命令缺少 qq_id 或 text 参数")
|
||||
|
||||
if not group_info:
|
||||
raise ValueError("艾特消息命令必须在群聊上下文中使用")
|
||||
|
||||
message_payload = [
|
||||
{"type": "at", "data": {"qq": str(at_user_id)}},
|
||||
{"type": "text", "data": {"text": " " + str(text)}},
|
||||
]
|
||||
|
||||
return (
|
||||
"send_group_msg",
|
||||
{
|
||||
"group_id": group_info.group_id,
|
||||
"message": message_payload,
|
||||
},
|
||||
)
|
||||
|
||||
send_handler = SendHandler()
|
||||
311
plugins/napcat_adapter_plugin/src/utils.py
Normal file
311
plugins/napcat_adapter_plugin/src/utils.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import websockets as Server
|
||||
import json
|
||||
import base64
|
||||
import uuid
|
||||
import urllib3
|
||||
import ssl
|
||||
import io
|
||||
|
||||
from .database import BanUser, db_manager
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .response_pool import get_response
|
||||
|
||||
from PIL import Image
|
||||
from typing import Union, List, Tuple, Optional
|
||||
|
||||
|
||||
class SSLAdapter(urllib3.PoolManager):
|
||||
def __init__(self, *args, **kwargs):
|
||||
context = ssl.create_default_context()
|
||||
context.set_ciphers("DEFAULT@SECLEVEL=1")
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
kwargs["ssl_context"] = context
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||
"""
|
||||
获取群相关信息
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群聊信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取群信息超时,群号: {group_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取群信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||
"""
|
||||
获取群详细信息
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群详细信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取群详细信息超时,群号: {group_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取群详细信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None:
|
||||
"""
|
||||
获取群成员信息
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群成员信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_group_member_info",
|
||||
"params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取成员信息超时,群号: {group_id}, 用户ID: {user_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取成员信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_image_base64(url: str) -> str:
|
||||
# sourcery skip: raise-specific-error
|
||||
"""获取图片/表情包的Base64"""
|
||||
logger.debug(f"下载图片: {url}")
|
||||
http = SSLAdapter()
|
||||
try:
|
||||
response = http.request("GET", url, timeout=10)
|
||||
if response.status != 200:
|
||||
raise Exception(f"HTTP Error: {response.status}")
|
||||
image_bytes = response.data
|
||||
return base64.b64encode(image_bytes).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"图片下载失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def convert_image_to_gif(image_base64: str) -> str:
|
||||
# sourcery skip: extract-method
|
||||
"""
|
||||
将Base64编码的图片转换为GIF格式
|
||||
Parameters:
|
||||
image_base64: str: Base64编码的图片数据
|
||||
Returns:
|
||||
str: Base64编码的GIF图片数据
|
||||
"""
|
||||
logger.debug("转换图片为GIF格式")
|
||||
try:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
output_buffer = io.BytesIO()
|
||||
image.save(output_buffer, format="GIF")
|
||||
output_buffer.seek(0)
|
||||
return base64.b64encode(output_buffer.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为GIF失败: {str(e)}")
|
||||
return image_base64
|
||||
|
||||
|
||||
async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
||||
"""
|
||||
获取自身信息
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
Returns:
|
||||
data: dict: 返回的自身信息
|
||||
"""
|
||||
logger.debug("获取自身信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error("获取自身信息超时")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取自身信息失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
|
||||
def get_image_format(raw_data: str) -> str:
|
||||
"""
|
||||
从Base64编码的数据中确定图片的格式。
|
||||
Parameters:
|
||||
raw_data: str: Base64编码的图片数据。
|
||||
Returns:
|
||||
format: str: 图片的格式(例如 'jpeg', 'png', 'gif')。
|
||||
"""
|
||||
image_bytes = base64.b64decode(raw_data)
|
||||
return Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
|
||||
async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict | None:
|
||||
"""
|
||||
获取陌生人信息
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
user_id: 用户ID
|
||||
Returns:
|
||||
dict: 返回的陌生人信息
|
||||
"""
|
||||
logger.debug("获取陌生人信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取陌生人信息超时,用户ID: {user_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取陌生人信息失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict | None:
|
||||
"""
|
||||
获取消息详情,可能为空
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
message_id: 消息ID
|
||||
Returns:
|
||||
dict: 返回的消息详情
|
||||
"""
|
||||
logger.debug("获取消息详情中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
|
||||
except TimeoutError:
|
||||
logger.error(f"获取消息详情超时,消息ID: {message_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取消息详情失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def get_record_detail(
|
||||
websocket: Server.ServerConnection, file: str, file_id: Optional[str] = None
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取语音消息内容
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
file: 文件名
|
||||
file_id: 文件ID
|
||||
Returns:
|
||||
dict: 返回的语音消息详情
|
||||
"""
|
||||
logger.debug("获取语音消息详情中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_record",
|
||||
"params": {"file": file, "file_id": file_id, "out_format": "wav"},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
|
||||
except TimeoutError:
|
||||
logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取语音消息详情失败: {e}")
|
||||
return None
|
||||
logger.debug(f"{str(response)[:200]}...") # 防止语音的超长base64编码导致日志过长
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def read_ban_list(
|
||||
websocket: Server.ServerConnection,
|
||||
) -> Tuple[List[BanUser], List[BanUser]]:
|
||||
"""
|
||||
从根目录下的data文件夹中的文件读取禁言列表。
|
||||
同时自动更新已经失效禁言
|
||||
Returns:
|
||||
Tuple[
|
||||
一个仍在禁言中的用户的BanUser列表,
|
||||
一个已经自然解除禁言的用户的BanUser列表,
|
||||
一个仍在全体禁言中的群的BanUser列表,
|
||||
一个已经自然解除全体禁言的群的BanUser列表,
|
||||
]
|
||||
"""
|
||||
try:
|
||||
ban_list = db_manager.get_ban_records()
|
||||
lifted_list: List[BanUser] = []
|
||||
logger.info("已经读取禁言列表")
|
||||
for ban_record in ban_list:
|
||||
if ban_record.user_id == 0:
|
||||
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
|
||||
if fetched_group_info is None:
|
||||
logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除")
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
continue
|
||||
group_all_shut: int = fetched_group_info.get("group_all_shut")
|
||||
if group_all_shut == 0:
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
continue
|
||||
else:
|
||||
fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id)
|
||||
if fetched_member_info is None:
|
||||
logger.warning(
|
||||
f"无法获取群成员信息,用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除"
|
||||
)
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
continue
|
||||
lift_ban_time: int = fetched_member_info.get("shut_up_timestamp")
|
||||
if lift_ban_time == 0:
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
else:
|
||||
ban_record.lift_time = lift_ban_time
|
||||
db_manager.update_ban_record(ban_list)
|
||||
return ban_list, lifted_list
|
||||
except Exception as e:
|
||||
logger.error(f"读取禁言列表失败: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
def save_ban_record(list: List[BanUser]):
|
||||
return db_manager.update_ban_record(list)
|
||||
192
plugins/napcat_adapter_plugin/src/video_handler.py
Normal file
192
plugins/napcat_adapter_plugin/src/video_handler.py
Normal file
@@ -0,0 +1,192 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
视频下载和处理模块
|
||||
用于从QQ消息中下载视频并转发给Bot进行分析
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("video_handler")
|
||||
|
||||
|
||||
class VideoDownloader:
|
||||
def __init__(self, max_size_mb: int = 100, download_timeout: int = 60):
|
||||
self.max_size_mb = max_size_mb
|
||||
self.download_timeout = download_timeout
|
||||
self.supported_formats = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
|
||||
|
||||
def is_video_url(self, url: str) -> bool:
|
||||
"""检查URL是否为视频文件"""
|
||||
try:
|
||||
# QQ视频URL可能没有扩展名,所以先检查Content-Type
|
||||
# 对于QQ视频,我们先假设是视频,稍后通过Content-Type验证
|
||||
|
||||
# 检查URL中是否包含视频相关的关键字
|
||||
video_keywords = ['video', 'mp4', 'avi', 'mov', 'mkv', 'flv', 'wmv', 'webm', 'm4v']
|
||||
url_lower = url.lower()
|
||||
|
||||
# 如果URL包含视频关键字,认为是视频
|
||||
if any(keyword in url_lower for keyword in video_keywords):
|
||||
return True
|
||||
|
||||
# 检查文件扩展名(传统方法)
|
||||
path = Path(url.split('?')[0]) # 移除查询参数
|
||||
if path.suffix.lower() in self.supported_formats:
|
||||
return True
|
||||
|
||||
# 对于QQ等特殊平台,URL可能没有扩展名
|
||||
# 我们允许这些URL通过,稍后通过HTTP头Content-Type验证
|
||||
qq_domains = ['qpic.cn', 'gtimg.cn', 'qq.com', 'tencent.com']
|
||||
if any(domain in url_lower for domain in qq_domains):
|
||||
return True
|
||||
|
||||
return False
|
||||
except:
|
||||
# 如果解析失败,默认允许尝试下载(稍后验证)
|
||||
return True
|
||||
|
||||
def check_file_size(self, content_length: Optional[str]) -> bool:
|
||||
"""检查文件大小是否在允许范围内"""
|
||||
if content_length is None:
|
||||
return True # 无法获取大小时允许下载
|
||||
|
||||
try:
|
||||
size_bytes = int(content_length)
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
return size_mb <= self.max_size_mb
|
||||
except:
|
||||
return True
|
||||
|
||||
async def download_video(self, url: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
下载视频文件
|
||||
|
||||
Args:
|
||||
url: 视频URL
|
||||
filename: 可选的文件名
|
||||
|
||||
Returns:
|
||||
dict: 下载结果,包含success、data、filename、error等字段
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始下载视频: {url}")
|
||||
|
||||
# 检查URL格式
|
||||
if not self.is_video_url(url):
|
||||
logger.warning(f"URL格式检查失败: {url}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "不支持的视频格式",
|
||||
"url": url
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# 先发送HEAD请求检查文件大小
|
||||
try:
|
||||
async with session.head(url, timeout=aiohttp.ClientTimeout(total=10)) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"HEAD请求失败,状态码: {response.status}")
|
||||
else:
|
||||
content_length = response.headers.get('Content-Length')
|
||||
if not self.check_file_size(content_length):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"视频文件过大,超过{self.max_size_mb}MB限制",
|
||||
"url": url
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"HEAD请求失败: {e},继续尝试下载")
|
||||
|
||||
# 下载文件
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=self.download_timeout)) as response:
|
||||
if response.status != 200:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"下载失败,HTTP状态码: {response.status}",
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 检查Content-Type是否为视频
|
||||
content_type = response.headers.get('Content-Type', '').lower()
|
||||
if content_type:
|
||||
# 检查是否为视频类型
|
||||
video_mime_types = [
|
||||
'video/', 'application/octet-stream',
|
||||
'application/x-msvideo', 'video/x-msvideo'
|
||||
]
|
||||
is_video_content = any(mime in content_type for mime in video_mime_types)
|
||||
|
||||
if not is_video_content:
|
||||
logger.warning(f"Content-Type不是视频格式: {content_type}")
|
||||
# 如果不是明确的视频类型,但可能是QQ的特殊格式,继续尝试
|
||||
if 'text/' in content_type or 'application/json' in content_type:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"URL返回的不是视频内容,Content-Type: {content_type}",
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 再次检查Content-Length
|
||||
content_length = response.headers.get('Content-Length')
|
||||
if not self.check_file_size(content_length):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"视频文件过大,超过{self.max_size_mb}MB限制",
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 读取文件内容
|
||||
video_data = await response.read()
|
||||
|
||||
# 检查实际文件大小
|
||||
actual_size_mb = len(video_data) / (1024 * 1024)
|
||||
if actual_size_mb > self.max_size_mb:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"视频文件过大,实际大小: {actual_size_mb:.2f}MB",
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 确定文件名
|
||||
if filename is None:
|
||||
filename = Path(url.split('?')[0]).name
|
||||
if not filename or '.' not in filename:
|
||||
filename = "video.mp4"
|
||||
|
||||
logger.info(f"视频下载成功: {filename}, 大小: {actual_size_mb:.2f}MB")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": video_data,
|
||||
"filename": filename,
|
||||
"size_mb": actual_size_mb,
|
||||
"url": url
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "下载超时",
|
||||
"url": url
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"下载视频时出错: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"url": url
|
||||
}
|
||||
|
||||
# 全局实例
|
||||
_video_downloader = None
|
||||
|
||||
def get_video_downloader(max_size_mb: int = 100, download_timeout: int = 60) -> VideoDownloader:
|
||||
"""获取视频下载器实例"""
|
||||
global _video_downloader
|
||||
if _video_downloader is None:
|
||||
_video_downloader = VideoDownloader(max_size_mb, download_timeout)
|
||||
return _video_downloader
|
||||
158
plugins/napcat_adapter_plugin/src/websocket_manager.py
Normal file
158
plugins/napcat_adapter_plugin/src/websocket_manager.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
import websockets as Server
|
||||
from typing import Optional, Callable, Any
|
||||
from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .config import global_config
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""WebSocket 连接管理器,支持正向和反向连接"""
|
||||
|
||||
def __init__(self):
|
||||
self.connection: Optional[Server.ServerConnection] = None
|
||||
self.server: Optional[Server.WebSocketServer] = None
|
||||
self.is_running = False
|
||||
self.reconnect_interval = 5 # 重连间隔(秒)
|
||||
self.max_reconnect_attempts = 10 # 最大重连次数
|
||||
|
||||
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
|
||||
"""根据配置启动 WebSocket 连接"""
|
||||
mode = global_config.napcat_server.mode
|
||||
|
||||
if mode == "reverse":
|
||||
await self._start_reverse_connection(message_handler)
|
||||
elif mode == "forward":
|
||||
await self._start_forward_connection(message_handler)
|
||||
else:
|
||||
raise ValueError(f"不支持的连接模式: {mode}")
|
||||
|
||||
async def _start_reverse_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
|
||||
"""启动反向连接(作为服务器)"""
|
||||
host = global_config.napcat_server.host
|
||||
port = global_config.napcat_server.port
|
||||
|
||||
logger.info(f"正在启动反向连接模式,监听地址: ws://{host}:{port}")
|
||||
|
||||
async def handle_client(websocket, path=None):
|
||||
self.connection = websocket
|
||||
logger.info(f"Napcat 客户端已连接: {websocket.remote_address}")
|
||||
try:
|
||||
await message_handler(websocket)
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端连接时出错: {e}")
|
||||
finally:
|
||||
self.connection = None
|
||||
logger.info("Napcat 客户端已断开连接")
|
||||
|
||||
self.server = await Server.serve(
|
||||
handle_client,
|
||||
host,
|
||||
port,
|
||||
max_size=2**26
|
||||
)
|
||||
self.is_running = True
|
||||
logger.info(f"反向连接服务器已启动,监听地址: ws://{host}:{port}")
|
||||
|
||||
# 保持服务器运行
|
||||
await self.server.serve_forever()
|
||||
|
||||
async def _start_forward_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
|
||||
"""启动正向连接(作为客户端)"""
|
||||
url = self._get_forward_url()
|
||||
logger.info(f"正在启动正向连接模式,目标地址: {url}")
|
||||
|
||||
reconnect_count = 0
|
||||
|
||||
while reconnect_count < self.max_reconnect_attempts:
|
||||
try:
|
||||
logger.info(f"尝试连接到 Napcat 服务器: {url}")
|
||||
|
||||
# 准备连接参数
|
||||
connect_kwargs = {"max_size": 2**26}
|
||||
|
||||
# 如果配置了访问令牌,添加到请求头
|
||||
if global_config.napcat_server.access_token:
|
||||
connect_kwargs["additional_headers"] = {
|
||||
"Authorization": f"Bearer {global_config.napcat_server.access_token}"
|
||||
}
|
||||
logger.info("已添加访问令牌到连接请求头")
|
||||
|
||||
async with Server.connect(url, **connect_kwargs) as websocket:
|
||||
self.connection = websocket
|
||||
self.is_running = True
|
||||
reconnect_count = 0 # 重置重连计数
|
||||
|
||||
logger.info(f"成功连接到 Napcat 服务器: {url}")
|
||||
|
||||
try:
|
||||
await message_handler(websocket)
|
||||
except Server.exceptions.ConnectionClosed:
|
||||
logger.warning("与 Napcat 服务器的连接已断开")
|
||||
except Exception as e:
|
||||
logger.error(f"处理正向连接时出错: {e}")
|
||||
finally:
|
||||
self.connection = None
|
||||
self.is_running = False
|
||||
|
||||
except (Server.exceptions.ConnectionClosed, Server.exceptions.InvalidMessage, OSError, ConnectionRefusedError) as e:
|
||||
reconnect_count += 1
|
||||
logger.warning(f"连接失败 ({reconnect_count}/{self.max_reconnect_attempts}): {e}")
|
||||
|
||||
if reconnect_count < self.max_reconnect_attempts:
|
||||
logger.info(f"将在 {self.reconnect_interval} 秒后重试连接...")
|
||||
await asyncio.sleep(self.reconnect_interval)
|
||||
else:
|
||||
logger.error("已达到最大重连次数,停止重连")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"正向连接时发生未知错误: {e}")
|
||||
raise
|
||||
|
||||
def _get_forward_url(self) -> str:
|
||||
"""获取正向连接的 URL"""
|
||||
config = global_config.napcat_server
|
||||
|
||||
# 如果配置了完整的 URL,直接使用
|
||||
if config.url:
|
||||
return config.url
|
||||
|
||||
# 否则根据 host 和 port 构建 URL
|
||||
host = config.host
|
||||
port = config.port
|
||||
return f"ws://{host}:{port}"
|
||||
|
||||
async def stop_connection(self) -> None:
|
||||
"""停止 WebSocket 连接"""
|
||||
self.is_running = False
|
||||
|
||||
if self.connection:
|
||||
try:
|
||||
await self.connection.close()
|
||||
logger.info("WebSocket 连接已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭 WebSocket 连接时出错: {e}")
|
||||
finally:
|
||||
self.connection = None
|
||||
|
||||
if self.server:
|
||||
try:
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
logger.info("WebSocket 服务器已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭 WebSocket 服务器时出错: {e}")
|
||||
finally:
|
||||
self.server = None
|
||||
|
||||
def get_connection(self) -> Optional[Server.ServerConnection]:
|
||||
"""获取当前的 WebSocket 连接"""
|
||||
return self.connection
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self.connection is not None and self.is_running
|
||||
|
||||
|
||||
# 全局 WebSocket 管理器实例
|
||||
websocket_manager = WebSocketManager()
|
||||
@@ -0,0 +1,43 @@
|
||||
# 权限配置文件
|
||||
# 此文件用于管理群聊和私聊的黑白名单设置,以及聊天相关功能
|
||||
# 支持热重载,修改后会自动生效
|
||||
|
||||
# 群聊权限设置
|
||||
group_list_type = "whitelist" # 群聊列表类型:whitelist(白名单)或 blacklist(黑名单)
|
||||
group_list = [] # 群聊ID列表
|
||||
# 当 group_list_type 为 whitelist 时,只有列表中的群聊可以使用机器人
|
||||
# 当 group_list_type 为 blacklist 时,列表中的群聊无法使用机器人
|
||||
# 示例:group_list = [123456789, 987654321]
|
||||
|
||||
# 私聊权限设置
|
||||
private_list_type = "whitelist" # 私聊列表类型:whitelist(白名单)或 blacklist(黑名单)
|
||||
private_list = [] # 用户ID列表
|
||||
# 当 private_list_type 为 whitelist 时,只有列表中的用户可以私聊机器人
|
||||
# 当 private_list_type 为 blacklist 时,列表中的用户无法私聊机器人
|
||||
# 示例:private_list = [123456789, 987654321]
|
||||
|
||||
# 全局禁止设置
|
||||
ban_user_id = [] # 全局禁止用户ID列表,这些用户无法在任何地方使用机器人
|
||||
ban_qq_bot = false # 是否屏蔽QQ官方机器人消息
|
||||
|
||||
# 聊天功能设置
|
||||
enable_poke = true # 是否启用戳一戳功能
|
||||
ignore_non_self_poke = false # 是否无视不是针对自己的戳一戳
|
||||
poke_debounce_seconds = 3 # 戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略
|
||||
enable_reply_at = true # 是否启用引用回复时艾特用户的功能
|
||||
reply_at_rate = 0.5 # 引用回复时艾特用户的几率 (0.0 ~ 1.0)
|
||||
|
||||
# 视频处理设置
|
||||
enable_video_analysis = true # 是否启用视频识别功能
|
||||
max_video_size_mb = 100 # 视频文件最大大小限制(MB)
|
||||
download_timeout = 60 # 视频下载超时时间(秒)
|
||||
supported_formats = ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"] # 支持的视频格式
|
||||
|
||||
# 消息缓冲设置
|
||||
enable_message_buffer = true # 是否启用消息缓冲合并功能
|
||||
message_buffer_enable_group = true # 是否启用群聊消息缓冲合并
|
||||
message_buffer_enable_private = true # 是否启用私聊消息缓冲合并
|
||||
message_buffer_interval = 3.0 # 消息合并间隔时间(秒),在此时间内的连续消息将被合并
|
||||
message_buffer_initial_delay = 0.5 # 消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并
|
||||
message_buffer_max_components = 50 # 单个会话最大缓冲消息组件数量,超过此数量将强制合并
|
||||
message_buffer_block_prefixes = ["/", "!", "!", ".", "。", "#", "%"] # 消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲
|
||||
25
plugins/napcat_adapter_plugin/template/template_config.toml
Normal file
25
plugins/napcat_adapter_plugin/template/template_config.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[inner]
|
||||
version = "0.2.0" # 版本号
|
||||
# 请勿修改版本号,除非你知道自己在做什么
|
||||
|
||||
[nickname] # 现在没用
|
||||
nickname = ""
|
||||
|
||||
[napcat_server] # Napcat连接的ws服务设置
|
||||
mode = "reverse" # 连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)
|
||||
host = "localhost" # 主机地址
|
||||
port = 8095 # 端口号
|
||||
url = "" # 正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)
|
||||
access_token = "" # WebSocket 连接的访问令牌,用于身份验证(可选)
|
||||
heartbeat_interval = 30 # 心跳间隔时间(按秒计)
|
||||
|
||||
[maibot_server] # 连接麦麦的ws服务设置
|
||||
host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段
|
||||
port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段
|
||||
|
||||
[voice] # 发送语音设置
|
||||
use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter)
|
||||
|
||||
[debug]
|
||||
level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
|
||||
7
plugins/napcat_adapter_plugin/todo.md
Normal file
7
plugins/napcat_adapter_plugin/todo.md
Normal file
@@ -0,0 +1,7 @@
|
||||
[x] logger使用主程序的
|
||||
[ ] 使用插件系统的config系统
|
||||
[ ] 接收从napcat传递的所有信息
|
||||
[ ] 优化架构,各模块解耦,暴露关键方法用于提供接口
|
||||
[ ] 单独一个模块负责与主程序通信
|
||||
[ ] 使用event系统完善接口api
|
||||
|
||||
Reference in New Issue
Block a user