refactor(db,plugin): 优化数据库初始化和插件系统类型安全

- 重构数据库初始化逻辑,添加防重入保护和更好的错误处理
- 优化插件组件注册系统的类型注解和代码结构
- 简化统计模块异常处理逻辑
- 移除插件管理器中的重载功能以简化代码
- 更新Pyright配置排除内置插件目录
- 修复权限管理器异步方法调用
This commit is contained in:
雅诺狐
2025-10-07 16:29:17 +08:00
parent 4971d18f14
commit 6659c60799
8 changed files with 234 additions and 256 deletions

4
bot.py
View File

@@ -560,9 +560,9 @@ class MaiBotMain:
logger.info("正在初始化数据库表结构...") logger.info("正在初始化数据库表结构...")
try: try:
start_time = time.time() start_time = time.time()
from src.common.database.sqlalchemy_models import initialize_database as init_db from src.common.database.sqlalchemy_models import initialize_database
await init_db() await initialize_database()
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}") logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}")
except Exception as e: except Exception as e:

View File

@@ -2,8 +2,7 @@
"$schema": "https://raw.githubusercontent.com/microsoft/pyright/main/packages/vscode-pyright/schemas/pyrightconfig.schema.json", "$schema": "https://raw.githubusercontent.com/microsoft/pyright/main/packages/vscode-pyright/schemas/pyrightconfig.schema.json",
"include": [ "include": [
"src", "src",
"bot.py", "bot.py"
"__main__.py"
], ],
"exclude": [ "exclude": [
"**/__pycache__", "**/__pycache__",
@@ -11,7 +10,9 @@
"logs", "logs",
"tests", "tests",
"target", "target",
"*.egg-info" "*.egg-info",
"src/plugins/built_in/*",
"__main__.py"
], ],
"typeCheckingMode": "standard", "typeCheckingMode": "standard",
"reportMissingImports": false, "reportMissingImports": false,

View File

@@ -5,7 +5,7 @@
""" """
import datetime import datetime
from typing import Any, Optional, TypedDict, Literal, Union, Callable, TypeVar, cast from typing import Any, Optional, TypeVar, cast
from sqlalchemy import select, delete from sqlalchemy import select, delete
@@ -47,13 +47,12 @@ class AntiInjectionStatistics:
"""当前会话开始时间""" """当前会话开始时间"""
@staticmethod @staticmethod
async def get_or_create_stats() -> Optional[AntiInjectionStats]: # type: ignore[name-defined] async def get_or_create_stats() -> AntiInjectionStats:
"""获取或创建统计记录 """获取或创建统计记录
Returns: Returns:
AntiInjectionStats | None: 成功返回模型实例,否则 None AntiInjectionStats | None: 成功返回模型实例,否则 None
""" """
try:
async with get_db_session() as session: async with get_db_session() as session:
# 获取最新的统计记录,如果没有则创建 # 获取最新的统计记录,如果没有则创建
stats = ( stats = (
@@ -67,9 +66,7 @@ class AntiInjectionStatistics:
await session.commit() await session.commit()
await session.refresh(stats) await session.refresh(stats)
return stats return stats
except Exception as e:
logger.error(f"获取统计记录失败: {e}")
return None
@staticmethod @staticmethod
async def update_stats(**kwargs: Any) -> None: async def update_stats(**kwargs: Any) -> None:
@@ -97,7 +94,7 @@ class AntiInjectionStatistics:
if key == "processing_time_delta": if key == "processing_time_delta":
# 处理时间累加 - 确保不为 None # 处理时间累加 - 确保不为 None
delta = float(value) delta = float(value)
stats.processing_time_total = _add_optional(stats.processing_time_total, delta) # type: ignore[attr-defined] stats.processing_time_total = _add_optional(stats.processing_time_total, delta)
continue continue
elif key == "last_processing_time": elif key == "last_processing_time":
# 直接设置最后处理时间 # 直接设置最后处理时间
@@ -146,7 +143,7 @@ class AntiInjectionStatistics:
# 计算派生统计信息 - 处理 None 值 # 计算派生统计信息 - 处理 None 值
total_messages = stats.total_messages or 0 # type: ignore[attr-defined] total_messages = stats.total_messages or 0
detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined] detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined]
processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined] processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined]

View File

@@ -19,7 +19,7 @@ from contextlib import asynccontextmanager
from typing import Any from typing import Any
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@@ -31,6 +31,10 @@ logger = get_logger("sqlalchemy_models")
# 创建基类 # 创建基类
Base = declarative_base() Base = declarative_base()
# 全局异步引擎与会话工厂占位(延迟初始化)
_engine: AsyncEngine | None = None
_SessionLocal: async_sessionmaker[AsyncSession] | None = None
async def enable_sqlite_wal_mode(engine): async def enable_sqlite_wal_mode(engine):
"""为 SQLite 启用 WAL 模式以提高并发性能""" """为 SQLite 启用 WAL 模式以提高并发性能"""
@@ -649,23 +653,13 @@ class MonthlyPlan(Base):
last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True) last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
# 保留 is_deleted 字段以兼容现有数据,但标记为已弃用
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
__table_args__ = ( __table_args__ = (
Index("idx_monthlyplan_target_month_status", "target_month", "status"), Index("idx_monthlyplan_target_month_status", "target_month", "status"),
Index("idx_monthlyplan_last_used_date", "last_used_date"), Index("idx_monthlyplan_last_used_date", "last_used_date"),
Index("idx_monthlyplan_usage_count", "usage_count"), Index("idx_monthlyplan_usage_count", "usage_count"),
# 保留旧索引以兼容
Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"),
) )
# 数据库引擎和会话管理
_engine = None
_SessionLocal = None
def get_database_url(): def get_database_url():
"""获取数据库连接URL""" """获取数据库连接URL"""
from src.config.config import global_config from src.config.config import global_config
@@ -709,13 +703,35 @@ def get_database_url():
return f"sqlite+aiosqlite:///{db_path}" return f"sqlite+aiosqlite:///{db_path}"
async def initialize_database(): _initializing: bool = False # 防止递归初始化
"""初始化异步数据库引擎和会话"""
global _engine, _SessionLocal
if _engine is not None: async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[AsyncSession]]:
"""初始化异步数据库引擎和会话
Returns:
tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: 创建好的异步引擎与会话工厂。
说明:
显式的返回类型标注有助于 Pyright/Pylance 正确推断调用处的对象,
避免后续对返回值再次 `await` 时出现 *"tuple[...] 并非 awaitable"* 的误用。
"""
global _engine, _SessionLocal, _initializing
# 已经初始化直接返回
if _engine is not None and _SessionLocal is not None:
return _engine, _SessionLocal return _engine, _SessionLocal
# 正在初始化的并发调用等待主初始化完成,避免递归
if _initializing:
import asyncio
for _ in range(1000): # 最多等待约10秒
await asyncio.sleep(0.01)
if _engine is not None and _SessionLocal is not None:
return _engine, _SessionLocal
raise RuntimeError("等待数据库初始化完成超时 (reentrancy guard)")
_initializing = True
try:
database_url = get_database_url() database_url = get_database_url()
from src.config.config import global_config from src.config.config import global_config
@@ -728,14 +744,13 @@ async def initialize_database():
} }
if config.database_type == "mysql": if config.database_type == "mysql":
# MySQL连接池配置 - 异步引擎使用默认连接池
engine_kwargs.update( engine_kwargs.update(
{ {
"pool_size": config.connection_pool_size, "pool_size": config.connection_pool_size,
"max_overflow": config.connection_pool_size * 2, "max_overflow": config.connection_pool_size * 2,
"pool_timeout": config.connection_timeout, "pool_timeout": config.connection_timeout,
"pool_recycle": 3600, # 1小时回收连接 "pool_recycle": 3600,
"pool_pre_ping": True, # 连接前ping检查 "pool_pre_ping": True,
"connect_args": { "connect_args": {
"autocommit": config.mysql_autocommit, "autocommit": config.mysql_autocommit,
"charset": config.mysql_charset, "charset": config.mysql_charset,
@@ -744,12 +759,11 @@ async def initialize_database():
} }
) )
else: else:
# SQLite配置 - aiosqlite不支持连接池参数
engine_kwargs.update( engine_kwargs.update(
{ {
"connect_args": { "connect_args": {
"check_same_thread": False, "check_same_thread": False,
"timeout": 60, # 增加超时时间 "timeout": 60,
}, },
} }
) )
@@ -757,17 +771,21 @@ async def initialize_database():
_engine = create_async_engine(database_url, **engine_kwargs) _engine = create_async_engine(database_url, **engine_kwargs)
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
# 调用新的迁移函数,它会处理表的创建和列的添加 # 迁移
try:
from src.common.database.db_migration import check_and_migrate_database from src.common.database.db_migration import check_and_migrate_database
await check_and_migrate_database(existing_engine=_engine)
except TypeError:
from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate
await _legacy_migrate()
await check_and_migrate_database()
# 如果是 SQLite启用 WAL 模式以提高并发性能
if config.database_type == "sqlite": if config.database_type == "sqlite":
await enable_sqlite_wal_mode(_engine) await enable_sqlite_wal_mode(_engine)
logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}") logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
return _engine, _SessionLocal return _engine, _SessionLocal
finally:
_initializing = False
@asynccontextmanager @asynccontextmanager

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import re import re
from pathlib import Path from pathlib import Path
from re import Pattern from re import Pattern
from typing import Any, Optional, Union from typing import Any, Optional, Union, cast
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_action import BaseAction
@@ -26,6 +28,23 @@ from src.plugin_system.base.plus_command import PlusCommand
logger = get_logger("component_registry") logger = get_logger("component_registry")
# 统一的组件类类型别名
ComponentClassType = (
type[BaseCommand]
| type[BaseAction]
| type[BaseTool]
| type[BaseEventHandler]
| type[PlusCommand]
| type[BaseChatter]
| type[BaseInterestCalculator]
)
def _assign_plugin_attrs(cls: Any, plugin_name: str, plugin_config: dict) -> None:
"""为组件类动态赋予插件相关属性(避免在各注册函数中重复代码)。"""
setattr(cls, "plugin_name", plugin_name)
setattr(cls, "plugin_config", plugin_config)
class ComponentRegistry: class ComponentRegistry:
"""统一的组件注册中心 """统一的组件注册中心
@@ -41,9 +60,8 @@ class ComponentRegistry:
types: {} for types in ComponentType types: {} for types in ComponentType
} }
"""类型 -> 组件原名称 -> 组件信息""" """类型 -> 组件原名称 -> 组件信息"""
self._components_classes: dict[ # 组件类注册表(命名空间式组件名 -> 组件类)
str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"] self._components_classes: dict[str, ComponentClassType] = {}
] = {}
"""命名空间式组件名 -> 组件类""" """命名空间式组件名 -> 组件类"""
# 插件注册表 # 插件注册表
@@ -100,10 +118,8 @@ class ComponentRegistry:
return True return True
def register_component( def register_component(
self, self, self_component_info: ComponentInfo, component_class: ComponentClassType
component_info: ComponentInfo, ) -> bool: # noqa: C901 (保持原有结构, 以后可再拆)
component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
) -> bool:
"""注册组件 """注册组件
Args: Args:
@@ -113,9 +129,11 @@ class ComponentRegistry:
Returns: Returns:
bool: 是否注册成功 bool: 是否注册成功
""" """
component_info = self_component_info # 局部别名
component_name = component_info.name component_name = component_info.name
component_type = component_info.component_type component_type = component_info.component_type
plugin_name = getattr(component_info, "plugin_name", "unknown") plugin_name = getattr(component_info, "plugin_name", "unknown")
if "." in component_name: if "." in component_name:
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代") logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
return False return False
@@ -123,22 +141,19 @@ class ComponentRegistry:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False return False
namespaced_name = f"{component_type}.{component_name}" namespaced_name = f"{component_type.value}.{component_name}"
if namespaced_name in self._components: if namespaced_name in self._components:
existing_info = self._components[namespaced_name] existing_info = self._components[namespaced_name]
existing_plugin = getattr(existing_info, "plugin_name", "unknown") existing_plugin = getattr(existing_info, "plugin_name", "unknown")
logger.warning( logger.warning(
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册" f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
) )
return False return False
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称) self._components[namespaced_name] = component_info
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名 self._components_by_type[component_type][component_name] = component_info
self._components_classes[namespaced_name] = component_class self._components_classes[namespaced_name] = component_class
# 根据组件类型进行特定注册(使用原始名称)
match component_type: match component_type:
case ComponentType.ACTION: case ComponentType.ACTION:
assert isinstance(component_info, ActionInfo) assert isinstance(component_info, ActionInfo)
@@ -175,12 +190,11 @@ class ComponentRegistry:
if not ret: if not ret:
return False return False
logger.debug( logger.debug(
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' " f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' ({component_class.__name__}) [插件: {plugin_name}]"
f"({component_class.__name__}) [插件: {plugin_name}]"
) )
return True return True
def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool: def _register_action_component(self, action_info: ActionInfo, action_class: type[BaseAction]) -> bool:
"""注册Action组件到Action特定注册表""" """注册Action组件到Action特定注册表"""
if not (action_name := action_info.name): if not (action_name := action_info.name):
logger.error(f"Action组件 {action_class.__name__} 必须指定名称") logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
@@ -188,19 +202,13 @@ class ComponentRegistry:
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction): if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
logger.error(f"注册失败: {action_name} 不是有效的Action") logger.error(f"注册失败: {action_name} 不是有效的Action")
return False return False
_assign_plugin_attrs(action_class, action_info.plugin_name, self.get_plugin_config(action_info.plugin_name) or {})
action_class.plugin_name = action_info.plugin_name
# 设置插件配置
action_class.plugin_config = self.get_plugin_config(action_info.plugin_name) or {}
self._action_registry[action_name] = action_class self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
if action_info.enabled: if action_info.enabled:
self._default_actions[action_name] = action_info self._default_actions[action_name] = action_info
return True return True
def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool: def _register_command_component(self, command_info: CommandInfo, command_class: type[BaseCommand]) -> bool:
"""注册Command组件到Command特定注册表""" """注册Command组件到Command特定注册表"""
if not (command_name := command_info.name): if not (command_name := command_info.name):
logger.error(f"Command组件 {command_class.__name__} 必须指定名称") logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
@@ -208,13 +216,10 @@ class ComponentRegistry:
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand): if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
logger.error(f"注册失败: {command_name} 不是有效的Command") logger.error(f"注册失败: {command_name} 不是有效的Command")
return False return False
_assign_plugin_attrs(
command_class.plugin_name = command_info.plugin_name command_class, command_info.plugin_name, self.get_plugin_config(command_info.plugin_name) or {}
# 设置插件配置 )
command_class.plugin_config = self.get_plugin_config(command_info.plugin_name) or {}
self._command_registry[command_name] = command_class self._command_registry[command_name] = command_class
# 如果启用了且有匹配模式
if command_info.enabled and command_info.command_pattern: if command_info.enabled and command_info.command_pattern:
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL) pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
if pattern not in self._command_patterns: if pattern not in self._command_patterns:
@@ -223,11 +228,10 @@ class ComponentRegistry:
logger.warning( logger.warning(
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令" f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
) )
return True return True
def _register_plus_command_component( def _register_plus_command_component(
self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"] self, plus_command_info: PlusCommandInfo, plus_command_class: type[PlusCommand]
) -> bool: ) -> bool:
"""注册PlusCommand组件到特定注册表""" """注册PlusCommand组件到特定注册表"""
plus_command_name = plus_command_info.name plus_command_name = plus_command_info.name
@@ -241,33 +245,27 @@ class ComponentRegistry:
# 创建专门的PlusCommand注册表如果还没有 # 创建专门的PlusCommand注册表如果还没有
if not hasattr(self, "_plus_command_registry"): if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: dict[str, type["PlusCommand"]] = {} self._plus_command_registry: dict[str, type[PlusCommand]] = {}
_assign_plugin_attrs(
plus_command_class.plugin_name = plus_command_info.plugin_name plus_command_class,
# 设置插件配置 plus_command_info.plugin_name,
plus_command_class.plugin_config = self.get_plugin_config(plus_command_info.plugin_name) or {} self.get_plugin_config(plus_command_info.plugin_name) or {},
)
self._plus_command_registry[plus_command_name] = plus_command_class self._plus_command_registry[plus_command_name] = plus_command_class
logger.debug(f"已注册PlusCommand组件: {plus_command_name}") logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
return True return True
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool: def _register_tool_component(self, tool_info: ToolInfo, tool_class: type[BaseTool]) -> bool:
"""注册Tool组件到Tool特定注册表""" """注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name tool_name = tool_info.name
_assign_plugin_attrs(tool_class, tool_info.plugin_name, self.get_plugin_config(tool_info.plugin_name) or {})
tool_class.plugin_name = tool_info.plugin_name
# 设置插件配置
tool_class.plugin_config = self.get_plugin_config(tool_info.plugin_name) or {}
self._tool_registry[tool_name] = tool_class self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.enabled: if tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class self._llm_available_tools[tool_name] = tool_class
return True return True
def _register_event_handler_component( def _register_event_handler_component(
self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"] self, handler_info: EventHandlerInfo, handler_class: type[BaseEventHandler]
) -> bool: ) -> bool:
if not (handler_name := handler_info.name): if not (handler_name := handler_info.name):
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
@@ -275,25 +273,19 @@ class ComponentRegistry:
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler): if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler") logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False return False
_assign_plugin_attrs(
handler_class.plugin_name = handler_info.plugin_name handler_class, handler_info.plugin_name, self.get_plugin_config(handler_info.plugin_name) or {}
# 设置插件配置 )
handler_class.plugin_config = self.get_plugin_config(handler_info.plugin_name) or {}
self._event_handler_registry[handler_name] = handler_class self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled: if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用") logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功 return True # 未启用,但是也是注册成功
handler_class.plugin_name = handler_info.plugin_name
# 使用EventManager进行事件处理器注册
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
return event_manager.register_event_handler( return event_manager.register_event_handler(
handler_class, self.get_plugin_config(handler_info.plugin_name) or {} handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
) )
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool: def _register_chatter_component(self, chatter_info: ChatterInfo, chatter_class: type[BaseChatter]) -> bool:
"""注册Chatter组件到Chatter特定注册表""" """注册Chatter组件到Chatter特定注册表"""
chatter_name = chatter_info.name chatter_name = chatter_info.name
@@ -303,18 +295,14 @@ class ComponentRegistry:
if not isinstance(chatter_info, ChatterInfo) or not issubclass(chatter_class, BaseChatter): if not isinstance(chatter_info, ChatterInfo) or not issubclass(chatter_class, BaseChatter):
logger.error(f"注册失败: {chatter_name} 不是有效的Chatter") logger.error(f"注册失败: {chatter_name} 不是有效的Chatter")
return False return False
_assign_plugin_attrs(
chatter_class.plugin_name = chatter_info.plugin_name chatter_class, chatter_info.plugin_name, self.get_plugin_config(chatter_info.plugin_name) or {}
# 设置插件配置 )
chatter_class.plugin_config = self.get_plugin_config(chatter_info.plugin_name) or {}
self._chatter_registry[chatter_name] = chatter_class self._chatter_registry[chatter_name] = chatter_class
if not chatter_info.enabled: if not chatter_info.enabled:
logger.warning(f"Chatter组件 {chatter_name} 未启用") logger.warning(f"Chatter组件 {chatter_name} 未启用")
return True # 未启用,但是也是注册成功 return True # 未启用,但是也是注册成功
self._enabled_chatter_registry[chatter_name] = chatter_class self._enabled_chatter_registry[chatter_name] = chatter_class
logger.debug(f"已注册Chatter组件: {chatter_name}") logger.debug(f"已注册Chatter组件: {chatter_name}")
return True return True
@@ -341,9 +329,13 @@ class ComponentRegistry:
if not hasattr(self, "_enabled_interest_calculator_registry"): if not hasattr(self, "_enabled_interest_calculator_registry"):
self._enabled_interest_calculator_registry: dict[str, type["BaseInterestCalculator"]] = {} self._enabled_interest_calculator_registry: dict[str, type["BaseInterestCalculator"]] = {}
interest_calculator_class.plugin_name = interest_calculator_info.plugin_name setattr(interest_calculator_class, "plugin_name", interest_calculator_info.plugin_name)
# 设置插件配置 # 设置插件配置
interest_calculator_class.plugin_config = self.get_plugin_config(interest_calculator_info.plugin_name) or {} setattr(
interest_calculator_class,
"plugin_config",
self.get_plugin_config(interest_calculator_info.plugin_name) or {},
)
self._interest_calculator_registry[calculator_name] = interest_calculator_class self._interest_calculator_registry[calculator_name] = interest_calculator_class
if not interest_calculator_info.enabled: if not interest_calculator_info.enabled:
@@ -356,7 +348,7 @@ class ComponentRegistry:
# === 组件移除相关 === # === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: "ComponentType", plugin_name: str) -> bool: async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
target_component_class = self.get_component_class(component_name, component_type) target_component_class = self.get_component_class(component_name, component_type)
if not target_component_class: if not target_component_class:
logger.warning(f"组件 {component_name} 未注册,无法移除") logger.warning(f"组件 {component_name} 未注册,无法移除")
@@ -398,8 +390,14 @@ class ComponentRegistry:
self._enabled_event_handlers.pop(component_name, None) self._enabled_event_handlers.pop(component_name, None)
try: try:
handler = event_manager.get_event_handler(component_name) handler = event_manager.get_event_handler(component_name)
for event in handler.subscribed_events: # 事件处理器可能未找到或未声明 subscribed_events,需判空
await event_manager.unsubscribe_handler_from_event(event, component_name) if handler and hasattr(handler, "subscribed_events"):
for event in getattr(handler, "subscribed_events"):
# 假设 unsubscribe_handler_from_event 是协程;若不是则移除 await
result = event_manager.unsubscribe_handler_from_event(event, component_name)
if hasattr(result, "__await__"):
await result # type: ignore[func-returns-value]
logger.debug(f"已移除EventHandler组件: {component_name}")
logger.debug(f"已移除EventHandler组件: {component_name}") logger.debug(f"已移除EventHandler组件: {component_name}")
except Exception as e: except Exception as e:
logger.warning(f"移除EventHandler事件订阅时出错: {e}") logger.warning(f"移除EventHandler事件订阅时出错: {e}")
@@ -415,7 +413,7 @@ class ComponentRegistry:
return False return False
# 移除通用注册信息 # 移除通用注册信息
namespaced_name = f"{component_type}.{component_name}" namespaced_name = f"{component_type.value}.{component_name}"
self._components.pop(namespaced_name, None) self._components.pop(namespaced_name, None)
self._components_by_type[component_type].pop(component_name, None) self._components_by_type[component_type].pop(component_name, None)
self._components_classes.pop(namespaced_name, None) self._components_classes.pop(namespaced_name, None)
@@ -477,9 +475,10 @@ class ComponentRegistry:
self._enabled_event_handlers[component_name] = target_component_class self._enabled_event_handlers[component_name] = target_component_class
from .event_manager import event_manager # 延迟导入防止循环导入问题 from .event_manager import event_manager # 延迟导入防止循环导入问题
event_manager.register_event_handler(component_name) # 重新注册事件处理器(启用)使用类而不是名称
cfg = self.get_plugin_config(target_component_info.plugin_name) or {}
namespaced_name = f"{component_type}.{component_name}" event_manager.register_event_handler(target_component_class, cfg) # type: ignore[arg-type]
namespaced_name = f"{component_type.value}.{component_name}"
self._components[namespaced_name].enabled = True self._components[namespaced_name].enabled = True
self._components_by_type[component_type][component_name].enabled = True self._components_by_type[component_type][component_name].enabled = True
logger.info(f"组件 {component_name} 已启用") logger.info(f"组件 {component_name} 已启用")
@@ -512,10 +511,16 @@ class ComponentRegistry:
from .event_manager import event_manager # 延迟导入防止循环导入问题 from .event_manager import event_manager # 延迟导入防止循环导入问题
handler = event_manager.get_event_handler(component_name) handler = event_manager.get_event_handler(component_name)
for event in handler.subscribed_events: if handler and hasattr(handler, "subscribed_events"):
await event_manager.unsubscribe_handler_from_event(event, component_name) for event in getattr(handler, "subscribed_events"):
result = event_manager.unsubscribe_handler_from_event(event, component_name)
if hasattr(result, "__await__"):
await result # type: ignore[func-returns-value]
self._components[component_name].enabled = False # 组件主注册表使用命名空间 key
namespaced_name = f"{component_type.value}.{component_name}"
if namespaced_name in self._components:
self._components[namespaced_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用") logger.info(f"组件 {component_name} 已禁用")
return True return True
@@ -528,8 +533,8 @@ class ComponentRegistry:
# === 组件查询方法 === # === 组件查询方法 ===
def get_component_info( def get_component_info(
self, component_name: str, component_type: Optional["ComponentType"] = None self, component_name: str, component_type: Optional[ComponentType] = None
) -> Optional["ComponentInfo"]: ) -> Optional[ComponentInfo]:
# sourcery skip: class-extract-method # sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析 """获取组件信息,支持自动命名空间解析
@@ -546,7 +551,7 @@ class ComponentRegistry:
# 2. 如果指定了组件类型,构造命名空间化的名称查找 # 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type: if component_type:
namespaced_name = f"{component_type}.{component_name}" namespaced_name = f"{component_type.value}.{component_name}"
return self._components.get(namespaced_name) return self._components.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找 # 3. 如果没有指定类型,尝试在所有命名空间中查找
@@ -573,8 +578,17 @@ class ComponentRegistry:
def get_component_class( def get_component_class(
self, self,
component_name: str, component_name: str,
component_type: Optional["ComponentType"] = None, component_type: Optional[ComponentType] = None,
) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None: ) -> (
type[BaseCommand]
| type[BaseAction]
| type[BaseEventHandler]
| type[BaseTool]
| type[PlusCommand]
| type[BaseChatter]
| type[BaseInterestCalculator]
| None
):
"""获取组件类,支持自动命名空间解析 """获取组件类,支持自动命名空间解析
Args: Args:
@@ -591,7 +605,17 @@ class ComponentRegistry:
# 2. 如果指定了组件类型,构造命名空间化的名称查找 # 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type: if component_type:
namespaced_name = f"{component_type.value}.{component_name}" namespaced_name = f"{component_type.value}.{component_name}"
return self._components_classes.get(namespaced_name) # type: ignore[valid-type] return cast(
type[BaseCommand]
| type[BaseAction]
| type[BaseEventHandler]
| type[BaseTool]
| type[PlusCommand]
| type[BaseChatter]
| type[BaseInterestCalculator]
| None,
self._components_classes.get(namespaced_name),
)
# 3. 如果没有指定类型,尝试在所有命名空间中查找 # 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = [] candidates = []
@@ -616,22 +640,22 @@ class ComponentRegistry:
# 4. 都没找到 # 4. 都没找到
return None return None
def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: def get_components_by_type(self, component_type: ComponentType) -> dict[str, ComponentInfo]:
"""获取指定类型的所有组件""" """获取指定类型的所有组件"""
return self._components_by_type.get(component_type, {}).copy() return self._components_by_type.get(component_type, {}).copy()
def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: def get_enabled_components_by_type(self, component_type: ComponentType) -> dict[str, ComponentInfo]:
"""获取指定类型的所有启用组件""" """获取指定类型的所有启用组件"""
components = self.get_components_by_type(component_type) components = self.get_components_by_type(component_type)
return {name: info for name, info in components.items() if info.enabled} return {name: info for name, info in components.items() if info.enabled}
# === Action特定查询方法 === # === Action特定查询方法 ===
def get_action_registry(self) -> dict[str, type["BaseAction"]]: def get_action_registry(self) -> dict[str, type[BaseAction]]:
"""获取Action注册表""" """获取Action注册表"""
return self._action_registry.copy() return self._action_registry.copy()
def get_registered_action_info(self, action_name: str) -> Optional["ActionInfo"]: def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
"""获取Action信息""" """获取Action信息"""
info = self.get_component_info(action_name, ComponentType.ACTION) info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None return info if isinstance(info, ActionInfo) else None
@@ -642,11 +666,11 @@ class ComponentRegistry:
# === Command特定查询方法 === # === Command特定查询方法 ===
def get_command_registry(self) -> dict[str, type["BaseCommand"]]: def get_command_registry(self) -> dict[str, type[BaseCommand]]:
"""获取Command注册表""" """获取Command注册表"""
return self._command_registry.copy() return self._command_registry.copy()
def get_registered_command_info(self, command_name: str) -> Optional["CommandInfo"]: def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command信息""" """获取Command信息"""
info = self.get_component_info(command_name, ComponentType.COMMAND) info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None return info if isinstance(info, CommandInfo) else None
@@ -655,7 +679,7 @@ class ComponentRegistry:
"""获取Command模式注册表""" """获取Command模式注册表"""
return self._command_patterns.copy() return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None: def find_command_by_text(self, text: str) -> tuple[type[BaseCommand], dict, CommandInfo] | None:
# sourcery skip: use-named-expression, use-next # sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令 """根据文本查找匹配的命令
@@ -682,15 +706,15 @@ class ComponentRegistry:
return None return None
# === Tool 特定查询方法 === # === Tool 特定查询方法 ===
def get_tool_registry(self) -> dict[str, type["BaseTool"]]: def get_tool_registry(self) -> dict[str, type[BaseTool]]:
"""获取Tool注册表""" """获取Tool注册表"""
return self._tool_registry.copy() return self._tool_registry.copy()
def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]: def get_llm_available_tools(self) -> dict[str, type[BaseTool]]:
"""获取LLM可用的Tool列表""" """获取LLM可用的Tool列表"""
return self._llm_available_tools.copy() return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional["ToolInfo"]: def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
"""获取Tool信息 """获取Tool信息
Args: Args:
@@ -703,13 +727,13 @@ class ComponentRegistry:
return info if isinstance(info, ToolInfo) else None return info if isinstance(info, ToolInfo) else None
# === PlusCommand 特定查询方法 === # === PlusCommand 特定查询方法 ===
def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]: def get_plus_command_registry(self) -> dict[str, type[PlusCommand]]:
"""获取PlusCommand注册表""" """获取PlusCommand注册表"""
if not hasattr(self, "_plus_command_registry"): if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: dict[str, type[PlusCommand]] = {} self._plus_command_registry: dict[str, type[PlusCommand]] = {}
return self._plus_command_registry.copy() return self._plus_command_registry.copy()
def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]: def get_registered_plus_command_info(self, command_name: str) -> Optional[PlusCommandInfo]:
"""获取PlusCommand信息 """获取PlusCommand信息
Args: Args:
@@ -723,44 +747,44 @@ class ComponentRegistry:
# === EventHandler 特定查询方法 === # === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]: def get_event_handler_registry(self) -> dict[str, type[BaseEventHandler]]:
"""获取事件处理器注册表""" """获取事件处理器注册表"""
return self._event_handler_registry.copy() return self._event_handler_registry.copy()
def get_registered_event_handler_info(self, handler_name: str) -> Optional["EventHandlerInfo"]: def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
"""获取事件处理器信息""" """获取事件处理器信息"""
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None return info if isinstance(info, EventHandlerInfo) else None
def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]: def get_enabled_event_handlers(self) -> dict[str, type[BaseEventHandler]]:
"""获取启用的事件处理器""" """获取启用的事件处理器"""
return self._enabled_event_handlers.copy() return self._enabled_event_handlers.copy()
# === Chatter 特定查询方法 === # === Chatter 特定查询方法 ===
def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]: def get_chatter_registry(self) -> dict[str, type[BaseChatter]]:
"""获取Chatter注册表""" """获取Chatter注册表"""
if not hasattr(self, "_chatter_registry"): if not hasattr(self, "_chatter_registry"):
self._chatter_registry: dict[str, type[BaseChatter]] = {} self._chatter_registry: dict[str, type[BaseChatter]] = {}
return self._chatter_registry.copy() return self._chatter_registry.copy()
def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]: def get_enabled_chatter_registry(self) -> dict[str, type[BaseChatter]]:
"""获取启用的Chatter注册表""" """获取启用的Chatter注册表"""
if not hasattr(self, "_enabled_chatter_registry"): if not hasattr(self, "_enabled_chatter_registry"):
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {} self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
return self._enabled_chatter_registry.copy() return self._enabled_chatter_registry.copy()
def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]: def get_registered_chatter_info(self, chatter_name: str) -> Optional[ChatterInfo]:
"""获取Chatter信息""" """获取Chatter信息"""
info = self.get_component_info(chatter_name, ComponentType.CHATTER) info = self.get_component_info(chatter_name, ComponentType.CHATTER)
return info if isinstance(info, ChatterInfo) else None return info if isinstance(info, ChatterInfo) else None
# === 插件查询方法 === # === 插件查询方法 ===
def get_plugin_info(self, plugin_name: str) -> Optional["PluginInfo"]: def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
"""获取插件信息""" """获取插件信息"""
return self._plugins.get(plugin_name) return self._plugins.get(plugin_name)
def get_all_plugins(self) -> dict[str, "PluginInfo"]: def get_all_plugins(self) -> dict[str, PluginInfo]:
"""获取所有插件""" """获取所有插件"""
return self._plugins.copy() return self._plugins.copy()

View File

@@ -22,8 +22,6 @@ class PermissionManager(IPermissionManager):
"""权限管理器实现类""" """权限管理器实现类"""
def __init__(self): def __init__(self):
self.engine = None
self.SessionLocal = None
self._master_users: set[tuple[str, str]] = set() self._master_users: set[tuple[str, str]] = set()
self._load_master_users() self._load_master_users()
@@ -52,7 +50,7 @@ class PermissionManager(IPermissionManager):
self._load_master_users() self._load_master_users()
logger.info("Master用户配置已重新加载") logger.info("Master用户配置已重新加载")
def is_master(self, user: UserInfo) -> bool: async def is_master(self, user: UserInfo) -> bool:
""" """
检查用户是否为Master用户 检查用户是否为Master用户
@@ -81,7 +79,7 @@ class PermissionManager(IPermissionManager):
""" """
try: try:
# Master用户拥有所有权限 # Master用户拥有所有权限
if self.is_master(user): if await self.is_master(user):
logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}") logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}")
return True return True
@@ -288,10 +286,10 @@ class PermissionManager(IPermissionManager):
""" """
try: try:
# Master用户拥有所有权限 # Master用户拥有所有权限
if self.is_master(user): if await self.is_master(user):
async with self.SessionLocal() as session: async with self.SessionLocal() as session:
result = await session.execute(select(PermissionNodes.node_name)) result = await session.execute(select(PermissionNodes.node_name))
all_nodes = result.scalars().all() all_nodes = list(result.scalars().all())
return all_nodes return all_nodes
permissions = [] permissions = []

View File

@@ -95,6 +95,7 @@ class PluginManager:
if not plugin_class: if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在") logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1 return False, 1
init_module = None # 预先定义,避免后续条件加载导致未绑定
try: try:
# 使用记录的插件目录路径 # 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name) plugin_dir = self.plugin_paths.get(plugin_name)
@@ -314,6 +315,7 @@ class PluginManager:
module_name = ".".join(plugin_path.parent.parts) module_name = ".".join(plugin_path.parent.parts)
try: try:
init_module = None # 确保下方引用存在
# 首先加载 __init__.py 来获取元数据 # 首先加载 __init__.py 来获取元数据
init_file = os.path.join(plugin_dir, "__init__.py") init_file = os.path.join(plugin_dir, "__init__.py")
if os.path.exists(init_file): if os.path.exists(init_file):
@@ -524,13 +526,10 @@ class PluginManager:
fut.result(timeout=5) fut.result(timeout=5)
else: else:
asyncio.run(component_registry.unregister_plugin(plugin_name)) asyncio.run(component_registry.unregister_plugin(plugin_name))
except Exception: except Exception as e: # 捕获并记录卸载阶段协程调用错误
# 最后兜底:直接同步调用(如果 unregister_plugin 为非协程)或忽略错误 logger.debug(
try: f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}", exc_info=True
# 如果 unregister_plugin 是普通函数 )
component_registry.unregister_plugin(plugin_name)
except Exception as e:
logger.debug(f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}")
# 从已加载插件中移除 # 从已加载插件中移除
del self.loaded_plugins[plugin_name] del self.loaded_plugins[plugin_name]
@@ -550,65 +549,6 @@ class PluginManager:
logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True) logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True)
return False return False
def reload_plugin(self, plugin_name: str) -> bool:
"""重载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""
try:
logger.info(f"🔄 开始重载插件: {plugin_name}")
# 卸载插件
if plugin_name in self.loaded_plugins:
if not self.unload_plugin(plugin_name):
logger.warning(f"⚠️ 插件卸载失败,继续重载: {plugin_name}")
# 重新扫描插件目录
self.rescan_plugin_directory()
# 重新加载插件实例
if plugin_name in self.plugin_classes:
success, _ = self.load_registered_plugin_classes(plugin_name)
if success:
logger.info(f"✅ 插件重载成功: {plugin_name}")
return True
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件类未找到")
return False
except Exception as e:
logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True)
return False
def force_reload_plugin(self, plugin_name: str) -> bool:
"""强制重载插件(使用简化的方法)
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""
return self.reload_plugin(plugin_name)
@staticmethod
def clear_all_plugin_caches():
"""清理所有插件相关的模块缓存(简化版)"""
try:
logger.info("🧹 清理模块缓存...")
# 清理importlib缓存
importlib.invalidate_caches()
logger.info("🧹 模块缓存清理完成")
except Exception as e:
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)
# 全局插件管理器实例 # 全局插件管理器实例
plugin_manager = PluginManager() plugin_manager = PluginManager()

View File

@@ -177,8 +177,8 @@ class ToolExecutor:
# 执行每个工具调用 # 执行每个工具调用
for tool_call in tool_calls: for tool_call in tool_calls:
tool_name = getattr(tool_call, "func_name", "unknown_tool")
try: try:
tool_name = tool_call.func_name
logger.debug(f"{self.log_prefix}执行工具: {tool_name}") logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具 # 执行工具