refactor(database): 将同步数据库操作迁移为异步操作
将整个项目的数据库操作从同步模式迁移为异步模式,主要涉及以下修改: - 将 `with get_db_session()` 改为 `async with get_db_session()` - 将同步的 SQLAlchemy 查询方法改为异步执行 - 更新相关的方法签名,添加 async/await 关键字 - 修复由于异步化导致的并发问题和性能问题 这些修改提高了数据库操作的并发性能,避免了阻塞主线程,提升了系统的整体响应能力。涉及修改的模块包括表情包管理、反提示注入统计、用户封禁管理、记忆系统、消息存储等多个核心组件。 BREAKING CHANGE: 所有涉及数据库操作的方法现在都需要使用异步调用,同步调用将不再工作
This commit is contained in:
@@ -65,7 +65,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
try:
|
||||
# 触发表达学习
|
||||
learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
learner = await expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
asyncio.create_task(learner.trigger_learning_for_chat())
|
||||
|
||||
unread_messages = context.get_unread_messages()
|
||||
|
||||
@@ -69,7 +69,7 @@ class ChatterInterestScoringSystem:
|
||||
|
||||
keywords = self._extract_keywords_from_database(message)
|
||||
interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords)
|
||||
relationship_score = self._calculate_relationship_score(message.user_info.user_id)
|
||||
relationship_score = await self._calculate_relationship_score(message.user_info.user_id)
|
||||
mentioned_score = self._calculate_mentioned_score(message, bot_nickname)
|
||||
|
||||
total_score = (
|
||||
@@ -189,7 +189,7 @@ class ChatterInterestScoringSystem:
|
||||
unique_keywords = list(set(keywords))
|
||||
return unique_keywords[:10] # 返回前10个唯一关键词
|
||||
|
||||
def _calculate_relationship_score(self, user_id: str) -> float:
|
||||
async def _calculate_relationship_score(self, user_id: str) -> float:
|
||||
"""计算关系分 - 从数据库获取关系分"""
|
||||
# 优先使用内存中的关系分
|
||||
if user_id in self.user_relationships:
|
||||
@@ -212,7 +212,7 @@ class ChatterInterestScoringSystem:
|
||||
|
||||
global_tracker = ChatterRelationshipTracker()
|
||||
if global_tracker:
|
||||
relationship_score = global_tracker.get_user_relationship_score(user_id)
|
||||
relationship_score = await global_tracker.get_user_relationship_score(user_id)
|
||||
# 同时更新内存缓存
|
||||
self.user_relationships[user_id] = relationship_score
|
||||
return relationship_score
|
||||
|
||||
@@ -287,7 +287,7 @@ class ChatterRelationshipTracker:
|
||||
|
||||
# ===== 数据库支持方法 =====
|
||||
|
||||
def get_user_relationship_score(self, user_id: str) -> float:
|
||||
async def get_user_relationship_score(self, user_id: str) -> float:
|
||||
"""获取用户关系分"""
|
||||
# 先检查缓存
|
||||
if user_id in self.user_relationship_cache:
|
||||
@@ -298,7 +298,7 @@ class ChatterRelationshipTracker:
|
||||
return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
|
||||
# 缓存过期或不存在,从数据库获取
|
||||
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||
relationship_data = await self._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
@@ -313,37 +313,38 @@ class ChatterRelationshipTracker:
|
||||
# 数据库中也没有,返回默认值
|
||||
return global_config.affinity_flow.base_relationship_score
|
||||
|
||||
def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
|
||||
async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
|
||||
"""从数据库获取用户关系数据"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查询用户关系表
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = session.execute(stmt).scalar_one_or_none()
|
||||
result = await session.execute(stmt)
|
||||
relationship = result.scalar_one_or_none()
|
||||
|
||||
if result:
|
||||
if relationship:
|
||||
return {
|
||||
"relationship_text": result.relationship_text or "",
|
||||
"relationship_score": float(result.relationship_score)
|
||||
if result.relationship_score is not None
|
||||
"relationship_text": relationship.relationship_text or "",
|
||||
"relationship_score": float(relationship.relationship_score)
|
||||
if relationship.relationship_score is not None
|
||||
else 0.3,
|
||||
"last_updated": result.last_updated,
|
||||
"last_updated": relationship.last_updated,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取用户关系失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
|
||||
async def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
|
||||
"""更新数据库中的用户关系"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 检查是否已存在关系记录
|
||||
existing = session.execute(
|
||||
select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
).scalar_one_or_none()
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
@@ -362,7 +363,7 @@ class ChatterRelationshipTracker:
|
||||
)
|
||||
session.add(new_relationship)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
@@ -399,7 +400,7 @@ class ChatterRelationshipTracker:
|
||||
logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息")
|
||||
|
||||
# 获取当前关系数据
|
||||
current_relationship = self._get_user_relationship_from_db(user_id)
|
||||
current_relationship = await self._get_user_relationship_from_db(user_id)
|
||||
current_score = (
|
||||
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
if current_relationship
|
||||
@@ -417,14 +418,14 @@ class ChatterRelationshipTracker:
|
||||
logger.error(f"回复后关系追踪失败: {e}")
|
||||
logger.debug("错误详情:", exc_info=True)
|
||||
|
||||
def _get_last_tracked_time(self, user_id: str) -> float:
|
||||
async def _get_last_tracked_time(self, user_id: str) -> float:
|
||||
"""获取上次追踪时间"""
|
||||
# 先检查缓存
|
||||
if user_id in self.user_relationship_cache:
|
||||
return self.user_relationship_cache[user_id].get("last_tracked", 0)
|
||||
|
||||
# 从数据库获取
|
||||
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||
relationship_data = await self._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
return relationship_data.get("last_updated", 0)
|
||||
|
||||
@@ -433,7 +434,7 @@ class ChatterRelationshipTracker:
|
||||
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
|
||||
"""获取上次bot回复该用户的消息"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查询bot回复给该用户的最新消息
|
||||
stmt = (
|
||||
select(Messages)
|
||||
@@ -443,10 +444,11 @@ class ChatterRelationshipTracker:
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = session.execute(stmt).scalar_one_or_none()
|
||||
if result:
|
||||
result = await session.execute(stmt)
|
||||
message = result.scalar_one_or_none()
|
||||
if message:
|
||||
# 将SQLAlchemy模型转换为DatabaseMessages对象
|
||||
return self._sqlalchemy_to_database_messages(result)
|
||||
return self._sqlalchemy_to_database_messages(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取上次回复消息失败: {e}")
|
||||
@@ -456,7 +458,7 @@ class ChatterRelationshipTracker:
|
||||
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
|
||||
"""获取用户在bot回复后的反应消息"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查询用户在回复时间之后的5分钟内的消息
|
||||
end_time = reply_time + 5 * 60 # 5分钟
|
||||
|
||||
@@ -468,9 +470,10 @@ class ChatterRelationshipTracker:
|
||||
.order_by(Messages.time)
|
||||
)
|
||||
|
||||
results = session.execute(stmt).scalars().all()
|
||||
if results:
|
||||
return [self._sqlalchemy_to_database_messages(result) for result in results]
|
||||
result = await session.execute(stmt)
|
||||
messages = result.scalars().all()
|
||||
if messages:
|
||||
return [self._sqlalchemy_to_database_messages(message) for message in messages]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户反应消息失败: {e}")
|
||||
@@ -593,7 +596,7 @@ class ChatterRelationshipTracker:
|
||||
quality = response_data.get("interaction_quality", "medium")
|
||||
|
||||
# 更新数据库
|
||||
self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
await self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
@@ -696,7 +699,7 @@ class ChatterRelationshipTracker:
|
||||
)
|
||||
|
||||
# 更新数据库和缓存
|
||||
self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
await self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
self.user_relationship_cache[user_id] = {
|
||||
"relationship_text": new_text,
|
||||
"relationship_score": new_score,
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Callable
|
||||
from src.common.logger import get_logger
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
|
||||
|
||||
from .qzone_service import QZoneService
|
||||
@@ -138,15 +139,13 @@ class SchedulerService:
|
||||
:return: 如果已处理过,返回 True,否则返回 False。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = (
|
||||
session.query(MaiZoneScheduleStatus)
|
||||
.filter(
|
||||
MaiZoneScheduleStatus.datetime_hour == hour_str,
|
||||
MaiZoneScheduleStatus.is_processed == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
async with get_db_session() as session:
|
||||
stmt = select(MaiZoneScheduleStatus).where(
|
||||
MaiZoneScheduleStatus.datetime_hour == hour_str,
|
||||
MaiZoneScheduleStatus.is_processed == True, # noqa: E712
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
return record is not None
|
||||
except Exception as e:
|
||||
logger.error(f"检查日程处理状态时发生数据库错误: {e}")
|
||||
@@ -162,11 +161,11 @@ class SchedulerService:
|
||||
:param content: 最终发送的说说内容或错误信息。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查找是否已存在该记录
|
||||
record = (
|
||||
session.query(MaiZoneScheduleStatus).filter(MaiZoneScheduleStatus.datetime_hour == hour_str).first()
|
||||
)
|
||||
stmt = select(MaiZoneScheduleStatus).where(MaiZoneScheduleStatus.datetime_hour == hour_str)
|
||||
result = await session.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
if record:
|
||||
# 如果存在,则更新状态
|
||||
@@ -185,7 +184,7 @@ class SchedulerService:
|
||||
send_success=success,
|
||||
)
|
||||
session.add(new_record)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}")
|
||||
except Exception as e:
|
||||
logger.error(f"更新日程处理状态时发生数据库错误: {e}")
|
||||
|
||||
@@ -64,15 +64,9 @@ async def message_recv(server_connection: Server.ServerConnection):
|
||||
|
||||
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
|
||||
post_type = decoded_raw_message.get("post_type")
|
||||
|
||||
# 兼容没有 post_type 的普通消息
|
||||
if not post_type and "message_type" in decoded_raw_message:
|
||||
decoded_raw_message["post_type"] = "message"
|
||||
post_type = "message"
|
||||
|
||||
if post_type in ["meta_event", "message", "notice"]:
|
||||
await message_queue.put(decoded_raw_message)
|
||||
else:
|
||||
elif post_type is None:
|
||||
await put_response(decoded_raw_message)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -428,8 +422,9 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
def get_plugin_components(self):
|
||||
self.register_events()
|
||||
|
||||
components = [(LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler),
|
||||
(StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler)]
|
||||
components = []
|
||||
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
|
||||
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
|
||||
for handler in get_classes_in_module(event_handlers):
|
||||
if issubclass(handler, BaseEventHandler):
|
||||
components.append((handler.get_handler_info(), handler))
|
||||
|
||||
@@ -1,156 +1,162 @@
|
||||
"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API)
|
||||
|
||||
本模块替换原先的 sqlmodel + 同步Session 实现:
|
||||
1. 复用主项目的异步数据库连接与迁移体系
|
||||
2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record)
|
||||
3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records
|
||||
|
||||
数据语义:
|
||||
user_id == 0 表示群全体禁言
|
||||
|
||||
注意: 所有方法均为异步, 需要在 async 上下文中调用。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Sequence
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.common.database.sqlalchemy_models import Base, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
"""
|
||||
表记录的方式:
|
||||
| group_id | user_id | lift_time |
|
||||
|----------|---------|-----------|
|
||||
|
||||
class NapcatBanRecord(Base):
|
||||
__tablename__ = "napcat_ban_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
group_id = Column(BigInteger, nullable=False, index=True)
|
||||
user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言
|
||||
lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"),
|
||||
Index("idx_napcat_ban_group", "group_id"),
|
||||
Index("idx_napcat_ban_user", "user_id"),
|
||||
)
|
||||
其中使用 user_id == 0 表示群全体禁言
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BanUser:
|
||||
"""
|
||||
程序处理使用的实例
|
||||
"""
|
||||
|
||||
user_id: int
|
||||
group_id: int
|
||||
lift_time: Optional[int] = -1
|
||||
|
||||
def identity(self) -> tuple[int, int]:
|
||||
return self.group_id, self.user_id
|
||||
lift_time: Optional[int] = Field(default=-1)
|
||||
|
||||
|
||||
class NapcatDatabase:
|
||||
async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]:
|
||||
result = await session.execute(select(NapcatBanRecord))
|
||||
return result.scalars().all()
|
||||
class DB_BanUser(SQLModel, table=True):
|
||||
"""
|
||||
表示数据库中的用户禁言记录。
|
||||
使用双重主键
|
||||
"""
|
||||
|
||||
async def get_ban_records(self) -> List[BanUser]:
|
||||
async with get_db_session() as session:
|
||||
rows = await self._fetch_all(session)
|
||||
return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows]
|
||||
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
|
||||
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
|
||||
lift_time: Optional[int] # 禁言解除的时间(时间戳)
|
||||
|
||||
async def update_ban_records(self, ban_list: List[BanUser]) -> None:
|
||||
target_map = {b.identity(): b for b in ban_list}
|
||||
async with get_db_session() as session:
|
||||
rows = await self._fetch_all(session)
|
||||
existing_map = {(r.group_id, r.user_id): r for r in rows}
|
||||
|
||||
changed = 0
|
||||
for ident, ban in target_map.items():
|
||||
if ident in existing_map:
|
||||
row = existing_map[ident]
|
||||
if row.lift_time != ban.lift_time:
|
||||
row.lift_time = ban.lift_time
|
||||
changed += 1
|
||||
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
|
||||
"""
|
||||
检查两个 BanUser 对象是否相同。
|
||||
"""
|
||||
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
数据库管理类,负责与数据库交互。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
|
||||
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
|
||||
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
|
||||
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
|
||||
self._ensure_database() # 确保数据库和表已创建
|
||||
|
||||
def _ensure_database(self) -> None:
|
||||
"""
|
||||
确保数据库和表已创建。
|
||||
"""
|
||||
logger.info("确保数据库文件和表已创建...")
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
logger.info("数据库和表已创建或已存在")
|
||||
|
||||
def update_ban_record(self, ban_list: List[BanUser]) -> None:
|
||||
# sourcery skip: class-extract-method
|
||||
"""
|
||||
更新禁言列表到数据库。
|
||||
支持在不存在时创建新记录,对于多余的项目自动删除。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
all_records = session.exec(select(DB_BanUser)).all()
|
||||
for ban_user in ban_list:
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
|
||||
)
|
||||
if existing_record := session.exec(statement).first():
|
||||
if existing_record.lift_time == ban_user.lift_time:
|
||||
logger.debug(f"禁言记录未变更: {existing_record}")
|
||||
continue
|
||||
# 更新现有记录的 lift_time
|
||||
existing_record.lift_time = ban_user.lift_time
|
||||
session.add(existing_record)
|
||||
logger.debug(f"更新禁言记录: {existing_record}")
|
||||
else:
|
||||
session.add(
|
||||
NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time)
|
||||
# 创建新记录
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
|
||||
)
|
||||
changed += 1
|
||||
|
||||
removed = 0
|
||||
for ident, row in existing_map.items():
|
||||
if ident not in target_map:
|
||||
await session.delete(row)
|
||||
removed += 1
|
||||
|
||||
logger.debug(
|
||||
f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}"
|
||||
)
|
||||
|
||||
async def create_or_update(self, ban_record: BanUser) -> None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(NapcatBanRecord).where(
|
||||
NapcatBanRecord.group_id == ban_record.group_id,
|
||||
NapcatBanRecord.user_id == ban_record.user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalars().first()
|
||||
if row:
|
||||
if row.lift_time != ban_record.lift_time:
|
||||
row.lift_time = ban_record.lift_time
|
||||
logger.debug(
|
||||
f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
|
||||
session.add(db_record)
|
||||
logger.debug(f"创建新禁言记录: {ban_user}")
|
||||
# 删除不在 ban_list 中的记录
|
||||
for db_record in all_records:
|
||||
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
|
||||
if not any(is_identical(record, ban_user) for ban_user in ban_list):
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
|
||||
)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
|
||||
logger.debug(f"删除禁言记录: {ban_record}")
|
||||
else:
|
||||
logger.info(f"未找到禁言记录: {ban_record}")
|
||||
|
||||
logger.info("禁言记录已更新")
|
||||
|
||||
def get_ban_records(self) -> List[BanUser]:
|
||||
"""
|
||||
读取所有禁言记录。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
statement = select(DB_BanUser)
|
||||
records = session.exec(statement).all()
|
||||
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
|
||||
|
||||
def create_ban_record(self, ban_record: BanUser) -> None:
|
||||
"""
|
||||
为特定群组中的用户创建禁言记录。
|
||||
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
|
||||
其同时还是简化版的更新方式。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
# 检查记录是否已存在
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
|
||||
)
|
||||
existing_record = session.exec(statement).first()
|
||||
if existing_record:
|
||||
# 如果记录已存在,更新 lift_time
|
||||
existing_record.lift_time = ban_record.lift_time
|
||||
session.add(existing_record)
|
||||
logger.debug(f"更新禁言记录: {ban_record}")
|
||||
else:
|
||||
session.add(
|
||||
NapcatBanRecord(
|
||||
group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
|
||||
# 如果记录不存在,创建新记录
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
|
||||
)
|
||||
session.add(db_record)
|
||||
logger.debug(f"创建新禁言记录: {ban_record}")
|
||||
|
||||
async def delete_record(self, ban_record: BanUser) -> None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(NapcatBanRecord).where(
|
||||
NapcatBanRecord.group_id == ban_record.group_id,
|
||||
NapcatBanRecord.user_id == ban_record.user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalars().first()
|
||||
if row:
|
||||
await session.delete(row)
|
||||
logger.debug(
|
||||
f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}"
|
||||
)
|
||||
def delete_ban_record(self, ban_record: BanUser):
|
||||
"""
|
||||
删除特定用户在特定群组中的禁言记录。
|
||||
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
|
||||
"""
|
||||
user_id = ban_record.user_id
|
||||
group_id = ban_record.group_id
|
||||
with Session(self.engine) as session:
|
||||
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
|
||||
logger.debug(f"删除禁言记录: {ban_record}")
|
||||
else:
|
||||
logger.info(
|
||||
f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}"
|
||||
)
|
||||
|
||||
# 兼容旧命名
|
||||
async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name
|
||||
await self.update_ban_records(ban_list)
|
||||
|
||||
async def create_ban_record(self, ban_record: BanUser) -> None: # old name
|
||||
await self.create_or_update(ban_record)
|
||||
|
||||
async def delete_ban_record(self, ban_record: BanUser) -> None: # old name
|
||||
await self.delete_record(ban_record)
|
||||
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
|
||||
|
||||
|
||||
napcat_db = NapcatDatabase()
|
||||
|
||||
|
||||
def is_identical(a: BanUser, b: BanUser) -> bool:
|
||||
return a.group_id == b.group_id and a.user_id == b.user_id
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BanUser",
|
||||
"NapcatBanRecord",
|
||||
"napcat_db",
|
||||
"is_identical",
|
||||
]
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
@@ -112,8 +112,7 @@ class MessageChunker:
|
||||
else:
|
||||
return [{"_original_message": message}]
|
||||
|
||||
@staticmethod
|
||||
def is_chunk_message(message: Union[str, Dict[str, Any]]) -> bool:
|
||||
def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
|
||||
"""判断是否是切片消息"""
|
||||
try:
|
||||
if isinstance(message, str):
|
||||
|
||||
@@ -14,7 +14,6 @@ class MetaEventHandler:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.last_heart_beat = time.time()
|
||||
self.interval = 5.0 # 默认值,稍后通过set_plugin_config设置
|
||||
self._interval_checking = False
|
||||
self.plugin_config = None
|
||||
@@ -40,6 +39,7 @@ class MetaEventHandler:
|
||||
if message["status"].get("online") and message["status"].get("good"):
|
||||
if not self._interval_checking:
|
||||
asyncio.create_task(self.check_heartbeat())
|
||||
self.last_heart_beat = time.time()
|
||||
self.interval = message.get("interval") / 1000
|
||||
else:
|
||||
self_id = message.get("self_id")
|
||||
|
||||
@@ -76,7 +76,7 @@ class SendHandler:
|
||||
processed_message = await self.handle_seg_recursive(message_segment, user_info)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时发生错误: {e}")
|
||||
return None
|
||||
return
|
||||
|
||||
if not processed_message:
|
||||
logger.critical("现在暂时不支持解析此回复!")
|
||||
@@ -94,7 +94,7 @@ class SendHandler:
|
||||
id_name = "user_id"
|
||||
else:
|
||||
logger.error("无法识别的消息类型")
|
||||
return None
|
||||
return
|
||||
logger.info("尝试发送到napcat")
|
||||
logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'")
|
||||
response = await self.send_message_to_napcat(
|
||||
@@ -108,10 +108,8 @@ class SendHandler:
|
||||
logger.info("消息发送成功")
|
||||
qq_message_id = response.get("data", {}).get("message_id")
|
||||
await self.message_sent_back(raw_message_base, qq_message_id)
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"消息发送失败,napcat返回:{str(response)}")
|
||||
return None
|
||||
|
||||
async def send_command(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
@@ -149,7 +147,7 @@ class SendHandler:
|
||||
command, args_dict = self.handle_send_like_command(args)
|
||||
case _:
|
||||
logger.error(f"未知命令: {command_name}")
|
||||
return None
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时发生错误: {e}")
|
||||
return None
|
||||
@@ -161,10 +159,8 @@ class SendHandler:
|
||||
response = await self.send_message_to_napcat(command, args_dict)
|
||||
if response.get("status") == "ok":
|
||||
logger.info(f"命令 {command_name} 执行成功")
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}")
|
||||
return None
|
||||
|
||||
async def handle_adapter_command(self, raw_message_base: MessageBase) -> None:
|
||||
"""
|
||||
@@ -272,8 +268,7 @@ class SendHandler:
|
||||
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
|
||||
return new_payload
|
||||
|
||||
@staticmethod
|
||||
def build_payload(payload: list, addon: dict | list, is_reply: bool = False) -> list:
|
||||
def build_payload(self, payload: list, addon: dict | list, is_reply: bool = False) -> list:
|
||||
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
|
||||
"""构建发送的消息体"""
|
||||
if is_reply:
|
||||
@@ -339,13 +334,11 @@ class SendHandler:
|
||||
logger.info(f"最终返回的回复段: {reply_seg}")
|
||||
return reply_seg
|
||||
|
||||
@staticmethod
|
||||
def handle_text_message(message: str) -> dict:
|
||||
def handle_text_message(self, message: str) -> dict:
|
||||
"""处理文本消息"""
|
||||
return {"type": "text", "data": {"text": message}}
|
||||
|
||||
@staticmethod
|
||||
def handle_image_message(encoded_image: str) -> dict:
|
||||
def handle_image_message(self, encoded_image: str) -> dict:
|
||||
"""处理图片消息"""
|
||||
return {
|
||||
"type": "image",
|
||||
@@ -355,8 +348,7 @@ class SendHandler:
|
||||
},
|
||||
} # base64 编码的图片
|
||||
|
||||
@staticmethod
|
||||
def handle_emoji_message(encoded_emoji: str) -> dict:
|
||||
def handle_emoji_message(self, encoded_emoji: str) -> dict:
|
||||
"""处理表情消息"""
|
||||
encoded_image = encoded_emoji
|
||||
image_format = get_image_format(encoded_emoji)
|
||||
@@ -387,45 +379,39 @@ class SendHandler:
|
||||
"data": {"file": f"base64://{encoded_voice}"},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def handle_voiceurl_message(voice_url: str) -> dict:
|
||||
def handle_voiceurl_message(self, voice_url: str) -> dict:
|
||||
"""处理语音链接消息"""
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": voice_url},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def handle_music_message(song_id: str) -> dict:
|
||||
def handle_music_message(self, song_id: str) -> dict:
|
||||
"""处理音乐消息"""
|
||||
return {
|
||||
"type": "music",
|
||||
"data": {"type": "163", "id": song_id},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def handle_videourl_message(video_url: str) -> dict:
|
||||
def handle_videourl_message(self, video_url: str) -> dict:
|
||||
"""处理视频链接消息"""
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {"file": video_url},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def handle_file_message(file_path: str) -> dict:
|
||||
def handle_file_message(self, file_path: str) -> dict:
|
||||
"""处理文件消息"""
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {"file": f"file://{file_path}"},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def delete_msg_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理删除消息命令"""
|
||||
return "delete_msg", {"message_id": args["message_id"]}
|
||||
|
||||
@staticmethod
|
||||
def handle_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理封禁命令
|
||||
|
||||
Args:
|
||||
@@ -453,8 +439,7 @@ class SendHandler:
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_whole_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理全体禁言命令
|
||||
|
||||
Args:
|
||||
@@ -477,8 +462,7 @@ class SendHandler:
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_kick_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理群成员踢出命令
|
||||
|
||||
Args:
|
||||
@@ -503,8 +487,7 @@ class SendHandler:
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_poke_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理戳一戳命令
|
||||
|
||||
Args:
|
||||
@@ -531,8 +514,7 @@ class SendHandler:
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_set_emoji_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理设置表情回应命令
|
||||
|
||||
Args:
|
||||
@@ -554,8 +536,7 @@ class SendHandler:
|
||||
{"message_id": message_id, "emoji_id": emoji_id, "set": set_like},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_send_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
处理发送点赞命令的逻辑。
|
||||
|
||||
@@ -576,8 +557,7 @@ class SendHandler:
|
||||
{"user_id": user_id, "times": times},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_ai_voice_send_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
处理AI语音发送命令的逻辑。
|
||||
并返回 NapCat 兼容的 (action, params) 元组。
|
||||
@@ -624,8 +604,7 @@ class SendHandler:
|
||||
return {"status": "error", "message": str(e)}
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
async def message_sent_back(message_base: MessageBase, qq_message_id: str) -> None:
|
||||
async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
|
||||
# 修改 additional_config,添加 echo 字段
|
||||
if message_base.message_info.additional_config is None:
|
||||
message_base.message_info.additional_config = {}
|
||||
@@ -643,9 +622,8 @@ class SendHandler:
|
||||
logger.debug("已回送消息ID")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
async def send_adapter_command_response(
|
||||
original_message: MessageBase, response_data: dict, request_id: str
|
||||
self, original_message: MessageBase, response_data: dict, request_id: str
|
||||
) -> None:
|
||||
"""
|
||||
发送适配器命令响应回MaiBot
|
||||
@@ -674,8 +652,7 @@ class SendHandler:
|
||||
except Exception as e:
|
||||
logger.error(f"发送适配器命令响应时出错: {e}")
|
||||
|
||||
@staticmethod
|
||||
def handle_at_message_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理艾特并发送消息命令
|
||||
|
||||
Args:
|
||||
|
||||
@@ -6,7 +6,7 @@ import urllib3
|
||||
import ssl
|
||||
import io
|
||||
|
||||
from .database import BanUser, napcat_db
|
||||
from .database import BanUser, db_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
@@ -270,11 +270,10 @@ async def read_ban_list(
|
||||
]
|
||||
"""
|
||||
try:
|
||||
ban_list = await napcat_db.get_ban_records()
|
||||
ban_list = db_manager.get_ban_records()
|
||||
lifted_list: List[BanUser] = []
|
||||
logger.info("已经读取禁言列表")
|
||||
# 复制列表以避免迭代中修改原列表问题
|
||||
for ban_record in list(ban_list):
|
||||
for ban_record in ban_list:
|
||||
if ban_record.user_id == 0:
|
||||
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
|
||||
if fetched_group_info is None:
|
||||
@@ -302,12 +301,12 @@ async def read_ban_list(
|
||||
ban_list.remove(ban_record)
|
||||
else:
|
||||
ban_record.lift_time = lift_ban_time
|
||||
await napcat_db.update_ban_record(ban_list)
|
||||
db_manager.update_ban_record(ban_list)
|
||||
return ban_list, lifted_list
|
||||
except Exception as e:
|
||||
logger.error(f"读取禁言列表失败: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
async def save_ban_record(list: List[BanUser]):
|
||||
return await napcat_db.update_ban_record(list)
|
||||
def save_ban_record(list: List[BanUser]):
|
||||
return db_manager.update_ban_record(list)
|
||||
|
||||
Reference in New Issue
Block a user