refactor(db): 修正SQLAlchemy异步操作调用方式

移除session.add()方法的不必要await调用,修正异步数据库操作模式。主要变更包括:

- 将 `await session.add()` 统一改为 `session.add()`
- 修正部分函数调用为异步版本(如消息查询函数)
- 重构SQLAlchemyTransaction为完全异步实现
- 重写napcat_adapter_plugin数据库层以符合异步规范
- 添加aiomysql和aiosqlite依赖支持
This commit is contained in:
雅诺狐
2025-09-20 17:26:28 +08:00
parent 55717669dd
commit 832743249d
23 changed files with 246 additions and 244 deletions

View File

@@ -32,7 +32,7 @@ class AntiInjectionStatistics:
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
if not stats:
stats = AntiInjectionStats()
await session.add(stats)
session.add(stats)
await session.commit()
await session.refresh(stats)
return stats
@@ -48,7 +48,7 @@ class AntiInjectionStatistics:
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
if not stats:
stats = AntiInjectionStats()
await session.add(stats)
session.add(stats)
# 更新统计字段
for key, value in kwargs.items():

View File

@@ -85,7 +85,7 @@ class UserBanManager:
reason=f"提示词注入攻击 (置信度: {detection_result.confidence:.2f})",
created_at=datetime.datetime.now(),
)
await session.add(ban_record)
session.add(ban_record)
await session.commit()

View File

@@ -166,7 +166,7 @@ class MaiEmoji:
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
await session.add(emoji)
session.add(emoji)
await session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")

View File

@@ -381,7 +381,7 @@ class ExpressionLearner:
type=type,
create_date=current_time, # 手动设置创建日期
)
await session.add(new_expression)
session.add(new_expression)
# 限制最大数量
exprs_result = await session.execute(
@@ -608,7 +608,7 @@ class ExpressionLearnerManager:
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
await session.add(new_expression)
session.add(new_expression)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")

View File

@@ -117,7 +117,7 @@ class InstantMemory:
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time,
)
await session.add(memory)
session.add(memory)
await session.commit()
async def get_memory(self, target: str):

View File

@@ -122,7 +122,7 @@ class MessageStorage:
is_picid=is_picid,
)
async with get_db_session() as session:
await session.add(new_message)
session.add(new_message)
await session.commit()
except Exception:

View File

@@ -128,7 +128,7 @@ class ImageManager:
description=description,
timestamp=current_timestamp,
)
await session.add(new_desc)
session.add(new_desc)
await session.commit()
# 会在上下文管理器中自动调用
except Exception as e:
@@ -278,7 +278,7 @@ class ImageManager:
description=detailed_description, # 保存详细描述
timestamp=current_timestamp,
)
await session.add(new_img)
session.add(new_img)
await session.commit()
except Exception as e:
logger.error(f"保存到Images表失败: {str(e)}")
@@ -370,7 +370,7 @@ class ImageManager:
vlm_processed=True,
count=1,
)
await session.add(new_img)
session.add(new_img)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
await session.commit()
@@ -590,7 +590,7 @@ class ImageManager:
vlm_processed=True,
count=1,
)
await session.add(new_img)
session.add(new_img)
await session.commit()
return image_id, f"[picid:{image_id}]"

View File

@@ -22,6 +22,7 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import get_db_session, Videos
from sqlalchemy import select
logger = get_logger("utils_video")
@@ -205,34 +206,29 @@ class VideoAnalyzer:
return hash_obj.hexdigest()
@staticmethod
def _check_video_exists(video_hash: str) -> Optional[Videos]:
"""检查视频是否已经分析过"""
async def _check_video_exists(video_hash: str) -> Optional[Videos]:
"""检查视频是否已经分析过 (异步)"""
try:
with get_db_session() as session:
# 明确刷新会话以确保看到其他事务的最新提交
session.expire_all()
return session.query(Videos).filter(Videos.video_hash == video_hash).first()
async with get_db_session() as session:
result = await session.execute(select(Videos).where(Videos.video_hash == video_hash))
return result.scalar_one_or_none()
except Exception as e:
logger.warning(f"检查视频是否存在时出错: {e}")
return None
@staticmethod
def _store_video_result(
video_hash: str, description: str, metadata: Optional[Dict] = None
async def _store_video_result(
video_hash: str, description: str, metadata: Optional[Dict] = None
) -> Optional[Videos]:
"""存储视频分析结果到数据库"""
# 检查描述是否为错误信息,如果是则不保存
"""存储视频分析结果到数据库 (异步)"""
if description.startswith(""):
logger.warning(f"⚠️ 检测到错误信息,不保存到数据库: {description[:50]}...")
return None
try:
with get_db_session() as session:
# 只根据video_hash查找
existing_video = session.query(Videos).filter(Videos.video_hash == video_hash).first()
async with get_db_session() as session:
result = await session.execute(select(Videos).where(Videos.video_hash == video_hash))
existing_video = result.scalar_one_or_none()
if existing_video:
# 如果已存在,更新描述和计数
existing_video.description = description
existing_video.count += 1
existing_video.timestamp = time.time()
@@ -243,12 +239,17 @@ class VideoAnalyzer:
existing_video.resolution = metadata.get("resolution")
existing_video.file_size = metadata.get("file_size")
await session.commit()
session.refresh(existing_video)
logger.info(f"✅ 更新已存在的视频记录hash: {video_hash[:16]}..., count: {existing_video.count}")
await session.refresh(existing_video)
logger.info(
f"✅ 更新已存在的视频记录hash: {video_hash[:16]}..., count: {existing_video.count}"
)
return existing_video
else:
video_record = Videos(
video_hash=video_hash, description=description, timestamp=time.time(), count=1
video_hash=video_hash,
description=description,
timestamp=time.time(),
count=1,
)
if metadata:
video_record.duration = metadata.get("duration")
@@ -256,11 +257,12 @@ class VideoAnalyzer:
video_record.fps = metadata.get("fps")
video_record.resolution = metadata.get("resolution")
video_record.file_size = metadata.get("file_size")
await session.add(video_record)
session.add(video_record)
await session.commit()
session.refresh(video_record)
logger.info(f"✅ 新视频分析结果已保存到数据库hash: {video_hash[:16]}...")
await session.refresh(video_record)
logger.info(
f"✅ 新视频分析结果已保存到数据库hash: {video_hash[:16]}..."
)
return video_record
except Exception as e:
logger.error(f"❌ 存储视频分析结果时出错: {e}")
@@ -708,7 +710,7 @@ class VideoAnalyzer:
logger.info("✅ 等待结束,检查是否有处理结果")
# 检查是否有结果了
existing_video = self._check_video_exists(video_hash)
existing_video = await self._check_video_exists(video_hash)
if existing_video:
logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})")
return {"summary": existing_video.description}
@@ -722,7 +724,7 @@ class VideoAnalyzer:
logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)")
# 再次检查数据库(可能在等待期间已经有结果了)
existing_video = self._check_video_exists(video_hash)
existing_video = await self._check_video_exists(video_hash)
if existing_video:
logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})")
video_event.set() # 通知其他等待者
@@ -753,7 +755,7 @@ class VideoAnalyzer:
# 保存分析结果到数据库(仅保存成功的结果)
if success:
metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()}
self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
logger.info("✅ 分析结果已保存到数据库")
else:
logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试")