Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -55,6 +55,8 @@ class ConnectionInfo:
|
|||||||
try:
|
try:
|
||||||
await self.session.close()
|
await self.session.close()
|
||||||
logger.debug("连接已关闭")
|
logger.debug("连接已关闭")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.warning("关闭连接时任务被取消")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"关闭连接时出错: {e}")
|
logger.warning(f"关闭连接时出错: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,162 +1,156 @@
|
|||||||
import os
|
"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API)
|
||||||
from typing import Optional, List
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
|
||||||
|
|
||||||
|
本模块替换原先的 sqlmodel + 同步Session 实现:
|
||||||
|
1. 复用主项目的异步数据库连接与迁移体系
|
||||||
|
2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record)
|
||||||
|
3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records
|
||||||
|
|
||||||
|
数据语义:
|
||||||
|
user_id == 0 表示群全体禁言
|
||||||
|
|
||||||
|
注意: 所有方法均为异步, 需要在 async 上下文中调用。
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, List, Sequence
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from src.common.database.sqlalchemy_models import Base, get_db_session
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("napcat_adapter")
|
logger = get_logger("napcat_adapter")
|
||||||
|
|
||||||
"""
|
|
||||||
表记录的方式:
|
|
||||||
| group_id | user_id | lift_time |
|
|
||||||
|----------|---------|-----------|
|
|
||||||
|
|
||||||
其中使用 user_id == 0 表示群全体禁言
|
class NapcatBanRecord(Base):
|
||||||
"""
|
__tablename__ = "napcat_ban_records"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
group_id = Column(BigInteger, nullable=False, index=True)
|
||||||
|
user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言
|
||||||
|
lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"),
|
||||||
|
Index("idx_napcat_ban_group", "group_id"),
|
||||||
|
Index("idx_napcat_ban_user", "user_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BanUser:
|
class BanUser:
|
||||||
"""
|
|
||||||
程序处理使用的实例
|
|
||||||
"""
|
|
||||||
|
|
||||||
user_id: int
|
user_id: int
|
||||||
group_id: int
|
group_id: int
|
||||||
lift_time: Optional[int] = Field(default=-1)
|
lift_time: Optional[int] = -1
|
||||||
|
|
||||||
|
def identity(self) -> tuple[int, int]:
|
||||||
|
return self.group_id, self.user_id
|
||||||
|
|
||||||
|
|
||||||
class DB_BanUser(SQLModel, table=True):
|
class NapcatDatabase:
|
||||||
"""
|
async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]:
|
||||||
表示数据库中的用户禁言记录。
|
result = await session.execute(select(NapcatBanRecord))
|
||||||
使用双重主键
|
return result.scalars().all()
|
||||||
"""
|
|
||||||
|
|
||||||
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
|
async def get_ban_records(self) -> List[BanUser]:
|
||||||
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
|
async with get_db_session() as session:
|
||||||
lift_time: Optional[int] # 禁言解除的时间(时间戳)
|
rows = await self._fetch_all(session)
|
||||||
|
return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows]
|
||||||
|
|
||||||
|
async def update_ban_records(self, ban_list: List[BanUser]) -> None:
|
||||||
|
target_map = {b.identity(): b for b in ban_list}
|
||||||
|
async with get_db_session() as session:
|
||||||
|
rows = await self._fetch_all(session)
|
||||||
|
existing_map = {(r.group_id, r.user_id): r for r in rows}
|
||||||
|
|
||||||
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
|
changed = 0
|
||||||
"""
|
for ident, ban in target_map.items():
|
||||||
检查两个 BanUser 对象是否相同。
|
if ident in existing_map:
|
||||||
"""
|
row = existing_map[ident]
|
||||||
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
|
if row.lift_time != ban.lift_time:
|
||||||
|
row.lift_time = ban.lift_time
|
||||||
|
changed += 1
|
||||||
class DatabaseManager:
|
|
||||||
"""
|
|
||||||
数据库管理类,负责与数据库交互。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
|
|
||||||
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
|
|
||||||
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
|
|
||||||
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
|
|
||||||
self._ensure_database() # 确保数据库和表已创建
|
|
||||||
|
|
||||||
def _ensure_database(self) -> None:
|
|
||||||
"""
|
|
||||||
确保数据库和表已创建。
|
|
||||||
"""
|
|
||||||
logger.info("确保数据库文件和表已创建...")
|
|
||||||
SQLModel.metadata.create_all(self.engine)
|
|
||||||
logger.info("数据库和表已创建或已存在")
|
|
||||||
|
|
||||||
def update_ban_record(self, ban_list: List[BanUser]) -> None:
|
|
||||||
# sourcery skip: class-extract-method
|
|
||||||
"""
|
|
||||||
更新禁言列表到数据库。
|
|
||||||
支持在不存在时创建新记录,对于多余的项目自动删除。
|
|
||||||
"""
|
|
||||||
with Session(self.engine) as session:
|
|
||||||
all_records = session.exec(select(DB_BanUser)).all()
|
|
||||||
for ban_user in ban_list:
|
|
||||||
statement = select(DB_BanUser).where(
|
|
||||||
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
|
|
||||||
)
|
|
||||||
if existing_record := session.exec(statement).first():
|
|
||||||
if existing_record.lift_time == ban_user.lift_time:
|
|
||||||
logger.debug(f"禁言记录未变更: {existing_record}")
|
|
||||||
continue
|
|
||||||
# 更新现有记录的 lift_time
|
|
||||||
existing_record.lift_time = ban_user.lift_time
|
|
||||||
session.add(existing_record)
|
|
||||||
logger.debug(f"更新禁言记录: {existing_record}")
|
|
||||||
else:
|
else:
|
||||||
# 创建新记录
|
session.add(
|
||||||
db_record = DB_BanUser(
|
NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time)
|
||||||
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
|
|
||||||
)
|
)
|
||||||
session.add(db_record)
|
changed += 1
|
||||||
logger.debug(f"创建新禁言记录: {ban_user}")
|
|
||||||
# 删除不在 ban_list 中的记录
|
|
||||||
for db_record in all_records:
|
|
||||||
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
|
|
||||||
if not any(is_identical(record, ban_user) for ban_user in ban_list):
|
|
||||||
statement = select(DB_BanUser).where(
|
|
||||||
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
|
|
||||||
)
|
|
||||||
if ban_record := session.exec(statement).first():
|
|
||||||
session.delete(ban_record)
|
|
||||||
|
|
||||||
logger.debug(f"删除禁言记录: {ban_record}")
|
removed = 0
|
||||||
else:
|
for ident, row in existing_map.items():
|
||||||
logger.info(f"未找到禁言记录: {ban_record}")
|
if ident not in target_map:
|
||||||
|
await session.delete(row)
|
||||||
|
removed += 1
|
||||||
|
|
||||||
logger.info("禁言记录已更新")
|
logger.debug(
|
||||||
|
f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}"
|
||||||
def get_ban_records(self) -> List[BanUser]:
|
|
||||||
"""
|
|
||||||
读取所有禁言记录。
|
|
||||||
"""
|
|
||||||
with Session(self.engine) as session:
|
|
||||||
statement = select(DB_BanUser)
|
|
||||||
records = session.exec(statement).all()
|
|
||||||
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
|
|
||||||
|
|
||||||
def create_ban_record(self, ban_record: BanUser) -> None:
|
|
||||||
"""
|
|
||||||
为特定群组中的用户创建禁言记录。
|
|
||||||
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
|
|
||||||
其同时还是简化版的更新方式。
|
|
||||||
"""
|
|
||||||
with Session(self.engine) as session:
|
|
||||||
# 检查记录是否已存在
|
|
||||||
statement = select(DB_BanUser).where(
|
|
||||||
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
|
|
||||||
)
|
)
|
||||||
existing_record = session.exec(statement).first()
|
|
||||||
if existing_record:
|
async def create_or_update(self, ban_record: BanUser) -> None:
|
||||||
# 如果记录已存在,更新 lift_time
|
async with get_db_session() as session:
|
||||||
existing_record.lift_time = ban_record.lift_time
|
stmt = select(NapcatBanRecord).where(
|
||||||
session.add(existing_record)
|
NapcatBanRecord.group_id == ban_record.group_id,
|
||||||
logger.debug(f"更新禁言记录: {ban_record}")
|
NapcatBanRecord.user_id == ban_record.user_id,
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
row = result.scalars().first()
|
||||||
|
if row:
|
||||||
|
if row.lift_time != ban_record.lift_time:
|
||||||
|
row.lift_time = ban_record.lift_time
|
||||||
|
logger.debug(
|
||||||
|
f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 如果记录不存在,创建新记录
|
session.add(
|
||||||
db_record = DB_BanUser(
|
NapcatBanRecord(
|
||||||
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
|
group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
|
||||||
)
|
)
|
||||||
session.add(db_record)
|
|
||||||
logger.debug(f"创建新禁言记录: {ban_record}")
|
|
||||||
|
|
||||||
def delete_ban_record(self, ban_record: BanUser):
|
async def delete_record(self, ban_record: BanUser) -> None:
|
||||||
"""
|
async with get_db_session() as session:
|
||||||
删除特定用户在特定群组中的禁言记录。
|
stmt = select(NapcatBanRecord).where(
|
||||||
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
|
NapcatBanRecord.group_id == ban_record.group_id,
|
||||||
"""
|
NapcatBanRecord.user_id == ban_record.user_id,
|
||||||
user_id = ban_record.user_id
|
)
|
||||||
group_id = ban_record.group_id
|
result = await session.execute(stmt)
|
||||||
with Session(self.engine) as session:
|
row = result.scalars().first()
|
||||||
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
|
if row:
|
||||||
if ban_record := session.exec(statement).first():
|
await session.delete(row)
|
||||||
session.delete(ban_record)
|
logger.debug(
|
||||||
|
f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}"
|
||||||
logger.debug(f"删除禁言记录: {ban_record}")
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
|
logger.info(
|
||||||
|
f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 兼容旧命名
|
||||||
|
async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name
|
||||||
|
await self.update_ban_records(ban_list)
|
||||||
|
|
||||||
|
async def create_ban_record(self, ban_record: BanUser) -> None: # old name
|
||||||
|
await self.create_or_update(ban_record)
|
||||||
|
|
||||||
|
async def delete_ban_record(self, ban_record: BanUser) -> None: # old name
|
||||||
|
await self.delete_record(ban_record)
|
||||||
|
|
||||||
|
|
||||||
db_manager = DatabaseManager()
|
napcat_db = NapcatDatabase()
|
||||||
|
|
||||||
|
|
||||||
|
def is_identical(a: BanUser, b: BanUser) -> bool:
|
||||||
|
return a.group_id == b.group_id and a.user_id == b.user_id
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BanUser",
|
||||||
|
"NapcatBanRecord",
|
||||||
|
"napcat_db",
|
||||||
|
"is_identical",
|
||||||
|
]
|
||||||
@@ -9,7 +9,7 @@ from src.common.logger import get_logger
|
|||||||
logger = get_logger("napcat_adapter")
|
logger = get_logger("napcat_adapter")
|
||||||
|
|
||||||
from src.plugin_system.apis import config_api
|
from src.plugin_system.apis import config_api
|
||||||
from ..database import BanUser, db_manager, is_identical
|
from ..database import BanUser, napcat_db, is_identical
|
||||||
from . import NoticeType, ACCEPT_FORMAT
|
from . import NoticeType, ACCEPT_FORMAT
|
||||||
from .message_sending import message_send_instance
|
from .message_sending import message_send_instance
|
||||||
from .message_handler import message_handler
|
from .message_handler import message_handler
|
||||||
@@ -62,7 +62,7 @@ class NoticeHandler:
|
|||||||
return self.server_connection
|
return self.server_connection
|
||||||
return websocket_manager.get_connection()
|
return websocket_manager.get_connection()
|
||||||
|
|
||||||
def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
|
async def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
|
||||||
"""
|
"""
|
||||||
将用户禁言记录添加到self.banned_list中
|
将用户禁言记录添加到self.banned_list中
|
||||||
如果是全体禁言,则user_id为0
|
如果是全体禁言,则user_id为0
|
||||||
@@ -71,16 +71,16 @@ class NoticeHandler:
|
|||||||
user_id = 0 # 使用0表示全体禁言
|
user_id = 0 # 使用0表示全体禁言
|
||||||
lift_time = -1
|
lift_time = -1
|
||||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time)
|
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time)
|
||||||
for record in self.banned_list:
|
for record in list(self.banned_list):
|
||||||
if is_identical(record, ban_record):
|
if is_identical(record, ban_record):
|
||||||
self.banned_list.remove(record)
|
self.banned_list.remove(record)
|
||||||
self.banned_list.append(ban_record)
|
self.banned_list.append(ban_record)
|
||||||
db_manager.create_ban_record(ban_record) # 作为更新
|
await napcat_db.create_ban_record(ban_record) # 更新
|
||||||
return
|
return
|
||||||
self.banned_list.append(ban_record)
|
self.banned_list.append(ban_record)
|
||||||
db_manager.create_ban_record(ban_record) # 添加到数据库
|
await napcat_db.create_ban_record(ban_record) # 新建
|
||||||
|
|
||||||
def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
|
async def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
|
||||||
"""
|
"""
|
||||||
从self.lifted_group_list中移除已经解除全体禁言的群
|
从self.lifted_group_list中移除已经解除全体禁言的群
|
||||||
"""
|
"""
|
||||||
@@ -88,7 +88,12 @@ class NoticeHandler:
|
|||||||
user_id = 0 # 使用0表示全体禁言
|
user_id = 0 # 使用0表示全体禁言
|
||||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1)
|
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1)
|
||||||
self.lifted_list.append(ban_record)
|
self.lifted_list.append(ban_record)
|
||||||
db_manager.delete_ban_record(ban_record) # 删除数据库中的记录
|
# 从被禁言列表里移除对应记录
|
||||||
|
for record in list(self.banned_list):
|
||||||
|
if is_identical(record, ban_record):
|
||||||
|
self.banned_list.remove(record)
|
||||||
|
break
|
||||||
|
await napcat_db.delete_ban_record(ban_record)
|
||||||
|
|
||||||
async def handle_notice(self, raw_message: dict) -> None:
|
async def handle_notice(self, raw_message: dict) -> None:
|
||||||
notice_type = raw_message.get("notice_type")
|
notice_type = raw_message.get("notice_type")
|
||||||
@@ -116,9 +121,9 @@ class NoticeHandler:
|
|||||||
sub_type = raw_message.get("sub_type")
|
sub_type = raw_message.get("sub_type")
|
||||||
match sub_type:
|
match sub_type:
|
||||||
case NoticeType.Notify.poke:
|
case NoticeType.Notify.poke:
|
||||||
if config_api.get_plugin_config(
|
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat(
|
||||||
self.plugin_config, "features.enable_poke", True
|
user_id, group_id, False, False
|
||||||
) and await message_handler.check_allow_to_chat(user_id, group_id, False, False):
|
):
|
||||||
logger.debug("处理戳一戳消息")
|
logger.debug("处理戳一戳消息")
|
||||||
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
|
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
|
||||||
else:
|
else:
|
||||||
@@ -127,18 +132,14 @@ class NoticeHandler:
|
|||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
from ...event_types import NapcatEvent
|
from ...event_types import NapcatEvent
|
||||||
|
|
||||||
await event_manager.trigger_event(
|
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
|
||||||
NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME
|
|
||||||
)
|
|
||||||
case _:
|
case _:
|
||||||
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
|
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
|
||||||
case NoticeType.group_msg_emoji_like:
|
case NoticeType.group_msg_emoji_like:
|
||||||
# 该事件转移到 handle_group_emoji_like_notify函数内触发
|
# 该事件转移到 handle_group_emoji_like_notify函数内触发
|
||||||
if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True):
|
if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True):
|
||||||
logger.debug("处理群聊表情回复")
|
logger.debug("处理群聊表情回复")
|
||||||
handled_message, user_info = await self.handle_group_emoji_like_notify(
|
handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id)
|
||||||
raw_message, group_id, user_id
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.warning("群聊表情回复被禁用,取消群聊表情回复处理")
|
logger.warning("群聊表情回复被禁用,取消群聊表情回复处理")
|
||||||
case NoticeType.group_ban:
|
case NoticeType.group_ban:
|
||||||
@@ -201,9 +202,11 @@ class NoticeHandler:
|
|||||||
|
|
||||||
if system_notice:
|
if system_notice:
|
||||||
await self.put_notice(message_base)
|
await self.put_notice(message_base)
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
logger.debug("发送到Maibot处理通知信息")
|
logger.debug("发送到Maibot处理通知信息")
|
||||||
await message_send_instance.message_send(message_base)
|
await message_send_instance.message_send(message_base)
|
||||||
|
return None
|
||||||
|
|
||||||
async def handle_poke_notify(
|
async def handle_poke_notify(
|
||||||
self, raw_message: dict, group_id: int, user_id: int
|
self, raw_message: dict, group_id: int, user_id: int
|
||||||
@@ -298,7 +301,7 @@ class NoticeHandler:
|
|||||||
async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int):
|
async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int):
|
||||||
if not group_id:
|
if not group_id:
|
||||||
logger.error("群ID不能为空,无法处理群聊表情回复通知")
|
logger.error("群ID不能为空,无法处理群聊表情回复通知")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||||
if user_qq_info:
|
if user_qq_info:
|
||||||
@@ -308,42 +311,37 @@ class NoticeHandler:
|
|||||||
user_name = "QQ用户"
|
user_name = "QQ用户"
|
||||||
user_cardname = "QQ用户"
|
user_cardname = "QQ用户"
|
||||||
logger.debug("无法获取表情回复对方的用户昵称")
|
logger.debug("无法获取表情回复对方的用户昵称")
|
||||||
|
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
from ...event_types import NapcatEvent
|
from ...event_types import NapcatEvent
|
||||||
|
|
||||||
target_message = await event_manager.trigger_event(
|
target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id",""))
|
||||||
NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "")
|
target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","")
|
||||||
)
|
|
||||||
target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "")
|
|
||||||
if not target_message:
|
if not target_message:
|
||||||
logger.error("未找到对应消息")
|
logger.error("未找到对应消息")
|
||||||
return None, None
|
return None, None
|
||||||
if len(target_message_text) > 15:
|
if len(target_message_text) > 15:
|
||||||
target_message_text = target_message_text[:15] + "..."
|
target_message_text = target_message_text[:15] + "..."
|
||||||
|
|
||||||
user_info: UserInfo = UserInfo(
|
user_info: UserInfo = UserInfo(
|
||||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
user_nickname=user_name,
|
user_nickname=user_name,
|
||||||
user_cardname=user_cardname,
|
user_cardname=user_cardname,
|
||||||
)
|
)
|
||||||
|
|
||||||
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
|
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
|
||||||
await event_manager.trigger_event(
|
await event_manager.trigger_event(
|
||||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||||
permission_group=PLUGIN_NAME,
|
permission_group=PLUGIN_NAME,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
message_id=raw_message.get("message_id", ""),
|
message_id=raw_message.get("message_id",""),
|
||||||
emoji_id=like_emoji_id,
|
emoji_id=like_emoji_id
|
||||||
)
|
)
|
||||||
seg_data = Seg(
|
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]")
|
||||||
type="text",
|
|
||||||
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
|
|
||||||
)
|
|
||||||
return seg_data, user_info
|
return seg_data, user_info
|
||||||
|
|
||||||
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||||
if not group_id:
|
if not group_id:
|
||||||
logger.error("群ID不能为空,无法处理禁言通知")
|
logger.error("群ID不能为空,无法处理禁言通知")
|
||||||
@@ -383,7 +381,7 @@ class NoticeHandler:
|
|||||||
|
|
||||||
if user_id == 0: # 为全体禁言
|
if user_id == 0: # 为全体禁言
|
||||||
sub_type: str = "whole_ban"
|
sub_type: str = "whole_ban"
|
||||||
self._ban_operation(group_id)
|
await self._ban_operation(group_id)
|
||||||
else: # 为单人禁言
|
else: # 为单人禁言
|
||||||
# 获取被禁言人的信息
|
# 获取被禁言人的信息
|
||||||
sub_type: str = "ban"
|
sub_type: str = "ban"
|
||||||
@@ -397,7 +395,7 @@ class NoticeHandler:
|
|||||||
user_nickname=user_nickname,
|
user_nickname=user_nickname,
|
||||||
user_cardname=user_cardname,
|
user_cardname=user_cardname,
|
||||||
)
|
)
|
||||||
self._ban_operation(group_id, user_id, int(time.time() + duration))
|
await self._ban_operation(group_id, user_id, int(time.time() + duration))
|
||||||
|
|
||||||
seg_data: Seg = Seg(
|
seg_data: Seg = Seg(
|
||||||
type="notify",
|
type="notify",
|
||||||
@@ -446,7 +444,7 @@ class NoticeHandler:
|
|||||||
user_id = raw_message.get("user_id")
|
user_id = raw_message.get("user_id")
|
||||||
if user_id == 0: # 全体禁言解除
|
if user_id == 0: # 全体禁言解除
|
||||||
sub_type = "whole_lift_ban"
|
sub_type = "whole_lift_ban"
|
||||||
self._lift_operation(group_id)
|
await self._lift_operation(group_id)
|
||||||
else: # 单人禁言解除
|
else: # 单人禁言解除
|
||||||
sub_type = "lift_ban"
|
sub_type = "lift_ban"
|
||||||
# 获取被解除禁言人的信息
|
# 获取被解除禁言人的信息
|
||||||
@@ -462,7 +460,7 @@ class NoticeHandler:
|
|||||||
user_nickname=user_nickname,
|
user_nickname=user_nickname,
|
||||||
user_cardname=user_cardname,
|
user_cardname=user_cardname,
|
||||||
)
|
)
|
||||||
self._lift_operation(group_id, user_id)
|
await self._lift_operation(group_id, user_id)
|
||||||
|
|
||||||
seg_data: Seg = Seg(
|
seg_data: Seg = Seg(
|
||||||
type="notify",
|
type="notify",
|
||||||
@@ -473,7 +471,8 @@ class NoticeHandler:
|
|||||||
)
|
)
|
||||||
return seg_data, operator_info
|
return seg_data, operator_info
|
||||||
|
|
||||||
async def put_notice(self, message_base: MessageBase) -> None:
|
@staticmethod
|
||||||
|
async def put_notice(message_base: MessageBase) -> None:
|
||||||
"""
|
"""
|
||||||
将处理后的通知消息放入通知队列
|
将处理后的通知消息放入通知队列
|
||||||
"""
|
"""
|
||||||
@@ -489,7 +488,7 @@ class NoticeHandler:
|
|||||||
group_id = lift_record.group_id
|
group_id = lift_record.group_id
|
||||||
user_id = lift_record.user_id
|
user_id = lift_record.user_id
|
||||||
|
|
||||||
db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录
|
asyncio.create_task(napcat_db.delete_ban_record(lift_record)) # 从数据库中删除禁言记录
|
||||||
|
|
||||||
seg_message: Seg = await self.natural_lift(group_id, user_id)
|
seg_message: Seg = await self.natural_lift(group_id, user_id)
|
||||||
|
|
||||||
@@ -586,7 +585,8 @@ class NoticeHandler:
|
|||||||
self.banned_list.remove(ban_record)
|
self.banned_list.remove(ban_record)
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
async def send_notice(self) -> None:
|
@staticmethod
|
||||||
|
async def send_notice() -> None:
|
||||||
"""
|
"""
|
||||||
发送通知消息到Napcat
|
发送通知消息到Napcat
|
||||||
"""
|
"""
|
||||||
@@ -617,4 +617,4 @@ class NoticeHandler:
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
notice_handler = NoticeHandler()
|
notice_handler = NoticeHandler()
|
||||||
@@ -6,33 +6,7 @@ import urllib3
|
|||||||
import ssl
|
import ssl
|
||||||
import io
|
import io
|
||||||
|
|
||||||
import time
|
from .database import BanUser, napcat_db
|
||||||
from asyncio import Lock
|
|
||||||
|
|
||||||
_internal_cache = {}
|
|
||||||
_cache_lock = Lock()
|
|
||||||
CACHE_TIMEOUT = 300 # 缓存5分钟
|
|
||||||
|
|
||||||
|
|
||||||
async def get_from_cache(key: str):
|
|
||||||
async with _cache_lock:
|
|
||||||
data = _internal_cache.get(key)
|
|
||||||
if not data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
result, timestamp = data
|
|
||||||
if time.time() - timestamp < CACHE_TIMEOUT:
|
|
||||||
logger.debug(f"从缓存命中: {key}")
|
|
||||||
return result
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def set_to_cache(key: str, value: any):
|
|
||||||
async with _cache_lock:
|
|
||||||
_internal_cache[key] = (value, time.time())
|
|
||||||
|
|
||||||
|
|
||||||
from .database import BanUser, db_manager
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("napcat_adapter")
|
logger = get_logger("napcat_adapter")
|
||||||
@@ -53,16 +27,11 @@ class SSLAdapter(urllib3.PoolManager):
|
|||||||
|
|
||||||
async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||||
"""
|
"""
|
||||||
获取群相关信息 (带缓存)
|
获取群相关信息
|
||||||
|
|
||||||
返回值需要处理可能为空的情况
|
返回值需要处理可能为空的情况
|
||||||
"""
|
"""
|
||||||
cache_key = f"group_info:{group_id}"
|
logger.debug("获取群聊信息中")
|
||||||
cached_data = await get_from_cache(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
return cached_data
|
|
||||||
|
|
||||||
logger.debug(f"获取群聊信息中 (无缓存): {group_id}")
|
|
||||||
request_uuid = str(uuid.uuid4())
|
request_uuid = str(uuid.uuid4())
|
||||||
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||||
try:
|
try:
|
||||||
@@ -74,11 +43,8 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取群信息失败: {e}")
|
logger.error(f"获取群信息失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
logger.debug(socket_response)
|
||||||
data = socket_response.get("data")
|
return socket_response.get("data")
|
||||||
if data:
|
|
||||||
await set_to_cache(cache_key, data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||||
@@ -105,16 +71,11 @@ async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in
|
|||||||
|
|
||||||
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None:
|
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None:
|
||||||
"""
|
"""
|
||||||
获取群成员信息 (带缓存)
|
获取群成员信息
|
||||||
|
|
||||||
返回值需要处理可能为空的情况
|
返回值需要处理可能为空的情况
|
||||||
"""
|
"""
|
||||||
cache_key = f"member_info:{group_id}:{user_id}"
|
logger.debug("获取群成员信息中")
|
||||||
cached_data = await get_from_cache(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
return cached_data
|
|
||||||
|
|
||||||
logger.debug(f"获取群成员信息中 (无缓存): group={group_id}, user={user_id}")
|
|
||||||
request_uuid = str(uuid.uuid4())
|
request_uuid = str(uuid.uuid4())
|
||||||
payload = json.dumps(
|
payload = json.dumps(
|
||||||
{
|
{
|
||||||
@@ -132,11 +93,8 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取成员信息失败: {e}")
|
logger.error(f"获取成员信息失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
logger.debug(socket_response)
|
||||||
data = socket_response.get("data")
|
return socket_response.get("data")
|
||||||
if data:
|
|
||||||
await set_to_cache(cache_key, data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
async def get_image_base64(url: str) -> str:
|
async def get_image_base64(url: str) -> str:
|
||||||
@@ -179,18 +137,13 @@ def convert_image_to_gif(image_base64: str) -> str:
|
|||||||
|
|
||||||
async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
||||||
"""
|
"""
|
||||||
获取自身信息 (带缓存)
|
获取自身信息
|
||||||
Parameters:
|
Parameters:
|
||||||
websocket: WebSocket连接对象
|
websocket: WebSocket连接对象
|
||||||
Returns:
|
Returns:
|
||||||
data: dict: 返回的自身信息
|
data: dict: 返回的自身信息
|
||||||
"""
|
"""
|
||||||
cache_key = "self_info"
|
logger.debug("获取自身信息中")
|
||||||
cached_data = await get_from_cache(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
return cached_data
|
|
||||||
|
|
||||||
logger.debug("获取自身信息中 (无缓存)")
|
|
||||||
request_uuid = str(uuid.uuid4())
|
request_uuid = str(uuid.uuid4())
|
||||||
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
|
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
|
||||||
try:
|
try:
|
||||||
@@ -202,11 +155,8 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取自身信息失败: {e}")
|
logger.error(f"获取自身信息失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
logger.debug(response)
|
||||||
data = response.get("data")
|
return response.get("data")
|
||||||
if data:
|
|
||||||
await set_to_cache(cache_key, data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_format(raw_data: str) -> str:
|
def get_image_format(raw_data: str) -> str:
|
||||||
@@ -320,10 +270,11 @@ async def read_ban_list(
|
|||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
ban_list = db_manager.get_ban_records()
|
ban_list = await napcat_db.get_ban_records()
|
||||||
lifted_list: List[BanUser] = []
|
lifted_list: List[BanUser] = []
|
||||||
logger.info("已经读取禁言列表")
|
logger.info("已经读取禁言列表")
|
||||||
for ban_record in ban_list:
|
# 复制列表以避免迭代中修改原列表问题
|
||||||
|
for ban_record in list(ban_list):
|
||||||
if ban_record.user_id == 0:
|
if ban_record.user_id == 0:
|
||||||
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
|
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
|
||||||
if fetched_group_info is None:
|
if fetched_group_info is None:
|
||||||
@@ -351,12 +302,12 @@ async def read_ban_list(
|
|||||||
ban_list.remove(ban_record)
|
ban_list.remove(ban_record)
|
||||||
else:
|
else:
|
||||||
ban_record.lift_time = lift_ban_time
|
ban_record.lift_time = lift_ban_time
|
||||||
db_manager.update_ban_record(ban_list)
|
await napcat_db.update_ban_record(ban_list)
|
||||||
return ban_list, lifted_list
|
return ban_list, lifted_list
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取禁言列表失败: {e}")
|
logger.error(f"读取禁言列表失败: {e}")
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
|
|
||||||
def save_ban_record(list: List[BanUser]):
|
async def save_ban_record(list: List[BanUser]):
|
||||||
return db_manager.update_ban_record(list)
|
return await napcat_db.update_ban_record(list)
|
||||||
Reference in New Issue
Block a user