diff --git a/src/chat/chat_loop/cycle_processor.py b/src/chat/chat_loop/cycle_processor.py index fe993f484..b2a092958 100644 --- a/src/chat/chat_loop/cycle_processor.py +++ b/src/chat/chat_loop/cycle_processor.py @@ -86,7 +86,8 @@ class CycleProcessor: platform, action_message.get("chat_info_user_id", ""), ) - person_name = await person_info_manager.get_value(person_id, "person_name") + person_info = await person_info_manager.get_values(person_id, ["person_name"]) + person_name = person_info.get("person_name") action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" # 存储动作信息到数据库 diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 631aa7c09..22e57edf0 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -15,7 +15,6 @@ from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.chat_stream import ChatStream -from src.chat.message_receive.chat_stream import ChatStream install(extra_lines=3) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index aa9c5eba0..bf3d4fe26 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -660,7 +660,7 @@ class DefaultReplyer: duration = end_time - start_time return name, result, duration - def build_s4u_chat_history_prompts( + async def build_s4u_chat_history_prompts( self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str ) -> Tuple[str, str]: """ @@ -692,7 +692,7 @@ class DefaultReplyer: all_dialogue_prompt = "" if message_list_before_now: latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :] - all_dialogue_prompt_str = build_readable_messages( + all_dialogue_prompt_str = await build_readable_messages( latest_25_msgs, replace_bot_name=True, timestamp_mode="normal", @@ -716,7 +716,7 @@ class DefaultReplyer: else: core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量 - core_dialogue_prompt_str = build_readable_messages( + core_dialogue_prompt_str = await build_readable_messages( core_dialogue_list, replace_bot_name=True, merge_messages=False, diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 9400032f8..51a0f4257 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -213,7 +213,7 @@ class BaseAction(ABC): # 检查新消息 current_time = time.time() - new_message_count = message_api.count_new_messages( + new_message_count = await message_api.count_new_messages( chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time ) diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index 9d996fd46..eb6083fc9 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -5,9 +5,10 @@ """ from typing import List, Set, Tuple -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.exc import IntegrityError, SQLAlchemyError from datetime import datetime +from sqlalchemy import select, delete from src.common.logger import get_logger from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions @@ -22,7 +23,7 @@ class PermissionManager(IPermissionManager): def __init__(self): self.engine = get_engine() - self.SessionLocal = sessionmaker(bind=self.engine) + self.SessionLocal = async_sessionmaker(bind=self.engine) self._master_users: Set[Tuple[str, str]] = set() self._load_master_users() logger.info("权限管理器初始化完成") @@ -62,7 +63,7 @@ class PermissionManager(IPermissionManager): logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户") return is_master - def check_permission(self, user: UserInfo, permission_node: str) -> bool: + async def check_permission(self, user: UserInfo, permission_node: str) -> bool: """ 检查用户是否拥有指定权限节点 @@ -79,34 +80,35 @@ class PermissionManager(IPermissionManager): logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}") return True - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查权限节点是否存在 - node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node)) + node = result.scalar_one_or_none() if not node: logger.warning(f"权限节点 {permission_node} 不存在") return False # 检查用户是否有明确的权限设置 - user_perm = ( - session.query(UserPermissions) + result = await session.execute( + select(UserPermissions) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) - .first() ) + user_perm = result.scalar_one_or_none() if user_perm: # 有明确设置,返回设置的值 - result = user_perm.granted + res = user_perm.granted logger.debug( - f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}" + f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {res}" ) - return result + return res else: # 没有明确设置,使用默认值 - result = node.default_granted + res = node.default_granted logger.debug( - f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}" + f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {res}" ) - return result + return res except SQLAlchemyError as e: logger.error(f"检查权限时数据库错误: {e}") @@ -115,7 +117,7 @@ class PermissionManager(IPermissionManager): logger.error(f"检查权限时发生未知错误: {e}") return False - def register_permission_node(self, node: PermissionNode) -> bool: + async def register_permission_node(self, node: PermissionNode) -> bool: """ 注册权限节点 @@ -126,15 +128,16 @@ class PermissionManager(IPermissionManager): bool: 注册是否成功 """ try: - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查节点是否已存在 - existing_node = session.query(PermissionNodes).filter_by(node_name=node.node_name).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=node.node_name)) + existing_node = result.scalar_one_or_none() if existing_node: # 更新现有节点的信息 existing_node.description = node.description existing_node.plugin_name = node.plugin_name existing_node.default_granted = node.default_granted - session.commit() + await session.commit() logger.debug(f"更新权限节点: {node.node_name}") return True @@ -147,7 +150,7 @@ class PermissionManager(IPermissionManager): created_at=datetime.utcnow(), ) session.add(new_node) - session.commit() + await session.commit() logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})") return True @@ -161,7 +164,7 @@ class PermissionManager(IPermissionManager): logger.error(f"注册权限节点时发生未知错误: {e}") return False - def grant_permission(self, user: UserInfo, permission_node: str) -> bool: + async def grant_permission(self, user: UserInfo, permission_node: str) -> bool: """ 授权用户权限节点 @@ -173,19 +176,20 @@ class PermissionManager(IPermissionManager): bool: 授权是否成功 """ try: - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查权限节点是否存在 - node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node)) + node = result.scalar_one_or_none() if not node: logger.error(f"尝试授权不存在的权限节点: {permission_node}") return False # 检查是否已有权限记录 - existing_perm = ( - session.query(UserPermissions) + result = await session.execute( + select(UserPermissions) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) - .first() ) + existing_perm = result.scalar_one_or_none() if existing_perm: # 更新现有记录 @@ -202,7 +206,7 @@ class PermissionManager(IPermissionManager): ) session.add(new_perm) - session.commit() + await session.commit() logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") return True @@ -213,7 +217,7 @@ class PermissionManager(IPermissionManager): logger.error(f"授权权限时发生未知错误: {e}") return False - def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: + async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: """ 撤销用户权限节点 @@ -225,19 +229,20 @@ class PermissionManager(IPermissionManager): bool: 撤销是否成功 """ try: - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查权限节点是否存在 - node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node)) + node = result.scalar_one_or_none() if not node: logger.error(f"尝试撤销不存在的权限节点: {permission_node}") return False # 检查是否已有权限记录 - existing_perm = ( - session.query(UserPermissions) + result = await session.execute( + select(UserPermissions) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) - .first() ) + existing_perm = result.scalar_one_or_none() if existing_perm: # 更新现有记录 @@ -254,7 +259,7 @@ class PermissionManager(IPermissionManager): ) session.add(new_perm) - session.commit() + await session.commit() logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") return True @@ -265,7 +270,7 @@ class PermissionManager(IPermissionManager): logger.error(f"撤销权限时发生未知错误: {e}") return False - def get_user_permissions(self, user: UserInfo) -> List[str]: + async def get_user_permissions(self, user: UserInfo) -> List[str]: """ 获取用户拥有的所有权限节点 @@ -278,23 +283,25 @@ class PermissionManager(IPermissionManager): try: # Master用户拥有所有权限 if self.is_master(user): - with self.SessionLocal() as session: - all_nodes = session.query(PermissionNodes.node_name).all() - return [node.node_name for node in all_nodes] + async with self.SessionLocal() as session: + result = await session.execute(select(PermissionNodes.node_name)) + all_nodes = result.scalars().all() + return all_nodes permissions = [] - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 获取所有权限节点 - all_nodes = session.query(PermissionNodes).all() + result = await session.execute(select(PermissionNodes)) + all_nodes = result.scalars().all() for node in all_nodes: # 检查用户是否有明确的权限设置 - user_perm = ( - session.query(UserPermissions) + result = await session.execute( + select(UserPermissions) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name) - .first() ) + user_perm = result.scalar_one_or_none() if user_perm: # 有明确设置,使用设置的值 @@ -314,7 +321,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取用户权限时发生未知错误: {e}") return [] - def get_all_permission_nodes(self) -> List[PermissionNode]: + async def get_all_permission_nodes(self) -> List[PermissionNode]: """ 获取所有已注册的权限节点 @@ -322,8 +329,9 @@ class PermissionManager(IPermissionManager): List[PermissionNode]: 权限节点列表 """ try: - with self.SessionLocal() as session: - nodes = session.query(PermissionNodes).all() + async with self.SessionLocal() as session: + result = await session.execute(select(PermissionNodes)) + nodes = result.scalars().all() return [ PermissionNode( node_name=node.node_name, @@ -341,7 +349,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取所有权限节点时发生未知错误: {e}") return [] - def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: + async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: """ 获取指定插件的所有权限节点 @@ -352,8 +360,9 @@ class PermissionManager(IPermissionManager): List[PermissionNode]: 权限节点列表 """ try: - with self.SessionLocal() as session: - nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all() + async with self.SessionLocal() as session: + result = await session.execute(select(PermissionNodes).filter_by(plugin_name=plugin_name)) + nodes = result.scalars().all() return [ PermissionNode( node_name=node.node_name, @@ -371,7 +380,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取插件权限节点时发生未知错误: {e}") return [] - def delete_plugin_permissions(self, plugin_name: str) -> bool: + async def delete_plugin_permissions(self, plugin_name: str) -> bool: """ 删除指定插件的所有权限节点(用于插件卸载时清理) @@ -382,9 +391,10 @@ class PermissionManager(IPermissionManager): bool: 删除是否成功 """ try: - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 获取插件的所有权限节点 - plugin_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all() + result = await session.execute(select(PermissionNodes).filter_by(plugin_name=plugin_name)) + plugin_nodes = result.scalars().all() node_names = [node.node_name for node in plugin_nodes] if not node_names: @@ -392,16 +402,17 @@ class PermissionManager(IPermissionManager): return True # 删除用户权限记录 - deleted_user_perms = ( - session.query(UserPermissions) - .filter(UserPermissions.permission_node.in_(node_names)) - .delete(synchronize_session=False) + result = await session.execute( + delete(UserPermissions) + .where(UserPermissions.permission_node.in_(node_names)) ) + deleted_user_perms = result.rowcount # 删除权限节点 - deleted_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).delete() + result = await session.execute(delete(PermissionNodes).filter_by(plugin_name=plugin_name)) + deleted_nodes = result.rowcount - session.commit() + await session.commit() logger.info( f"已删除插件 {plugin_name} 的 {deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录" ) @@ -414,7 +425,7 @@ class PermissionManager(IPermissionManager): logger.error(f"删除插件权限时发生未知错误: {e}") return False - def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]: + async def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]: """ 获取拥有指定权限的所有用户 @@ -427,17 +438,19 @@ class PermissionManager(IPermissionManager): try: users = [] - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查权限节点是否存在 - node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node)) + node = result.scalar_one_or_none() if not node: logger.warning(f"权限节点 {permission_node} 不存在") return users # 获取明确授权的用户 - granted_users = ( - session.query(UserPermissions).filter_by(permission_node=permission_node, granted=True).all() + result = await session.execute( + select(UserPermissions).filter_by(permission_node=permission_node, granted=True) ) + granted_users = result.scalars().all() for user_perm in granted_users: users.append((user_perm.platform, user_perm.user_id))