refactor(database): 将同步数据库操作迁移为异步操作

将整个项目的数据库操作从同步模式迁移为异步模式,主要涉及以下修改:

- 将 `with get_db_session()` 改为 `async with get_db_session()`
- 将同步的 SQLAlchemy 查询方法改为异步执行
- 更新相关的方法签名,添加 async/await 关键字
- 修复由于异步化导致的并发问题和性能问题

这些修改提高了数据库操作的并发性能,避免了阻塞主线程,提升了系统的整体响应能力。涉及修改的模块包括表情包管理、反提示注入统计、用户封禁管理、记忆系统、消息存储等多个核心组件。

BREAKING CHANGE: 所有涉及数据库操作的方法现在都需要使用异步调用,同步调用将不再工作
This commit is contained in:
Windpicker-owo
2025-09-28 15:42:30 +08:00
parent ff24bd8148
commit 08ef960947
35 changed files with 1180 additions and 1053 deletions

View File

@@ -65,7 +65,7 @@ class AffinityChatter(BaseChatter):
"""
try:
# 触发表达学习
learner = expression_learner_manager.get_expression_learner(self.stream_id)
learner = await expression_learner_manager.get_expression_learner(self.stream_id)
asyncio.create_task(learner.trigger_learning_for_chat())
unread_messages = context.get_unread_messages()

View File

@@ -69,7 +69,7 @@ class ChatterInterestScoringSystem:
keywords = self._extract_keywords_from_database(message)
interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords)
relationship_score = self._calculate_relationship_score(message.user_info.user_id)
relationship_score = await self._calculate_relationship_score(message.user_info.user_id)
mentioned_score = self._calculate_mentioned_score(message, bot_nickname)
total_score = (
@@ -189,7 +189,7 @@ class ChatterInterestScoringSystem:
unique_keywords = list(set(keywords))
return unique_keywords[:10] # 返回前10个唯一关键词
def _calculate_relationship_score(self, user_id: str) -> float:
async def _calculate_relationship_score(self, user_id: str) -> float:
"""计算关系分 - 从数据库获取关系分"""
# 优先使用内存中的关系分
if user_id in self.user_relationships:
@@ -212,7 +212,7 @@ class ChatterInterestScoringSystem:
global_tracker = ChatterRelationshipTracker()
if global_tracker:
relationship_score = global_tracker.get_user_relationship_score(user_id)
relationship_score = await global_tracker.get_user_relationship_score(user_id)
# 同时更新内存缓存
self.user_relationships[user_id] = relationship_score
return relationship_score

View File

@@ -287,7 +287,7 @@ class ChatterRelationshipTracker:
# ===== 数据库支持方法 =====
def get_user_relationship_score(self, user_id: str) -> float:
async def get_user_relationship_score(self, user_id: str) -> float:
"""获取用户关系分"""
# 先检查缓存
if user_id in self.user_relationship_cache:
@@ -298,7 +298,7 @@ class ChatterRelationshipTracker:
return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
# 缓存过期或不存在,从数据库获取
relationship_data = self._get_user_relationship_from_db(user_id)
relationship_data = await self._get_user_relationship_from_db(user_id)
if relationship_data:
# 更新缓存
self.user_relationship_cache[user_id] = {
@@ -313,37 +313,38 @@ class ChatterRelationshipTracker:
# 数据库中也没有,返回默认值
return global_config.affinity_flow.base_relationship_score
def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
"""从数据库获取用户关系数据"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 查询用户关系表
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
result = session.execute(stmt).scalar_one_or_none()
result = await session.execute(stmt)
relationship = result.scalar_one_or_none()
if result:
if relationship:
return {
"relationship_text": result.relationship_text or "",
"relationship_score": float(result.relationship_score)
if result.relationship_score is not None
"relationship_text": relationship.relationship_text or "",
"relationship_score": float(relationship.relationship_score)
if relationship.relationship_score is not None
else 0.3,
"last_updated": result.last_updated,
"last_updated": relationship.last_updated,
}
except Exception as e:
logger.error(f"从数据库获取用户关系失败: {e}")
return None
def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
async def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
"""更新数据库中的用户关系"""
try:
current_time = time.time()
with get_db_session() as session:
async with get_db_session() as session:
# 检查是否已存在关系记录
existing = session.execute(
select(UserRelationships).where(UserRelationships.user_id == user_id)
).scalar_one_or_none()
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
# 更新现有记录
@@ -362,7 +363,7 @@ class ChatterRelationshipTracker:
)
session.add(new_relationship)
session.commit()
await session.commit()
logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
except Exception as e:
@@ -399,7 +400,7 @@ class ChatterRelationshipTracker:
logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息")
# 获取当前关系数据
current_relationship = self._get_user_relationship_from_db(user_id)
current_relationship = await self._get_user_relationship_from_db(user_id)
current_score = (
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
if current_relationship
@@ -417,14 +418,14 @@ class ChatterRelationshipTracker:
logger.error(f"回复后关系追踪失败: {e}")
logger.debug("错误详情:", exc_info=True)
def _get_last_tracked_time(self, user_id: str) -> float:
async def _get_last_tracked_time(self, user_id: str) -> float:
"""获取上次追踪时间"""
# 先检查缓存
if user_id in self.user_relationship_cache:
return self.user_relationship_cache[user_id].get("last_tracked", 0)
# 从数据库获取
relationship_data = self._get_user_relationship_from_db(user_id)
relationship_data = await self._get_user_relationship_from_db(user_id)
if relationship_data:
return relationship_data.get("last_updated", 0)
@@ -433,7 +434,7 @@ class ChatterRelationshipTracker:
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
"""获取上次bot回复该用户的消息"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 查询bot回复给该用户的最新消息
stmt = (
select(Messages)
@@ -443,10 +444,11 @@ class ChatterRelationshipTracker:
.limit(1)
)
result = session.execute(stmt).scalar_one_or_none()
if result:
result = await session.execute(stmt)
message = result.scalar_one_or_none()
if message:
# 将SQLAlchemy模型转换为DatabaseMessages对象
return self._sqlalchemy_to_database_messages(result)
return self._sqlalchemy_to_database_messages(message)
except Exception as e:
logger.error(f"获取上次回复消息失败: {e}")
@@ -456,7 +458,7 @@ class ChatterRelationshipTracker:
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
"""获取用户在bot回复后的反应消息"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 查询用户在回复时间之后的5分钟内的消息
end_time = reply_time + 5 * 60 # 5分钟
@@ -468,9 +470,10 @@ class ChatterRelationshipTracker:
.order_by(Messages.time)
)
results = session.execute(stmt).scalars().all()
if results:
return [self._sqlalchemy_to_database_messages(result) for result in results]
result = await session.execute(stmt)
messages = result.scalars().all()
if messages:
return [self._sqlalchemy_to_database_messages(message) for message in messages]
except Exception as e:
logger.error(f"获取用户反应消息失败: {e}")
@@ -593,7 +596,7 @@ class ChatterRelationshipTracker:
quality = response_data.get("interaction_quality", "medium")
# 更新数据库
self._update_user_relationship_in_db(user_id, new_text, new_score)
await self._update_user_relationship_in_db(user_id, new_text, new_score)
# 更新缓存
self.user_relationship_cache[user_id] = {
@@ -696,7 +699,7 @@ class ChatterRelationshipTracker:
)
# 更新数据库和缓存
self._update_user_relationship_in_db(user_id, new_text, new_score)
await self._update_user_relationship_in_db(user_id, new_text, new_score)
self.user_relationship_cache[user_id] = {
"relationship_text": new_text,
"relationship_score": new_score,

View File

@@ -13,6 +13,7 @@ from typing import Callable
from src.common.logger import get_logger
from src.schedule.schedule_manager import schedule_manager
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
from .qzone_service import QZoneService
@@ -138,15 +139,13 @@ class SchedulerService:
:return: 如果已处理过,返回 True否则返回 False。
"""
try:
with get_db_session() as session:
record = (
session.query(MaiZoneScheduleStatus)
.filter(
MaiZoneScheduleStatus.datetime_hour == hour_str,
MaiZoneScheduleStatus.is_processed == True, # noqa: E712
)
.first()
async with get_db_session() as session:
stmt = select(MaiZoneScheduleStatus).where(
MaiZoneScheduleStatus.datetime_hour == hour_str,
MaiZoneScheduleStatus.is_processed == True, # noqa: E712
)
result = await session.execute(stmt)
record = result.scalar_one_or_none()
return record is not None
except Exception as e:
logger.error(f"检查日程处理状态时发生数据库错误: {e}")
@@ -162,11 +161,11 @@ class SchedulerService:
:param content: 最终发送的说说内容或错误信息。
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 查找是否已存在该记录
record = (
session.query(MaiZoneScheduleStatus).filter(MaiZoneScheduleStatus.datetime_hour == hour_str).first()
)
stmt = select(MaiZoneScheduleStatus).where(MaiZoneScheduleStatus.datetime_hour == hour_str)
result = await session.execute(stmt)
record = result.scalar_one_or_none()
if record:
# 如果存在,则更新状态
@@ -185,7 +184,7 @@ class SchedulerService:
send_success=success,
)
session.add(new_record)
session.commit()
await session.commit()
logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}")
except Exception as e:
logger.error(f"更新日程处理状态时发生数据库错误: {e}")

View File

@@ -64,15 +64,9 @@ async def message_recv(server_connection: Server.ServerConnection):
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
post_type = decoded_raw_message.get("post_type")
# 兼容没有 post_type 的普通消息
if not post_type and "message_type" in decoded_raw_message:
decoded_raw_message["post_type"] = "message"
post_type = "message"
if post_type in ["meta_event", "message", "notice"]:
await message_queue.put(decoded_raw_message)
else:
elif post_type is None:
await put_response(decoded_raw_message)
except json.JSONDecodeError as e:
@@ -428,8 +422,9 @@ class NapcatAdapterPlugin(BasePlugin):
def get_plugin_components(self):
self.register_events()
components = [(LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler),
(StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler)]
components = []
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
for handler in get_classes_in_module(event_handlers):
if issubclass(handler, BaseEventHandler):
components.append((handler.get_handler_info(), handler))

View File

@@ -1,156 +1,162 @@
"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API)
本模块替换原先的 sqlmodel + 同步Session 实现:
1. 复用主项目的异步数据库连接与迁移体系
2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record)
3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records
数据语义:
user_id == 0 表示群全体禁言
注意: 所有方法均为异步, 需要在 async 上下文中调用。
"""
from __future__ import annotations
import os
from typing import Optional, List
from dataclasses import dataclass
from typing import Optional, List, Sequence
from sqlmodel import Field, Session, SQLModel, create_engine, select
from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.database.sqlalchemy_models import Base, get_db_session
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
"""
表记录的方式:
| group_id | user_id | lift_time |
|----------|---------|-----------|
class NapcatBanRecord(Base):
__tablename__ = "napcat_ban_records"
id = Column(Integer, primary_key=True, autoincrement=True)
group_id = Column(BigInteger, nullable=False, index=True)
user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言
lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久
__table_args__ = (
UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"),
Index("idx_napcat_ban_group", "group_id"),
Index("idx_napcat_ban_user", "user_id"),
)
其中使用 user_id == 0 表示群全体禁言
"""
@dataclass
class BanUser:
"""
程序处理使用的实例
"""
user_id: int
group_id: int
lift_time: Optional[int] = -1
def identity(self) -> tuple[int, int]:
return self.group_id, self.user_id
lift_time: Optional[int] = Field(default=-1)
class NapcatDatabase:
async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]:
result = await session.execute(select(NapcatBanRecord))
return result.scalars().all()
class DB_BanUser(SQLModel, table=True):
"""
表示数据库中的用户禁言记录。
使用双重主键
"""
async def get_ban_records(self) -> List[BanUser]:
async with get_db_session() as session:
rows = await self._fetch_all(session)
return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows]
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
lift_time: Optional[int] # 禁言解除的时间(时间戳)
async def update_ban_records(self, ban_list: List[BanUser]) -> None:
target_map = {b.identity(): b for b in ban_list}
async with get_db_session() as session:
rows = await self._fetch_all(session)
existing_map = {(r.group_id, r.user_id): r for r in rows}
changed = 0
for ident, ban in target_map.items():
if ident in existing_map:
row = existing_map[ident]
if row.lift_time != ban.lift_time:
row.lift_time = ban.lift_time
changed += 1
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
"""
检查两个 BanUser 对象是否相同。
"""
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
class DatabaseManager:
"""
数据库管理类,负责与数据库交互。
"""
def __init__(self):
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
self._ensure_database() # 确保数据库和表已创建
def _ensure_database(self) -> None:
"""
确保数据库和表已创建。
"""
logger.info("确保数据库文件和表已创建...")
SQLModel.metadata.create_all(self.engine)
logger.info("数据库和表已创建或已存在")
def update_ban_record(self, ban_list: List[BanUser]) -> None:
# sourcery skip: class-extract-method
"""
更新禁言列表到数据库。
支持在不存在时创建新记录,对于多余的项目自动删除。
"""
with Session(self.engine) as session:
all_records = session.exec(select(DB_BanUser)).all()
for ban_user in ban_list:
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
)
if existing_record := session.exec(statement).first():
if existing_record.lift_time == ban_user.lift_time:
logger.debug(f"禁言记录未变更: {existing_record}")
continue
# 更新现有记录的 lift_time
existing_record.lift_time = ban_user.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {existing_record}")
else:
session.add(
NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time)
# 创建新记录
db_record = DB_BanUser(
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
)
changed += 1
removed = 0
for ident, row in existing_map.items():
if ident not in target_map:
await session.delete(row)
removed += 1
logger.debug(
f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}"
)
async def create_or_update(self, ban_record: BanUser) -> None:
async with get_db_session() as session:
stmt = select(NapcatBanRecord).where(
NapcatBanRecord.group_id == ban_record.group_id,
NapcatBanRecord.user_id == ban_record.user_id,
)
result = await session.execute(stmt)
row = result.scalars().first()
if row:
if row.lift_time != ban_record.lift_time:
row.lift_time = ban_record.lift_time
logger.debug(
f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_user}")
# 删除不在 ban_list 中的记录
for db_record in all_records:
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
if not any(is_identical(record, ban_user) for ban_user in ban_list):
statement = select(DB_BanUser).where(
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
logger.debug(f"删除禁言记录: {ban_record}")
else:
logger.info(f"未找到禁言记录: {ban_record}")
logger.info("禁言记录已更新")
def get_ban_records(self) -> List[BanUser]:
"""
读取所有禁言记录。
"""
with Session(self.engine) as session:
statement = select(DB_BanUser)
records = session.exec(statement).all()
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
def create_ban_record(self, ban_record: BanUser) -> None:
"""
为特定群组中的用户创建禁言记录。
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
其同时还是简化版的更新方式。
"""
with Session(self.engine) as session:
# 检查记录是否已存在
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
)
existing_record = session.exec(statement).first()
if existing_record:
# 如果记录已存在,更新 lift_time
existing_record.lift_time = ban_record.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {ban_record}")
else:
session.add(
NapcatBanRecord(
group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time
)
)
logger.debug(
f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}"
# 如果记录不存在,创建新记录
db_record = DB_BanUser(
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
)
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_record}")
async def delete_record(self, ban_record: BanUser) -> None:
async with get_db_session() as session:
stmt = select(NapcatBanRecord).where(
NapcatBanRecord.group_id == ban_record.group_id,
NapcatBanRecord.user_id == ban_record.user_id,
)
result = await session.execute(stmt)
row = result.scalars().first()
if row:
await session.delete(row)
logger.debug(
f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}"
)
def delete_ban_record(self, ban_record: BanUser):
"""
删除特定用户在特定群组中的禁言记录。
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
"""
user_id = ban_record.user_id
group_id = ban_record.group_id
with Session(self.engine) as session:
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
logger.debug(f"删除禁言记录: {ban_record}")
else:
logger.info(
f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}"
)
# 兼容旧命名
async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name
await self.update_ban_records(ban_list)
async def create_ban_record(self, ban_record: BanUser) -> None: # old name
await self.create_or_update(ban_record)
async def delete_ban_record(self, ban_record: BanUser) -> None: # old name
await self.delete_record(ban_record)
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
napcat_db = NapcatDatabase()
def is_identical(a: BanUser, b: BanUser) -> bool:
return a.group_id == b.group_id and a.user_id == b.user_id
__all__ = [
"BanUser",
"NapcatBanRecord",
"napcat_db",
"is_identical",
]
db_manager = DatabaseManager()

View File

@@ -112,8 +112,7 @@ class MessageChunker:
else:
return [{"_original_message": message}]
@staticmethod
def is_chunk_message(message: Union[str, Dict[str, Any]]) -> bool:
def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
"""判断是否是切片消息"""
try:
if isinstance(message, str):

View File

@@ -14,7 +14,6 @@ class MetaEventHandler:
"""
def __init__(self):
self.last_heart_beat = time.time()
self.interval = 5.0 # 默认值稍后通过set_plugin_config设置
self._interval_checking = False
self.plugin_config = None
@@ -40,6 +39,7 @@ class MetaEventHandler:
if message["status"].get("online") and message["status"].get("good"):
if not self._interval_checking:
asyncio.create_task(self.check_heartbeat())
self.last_heart_beat = time.time()
self.interval = message.get("interval") / 1000
else:
self_id = message.get("self_id")

View File

@@ -76,7 +76,7 @@ class SendHandler:
processed_message = await self.handle_seg_recursive(message_segment, user_info)
except Exception as e:
logger.error(f"处理消息时发生错误: {e}")
return None
return
if not processed_message:
logger.critical("现在暂时不支持解析此回复!")
@@ -94,7 +94,7 @@ class SendHandler:
id_name = "user_id"
else:
logger.error("无法识别的消息类型")
return None
return
logger.info("尝试发送到napcat")
logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'")
response = await self.send_message_to_napcat(
@@ -108,10 +108,8 @@ class SendHandler:
logger.info("消息发送成功")
qq_message_id = response.get("data", {}).get("message_id")
await self.message_sent_back(raw_message_base, qq_message_id)
return None
else:
logger.warning(f"消息发送失败napcat返回{str(response)}")
return None
async def send_command(self, raw_message_base: MessageBase) -> None:
"""
@@ -149,7 +147,7 @@ class SendHandler:
command, args_dict = self.handle_send_like_command(args)
case _:
logger.error(f"未知命令: {command_name}")
return None
return
except Exception as e:
logger.error(f"处理命令时发生错误: {e}")
return None
@@ -161,10 +159,8 @@ class SendHandler:
response = await self.send_message_to_napcat(command, args_dict)
if response.get("status") == "ok":
logger.info(f"命令 {command_name} 执行成功")
return None
else:
logger.warning(f"命令 {command_name} 执行失败napcat返回{str(response)}")
return None
async def handle_adapter_command(self, raw_message_base: MessageBase) -> None:
"""
@@ -272,8 +268,7 @@ class SendHandler:
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
return new_payload
@staticmethod
def build_payload(payload: list, addon: dict | list, is_reply: bool = False) -> list:
def build_payload(self, payload: list, addon: dict | list, is_reply: bool = False) -> list:
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
"""构建发送的消息体"""
if is_reply:
@@ -339,13 +334,11 @@ class SendHandler:
logger.info(f"最终返回的回复段: {reply_seg}")
return reply_seg
@staticmethod
def handle_text_message(message: str) -> dict:
def handle_text_message(self, message: str) -> dict:
"""处理文本消息"""
return {"type": "text", "data": {"text": message}}
@staticmethod
def handle_image_message(encoded_image: str) -> dict:
def handle_image_message(self, encoded_image: str) -> dict:
"""处理图片消息"""
return {
"type": "image",
@@ -355,8 +348,7 @@ class SendHandler:
},
} # base64 编码的图片
@staticmethod
def handle_emoji_message(encoded_emoji: str) -> dict:
def handle_emoji_message(self, encoded_emoji: str) -> dict:
"""处理表情消息"""
encoded_image = encoded_emoji
image_format = get_image_format(encoded_emoji)
@@ -387,45 +379,39 @@ class SendHandler:
"data": {"file": f"base64://{encoded_voice}"},
}
@staticmethod
def handle_voiceurl_message(voice_url: str) -> dict:
def handle_voiceurl_message(self, voice_url: str) -> dict:
"""处理语音链接消息"""
return {
"type": "record",
"data": {"file": voice_url},
}
@staticmethod
def handle_music_message(song_id: str) -> dict:
def handle_music_message(self, song_id: str) -> dict:
"""处理音乐消息"""
return {
"type": "music",
"data": {"type": "163", "id": song_id},
}
@staticmethod
def handle_videourl_message(video_url: str) -> dict:
def handle_videourl_message(self, video_url: str) -> dict:
"""处理视频链接消息"""
return {
"type": "video",
"data": {"file": video_url},
}
@staticmethod
def handle_file_message(file_path: str) -> dict:
def handle_file_message(self, file_path: str) -> dict:
"""处理文件消息"""
return {
"type": "file",
"data": {"file": f"file://{file_path}"},
}
@staticmethod
def delete_msg_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理删除消息命令"""
return "delete_msg", {"message_id": args["message_id"]}
@staticmethod
def handle_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理封禁命令
Args:
@@ -453,8 +439,7 @@ class SendHandler:
},
)
@staticmethod
def handle_whole_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理全体禁言命令
Args:
@@ -477,8 +462,7 @@ class SendHandler:
},
)
@staticmethod
def handle_kick_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理群成员踢出命令
Args:
@@ -503,8 +487,7 @@ class SendHandler:
},
)
@staticmethod
def handle_poke_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理戳一戳命令
Args:
@@ -531,8 +514,7 @@ class SendHandler:
},
)
@staticmethod
def handle_set_emoji_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理设置表情回应命令
Args:
@@ -554,8 +536,7 @@ class SendHandler:
{"message_id": message_id, "emoji_id": emoji_id, "set": set_like},
)
@staticmethod
def handle_send_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""
处理发送点赞命令的逻辑。
@@ -576,8 +557,7 @@ class SendHandler:
{"user_id": user_id, "times": times},
)
@staticmethod
def handle_ai_voice_send_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""
处理AI语音发送命令的逻辑。
并返回 NapCat 兼容的 (action, params) 元组。
@@ -624,8 +604,7 @@ class SendHandler:
return {"status": "error", "message": str(e)}
return response
@staticmethod
async def message_sent_back(message_base: MessageBase, qq_message_id: str) -> None:
async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
# 修改 additional_config添加 echo 字段
if message_base.message_info.additional_config is None:
message_base.message_info.additional_config = {}
@@ -643,9 +622,8 @@ class SendHandler:
logger.debug("已回送消息ID")
return
@staticmethod
async def send_adapter_command_response(
original_message: MessageBase, response_data: dict, request_id: str
self, original_message: MessageBase, response_data: dict, request_id: str
) -> None:
"""
发送适配器命令响应回MaiBot
@@ -674,8 +652,7 @@ class SendHandler:
except Exception as e:
logger.error(f"发送适配器命令响应时出错: {e}")
@staticmethod
def handle_at_message_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理艾特并发送消息命令
Args:

View File

@@ -6,7 +6,7 @@ import urllib3
import ssl
import io
from .database import BanUser, napcat_db
from .database import BanUser, db_manager
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
@@ -270,11 +270,10 @@ async def read_ban_list(
]
"""
try:
ban_list = await napcat_db.get_ban_records()
ban_list = db_manager.get_ban_records()
lifted_list: List[BanUser] = []
logger.info("已经读取禁言列表")
# 复制列表以避免迭代中修改原列表问题
for ban_record in list(ban_list):
for ban_record in ban_list:
if ban_record.user_id == 0:
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
if fetched_group_info is None:
@@ -302,12 +301,12 @@ async def read_ban_list(
ban_list.remove(ban_record)
else:
ban_record.lift_time = lift_ban_time
await napcat_db.update_ban_record(ban_list)
db_manager.update_ban_record(ban_list)
return ban_list, lifted_list
except Exception as e:
logger.error(f"读取禁言列表失败: {e}")
return [], []
async def save_ban_record(list: List[BanUser]):
return await napcat_db.update_ban_record(list)
def save_ban_record(list: List[BanUser]):
return db_manager.update_ban_record(list)