refactor(db): 重构数据库交互为异步模式

为了提升性能并与项目整体的异步架构保持一致,对核心数据库交互模块进行了异步化重构。

主要修改内容包括:
- 将 `PermissionManager` 中的所有数据库操作从同步改为异步,以避免阻塞事件循环。
- 使用 `async_sessionmaker` 和 `async with session` 替代原有的同步会话管理。
- 将 SQLAlchemy 查询语法更新为异步兼容的 `await session.execute(select(...))` 模式。
- 相应地,调用链中依赖数据库操作的多个方法也已更新为 `async` 函数。
This commit is contained in:
tt-P607
2025-09-20 13:07:06 +08:00
parent 57b2e32ba0
commit 8d9aa4fb9e
5 changed files with 81 additions and 68 deletions

View File

@@ -86,7 +86,8 @@ class CycleProcessor:
platform, platform,
action_message.get("chat_info_user_id", ""), action_message.get("chat_info_user_id", ""),
) )
person_name = await person_info_manager.get_value(person_id, "person_name") person_info = await person_info_manager.get_values(person_id, ["person_name"])
person_name = person_info.get("person_name")
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
# 存储动作信息到数据库 # 存储动作信息到数据库

View File

@@ -15,7 +15,6 @@ from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_stream import ChatStream
install(extra_lines=3) install(extra_lines=3)

View File

@@ -660,7 +660,7 @@ class DefaultReplyer:
duration = end_time - start_time duration = end_time - start_time
return name, result, duration return name, result, duration
def build_s4u_chat_history_prompts( async def build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
) -> Tuple[str, str]: ) -> Tuple[str, str]:
""" """
@@ -692,7 +692,7 @@ class DefaultReplyer:
all_dialogue_prompt = "" all_dialogue_prompt = ""
if message_list_before_now: if message_list_before_now:
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :] latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt_str = build_readable_messages( all_dialogue_prompt_str = await build_readable_messages(
latest_25_msgs, latest_25_msgs,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal", timestamp_mode="normal",
@@ -716,7 +716,7 @@ class DefaultReplyer:
else: else:
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量 core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
core_dialogue_prompt_str = build_readable_messages( core_dialogue_prompt_str = await build_readable_messages(
core_dialogue_list, core_dialogue_list,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,

View File

@@ -213,7 +213,7 @@ class BaseAction(ABC):
# 检查新消息 # 检查新消息
current_time = time.time() current_time = time.time()
new_message_count = message_api.count_new_messages( new_message_count = await message_api.count_new_messages(
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
) )

View File

@@ -5,9 +5,10 @@
""" """
from typing import List, Set, Tuple from typing import List, Set, Tuple
from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from datetime import datetime from datetime import datetime
from sqlalchemy import select, delete
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions
@@ -22,7 +23,7 @@ class PermissionManager(IPermissionManager):
def __init__(self): def __init__(self):
self.engine = get_engine() self.engine = get_engine()
self.SessionLocal = sessionmaker(bind=self.engine) self.SessionLocal = async_sessionmaker(bind=self.engine)
self._master_users: Set[Tuple[str, str]] = set() self._master_users: Set[Tuple[str, str]] = set()
self._load_master_users() self._load_master_users()
logger.info("权限管理器初始化完成") logger.info("权限管理器初始化完成")
@@ -62,7 +63,7 @@ class PermissionManager(IPermissionManager):
logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户") logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户")
return is_master return is_master
def check_permission(self, user: UserInfo, permission_node: str) -> bool: async def check_permission(self, user: UserInfo, permission_node: str) -> bool:
""" """
检查用户是否拥有指定权限节点 检查用户是否拥有指定权限节点
@@ -79,34 +80,35 @@ class PermissionManager(IPermissionManager):
logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}") logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}")
return True return True
with self.SessionLocal() as session: async with self.SessionLocal() as session:
# 检查权限节点是否存在 # 检查权限节点是否存在
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node))
node = result.scalar_one_or_none()
if not node: if not node:
logger.warning(f"权限节点 {permission_node} 不存在") logger.warning(f"权限节点 {permission_node} 不存在")
return False return False
# 检查用户是否有明确的权限设置 # 检查用户是否有明确的权限设置
user_perm = ( result = await session.execute(
session.query(UserPermissions) select(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
.first()
) )
user_perm = result.scalar_one_or_none()
if user_perm: if user_perm:
# 有明确设置,返回设置的值 # 有明确设置,返回设置的值
result = user_perm.granted res = user_perm.granted
logger.debug( logger.debug(
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}" f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {res}"
) )
return result return res
else: else:
# 没有明确设置,使用默认值 # 没有明确设置,使用默认值
result = node.default_granted res = node.default_granted
logger.debug( logger.debug(
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}" f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {res}"
) )
return result return res
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"检查权限时数据库错误: {e}") logger.error(f"检查权限时数据库错误: {e}")
@@ -115,7 +117,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"检查权限时发生未知错误: {e}") logger.error(f"检查权限时发生未知错误: {e}")
return False return False
def register_permission_node(self, node: PermissionNode) -> bool: async def register_permission_node(self, node: PermissionNode) -> bool:
""" """
注册权限节点 注册权限节点
@@ -126,15 +128,16 @@ class PermissionManager(IPermissionManager):
bool: 注册是否成功 bool: 注册是否成功
""" """
try: try:
with self.SessionLocal() as session: async with self.SessionLocal() as session:
# 检查节点是否已存在 # 检查节点是否已存在
existing_node = session.query(PermissionNodes).filter_by(node_name=node.node_name).first() result = await session.execute(select(PermissionNodes).filter_by(node_name=node.node_name))
existing_node = result.scalar_one_or_none()
if existing_node: if existing_node:
# 更新现有节点的信息 # 更新现有节点的信息
existing_node.description = node.description existing_node.description = node.description
existing_node.plugin_name = node.plugin_name existing_node.plugin_name = node.plugin_name
existing_node.default_granted = node.default_granted existing_node.default_granted = node.default_granted
session.commit() await session.commit()
logger.debug(f"更新权限节点: {node.node_name}") logger.debug(f"更新权限节点: {node.node_name}")
return True return True
@@ -147,7 +150,7 @@ class PermissionManager(IPermissionManager):
created_at=datetime.utcnow(), created_at=datetime.utcnow(),
) )
session.add(new_node) session.add(new_node)
session.commit() await session.commit()
logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})") logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})")
return True return True
@@ -161,7 +164,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"注册权限节点时发生未知错误: {e}") logger.error(f"注册权限节点时发生未知错误: {e}")
return False return False
def grant_permission(self, user: UserInfo, permission_node: str) -> bool: async def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
""" """
授权用户权限节点 授权用户权限节点
@@ -173,19 +176,20 @@ class PermissionManager(IPermissionManager):
bool: 授权是否成功 bool: 授权是否成功
""" """
try: try:
with self.SessionLocal() as session: async with self.SessionLocal() as session:
# 检查权限节点是否存在 # 检查权限节点是否存在
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node))
node = result.scalar_one_or_none()
if not node: if not node:
logger.error(f"尝试授权不存在的权限节点: {permission_node}") logger.error(f"尝试授权不存在的权限节点: {permission_node}")
return False return False
# 检查是否已有权限记录 # 检查是否已有权限记录
existing_perm = ( result = await session.execute(
session.query(UserPermissions) select(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
.first()
) )
existing_perm = result.scalar_one_or_none()
if existing_perm: if existing_perm:
# 更新现有记录 # 更新现有记录
@@ -202,7 +206,7 @@ class PermissionManager(IPermissionManager):
) )
session.add(new_perm) session.add(new_perm)
session.commit() await session.commit()
logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
return True return True
@@ -213,7 +217,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"授权权限时发生未知错误: {e}") logger.error(f"授权权限时发生未知错误: {e}")
return False return False
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
""" """
撤销用户权限节点 撤销用户权限节点
@@ -225,19 +229,20 @@ class PermissionManager(IPermissionManager):
bool: 撤销是否成功 bool: 撤销是否成功
""" """
try: try:
with self.SessionLocal() as session: async with self.SessionLocal() as session:
# 检查权限节点是否存在 # 检查权限节点是否存在
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node))
node = result.scalar_one_or_none()
if not node: if not node:
logger.error(f"尝试撤销不存在的权限节点: {permission_node}") logger.error(f"尝试撤销不存在的权限节点: {permission_node}")
return False return False
# 检查是否已有权限记录 # 检查是否已有权限记录
existing_perm = ( result = await session.execute(
session.query(UserPermissions) select(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
.first()
) )
existing_perm = result.scalar_one_or_none()
if existing_perm: if existing_perm:
# 更新现有记录 # 更新现有记录
@@ -254,7 +259,7 @@ class PermissionManager(IPermissionManager):
) )
session.add(new_perm) session.add(new_perm)
session.commit() await session.commit()
logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
return True return True
@@ -265,7 +270,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"撤销权限时发生未知错误: {e}") logger.error(f"撤销权限时发生未知错误: {e}")
return False return False
def get_user_permissions(self, user: UserInfo) -> List[str]: async def get_user_permissions(self, user: UserInfo) -> List[str]:
""" """
获取用户拥有的所有权限节点 获取用户拥有的所有权限节点
@@ -278,23 +283,25 @@ class PermissionManager(IPermissionManager):
try: try:
# Master用户拥有所有权限 # Master用户拥有所有权限
if self.is_master(user): if self.is_master(user):
with self.SessionLocal() as session: async with self.SessionLocal() as session:
all_nodes = session.query(PermissionNodes.node_name).all() result = await session.execute(select(PermissionNodes.node_name))
return [node.node_name for node in all_nodes] all_nodes = result.scalars().all()
return all_nodes
permissions = [] permissions = []
with self.SessionLocal() as session: async with self.SessionLocal() as session:
# 获取所有权限节点 # 获取所有权限节点
all_nodes = session.query(PermissionNodes).all() result = await session.execute(select(PermissionNodes))
all_nodes = result.scalars().all()
for node in all_nodes: for node in all_nodes:
# 检查用户是否有明确的权限设置 # 检查用户是否有明确的权限设置
user_perm = ( result = await session.execute(
session.query(UserPermissions) select(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name)
.first()
) )
user_perm = result.scalar_one_or_none()
if user_perm: if user_perm:
# 有明确设置,使用设置的值 # 有明确设置,使用设置的值
@@ -314,7 +321,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"获取用户权限时发生未知错误: {e}") logger.error(f"获取用户权限时发生未知错误: {e}")
return [] return []
def get_all_permission_nodes(self) -> List[PermissionNode]: async def get_all_permission_nodes(self) -> List[PermissionNode]:
""" """
获取所有已注册的权限节点 获取所有已注册的权限节点
@@ -322,8 +329,9 @@ class PermissionManager(IPermissionManager):
List[PermissionNode]: 权限节点列表 List[PermissionNode]: 权限节点列表
""" """
try: try:
with self.SessionLocal() as session: async with self.SessionLocal() as session:
nodes = session.query(PermissionNodes).all() result = await session.execute(select(PermissionNodes))
nodes = result.scalars().all()
return [ return [
PermissionNode( PermissionNode(
node_name=node.node_name, node_name=node.node_name,
@@ -341,7 +349,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"获取所有权限节点时发生未知错误: {e}") logger.error(f"获取所有权限节点时发生未知错误: {e}")
return [] return []
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
""" """
获取指定插件的所有权限节点 获取指定插件的所有权限节点
@@ -352,8 +360,9 @@ class PermissionManager(IPermissionManager):
List[PermissionNode]: 权限节点列表 List[PermissionNode]: 权限节点列表
""" """
try: try:
with self.SessionLocal() as session: async with self.SessionLocal() as session:
nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all() result = await session.execute(select(PermissionNodes).filter_by(plugin_name=plugin_name))
nodes = result.scalars().all()
return [ return [
PermissionNode( PermissionNode(
node_name=node.node_name, node_name=node.node_name,
@@ -371,7 +380,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"获取插件权限节点时发生未知错误: {e}") logger.error(f"获取插件权限节点时发生未知错误: {e}")
return [] return []
def delete_plugin_permissions(self, plugin_name: str) -> bool: async def delete_plugin_permissions(self, plugin_name: str) -> bool:
""" """
删除指定插件的所有权限节点(用于插件卸载时清理) 删除指定插件的所有权限节点(用于插件卸载时清理)
@@ -382,9 +391,10 @@ class PermissionManager(IPermissionManager):
bool: 删除是否成功 bool: 删除是否成功
""" """
try: try:
with self.SessionLocal() as session: async with self.SessionLocal() as session:
# 获取插件的所有权限节点 # 获取插件的所有权限节点
plugin_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all() result = await session.execute(select(PermissionNodes).filter_by(plugin_name=plugin_name))
plugin_nodes = result.scalars().all()
node_names = [node.node_name for node in plugin_nodes] node_names = [node.node_name for node in plugin_nodes]
if not node_names: if not node_names:
@@ -392,16 +402,17 @@ class PermissionManager(IPermissionManager):
return True return True
# 删除用户权限记录 # 删除用户权限记录
deleted_user_perms = ( result = await session.execute(
session.query(UserPermissions) delete(UserPermissions)
.filter(UserPermissions.permission_node.in_(node_names)) .where(UserPermissions.permission_node.in_(node_names))
.delete(synchronize_session=False)
) )
deleted_user_perms = result.rowcount
# 删除权限节点 # 删除权限节点
deleted_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).delete() result = await session.execute(delete(PermissionNodes).filter_by(plugin_name=plugin_name))
deleted_nodes = result.rowcount
session.commit() await session.commit()
logger.info( logger.info(
f"已删除插件 {plugin_name}{deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录" f"已删除插件 {plugin_name}{deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录"
) )
@@ -414,7 +425,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"删除插件权限时发生未知错误: {e}") logger.error(f"删除插件权限时发生未知错误: {e}")
return False return False
def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]: async def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
""" """
获取拥有指定权限的所有用户 获取拥有指定权限的所有用户
@@ -427,17 +438,19 @@ class PermissionManager(IPermissionManager):
try: try:
users = [] users = []
with self.SessionLocal() as session: async with self.SessionLocal() as session:
# 检查权限节点是否存在 # 检查权限节点是否存在
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node))
node = result.scalar_one_or_none()
if not node: if not node:
logger.warning(f"权限节点 {permission_node} 不存在") logger.warning(f"权限节点 {permission_node} 不存在")
return users return users
# 获取明确授权的用户 # 获取明确授权的用户
granted_users = ( result = await session.execute(
session.query(UserPermissions).filter_by(permission_node=permission_node, granted=True).all() select(UserPermissions).filter_by(permission_node=permission_node, granted=True)
) )
granted_users = result.scalars().all()
for user_perm in granted_users: for user_perm in granted_users:
users.append((user_perm.platform, user_perm.user_id)) users.append((user_perm.platform, user_perm.user_id))