refactor(db): 将数据库操作异步化
- 将所有 session.add() 改为 await session.add() - 将所有 session.commit() 改为 await session.commit() - 将 session.refresh() 改为 await session.refresh()
This commit is contained in:
@@ -265,7 +265,7 @@ class AntiPromptInjector:
|
||||
# 删除对应的消息记录
|
||||
stmt = delete(Messages).where(Messages.message_id == message_id)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
logger.debug(f"成功删除违禁消息记录: {message_id}")
|
||||
@@ -295,7 +295,7 @@ class AntiPromptInjector:
|
||||
.values(processed_plain_text=new_content, display_message=new_content)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
logger.debug(f"成功更新消息内容为加盾版本: {message_id}")
|
||||
|
||||
@@ -32,9 +32,9 @@ class AntiInjectionStatistics:
|
||||
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
|
||||
if not stats:
|
||||
stats = AntiInjectionStats()
|
||||
session.add(stats)
|
||||
session.commit()
|
||||
session.refresh(stats)
|
||||
await session.add(stats)
|
||||
await session.commit()
|
||||
await session.refresh(stats)
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计记录失败: {e}")
|
||||
@@ -48,7 +48,7 @@ class AntiInjectionStatistics:
|
||||
stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first()
|
||||
if not stats:
|
||||
stats = AntiInjectionStats()
|
||||
session.add(stats)
|
||||
await session.add(stats)
|
||||
|
||||
# 更新统计字段
|
||||
for key, value in kwargs.items():
|
||||
@@ -80,7 +80,7 @@ class AntiInjectionStatistics:
|
||||
# 直接设置的字段
|
||||
setattr(stats, key, value)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"更新统计数据失败: {e}")
|
||||
|
||||
@@ -141,7 +141,7 @@ class AntiInjectionStatistics:
|
||||
with get_db_session() as session:
|
||||
# 删除现有统计记录
|
||||
session.query(AntiInjectionStats).delete()
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info("统计信息已重置")
|
||||
except Exception as e:
|
||||
logger.error(f"重置统计信息失败: {e}")
|
||||
|
||||
@@ -52,7 +52,7 @@ class UserBanManager:
|
||||
# 封禁已过期,重置违规次数
|
||||
ban_record.violation_num = 0
|
||||
ban_record.created_at = datetime.datetime.now()
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info(f"用户 {platform}:{user_id} 封禁已过期,违规次数已重置")
|
||||
|
||||
return None
|
||||
@@ -85,9 +85,9 @@ class UserBanManager:
|
||||
reason=f"提示词注入攻击 (置信度: {detection_result.confidence:.2f})",
|
||||
created_at=datetime.datetime.now(),
|
||||
)
|
||||
session.add(ban_record)
|
||||
await session.add(ban_record)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# 检查是否需要自动封禁
|
||||
if ban_record.violation_num >= self.config.auto_ban_violation_threshold:
|
||||
@@ -95,7 +95,7 @@ class UserBanManager:
|
||||
# 只有在首次达到阈值时才更新封禁开始时间
|
||||
if ban_record.violation_num == self.config.auto_ban_violation_threshold:
|
||||
ban_record.created_at = datetime.datetime.now()
|
||||
session.commit()
|
||||
await session.commit()
|
||||
else:
|
||||
logger.info(f"用户 {platform}:{user_id} 违规记录已更新,当前违规次数: {ban_record.violation_num}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user