refactor(db): 将数据库操作异步化

- 将所有 session.add() 改为 await session.add()
- 将所有 session.commit() 改为 await session.commit()
- 将 session.refresh() 改为 await session.refresh()
This commit is contained in:
雅诺狐
2025-09-20 14:35:31 +08:00
committed by Windpicker-owo
parent 883bf3a7ea
commit 0cffc0aa95
20 changed files with 44 additions and 47 deletions

View File

@@ -265,7 +265,7 @@ class AntiPromptInjector:
# 删除对应的消息记录
stmt = delete(Messages).where(Messages.message_id == message_id)
result = session.execute(stmt)
session.commit()
await session.commit()
if result.rowcount > 0:
logger.debug(f"成功删除违禁消息记录: {message_id}")
@@ -295,7 +295,7 @@ class AntiPromptInjector:
.values(processed_plain_text=new_content, display_message=new_content)
)
result = session.execute(stmt)
session.commit()
await session.commit()
if result.rowcount > 0:
logger.debug(f"成功更新消息内容为加盾版本: {message_id}")

View File

@@ -32,9 +32,9 @@ class AntiInjectionStatistics:
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
if not stats:
stats = AntiInjectionStats()
session.add(stats)
session.commit()
session.refresh(stats)
await session.add(stats)
await session.commit()
await session.refresh(stats)
return stats
except Exception as e:
logger.error(f"获取统计记录失败: {e}")
@@ -48,7 +48,7 @@ class AntiInjectionStatistics:
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
if not stats:
stats = AntiInjectionStats()
session.add(stats)
await session.add(stats)
# 更新统计字段
for key, value in kwargs.items():
@@ -80,7 +80,7 @@ class AntiInjectionStatistics:
# 直接设置的字段
setattr(stats, key, value)
session.commit()
await session.commit()
except Exception as e:
logger.error(f"更新统计数据失败: {e}")
@@ -141,7 +141,7 @@ class AntiInjectionStatistics:
with get_db_session() as session:
# 删除现有统计记录
session.query(AntiInjectionStats).delete()
session.commit()
await session.commit()
logger.info("统计信息已重置")
except Exception as e:
logger.error(f"重置统计信息失败: {e}")

View File

@@ -52,7 +52,7 @@ class UserBanManager:
# 封禁已过期,重置违规次数
ban_record.violation_num = 0
ban_record.created_at = datetime.datetime.now()
session.commit()
await session.commit()
logger.info(f"用户 {platform}:{user_id} 封禁已过期,违规次数已重置")
return None
@@ -85,9 +85,9 @@ class UserBanManager:
reason=f"提示词注入攻击 (置信度: {detection_result.confidence:.2f})",
created_at=datetime.datetime.now(),
)
session.add(ban_record)
await session.add(ban_record)
session.commit()
await session.commit()
# 检查是否需要自动封禁
if ban_record.violation_num >= self.config.auto_ban_violation_threshold:
@@ -95,7 +95,7 @@ class UserBanManager:
# 只有在首次达到阈值时才更新封禁开始时间
if ban_record.violation_num == self.config.auto_ban_violation_threshold:
ban_record.created_at = datetime.datetime.now()
session.commit()
await session.commit()
else:
logger.info(f"用户 {platform}:{user_id} 违规记录已更新,当前违规次数: {ban_record.violation_num}")

View File

@@ -166,7 +166,7 @@ class MaiEmoji:
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
session.add(emoji)
await session.add(emoji)
await session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")

View File

@@ -117,8 +117,8 @@ class InstantMemory:
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time,
)
session.add(memory)
session.commit()
await session.add(memory)
await session.commit()
async def get_memory(self, target: str):
from json_repair import repair_json

View File

@@ -147,7 +147,7 @@ class ChatManager:
# db.connect(reuse_if_open=True)
# # 确保 ChatStreams 表存在
# session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
# session.commit()
# await session.commit()
# except Exception as e:
# logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")

View File

@@ -123,7 +123,8 @@ class MessageStorage:
is_picid=is_picid,
)
async with get_db_session() as session:
session.add(new_message)
await session.add(new_message)
await session.commit()
except Exception:
logger.exception("存储消息失败")
@@ -162,9 +163,6 @@ class MessageStorage:
logger.debug(f"消息段数据: {message.message_segment.data}")
return
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
async with get_db_session() as session:
matched_message = (
await session.execute(

View File

@@ -128,7 +128,7 @@ class ImageManager:
description=description,
timestamp=current_timestamp,
)
session.add(new_desc)
await session.add(new_desc)
await session.commit()
# 会在上下文管理器中自动调用
except Exception as e:
@@ -278,7 +278,7 @@ class ImageManager:
description=detailed_description, # 保存详细描述
timestamp=current_timestamp,
)
session.add(new_img)
await session.add(new_img)
await session.commit()
except Exception as e:
logger.error(f"保存到Images表失败: {str(e)}")
@@ -370,7 +370,7 @@ class ImageManager:
vlm_processed=True,
count=1,
)
session.add(new_img)
await session.add(new_img)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
await session.commit()
@@ -590,7 +590,7 @@ class ImageManager:
vlm_processed=True,
count=1,
)
session.add(new_img)
await session.add(new_img)
await session.commit()
return image_id, f"[picid:{image_id}]"

View File

@@ -242,7 +242,7 @@ class VideoAnalyzer:
existing_video.fps = metadata.get("fps")
existing_video.resolution = metadata.get("resolution")
existing_video.file_size = metadata.get("file_size")
session.commit()
await session.commit()
session.refresh(existing_video)
logger.info(f"✅ 更新已存在的视频记录hash: {video_hash[:16]}..., count: {existing_video.count}")
return existing_video
@@ -257,8 +257,8 @@ class VideoAnalyzer:
video_record.resolution = metadata.get("resolution")
video_record.file_size = metadata.get("file_size")
session.add(video_record)
session.commit()
await session.add(video_record)
await session.commit()
session.refresh(video_record)
logger.info(f"✅ 新视频分析结果已保存到数据库hash: {video_hash[:16]}...")
return video_record