163 lines
6.4 KiB
Python
163 lines
6.4 KiB
Python
import os
|
|
from typing import Optional, List
|
|
from dataclasses import dataclass
|
|
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
|
|
|
from src.common.logger import get_logger
|
|
|
|
logger = get_logger("napcat_adapter")
|
|
|
|
"""
|
|
表记录的方式:
|
|
| group_id | user_id | lift_time |
|
|
|----------|---------|-----------|
|
|
|
|
其中使用 user_id == 0 表示群全体禁言
|
|
"""
|
|
|
|
|
|
@dataclass
|
|
class BanUser:
|
|
"""
|
|
程序处理使用的实例
|
|
"""
|
|
|
|
user_id: int
|
|
group_id: int
|
|
lift_time: Optional[int] = Field(default=-1)
|
|
|
|
|
|
class DB_BanUser(SQLModel, table=True):
|
|
"""
|
|
表示数据库中的用户禁言记录。
|
|
使用双重主键
|
|
"""
|
|
|
|
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
|
|
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
|
|
lift_time: Optional[int] # 禁言解除的时间(时间戳)
|
|
|
|
|
|
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
|
|
"""
|
|
检查两个 BanUser 对象是否相同。
|
|
"""
|
|
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
|
|
|
|
|
|
class DatabaseManager:
|
|
"""
|
|
数据库管理类,负责与数据库交互。
|
|
"""
|
|
|
|
def __init__(self):
|
|
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
|
|
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
|
|
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
|
|
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
|
|
self._ensure_database() # 确保数据库和表已创建
|
|
|
|
def _ensure_database(self) -> None:
|
|
"""
|
|
确保数据库和表已创建。
|
|
"""
|
|
logger.info("确保数据库文件和表已创建...")
|
|
SQLModel.metadata.create_all(self.engine)
|
|
logger.info("数据库和表已创建或已存在")
|
|
|
|
def update_ban_record(self, ban_list: List[BanUser]) -> None:
|
|
# sourcery skip: class-extract-method
|
|
"""
|
|
更新禁言列表到数据库。
|
|
支持在不存在时创建新记录,对于多余的项目自动删除。
|
|
"""
|
|
with Session(self.engine) as session:
|
|
all_records = session.exec(select(DB_BanUser)).all()
|
|
for ban_user in ban_list:
|
|
statement = select(DB_BanUser).where(
|
|
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
|
|
)
|
|
if existing_record := session.exec(statement).first():
|
|
if existing_record.lift_time == ban_user.lift_time:
|
|
logger.debug(f"禁言记录未变更: {existing_record}")
|
|
continue
|
|
# 更新现有记录的 lift_time
|
|
existing_record.lift_time = ban_user.lift_time
|
|
session.add(existing_record)
|
|
logger.debug(f"更新禁言记录: {existing_record}")
|
|
else:
|
|
# 创建新记录
|
|
db_record = DB_BanUser(
|
|
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
|
|
)
|
|
session.add(db_record)
|
|
logger.debug(f"创建新禁言记录: {ban_user}")
|
|
# 删除不在 ban_list 中的记录
|
|
for db_record in all_records:
|
|
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
|
|
if not any(is_identical(record, ban_user) for ban_user in ban_list):
|
|
statement = select(DB_BanUser).where(
|
|
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
|
|
)
|
|
if ban_record := session.exec(statement).first():
|
|
session.delete(ban_record)
|
|
|
|
logger.debug(f"删除禁言记录: {ban_record}")
|
|
else:
|
|
logger.info(f"未找到禁言记录: {ban_record}")
|
|
|
|
logger.info("禁言记录已更新")
|
|
|
|
def get_ban_records(self) -> List[BanUser]:
|
|
"""
|
|
读取所有禁言记录。
|
|
"""
|
|
with Session(self.engine) as session:
|
|
statement = select(DB_BanUser)
|
|
records = session.exec(statement).all()
|
|
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
|
|
|
|
def create_ban_record(self, ban_record: BanUser) -> None:
|
|
"""
|
|
为特定群组中的用户创建禁言记录。
|
|
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
|
|
其同时还是简化版的更新方式。
|
|
"""
|
|
with Session(self.engine) as session:
|
|
# 检查记录是否已存在
|
|
statement = select(DB_BanUser).where(
|
|
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
|
|
)
|
|
existing_record = session.exec(statement).first()
|
|
if existing_record:
|
|
# 如果记录已存在,更新 lift_time
|
|
existing_record.lift_time = ban_record.lift_time
|
|
session.add(existing_record)
|
|
logger.debug(f"更新禁言记录: {ban_record}")
|
|
else:
|
|
# 如果记录不存在,创建新记录
|
|
db_record = DB_BanUser(
|
|
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
|
|
)
|
|
session.add(db_record)
|
|
logger.debug(f"创建新禁言记录: {ban_record}")
|
|
|
|
def delete_ban_record(self, ban_record: BanUser):
|
|
"""
|
|
删除特定用户在特定群组中的禁言记录。
|
|
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
|
|
"""
|
|
user_id = ban_record.user_id
|
|
group_id = ban_record.group_id
|
|
with Session(self.engine) as session:
|
|
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
|
|
if ban_record := session.exec(statement).first():
|
|
session.delete(ban_record)
|
|
|
|
logger.debug(f"删除禁言记录: {ban_record}")
|
|
else:
|
|
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
|
|
|
|
|
|
db_manager = DatabaseManager()
|