From 0cffc0aa95cd90efeed06e5799214179ff38fef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 14:35:31 +0800 Subject: [PATCH] =?UTF-8?q?refactor(db):=20=E5=B0=86=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E6=93=8D=E4=BD=9C=E5=BC=82=E6=AD=A5=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将所有 session.add() 改为 await session.add() - 将所有 session.commit() 改为 await session.commit() - 将 session.refresh() 改为 await session.refresh() --- src/chat/antipromptinjector/anti_injector.py | 4 ++-- src/chat/antipromptinjector/management/statistics.py | 12 ++++++------ src/chat/antipromptinjector/management/user_ban.py | 8 ++++---- src/chat/emoji_system/emoji_manager.py | 2 +- src/chat/memory_system/instant_memory.py | 4 ++-- src/chat/message_receive/chat_stream.py | 2 +- src/chat/message_receive/storage.py | 6 ++---- src/chat/utils/utils_image.py | 8 ++++---- src/chat/utils/utils_video.py | 6 +++--- src/common/database/database.py | 2 +- src/common/database/sqlalchemy_database_api.py | 4 ++-- src/common/database/sqlalchemy_models.py | 1 - src/common/message_repository.py | 2 +- src/llm_models/utils.py | 2 +- src/person_info/person_info.py | 6 +++--- src/plugin_system/core/permission_manager.py | 6 +++--- .../maizone_refactored/services/scheduler_service.py | 4 ++-- .../built_in/napcat_adapter_plugin/src/database.py | 8 ++++---- src/schedule/database.py | 2 +- src/schedule/schedule_manager.py | 2 +- 20 files changed, 44 insertions(+), 47 deletions(-) diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index 751a7d87e..f35070135 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -265,7 +265,7 @@ class AntiPromptInjector: # 删除对应的消息记录 stmt = delete(Messages).where(Messages.message_id == message_id) result = session.execute(stmt) - session.commit() + await session.commit() if result.rowcount > 0: logger.debug(f"成功删除违禁消息记录: {message_id}") @@ -295,7 +295,7 @@ class AntiPromptInjector: .values(processed_plain_text=new_content, display_message=new_content) ) result = session.execute(stmt) - session.commit() + await session.commit() if result.rowcount > 0: logger.debug(f"成功更新消息内容为加盾版本: {message_id}") diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 12606d4ba..e9b4be66b 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -32,9 +32,9 @@ class AntiInjectionStatistics: stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() if not stats: stats = AntiInjectionStats() - session.add(stats) - session.commit() - session.refresh(stats) + await session.add(stats) + await session.commit() + await session.refresh(stats) return stats except Exception as e: logger.error(f"获取统计记录失败: {e}") @@ -48,7 +48,7 @@ class AntiInjectionStatistics: stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() if not stats: stats = AntiInjectionStats() - session.add(stats) + await session.add(stats) # 更新统计字段 for key, value in kwargs.items(): @@ -80,7 +80,7 @@ class AntiInjectionStatistics: # 直接设置的字段 setattr(stats, key, value) - session.commit() + await session.commit() except Exception as e: logger.error(f"更新统计数据失败: {e}") @@ -141,7 +141,7 @@ class AntiInjectionStatistics: with get_db_session() as session: # 删除现有统计记录 session.query(AntiInjectionStats).delete() - session.commit() + await session.commit() logger.info("统计信息已重置") except Exception as e: logger.error(f"重置统计信息失败: {e}") diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index 5a2239162..865ddddb9 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -52,7 +52,7 @@ class UserBanManager: # 封禁已过期,重置违规次数 ban_record.violation_num = 0 ban_record.created_at = datetime.datetime.now() - session.commit() + await session.commit() logger.info(f"用户 {platform}:{user_id} 封禁已过期,违规次数已重置") return None @@ -85,9 +85,9 @@ class UserBanManager: reason=f"提示词注入攻击 (置信度: {detection_result.confidence:.2f})", created_at=datetime.datetime.now(), ) - session.add(ban_record) + await session.add(ban_record) - session.commit() + await session.commit() # 检查是否需要自动封禁 if ban_record.violation_num >= self.config.auto_ban_violation_threshold: @@ -95,7 +95,7 @@ class UserBanManager: # 只有在首次达到阈值时才更新封禁开始时间 if ban_record.violation_num == self.config.auto_ban_violation_threshold: ban_record.created_at = datetime.datetime.now() - session.commit() + await session.commit() else: logger.info(f"用户 {platform}:{user_id} 违规记录已更新,当前违规次数: {ban_record.violation_num}") diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index e2a6eb7f1..6b2c8df5a 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -166,7 +166,7 @@ class MaiEmoji: usage_count=self.usage_count, last_used_time=self.last_used_time, ) - session.add(emoji) + await session.add(emoji) await session.commit() logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index 0b4b0b2e3..5b78f4d3d 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -117,8 +117,8 @@ class InstantMemory: create_time=memory_item.create_time, last_view_time=memory_item.last_view_time, ) - session.add(memory) - session.commit() + await session.add(memory) + await session.commit() async def get_memory(self, target: str): from json_repair import repair_json diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 9e90eed25..caf87383a 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -147,7 +147,7 @@ class ChatManager: # db.connect(reuse_if_open=True) # # 确保 ChatStreams 表存在 # session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)")) - # session.commit() + # await session.commit() # except Exception as e: # logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 4666fae67..bcb7f1034 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -123,7 +123,8 @@ class MessageStorage: is_picid=is_picid, ) async with get_db_session() as session: - session.add(new_message) + await session.add(new_message) + await session.commit() except Exception: logger.exception("存储消息失败") @@ -162,9 +163,6 @@ class MessageStorage: logger.debug(f"消息段数据: {message.message_segment.data}") return - # 使用上下文管理器确保session正确管理 - from src.common.database.sqlalchemy_models import get_db_session - async with get_db_session() as session: matched_message = ( await session.execute( diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 93ec14957..bcfc6e7fd 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -128,7 +128,7 @@ class ImageManager: description=description, timestamp=current_timestamp, ) - session.add(new_desc) + await session.add(new_desc) await session.commit() # 会在上下文管理器中自动调用 except Exception as e: @@ -278,7 +278,7 @@ class ImageManager: description=detailed_description, # 保存详细描述 timestamp=current_timestamp, ) - session.add(new_img) + await session.add(new_img) await session.commit() except Exception as e: logger.error(f"保存到Images表失败: {str(e)}") @@ -370,7 +370,7 @@ class ImageManager: vlm_processed=True, count=1, ) - session.add(new_img) + await session.add(new_img) logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") await session.commit() @@ -590,7 +590,7 @@ class ImageManager: vlm_processed=True, count=1, ) - session.add(new_img) + await session.add(new_img) await session.commit() return image_id, f"[picid:{image_id}]" diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 6ea5a111f..e249bc133 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -242,7 +242,7 @@ class VideoAnalyzer: existing_video.fps = metadata.get("fps") existing_video.resolution = metadata.get("resolution") existing_video.file_size = metadata.get("file_size") - session.commit() + await session.commit() session.refresh(existing_video) logger.info(f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}") return existing_video @@ -257,8 +257,8 @@ class VideoAnalyzer: video_record.resolution = metadata.get("resolution") video_record.file_size = metadata.get("file_size") - session.add(video_record) - session.commit() + await session.add(video_record) + await session.commit() session.refresh(video_record) logger.info(f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}...") return video_record diff --git a/src/common/database/database.py b/src/common/database/database.py index 3279a67ed..293f0cd1f 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -43,7 +43,7 @@ class SQLAlchemyTransaction: def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: - self.session.commit() + self.await session.commit() else: self.session.rollback() self.session.close() diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 13ef39c1a..63de1e43b 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -168,7 +168,7 @@ async def db_query( # 创建新记录 new_record = model_class(**data) - session.add(new_record) + await session.add(new_record) await session.flush() # 获取自动生成的ID # 转换为字典格式返回 @@ -295,7 +295,7 @@ async def db_save( # 创建新记录 new_record = model_class(**data) - session.add(new_record) + await session.add(new_record) await session.flush() # 转换为字典格式返回 diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 0c193e358..2b276213d 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -676,7 +676,6 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: raise RuntimeError("Database session not initialized") session = SessionLocal() yield session - # await session.commit() except Exception: if session: await session.rollback() diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 63d4c000d..7c620d2c7 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -201,5 +201,5 @@ async def count_messages(message_filter: dict[str, Any]) -> int: # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 -# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。 +# 注意:对于 SQLAlchemy,插入操作通常是使用 await session.add() 和 await session.commit()。 # 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。 diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index bf23f144a..659fc5399 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -178,7 +178,7 @@ class LLMUsageRecorder: timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段 ) - session.add(usage_record) + await session.add(usage_record) await session.commit() logger.debug( diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index d311ae491..d6dd741f2 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -510,7 +510,7 @@ class PersonInfoManager: async with get_db_session() as session: try: new_person = PersonInfo(**p_data) - session.add(new_person) + await session.add(new_person) await session.commit() return True except Exception as e: @@ -575,7 +575,7 @@ class PersonInfoManager: # 尝试创建 new_person = PersonInfo(**p_data) - session.add(new_person) + await session.add(new_person) await session.commit() return True except Exception as e: @@ -941,7 +941,7 @@ class PersonInfoManager: # 记录不存在,尝试创建 try: new_person = PersonInfo(**init_data) - session.add(new_person) + await session.add(new_person) await session.commit() await session.refresh(new_person) return new_person, True # 创建成功 diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index eb6083fc9..db7ef9b1a 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -149,7 +149,7 @@ class PermissionManager(IPermissionManager): default_granted=node.default_granted, created_at=datetime.utcnow(), ) - session.add(new_node) + await session.add(new_node) await session.commit() logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})") return True @@ -204,7 +204,7 @@ class PermissionManager(IPermissionManager): granted=True, granted_at=datetime.utcnow(), ) - session.add(new_perm) + await session.add(new_perm) await session.commit() logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") @@ -257,7 +257,7 @@ class PermissionManager(IPermissionManager): granted=False, granted_at=datetime.utcnow(), ) - session.add(new_perm) + await session.add(new_perm) await session.commit() logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 69ec0956e..ca6dc52c3 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -186,8 +186,8 @@ class SchedulerService: story_content=content, send_success=success, ) - session.add(new_record) - session.commit() + await session.add(new_record) + await session.commit() logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}") except Exception as e: logger.error(f"更新日程处理状态时发生数据库错误: {e}") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/database.py b/src/plugins/built_in/napcat_adapter_plugin/src/database.py index 74842eed5..23b5d1f5d 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/database.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/database.py @@ -83,14 +83,14 @@ class DatabaseManager: continue # 更新现有记录的 lift_time existing_record.lift_time = ban_user.lift_time - session.add(existing_record) + await session.add(existing_record) logger.debug(f"更新禁言记录: {existing_record}") else: # 创建新记录 db_record = DB_BanUser( user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time ) - session.add(db_record) + await session.add(db_record) logger.debug(f"创建新禁言记录: {ban_user}") # 删除不在 ban_list 中的记录 for db_record in all_records: @@ -132,14 +132,14 @@ class DatabaseManager: if existing_record: # 如果记录已存在,更新 lift_time existing_record.lift_time = ban_record.lift_time - session.add(existing_record) + await session.add(existing_record) logger.debug(f"更新禁言记录: {ban_record}") else: # 如果记录不存在,创建新记录 db_record = DB_BanUser( user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time ) - session.add(db_record) + await session.add(db_record) logger.debug(f"创建新禁言记录: {ban_record}") def delete_ban_record(self, ban_record: BanUser): diff --git a/src/schedule/database.py b/src/schedule/database.py index 5025c1fa3..b420f0686 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -42,7 +42,7 @@ async def add_new_plans(plans: List[str], month: str): new_plan_objects = [ MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add ] - session.add_all(new_plan_objects) + await session.add_all(new_plan_objects) await session.commit() logger.info(f"成功向数据库添加了 {len(new_plan_objects)} 条 {month} 的月度计划。") diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index 115480381..4e66bf0c8 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -128,7 +128,7 @@ class ScheduleManager: existing_schedule.updated_at = datetime.now() else: new_schedule = Schedule(date=date_str, schedule_data=schedule_json) - session.add(new_schedule) + await session.add(new_schedule) await session.commit() @staticmethod