diff --git a/src/common/database/connection_pool_manager.py b/src/common/database/connection_pool_manager.py index 4ca789b6f..d0a68e8d4 100644 --- a/src/common/database/connection_pool_manager.py +++ b/src/common/database/connection_pool_manager.py @@ -55,6 +55,8 @@ class ConnectionInfo: try: await self.session.close() logger.debug("连接已关闭") + except asyncio.CancelledError: + logger.warning("关闭连接时任务被取消") except Exception as e: logger.warning(f"关闭连接时出错: {e}") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/database.py b/src/plugins/built_in/napcat_adapter_plugin/src/database.py index 74842eed5..c0eb471ee 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/database.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/database.py @@ -1,162 +1,156 @@ -import os -from typing import Optional, List -from dataclasses import dataclass -from sqlmodel import Field, Session, SQLModel, create_engine, select +"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API) +本模块替换原先的 sqlmodel + 同步Session 实现: +1. 复用主项目的异步数据库连接与迁移体系 +2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record) +3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records + +数据语义: + user_id == 0 表示群全体禁言 + +注意: 所有方法均为异步, 需要在 async 上下文中调用。 +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, List, Sequence + +from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index +from sqlalchemy.ext.asyncio import AsyncSession + +from src.common.database.sqlalchemy_models import Base, get_db_session from src.common.logger import get_logger logger = get_logger("napcat_adapter") -""" -表记录的方式: -| group_id | user_id | lift_time | -|----------|---------|-----------| -其中使用 user_id == 0 表示群全体禁言 -""" +class NapcatBanRecord(Base): + __tablename__ = "napcat_ban_records" + + id = Column(Integer, primary_key=True, autoincrement=True) + group_id = Column(BigInteger, nullable=False, index=True) + user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言 + lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久 + + __table_args__ = ( + UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"), + Index("idx_napcat_ban_group", "group_id"), + Index("idx_napcat_ban_user", "user_id"), + ) @dataclass class BanUser: - """ - 程序处理使用的实例 - """ - user_id: int group_id: int - lift_time: Optional[int] = Field(default=-1) + lift_time: Optional[int] = -1 + + def identity(self) -> tuple[int, int]: + return self.group_id, self.user_id -class DB_BanUser(SQLModel, table=True): - """ - 表示数据库中的用户禁言记录。 - 使用双重主键 - """ +class NapcatDatabase: + async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]: + result = await session.execute(select(NapcatBanRecord)) + return result.scalars().all() - user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID - group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID - lift_time: Optional[int] # 禁言解除的时间(时间戳) + async def get_ban_records(self) -> List[BanUser]: + async with get_db_session() as session: + rows = await self._fetch_all(session) + return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows] + async def update_ban_records(self, ban_list: List[BanUser]) -> None: + target_map = {b.identity(): b for b in ban_list} + async with get_db_session() as session: + rows = await self._fetch_all(session) + existing_map = {(r.group_id, r.user_id): r for r in rows} -def is_identical(obj1: BanUser, obj2: BanUser) -> bool: - """ - 检查两个 BanUser 对象是否相同。 - """ - return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id - - -class DatabaseManager: - """ - 数据库管理类,负责与数据库交互。 - """ - - def __init__(self): - os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在 - DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db") - self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL - self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎 - self._ensure_database() # 确保数据库和表已创建 - - def _ensure_database(self) -> None: - """ - 确保数据库和表已创建。 - """ - logger.info("确保数据库文件和表已创建...") - SQLModel.metadata.create_all(self.engine) - logger.info("数据库和表已创建或已存在") - - def update_ban_record(self, ban_list: List[BanUser]) -> None: - # sourcery skip: class-extract-method - """ - 更新禁言列表到数据库。 - 支持在不存在时创建新记录,对于多余的项目自动删除。 - """ - with Session(self.engine) as session: - all_records = session.exec(select(DB_BanUser)).all() - for ban_user in ban_list: - statement = select(DB_BanUser).where( - DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id - ) - if existing_record := session.exec(statement).first(): - if existing_record.lift_time == ban_user.lift_time: - logger.debug(f"禁言记录未变更: {existing_record}") - continue - # 更新现有记录的 lift_time - existing_record.lift_time = ban_user.lift_time - session.add(existing_record) - logger.debug(f"更新禁言记录: {existing_record}") + changed = 0 + for ident, ban in target_map.items(): + if ident in existing_map: + row = existing_map[ident] + if row.lift_time != ban.lift_time: + row.lift_time = ban.lift_time + changed += 1 else: - # 创建新记录 - db_record = DB_BanUser( - user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time + session.add( + NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time) ) - session.add(db_record) - logger.debug(f"创建新禁言记录: {ban_user}") - # 删除不在 ban_list 中的记录 - for db_record in all_records: - record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time) - if not any(is_identical(record, ban_user) for ban_user in ban_list): - statement = select(DB_BanUser).where( - DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id - ) - if ban_record := session.exec(statement).first(): - session.delete(ban_record) + changed += 1 - logger.debug(f"删除禁言记录: {ban_record}") - else: - logger.info(f"未找到禁言记录: {ban_record}") + removed = 0 + for ident, row in existing_map.items(): + if ident not in target_map: + await session.delete(row) + removed += 1 - logger.info("禁言记录已更新") - - def get_ban_records(self) -> List[BanUser]: - """ - 读取所有禁言记录。 - """ - with Session(self.engine) as session: - statement = select(DB_BanUser) - records = session.exec(statement).all() - return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records] - - def create_ban_record(self, ban_record: BanUser) -> None: - """ - 为特定群组中的用户创建禁言记录。 - 一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。 - 其同时还是简化版的更新方式。 - """ - with Session(self.engine) as session: - # 检查记录是否已存在 - statement = select(DB_BanUser).where( - DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id + logger.debug( + f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}" ) - existing_record = session.exec(statement).first() - if existing_record: - # 如果记录已存在,更新 lift_time - existing_record.lift_time = ban_record.lift_time - session.add(existing_record) - logger.debug(f"更新禁言记录: {ban_record}") + + async def create_or_update(self, ban_record: BanUser) -> None: + async with get_db_session() as session: + stmt = select(NapcatBanRecord).where( + NapcatBanRecord.group_id == ban_record.group_id, + NapcatBanRecord.user_id == ban_record.user_id, + ) + result = await session.execute(stmt) + row = result.scalars().first() + if row: + if row.lift_time != ban_record.lift_time: + row.lift_time = ban_record.lift_time + logger.debug( + f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}" + ) else: - # 如果记录不存在,创建新记录 - db_record = DB_BanUser( - user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + session.add( + NapcatBanRecord( + group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time + ) + ) + logger.debug( + f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}" ) - session.add(db_record) - logger.debug(f"创建新禁言记录: {ban_record}") - def delete_ban_record(self, ban_record: BanUser): - """ - 删除特定用户在特定群组中的禁言记录。 - 一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。 - """ - user_id = ban_record.user_id - group_id = ban_record.group_id - with Session(self.engine) as session: - statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id) - if ban_record := session.exec(statement).first(): - session.delete(ban_record) - - logger.debug(f"删除禁言记录: {ban_record}") + async def delete_record(self, ban_record: BanUser) -> None: + async with get_db_session() as session: + stmt = select(NapcatBanRecord).where( + NapcatBanRecord.group_id == ban_record.group_id, + NapcatBanRecord.user_id == ban_record.user_id, + ) + result = await session.execute(stmt) + row = result.scalars().first() + if row: + await session.delete(row) + logger.debug( + f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}" + ) else: - logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}") + logger.info( + f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}" + ) + + # 兼容旧命名 + async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name + await self.update_ban_records(ban_list) + + async def create_ban_record(self, ban_record: BanUser) -> None: # old name + await self.create_or_update(ban_record) + + async def delete_ban_record(self, ban_record: BanUser) -> None: # old name + await self.delete_record(ban_record) -db_manager = DatabaseManager() +napcat_db = NapcatDatabase() + + +def is_identical(a: BanUser, b: BanUser) -> bool: + return a.group_id == b.group_id and a.user_id == b.user_id + + +__all__ = [ + "BanUser", + "NapcatBanRecord", + "napcat_db", + "is_identical", +] \ No newline at end of file diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index 5ea018f4d..a15c46d8f 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -9,7 +9,7 @@ from src.common.logger import get_logger logger = get_logger("napcat_adapter") from src.plugin_system.apis import config_api -from ..database import BanUser, db_manager, is_identical +from ..database import BanUser, napcat_db, is_identical from . import NoticeType, ACCEPT_FORMAT from .message_sending import message_send_instance from .message_handler import message_handler @@ -62,7 +62,7 @@ class NoticeHandler: return self.server_connection return websocket_manager.get_connection() - def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: + async def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: """ 将用户禁言记录添加到self.banned_list中 如果是全体禁言,则user_id为0 @@ -71,16 +71,16 @@ class NoticeHandler: user_id = 0 # 使用0表示全体禁言 lift_time = -1 ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time) - for record in self.banned_list: + for record in list(self.banned_list): if is_identical(record, ban_record): self.banned_list.remove(record) self.banned_list.append(ban_record) - db_manager.create_ban_record(ban_record) # 作为更新 + await napcat_db.create_ban_record(ban_record) # 更新 return self.banned_list.append(ban_record) - db_manager.create_ban_record(ban_record) # 添加到数据库 + await napcat_db.create_ban_record(ban_record) # 新建 - def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: + async def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: """ 从self.lifted_group_list中移除已经解除全体禁言的群 """ @@ -88,7 +88,12 @@ class NoticeHandler: user_id = 0 # 使用0表示全体禁言 ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1) self.lifted_list.append(ban_record) - db_manager.delete_ban_record(ban_record) # 删除数据库中的记录 + # 从被禁言列表里移除对应记录 + for record in list(self.banned_list): + if is_identical(record, ban_record): + self.banned_list.remove(record) + break + await napcat_db.delete_ban_record(ban_record) async def handle_notice(self, raw_message: dict) -> None: notice_type = raw_message.get("notice_type") @@ -116,9 +121,9 @@ class NoticeHandler: sub_type = raw_message.get("sub_type") match sub_type: case NoticeType.Notify.poke: - if config_api.get_plugin_config( - self.plugin_config, "features.enable_poke", True - ) and await message_handler.check_allow_to_chat(user_id, group_id, False, False): + if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat( + user_id, group_id, False, False + ): logger.debug("处理戳一戳消息") handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id) else: @@ -127,18 +132,14 @@ class NoticeHandler: from src.plugin_system.core.event_manager import event_manager from ...event_types import NapcatEvent - await event_manager.trigger_event( - NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME - ) + await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME) case _: logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") - case NoticeType.group_msg_emoji_like: + case NoticeType.group_msg_emoji_like: # 该事件转移到 handle_group_emoji_like_notify函数内触发 if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True): logger.debug("处理群聊表情回复") - handled_message, user_info = await self.handle_group_emoji_like_notify( - raw_message, group_id, user_id - ) + handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id) else: logger.warning("群聊表情回复被禁用,取消群聊表情回复处理") case NoticeType.group_ban: @@ -201,9 +202,11 @@ class NoticeHandler: if system_notice: await self.put_notice(message_base) + return None else: logger.debug("发送到Maibot处理通知信息") await message_send_instance.message_send(message_base) + return None async def handle_poke_notify( self, raw_message: dict, group_id: int, user_id: int @@ -298,7 +301,7 @@ class NoticeHandler: async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int): if not group_id: logger.error("群ID不能为空,无法处理群聊表情回复通知") - return None, None + return None, None user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id) if user_qq_info: @@ -308,42 +311,37 @@ class NoticeHandler: user_name = "QQ用户" user_cardname = "QQ用户" logger.debug("无法获取表情回复对方的用户昵称") - + from src.plugin_system.core.event_manager import event_manager from ...event_types import NapcatEvent - target_message = await event_manager.trigger_event( - NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "") - ) - target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "") + target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id","")) + target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","") if not target_message: logger.error("未找到对应消息") return None, None if len(target_message_text) > 15: target_message_text = target_message_text[:15] + "..." - + user_info: UserInfo = UserInfo( platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), user_id=user_id, user_nickname=user_name, user_cardname=user_cardname, ) - + like_emoji_id = raw_message.get("likes")[0].get("emoji_id") await event_manager.trigger_event( - NapcatEvent.ON_RECEIVED.EMOJI_LIEK, - permission_group=PLUGIN_NAME, - group_id=group_id, - user_id=user_id, - message_id=raw_message.get("message_id", ""), - emoji_id=like_emoji_id, - ) - seg_data = Seg( - type="text", - data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]", - ) + NapcatEvent.ON_RECEIVED.EMOJI_LIEK, + permission_group=PLUGIN_NAME, + group_id=group_id, + user_id=user_id, + message_id=raw_message.get("message_id",""), + emoji_id=like_emoji_id + ) + seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]") return seg_data, user_info - + async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: if not group_id: logger.error("群ID不能为空,无法处理禁言通知") @@ -383,7 +381,7 @@ class NoticeHandler: if user_id == 0: # 为全体禁言 sub_type: str = "whole_ban" - self._ban_operation(group_id) + await self._ban_operation(group_id) else: # 为单人禁言 # 获取被禁言人的信息 sub_type: str = "ban" @@ -397,7 +395,7 @@ class NoticeHandler: user_nickname=user_nickname, user_cardname=user_cardname, ) - self._ban_operation(group_id, user_id, int(time.time() + duration)) + await self._ban_operation(group_id, user_id, int(time.time() + duration)) seg_data: Seg = Seg( type="notify", @@ -446,7 +444,7 @@ class NoticeHandler: user_id = raw_message.get("user_id") if user_id == 0: # 全体禁言解除 sub_type = "whole_lift_ban" - self._lift_operation(group_id) + await self._lift_operation(group_id) else: # 单人禁言解除 sub_type = "lift_ban" # 获取被解除禁言人的信息 @@ -462,7 +460,7 @@ class NoticeHandler: user_nickname=user_nickname, user_cardname=user_cardname, ) - self._lift_operation(group_id, user_id) + await self._lift_operation(group_id, user_id) seg_data: Seg = Seg( type="notify", @@ -473,7 +471,8 @@ class NoticeHandler: ) return seg_data, operator_info - async def put_notice(self, message_base: MessageBase) -> None: + @staticmethod + async def put_notice(message_base: MessageBase) -> None: """ 将处理后的通知消息放入通知队列 """ @@ -489,7 +488,7 @@ class NoticeHandler: group_id = lift_record.group_id user_id = lift_record.user_id - db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录 + asyncio.create_task(napcat_db.delete_ban_record(lift_record)) # 从数据库中删除禁言记录 seg_message: Seg = await self.natural_lift(group_id, user_id) @@ -586,7 +585,8 @@ class NoticeHandler: self.banned_list.remove(ban_record) await asyncio.sleep(5) - async def send_notice(self) -> None: + @staticmethod + async def send_notice() -> None: """ 发送通知消息到Napcat """ @@ -617,4 +617,4 @@ class NoticeHandler: await asyncio.sleep(1) -notice_handler = NoticeHandler() +notice_handler = NoticeHandler() \ No newline at end of file diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py index 3ec4ca181..a2e1d548b 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py @@ -6,33 +6,7 @@ import urllib3 import ssl import io -import time -from asyncio import Lock - -_internal_cache = {} -_cache_lock = Lock() -CACHE_TIMEOUT = 300 # 缓存5分钟 - - -async def get_from_cache(key: str): - async with _cache_lock: - data = _internal_cache.get(key) - if not data: - return None - - result, timestamp = data - if time.time() - timestamp < CACHE_TIMEOUT: - logger.debug(f"从缓存命中: {key}") - return result - return None - - -async def set_to_cache(key: str, value: any): - async with _cache_lock: - _internal_cache[key] = (value, time.time()) - - -from .database import BanUser, db_manager +from .database import BanUser, napcat_db from src.common.logger import get_logger logger = get_logger("napcat_adapter") @@ -53,16 +27,11 @@ class SSLAdapter(urllib3.PoolManager): async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None: """ - 获取群相关信息 (带缓存) + 获取群相关信息 返回值需要处理可能为空的情况 """ - cache_key = f"group_info:{group_id}" - cached_data = await get_from_cache(cache_key) - if cached_data: - return cached_data - - logger.debug(f"获取群聊信息中 (无缓存): {group_id}") + logger.debug("获取群聊信息中") request_uuid = str(uuid.uuid4()) payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid}) try: @@ -74,11 +43,8 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d except Exception as e: logger.error(f"获取群信息失败: {e}") return None - - data = socket_response.get("data") - if data: - await set_to_cache(cache_key, data) - return data + logger.debug(socket_response) + return socket_response.get("data") async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None: @@ -105,16 +71,11 @@ async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None: """ - 获取群成员信息 (带缓存) + 获取群成员信息 返回值需要处理可能为空的情况 """ - cache_key = f"member_info:{group_id}:{user_id}" - cached_data = await get_from_cache(cache_key) - if cached_data: - return cached_data - - logger.debug(f"获取群成员信息中 (无缓存): group={group_id}, user={user_id}") + logger.debug("获取群成员信息中") request_uuid = str(uuid.uuid4()) payload = json.dumps( { @@ -132,11 +93,8 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use except Exception as e: logger.error(f"获取成员信息失败: {e}") return None - - data = socket_response.get("data") - if data: - await set_to_cache(cache_key, data) - return data + logger.debug(socket_response) + return socket_response.get("data") async def get_image_base64(url: str) -> str: @@ -179,18 +137,13 @@ def convert_image_to_gif(image_base64: str) -> str: async def get_self_info(websocket: Server.ServerConnection) -> dict | None: """ - 获取自身信息 (带缓存) + 获取自身信息 Parameters: websocket: WebSocket连接对象 Returns: data: dict: 返回的自身信息 """ - cache_key = "self_info" - cached_data = await get_from_cache(cache_key) - if cached_data: - return cached_data - - logger.debug("获取自身信息中 (无缓存)") + logger.debug("获取自身信息中") request_uuid = str(uuid.uuid4()) payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid}) try: @@ -202,11 +155,8 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None: except Exception as e: logger.error(f"获取自身信息失败: {e}") return None - - data = response.get("data") - if data: - await set_to_cache(cache_key, data) - return data + logger.debug(response) + return response.get("data") def get_image_format(raw_data: str) -> str: @@ -320,10 +270,11 @@ async def read_ban_list( ] """ try: - ban_list = db_manager.get_ban_records() + ban_list = await napcat_db.get_ban_records() lifted_list: List[BanUser] = [] logger.info("已经读取禁言列表") - for ban_record in ban_list: + # 复制列表以避免迭代中修改原列表问题 + for ban_record in list(ban_list): if ban_record.user_id == 0: fetched_group_info = await get_group_info(websocket, ban_record.group_id) if fetched_group_info is None: @@ -351,12 +302,12 @@ async def read_ban_list( ban_list.remove(ban_record) else: ban_record.lift_time = lift_ban_time - db_manager.update_ban_record(ban_list) + await napcat_db.update_ban_record(ban_list) return ban_list, lifted_list except Exception as e: logger.error(f"读取禁言列表失败: {e}") return [], [] -def save_ban_record(list: List[BanUser]): - return db_manager.update_ban_record(list) +async def save_ban_record(list: List[BanUser]): + return await napcat_db.update_ban_record(list) \ No newline at end of file