From 6659c607992ebdb178a4e25e7599125f550c205b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Tue, 7 Oct 2025 16:29:17 +0800 Subject: [PATCH] =?UTF-8?q?refactor(db,plugin):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=88=9D=E5=A7=8B=E5=8C=96=E5=92=8C?= =?UTF-8?q?=E6=8F=92=E4=BB=B6=E7=B3=BB=E7=BB=9F=E7=B1=BB=E5=9E=8B=E5=AE=89?= =?UTF-8?q?=E5=85=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 重构数据库初始化逻辑,添加防重入保护和更好的错误处理 - 优化插件组件注册系统的类型注解和代码结构 - 简化统计模块异常处理逻辑 - 移除插件管理器中的重载功能以简化代码 - 更新Pyright配置排除内置插件目录 - 修复权限管理器异步方法调用 --- bot.py | 4 +- pyrightconfig.json | 7 +- .../management/statistics.py | 29 +-- src/common/database/sqlalchemy_models.py | 136 ++++++----- src/plugin_system/core/component_registry.py | 230 ++++++++++-------- src/plugin_system/core/permission_manager.py | 10 +- src/plugin_system/core/plugin_manager.py | 72 +----- src/plugin_system/core/tool_use.py | 2 +- 8 files changed, 234 insertions(+), 256 deletions(-) diff --git a/bot.py b/bot.py index debeaac5f..fa53b98bb 100644 --- a/bot.py +++ b/bot.py @@ -560,9 +560,9 @@ class MaiBotMain: logger.info("正在初始化数据库表结构...") try: start_time = time.time() - from src.common.database.sqlalchemy_models import initialize_database as init_db + from src.common.database.sqlalchemy_models import initialize_database - await init_db() + await initialize_database() elapsed_time = time.time() - start_time logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}秒") except Exception as e: diff --git a/pyrightconfig.json b/pyrightconfig.json index 0dff0f212..3cffac58c 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -2,8 +2,7 @@ "$schema": "https://raw.githubusercontent.com/microsoft/pyright/main/packages/vscode-pyright/schemas/pyrightconfig.schema.json", "include": [ "src", - "bot.py", - "__main__.py" + "bot.py" ], "exclude": [ "**/__pycache__", @@ -11,7 +10,9 @@ "logs", "tests", "target", - "*.egg-info" + "*.egg-info", + "src/plugins/built_in/*", + "__main__.py" ], "typeCheckingMode": "standard", "reportMissingImports": false, diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 9820ea525..3f690413a 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -5,7 +5,7 @@ """ import datetime -from typing import Any, Optional, TypedDict, Literal, Union, Callable, TypeVar, cast +from typing import Any, Optional, TypeVar, cast from sqlalchemy import select, delete @@ -47,29 +47,26 @@ class AntiInjectionStatistics: """当前会话开始时间""" @staticmethod - async def get_or_create_stats() -> Optional[AntiInjectionStats]: # type: ignore[name-defined] + async def get_or_create_stats() -> AntiInjectionStats: """获取或创建统计记录 Returns: AntiInjectionStats | None: 成功返回模型实例,否则 None """ - try: - async with get_db_session() as session: + async with get_db_session() as session: # 获取最新的统计记录,如果没有则创建 - stats = ( + stats = ( (await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()))) .scalars() .first() ) - if not stats: - stats = AntiInjectionStats() - session.add(stats) - await session.commit() - await session.refresh(stats) - return stats - except Exception as e: - logger.error(f"获取统计记录失败: {e}") - return None + if not stats: + stats = AntiInjectionStats() + session.add(stats) + await session.commit() + await session.refresh(stats) + return stats + @staticmethod async def update_stats(**kwargs: Any) -> None: @@ -97,7 +94,7 @@ class AntiInjectionStatistics: if key == "processing_time_delta": # 处理时间累加 - 确保不为 None delta = float(value) - stats.processing_time_total = _add_optional(stats.processing_time_total, delta) # type: ignore[attr-defined] + stats.processing_time_total = _add_optional(stats.processing_time_total, delta) continue elif key == "last_processing_time": # 直接设置最后处理时间 @@ -146,7 +143,7 @@ class AntiInjectionStatistics: # 计算派生统计信息 - 处理 None 值 - total_messages = stats.total_messages or 0 # type: ignore[attr-defined] + total_messages = stats.total_messages or 0 detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined] processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined] diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 9319e11f4..d182291d9 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -19,7 +19,7 @@ from contextlib import asynccontextmanager from typing import Any from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker, create_async_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Mapped, mapped_column @@ -31,6 +31,10 @@ logger = get_logger("sqlalchemy_models") # 创建基类 Base = declarative_base() +# 全局异步引擎与会话工厂占位(延迟初始化) +_engine: AsyncEngine | None = None +_SessionLocal: async_sessionmaker[AsyncSession] | None = None + async def enable_sqlite_wal_mode(engine): """为 SQLite 启用 WAL 模式以提高并发性能""" @@ -649,23 +653,13 @@ class MonthlyPlan(Base): last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True) created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - # 保留 is_deleted 字段以兼容现有数据,但标记为已弃用 - is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - __table_args__ = ( Index("idx_monthlyplan_target_month_status", "target_month", "status"), Index("idx_monthlyplan_last_used_date", "last_used_date"), Index("idx_monthlyplan_usage_count", "usage_count"), - # 保留旧索引以兼容 - Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"), ) -# 数据库引擎和会话管理 -_engine = None -_SessionLocal = None - - def get_database_url(): """获取数据库连接URL""" from src.config.config import global_config @@ -709,65 +703,89 @@ def get_database_url(): return f"sqlite+aiosqlite:///{db_path}" -async def initialize_database(): - """初始化异步数据库引擎和会话""" - global _engine, _SessionLocal +_initializing: bool = False # 防止递归初始化 - if _engine is not None: +async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[AsyncSession]]: + """初始化异步数据库引擎和会话 + + Returns: + tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: 创建好的异步引擎与会话工厂。 + + 说明: + 显式的返回类型标注有助于 Pyright/Pylance 正确推断调用处的对象, + 避免后续对返回值再次 `await` 时出现 *"tuple[...] 并非 awaitable"* 的误用。 + """ + global _engine, _SessionLocal, _initializing + + # 已经初始化直接返回 + if _engine is not None and _SessionLocal is not None: return _engine, _SessionLocal - database_url = get_database_url() - from src.config.config import global_config + # 正在初始化的并发调用等待主初始化完成,避免递归 + if _initializing: + import asyncio + for _ in range(1000): # 最多等待约10秒 + await asyncio.sleep(0.01) + if _engine is not None and _SessionLocal is not None: + return _engine, _SessionLocal + raise RuntimeError("等待数据库初始化完成超时 (reentrancy guard)") - config = global_config.database + _initializing = True + try: + database_url = get_database_url() + from src.config.config import global_config - # 配置引擎参数 - engine_kwargs: dict[str, Any] = { - "echo": False, # 生产环境关闭SQL日志 - "future": True, - } + config = global_config.database - if config.database_type == "mysql": - # MySQL连接池配置 - 异步引擎使用默认连接池 - engine_kwargs.update( - { - "pool_size": config.connection_pool_size, - "max_overflow": config.connection_pool_size * 2, - "pool_timeout": config.connection_timeout, - "pool_recycle": 3600, # 1小时回收连接 - "pool_pre_ping": True, # 连接前ping检查 - "connect_args": { - "autocommit": config.mysql_autocommit, - "charset": config.mysql_charset, - "connect_timeout": config.connection_timeout, - }, - } - ) - else: - # SQLite配置 - aiosqlite不支持连接池参数 - engine_kwargs.update( - { - "connect_args": { - "check_same_thread": False, - "timeout": 60, # 增加超时时间 - }, - } - ) + # 配置引擎参数 + engine_kwargs: dict[str, Any] = { + "echo": False, # 生产环境关闭SQL日志 + "future": True, + } - _engine = create_async_engine(database_url, **engine_kwargs) - _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) + if config.database_type == "mysql": + engine_kwargs.update( + { + "pool_size": config.connection_pool_size, + "max_overflow": config.connection_pool_size * 2, + "pool_timeout": config.connection_timeout, + "pool_recycle": 3600, + "pool_pre_ping": True, + "connect_args": { + "autocommit": config.mysql_autocommit, + "charset": config.mysql_charset, + "connect_timeout": config.connection_timeout, + }, + } + ) + else: + engine_kwargs.update( + { + "connect_args": { + "check_same_thread": False, + "timeout": 60, + }, + } + ) - # 调用新的迁移函数,它会处理表的创建和列的添加 - from src.common.database.db_migration import check_and_migrate_database + _engine = create_async_engine(database_url, **engine_kwargs) + _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) - await check_and_migrate_database() + # 迁移 + try: + from src.common.database.db_migration import check_and_migrate_database + await check_and_migrate_database(existing_engine=_engine) + except TypeError: + from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate + await _legacy_migrate() - # 如果是 SQLite,启用 WAL 模式以提高并发性能 - if config.database_type == "sqlite": - await enable_sqlite_wal_mode(_engine) + if config.database_type == "sqlite": + await enable_sqlite_wal_mode(_engine) - logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}") - return _engine, _SessionLocal + logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}") + return _engine, _SessionLocal + finally: + _initializing = False @asynccontextmanager diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 91b3001da..3194272c8 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import re from pathlib import Path from re import Pattern -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from src.common.logger import get_logger from src.plugin_system.base.base_action import BaseAction @@ -26,6 +28,23 @@ from src.plugin_system.base.plus_command import PlusCommand logger = get_logger("component_registry") +# 统一的组件类类型别名 +ComponentClassType = ( + type[BaseCommand] + | type[BaseAction] + | type[BaseTool] + | type[BaseEventHandler] + | type[PlusCommand] + | type[BaseChatter] + | type[BaseInterestCalculator] +) + + +def _assign_plugin_attrs(cls: Any, plugin_name: str, plugin_config: dict) -> None: + """为组件类动态赋予插件相关属性(避免在各注册函数中重复代码)。""" + setattr(cls, "plugin_name", plugin_name) + setattr(cls, "plugin_config", plugin_config) + class ComponentRegistry: """统一的组件注册中心 @@ -41,9 +60,8 @@ class ComponentRegistry: types: {} for types in ComponentType } """类型 -> 组件原名称 -> 组件信息""" - self._components_classes: dict[ - str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"] - ] = {} + # 组件类注册表(命名空间式组件名 -> 组件类) + self._components_classes: dict[str, ComponentClassType] = {} """命名空间式组件名 -> 组件类""" # 插件注册表 @@ -100,10 +118,8 @@ class ComponentRegistry: return True def register_component( - self, - component_info: ComponentInfo, - component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]], - ) -> bool: + self, self_component_info: ComponentInfo, component_class: ComponentClassType + ) -> bool: # noqa: C901 (保持原有结构, 以后可再拆) """注册组件 Args: @@ -113,9 +129,11 @@ class ComponentRegistry: Returns: bool: 是否注册成功 """ + component_info = self_component_info # 局部别名 component_name = component_info.name component_type = component_info.component_type plugin_name = getattr(component_info, "plugin_name", "unknown") + if "." in component_name: logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代") return False @@ -123,22 +141,19 @@ class ComponentRegistry: logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") return False - namespaced_name = f"{component_type}.{component_name}" - + namespaced_name = f"{component_type.value}.{component_name}" if namespaced_name in self._components: existing_info = self._components[namespaced_name] existing_plugin = getattr(existing_info, "plugin_name", "unknown") - logger.warning( f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册" ) return False - self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称) - self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名 + self._components[namespaced_name] = component_info + self._components_by_type[component_type][component_name] = component_info self._components_classes[namespaced_name] = component_class - # 根据组件类型进行特定注册(使用原始名称) match component_type: case ComponentType.ACTION: assert isinstance(component_info, ActionInfo) @@ -175,12 +190,11 @@ class ComponentRegistry: if not ret: return False logger.debug( - f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' " - f"({component_class.__name__}) [插件: {plugin_name}]" + f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' ({component_class.__name__}) [插件: {plugin_name}]" ) return True - def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool: + def _register_action_component(self, action_info: ActionInfo, action_class: type[BaseAction]) -> bool: """注册Action组件到Action特定注册表""" if not (action_name := action_info.name): logger.error(f"Action组件 {action_class.__name__} 必须指定名称") @@ -188,19 +202,13 @@ class ComponentRegistry: if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction): logger.error(f"注册失败: {action_name} 不是有效的Action") return False - - action_class.plugin_name = action_info.plugin_name - # 设置插件配置 - action_class.plugin_config = self.get_plugin_config(action_info.plugin_name) or {} + _assign_plugin_attrs(action_class, action_info.plugin_name, self.get_plugin_config(action_info.plugin_name) or {}) self._action_registry[action_name] = action_class - - # 如果启用,添加到默认动作集 if action_info.enabled: self._default_actions[action_name] = action_info - return True - def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool: + def _register_command_component(self, command_info: CommandInfo, command_class: type[BaseCommand]) -> bool: """注册Command组件到Command特定注册表""" if not (command_name := command_info.name): logger.error(f"Command组件 {command_class.__name__} 必须指定名称") @@ -208,13 +216,10 @@ class ComponentRegistry: if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand): logger.error(f"注册失败: {command_name} 不是有效的Command") return False - - command_class.plugin_name = command_info.plugin_name - # 设置插件配置 - command_class.plugin_config = self.get_plugin_config(command_info.plugin_name) or {} + _assign_plugin_attrs( + command_class, command_info.plugin_name, self.get_plugin_config(command_info.plugin_name) or {} + ) self._command_registry[command_name] = command_class - - # 如果启用了且有匹配模式 if command_info.enabled and command_info.command_pattern: pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL) if pattern not in self._command_patterns: @@ -223,11 +228,10 @@ class ComponentRegistry: logger.warning( f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令" ) - return True def _register_plus_command_component( - self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"] + self, plus_command_info: PlusCommandInfo, plus_command_class: type[PlusCommand] ) -> bool: """注册PlusCommand组件到特定注册表""" plus_command_name = plus_command_info.name @@ -241,33 +245,27 @@ class ComponentRegistry: # 创建专门的PlusCommand注册表(如果还没有) if not hasattr(self, "_plus_command_registry"): - self._plus_command_registry: dict[str, type["PlusCommand"]] = {} - - plus_command_class.plugin_name = plus_command_info.plugin_name - # 设置插件配置 - plus_command_class.plugin_config = self.get_plugin_config(plus_command_info.plugin_name) or {} + self._plus_command_registry: dict[str, type[PlusCommand]] = {} + _assign_plugin_attrs( + plus_command_class, + plus_command_info.plugin_name, + self.get_plugin_config(plus_command_info.plugin_name) or {}, + ) self._plus_command_registry[plus_command_name] = plus_command_class - logger.debug(f"已注册PlusCommand组件: {plus_command_name}") return True - def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool: + def _register_tool_component(self, tool_info: ToolInfo, tool_class: type[BaseTool]) -> bool: """注册Tool组件到Tool特定注册表""" tool_name = tool_info.name - - tool_class.plugin_name = tool_info.plugin_name - # 设置插件配置 - tool_class.plugin_config = self.get_plugin_config(tool_info.plugin_name) or {} + _assign_plugin_attrs(tool_class, tool_info.plugin_name, self.get_plugin_config(tool_info.plugin_name) or {}) self._tool_registry[tool_name] = tool_class - - # 如果是llm可用的且启用的工具,添加到 llm可用工具列表 if tool_info.enabled: self._llm_available_tools[tool_name] = tool_class - return True def _register_event_handler_component( - self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"] + self, handler_info: EventHandlerInfo, handler_class: type[BaseEventHandler] ) -> bool: if not (handler_name := handler_info.name): logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") @@ -275,25 +273,19 @@ class ComponentRegistry: if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler): logger.error(f"注册失败: {handler_name} 不是有效的EventHandler") return False - - handler_class.plugin_name = handler_info.plugin_name - # 设置插件配置 - handler_class.plugin_config = self.get_plugin_config(handler_info.plugin_name) or {} + _assign_plugin_attrs( + handler_class, handler_info.plugin_name, self.get_plugin_config(handler_info.plugin_name) or {} + ) self._event_handler_registry[handler_name] = handler_class - if not handler_info.enabled: logger.warning(f"EventHandler组件 {handler_name} 未启用") return True # 未启用,但是也是注册成功 - - handler_class.plugin_name = handler_info.plugin_name - # 使用EventManager进行事件处理器注册 from src.plugin_system.core.event_manager import event_manager - return event_manager.register_event_handler( handler_class, self.get_plugin_config(handler_info.plugin_name) or {} ) - def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool: + def _register_chatter_component(self, chatter_info: ChatterInfo, chatter_class: type[BaseChatter]) -> bool: """注册Chatter组件到Chatter特定注册表""" chatter_name = chatter_info.name @@ -303,18 +295,14 @@ class ComponentRegistry: if not isinstance(chatter_info, ChatterInfo) or not issubclass(chatter_class, BaseChatter): logger.error(f"注册失败: {chatter_name} 不是有效的Chatter") return False - - chatter_class.plugin_name = chatter_info.plugin_name - # 设置插件配置 - chatter_class.plugin_config = self.get_plugin_config(chatter_info.plugin_name) or {} - + _assign_plugin_attrs( + chatter_class, chatter_info.plugin_name, self.get_plugin_config(chatter_info.plugin_name) or {} + ) self._chatter_registry[chatter_name] = chatter_class - if not chatter_info.enabled: logger.warning(f"Chatter组件 {chatter_name} 未启用") return True # 未启用,但是也是注册成功 self._enabled_chatter_registry[chatter_name] = chatter_class - logger.debug(f"已注册Chatter组件: {chatter_name}") return True @@ -341,9 +329,13 @@ class ComponentRegistry: if not hasattr(self, "_enabled_interest_calculator_registry"): self._enabled_interest_calculator_registry: dict[str, type["BaseInterestCalculator"]] = {} - interest_calculator_class.plugin_name = interest_calculator_info.plugin_name + setattr(interest_calculator_class, "plugin_name", interest_calculator_info.plugin_name) # 设置插件配置 - interest_calculator_class.plugin_config = self.get_plugin_config(interest_calculator_info.plugin_name) or {} + setattr( + interest_calculator_class, + "plugin_config", + self.get_plugin_config(interest_calculator_info.plugin_name) or {}, + ) self._interest_calculator_registry[calculator_name] = interest_calculator_class if not interest_calculator_info.enabled: @@ -356,7 +348,7 @@ class ComponentRegistry: # === 组件移除相关 === - async def remove_component(self, component_name: str, component_type: "ComponentType", plugin_name: str) -> bool: + async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool: target_component_class = self.get_component_class(component_name, component_type) if not target_component_class: logger.warning(f"组件 {component_name} 未注册,无法移除") @@ -398,8 +390,14 @@ class ComponentRegistry: self._enabled_event_handlers.pop(component_name, None) try: handler = event_manager.get_event_handler(component_name) - for event in handler.subscribed_events: - await event_manager.unsubscribe_handler_from_event(event, component_name) + # 事件处理器可能未找到或未声明 subscribed_events,需判空 + if handler and hasattr(handler, "subscribed_events"): + for event in getattr(handler, "subscribed_events"): + # 假设 unsubscribe_handler_from_event 是协程;若不是则移除 await + result = event_manager.unsubscribe_handler_from_event(event, component_name) + if hasattr(result, "__await__"): + await result # type: ignore[func-returns-value] + logger.debug(f"已移除EventHandler组件: {component_name}") logger.debug(f"已移除EventHandler组件: {component_name}") except Exception as e: logger.warning(f"移除EventHandler事件订阅时出错: {e}") @@ -415,7 +413,7 @@ class ComponentRegistry: return False # 移除通用注册信息 - namespaced_name = f"{component_type}.{component_name}" + namespaced_name = f"{component_type.value}.{component_name}" self._components.pop(namespaced_name, None) self._components_by_type[component_type].pop(component_name, None) self._components_classes.pop(namespaced_name, None) @@ -477,9 +475,10 @@ class ComponentRegistry: self._enabled_event_handlers[component_name] = target_component_class from .event_manager import event_manager # 延迟导入防止循环导入问题 - event_manager.register_event_handler(component_name) - - namespaced_name = f"{component_type}.{component_name}" + # 重新注册事件处理器(启用)使用类而不是名称 + cfg = self.get_plugin_config(target_component_info.plugin_name) or {} + event_manager.register_event_handler(target_component_class, cfg) # type: ignore[arg-type] + namespaced_name = f"{component_type.value}.{component_name}" self._components[namespaced_name].enabled = True self._components_by_type[component_type][component_name].enabled = True logger.info(f"组件 {component_name} 已启用") @@ -512,10 +511,16 @@ class ComponentRegistry: from .event_manager import event_manager # 延迟导入防止循环导入问题 handler = event_manager.get_event_handler(component_name) - for event in handler.subscribed_events: - await event_manager.unsubscribe_handler_from_event(event, component_name) + if handler and hasattr(handler, "subscribed_events"): + for event in getattr(handler, "subscribed_events"): + result = event_manager.unsubscribe_handler_from_event(event, component_name) + if hasattr(result, "__await__"): + await result # type: ignore[func-returns-value] - self._components[component_name].enabled = False + # 组件主注册表使用命名空间 key + namespaced_name = f"{component_type.value}.{component_name}" + if namespaced_name in self._components: + self._components[namespaced_name].enabled = False self._components_by_type[component_type][component_name].enabled = False logger.info(f"组件 {component_name} 已禁用") return True @@ -528,8 +533,8 @@ class ComponentRegistry: # === 组件查询方法 === def get_component_info( - self, component_name: str, component_type: Optional["ComponentType"] = None - ) -> Optional["ComponentInfo"]: + self, component_name: str, component_type: Optional[ComponentType] = None + ) -> Optional[ComponentInfo]: # sourcery skip: class-extract-method """获取组件信息,支持自动命名空间解析 @@ -546,7 +551,7 @@ class ComponentRegistry: # 2. 如果指定了组件类型,构造命名空间化的名称查找 if component_type: - namespaced_name = f"{component_type}.{component_name}" + namespaced_name = f"{component_type.value}.{component_name}" return self._components.get(namespaced_name) # 3. 如果没有指定类型,尝试在所有命名空间中查找 @@ -573,8 +578,17 @@ class ComponentRegistry: def get_component_class( self, component_name: str, - component_type: Optional["ComponentType"] = None, - ) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None: + component_type: Optional[ComponentType] = None, + ) -> ( + type[BaseCommand] + | type[BaseAction] + | type[BaseEventHandler] + | type[BaseTool] + | type[PlusCommand] + | type[BaseChatter] + | type[BaseInterestCalculator] + | None + ): """获取组件类,支持自动命名空间解析 Args: @@ -591,7 +605,17 @@ class ComponentRegistry: # 2. 如果指定了组件类型,构造命名空间化的名称查找 if component_type: namespaced_name = f"{component_type.value}.{component_name}" - return self._components_classes.get(namespaced_name) # type: ignore[valid-type] + return cast( + type[BaseCommand] + | type[BaseAction] + | type[BaseEventHandler] + | type[BaseTool] + | type[PlusCommand] + | type[BaseChatter] + | type[BaseInterestCalculator] + | None, + self._components_classes.get(namespaced_name), + ) # 3. 如果没有指定类型,尝试在所有命名空间中查找 candidates = [] @@ -616,22 +640,22 @@ class ComponentRegistry: # 4. 都没找到 return None - def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: + def get_components_by_type(self, component_type: ComponentType) -> dict[str, ComponentInfo]: """获取指定类型的所有组件""" return self._components_by_type.get(component_type, {}).copy() - def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: + def get_enabled_components_by_type(self, component_type: ComponentType) -> dict[str, ComponentInfo]: """获取指定类型的所有启用组件""" components = self.get_components_by_type(component_type) return {name: info for name, info in components.items() if info.enabled} # === Action特定查询方法 === - def get_action_registry(self) -> dict[str, type["BaseAction"]]: + def get_action_registry(self) -> dict[str, type[BaseAction]]: """获取Action注册表""" return self._action_registry.copy() - def get_registered_action_info(self, action_name: str) -> Optional["ActionInfo"]: + def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]: """获取Action信息""" info = self.get_component_info(action_name, ComponentType.ACTION) return info if isinstance(info, ActionInfo) else None @@ -642,11 +666,11 @@ class ComponentRegistry: # === Command特定查询方法 === - def get_command_registry(self) -> dict[str, type["BaseCommand"]]: + def get_command_registry(self) -> dict[str, type[BaseCommand]]: """获取Command注册表""" return self._command_registry.copy() - def get_registered_command_info(self, command_name: str) -> Optional["CommandInfo"]: + def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]: """获取Command信息""" info = self.get_component_info(command_name, ComponentType.COMMAND) return info if isinstance(info, CommandInfo) else None @@ -655,7 +679,7 @@ class ComponentRegistry: """获取Command模式注册表""" return self._command_patterns.copy() - def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None: + def find_command_by_text(self, text: str) -> tuple[type[BaseCommand], dict, CommandInfo] | None: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -682,15 +706,15 @@ class ComponentRegistry: return None # === Tool 特定查询方法 === - def get_tool_registry(self) -> dict[str, type["BaseTool"]]: + def get_tool_registry(self) -> dict[str, type[BaseTool]]: """获取Tool注册表""" return self._tool_registry.copy() - def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]: + def get_llm_available_tools(self) -> dict[str, type[BaseTool]]: """获取LLM可用的Tool列表""" return self._llm_available_tools.copy() - def get_registered_tool_info(self, tool_name: str) -> Optional["ToolInfo"]: + def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]: """获取Tool信息 Args: @@ -703,13 +727,13 @@ class ComponentRegistry: return info if isinstance(info, ToolInfo) else None # === PlusCommand 特定查询方法 === - def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]: + def get_plus_command_registry(self) -> dict[str, type[PlusCommand]]: """获取PlusCommand注册表""" if not hasattr(self, "_plus_command_registry"): self._plus_command_registry: dict[str, type[PlusCommand]] = {} return self._plus_command_registry.copy() - def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]: + def get_registered_plus_command_info(self, command_name: str) -> Optional[PlusCommandInfo]: """获取PlusCommand信息 Args: @@ -723,44 +747,44 @@ class ComponentRegistry: # === EventHandler 特定查询方法 === - def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]: + def get_event_handler_registry(self) -> dict[str, type[BaseEventHandler]]: """获取事件处理器注册表""" return self._event_handler_registry.copy() - def get_registered_event_handler_info(self, handler_name: str) -> Optional["EventHandlerInfo"]: + def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]: """获取事件处理器信息""" info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) return info if isinstance(info, EventHandlerInfo) else None - def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]: + def get_enabled_event_handlers(self) -> dict[str, type[BaseEventHandler]]: """获取启用的事件处理器""" return self._enabled_event_handlers.copy() # === Chatter 特定查询方法 === - def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]: + def get_chatter_registry(self) -> dict[str, type[BaseChatter]]: """获取Chatter注册表""" if not hasattr(self, "_chatter_registry"): self._chatter_registry: dict[str, type[BaseChatter]] = {} return self._chatter_registry.copy() - def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]: + def get_enabled_chatter_registry(self) -> dict[str, type[BaseChatter]]: """获取启用的Chatter注册表""" if not hasattr(self, "_enabled_chatter_registry"): self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {} return self._enabled_chatter_registry.copy() - def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]: + def get_registered_chatter_info(self, chatter_name: str) -> Optional[ChatterInfo]: """获取Chatter信息""" info = self.get_component_info(chatter_name, ComponentType.CHATTER) return info if isinstance(info, ChatterInfo) else None # === 插件查询方法 === - def get_plugin_info(self, plugin_name: str) -> Optional["PluginInfo"]: + def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]: """获取插件信息""" return self._plugins.get(plugin_name) - def get_all_plugins(self) -> dict[str, "PluginInfo"]: + def get_all_plugins(self) -> dict[str, PluginInfo]: """获取所有插件""" return self._plugins.copy() diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index 99f00340c..544564e56 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -22,8 +22,6 @@ class PermissionManager(IPermissionManager): """权限管理器实现类""" def __init__(self): - self.engine = None - self.SessionLocal = None self._master_users: set[tuple[str, str]] = set() self._load_master_users() @@ -52,7 +50,7 @@ class PermissionManager(IPermissionManager): self._load_master_users() logger.info("Master用户配置已重新加载") - def is_master(self, user: UserInfo) -> bool: + async def is_master(self, user: UserInfo) -> bool: """ 检查用户是否为Master用户 @@ -81,7 +79,7 @@ class PermissionManager(IPermissionManager): """ try: # Master用户拥有所有权限 - if self.is_master(user): + if await self.is_master(user): logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}") return True @@ -288,10 +286,10 @@ class PermissionManager(IPermissionManager): """ try: # Master用户拥有所有权限 - if self.is_master(user): + if await self.is_master(user): async with self.SessionLocal() as session: result = await session.execute(select(PermissionNodes.node_name)) - all_nodes = result.scalars().all() + all_nodes = list(result.scalars().all()) return all_nodes permissions = [] diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 6542365b7..a5244e0e7 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -95,6 +95,7 @@ class PluginManager: if not plugin_class: logger.error(f"插件 {plugin_name} 的插件类未注册或不存在") return False, 1 + init_module = None # 预先定义,避免后续条件加载导致未绑定 try: # 使用记录的插件目录路径 plugin_dir = self.plugin_paths.get(plugin_name) @@ -314,6 +315,7 @@ class PluginManager: module_name = ".".join(plugin_path.parent.parts) try: + init_module = None # 确保下方引用存在 # 首先加载 __init__.py 来获取元数据 init_file = os.path.join(plugin_dir, "__init__.py") if os.path.exists(init_file): @@ -524,13 +526,10 @@ class PluginManager: fut.result(timeout=5) else: asyncio.run(component_registry.unregister_plugin(plugin_name)) - except Exception: - # 最后兜底:直接同步调用(如果 unregister_plugin 为非协程)或忽略错误 - try: - # 如果 unregister_plugin 是普通函数 - component_registry.unregister_plugin(plugin_name) - except Exception as e: - logger.debug(f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}") + except Exception as e: # 捕获并记录卸载阶段协程调用错误 + logger.debug( + f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}", exc_info=True + ) # 从已加载插件中移除 del self.loaded_plugins[plugin_name] @@ -550,65 +549,6 @@ class PluginManager: logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True) return False - def reload_plugin(self, plugin_name: str) -> bool: - """重载指定插件 - - Args: - plugin_name: 插件名称 - - Returns: - bool: 重载是否成功 - """ - try: - logger.info(f"🔄 开始重载插件: {plugin_name}") - - # 卸载插件 - if plugin_name in self.loaded_plugins: - if not self.unload_plugin(plugin_name): - logger.warning(f"⚠️ 插件卸载失败,继续重载: {plugin_name}") - - # 重新扫描插件目录 - self.rescan_plugin_directory() - - # 重新加载插件实例 - if plugin_name in self.plugin_classes: - success, _ = self.load_registered_plugin_classes(plugin_name) - if success: - logger.info(f"✅ 插件重载成功: {plugin_name}") - return True - else: - logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败") - return False - else: - logger.error(f"❌ 插件重载失败: {plugin_name} - 插件类未找到") - return False - - except Exception as e: - logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True) - return False - - def force_reload_plugin(self, plugin_name: str) -> bool: - """强制重载插件(使用简化的方法) - - Args: - plugin_name: 插件名称 - - Returns: - bool: 重载是否成功 - """ - return self.reload_plugin(plugin_name) - - @staticmethod - def clear_all_plugin_caches(): - """清理所有插件相关的模块缓存(简化版)""" - try: - logger.info("🧹 清理模块缓存...") - # 清理importlib缓存 - importlib.invalidate_caches() - logger.info("🧹 模块缓存清理完成") - except Exception as e: - logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True) - # 全局插件管理器实例 plugin_manager = PluginManager() diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 7dd09a894..82a9bf721 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -177,8 +177,8 @@ class ToolExecutor: # 执行每个工具调用 for tool_call in tool_calls: + tool_name = getattr(tool_call, "func_name", "unknown_tool") try: - tool_name = tool_call.func_name logger.debug(f"{self.log_prefix}执行工具: {tool_name}") # 执行工具