Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
4
bot.py
4
bot.py
@@ -560,9 +560,9 @@ class MaiBotMain:
|
||||
logger.info("正在初始化数据库表结构...")
|
||||
try:
|
||||
start_time = time.time()
|
||||
from src.common.database.sqlalchemy_models import initialize_database as init_db
|
||||
from src.common.database.sqlalchemy_models import initialize_database
|
||||
|
||||
await init_db()
|
||||
await initialize_database()
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}秒")
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
"$schema": "https://raw.githubusercontent.com/microsoft/pyright/main/packages/vscode-pyright/schemas/pyrightconfig.schema.json",
|
||||
"include": [
|
||||
"src",
|
||||
"bot.py",
|
||||
"__main__.py"
|
||||
"bot.py"
|
||||
],
|
||||
"exclude": [
|
||||
"**/__pycache__",
|
||||
@@ -11,7 +10,9 @@
|
||||
"logs",
|
||||
"tests",
|
||||
"target",
|
||||
"*.egg-info"
|
||||
"*.egg-info",
|
||||
"src/plugins/built_in/*",
|
||||
"__main__.py"
|
||||
],
|
||||
"typeCheckingMode": "standard",
|
||||
"reportMissingImports": false,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Any, Optional, TypedDict, Literal, Union, Callable, TypeVar, cast
|
||||
from typing import Any, Optional, TypeVar, cast
|
||||
|
||||
from sqlalchemy import select, delete
|
||||
|
||||
@@ -47,29 +47,26 @@ class AntiInjectionStatistics:
|
||||
"""当前会话开始时间"""
|
||||
|
||||
@staticmethod
|
||||
async def get_or_create_stats() -> Optional[AntiInjectionStats]: # type: ignore[name-defined]
|
||||
async def get_or_create_stats() -> AntiInjectionStats:
|
||||
"""获取或创建统计记录
|
||||
|
||||
Returns:
|
||||
AntiInjectionStats | None: 成功返回模型实例,否则 None
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 获取最新的统计记录,如果没有则创建
|
||||
stats = (
|
||||
stats = (
|
||||
(await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if not stats:
|
||||
stats = AntiInjectionStats()
|
||||
session.add(stats)
|
||||
await session.commit()
|
||||
await session.refresh(stats)
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计记录失败: {e}")
|
||||
return None
|
||||
if not stats:
|
||||
stats = AntiInjectionStats()
|
||||
session.add(stats)
|
||||
await session.commit()
|
||||
await session.refresh(stats)
|
||||
return stats
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def update_stats(**kwargs: Any) -> None:
|
||||
@@ -97,7 +94,7 @@ class AntiInjectionStatistics:
|
||||
if key == "processing_time_delta":
|
||||
# 处理时间累加 - 确保不为 None
|
||||
delta = float(value)
|
||||
stats.processing_time_total = _add_optional(stats.processing_time_total, delta) # type: ignore[attr-defined]
|
||||
stats.processing_time_total = _add_optional(stats.processing_time_total, delta)
|
||||
continue
|
||||
elif key == "last_processing_time":
|
||||
# 直接设置最后处理时间
|
||||
@@ -146,7 +143,7 @@ class AntiInjectionStatistics:
|
||||
|
||||
|
||||
# 计算派生统计信息 - 处理 None 值
|
||||
total_messages = stats.total_messages or 0 # type: ignore[attr-defined]
|
||||
total_messages = stats.total_messages or 0
|
||||
detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined]
|
||||
processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -31,6 +31,10 @@ logger = get_logger("sqlalchemy_models")
|
||||
# 创建基类
|
||||
Base = declarative_base()
|
||||
|
||||
# 全局异步引擎与会话工厂占位(延迟初始化)
|
||||
_engine: AsyncEngine | None = None
|
||||
_SessionLocal: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
async def enable_sqlite_wal_mode(engine):
|
||||
"""为 SQLite 启用 WAL 模式以提高并发性能"""
|
||||
@@ -649,23 +653,13 @@ class MonthlyPlan(Base):
|
||||
last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
# 保留 is_deleted 字段以兼容现有数据,但标记为已弃用
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_monthlyplan_target_month_status", "target_month", "status"),
|
||||
Index("idx_monthlyplan_last_used_date", "last_used_date"),
|
||||
Index("idx_monthlyplan_usage_count", "usage_count"),
|
||||
# 保留旧索引以兼容
|
||||
Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"),
|
||||
)
|
||||
|
||||
|
||||
# 数据库引擎和会话管理
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
||||
|
||||
def get_database_url():
|
||||
"""获取数据库连接URL"""
|
||||
from src.config.config import global_config
|
||||
@@ -709,65 +703,89 @@ def get_database_url():
|
||||
return f"sqlite+aiosqlite:///{db_path}"
|
||||
|
||||
|
||||
async def initialize_database():
|
||||
"""初始化异步数据库引擎和会话"""
|
||||
global _engine, _SessionLocal
|
||||
_initializing: bool = False # 防止递归初始化
|
||||
|
||||
if _engine is not None:
|
||||
async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[AsyncSession]]:
|
||||
"""初始化异步数据库引擎和会话
|
||||
|
||||
Returns:
|
||||
tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: 创建好的异步引擎与会话工厂。
|
||||
|
||||
说明:
|
||||
显式的返回类型标注有助于 Pyright/Pylance 正确推断调用处的对象,
|
||||
避免后续对返回值再次 `await` 时出现 *"tuple[...] 并非 awaitable"* 的误用。
|
||||
"""
|
||||
global _engine, _SessionLocal, _initializing
|
||||
|
||||
# 已经初始化直接返回
|
||||
if _engine is not None and _SessionLocal is not None:
|
||||
return _engine, _SessionLocal
|
||||
|
||||
database_url = get_database_url()
|
||||
from src.config.config import global_config
|
||||
# 正在初始化的并发调用等待主初始化完成,避免递归
|
||||
if _initializing:
|
||||
import asyncio
|
||||
for _ in range(1000): # 最多等待约10秒
|
||||
await asyncio.sleep(0.01)
|
||||
if _engine is not None and _SessionLocal is not None:
|
||||
return _engine, _SessionLocal
|
||||
raise RuntimeError("等待数据库初始化完成超时 (reentrancy guard)")
|
||||
|
||||
config = global_config.database
|
||||
_initializing = True
|
||||
try:
|
||||
database_url = get_database_url()
|
||||
from src.config.config import global_config
|
||||
|
||||
# 配置引擎参数
|
||||
engine_kwargs: dict[str, Any] = {
|
||||
"echo": False, # 生产环境关闭SQL日志
|
||||
"future": True,
|
||||
}
|
||||
config = global_config.database
|
||||
|
||||
if config.database_type == "mysql":
|
||||
# MySQL连接池配置 - 异步引擎使用默认连接池
|
||||
engine_kwargs.update(
|
||||
{
|
||||
"pool_size": config.connection_pool_size,
|
||||
"max_overflow": config.connection_pool_size * 2,
|
||||
"pool_timeout": config.connection_timeout,
|
||||
"pool_recycle": 3600, # 1小时回收连接
|
||||
"pool_pre_ping": True, # 连接前ping检查
|
||||
"connect_args": {
|
||||
"autocommit": config.mysql_autocommit,
|
||||
"charset": config.mysql_charset,
|
||||
"connect_timeout": config.connection_timeout,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# SQLite配置 - aiosqlite不支持连接池参数
|
||||
engine_kwargs.update(
|
||||
{
|
||||
"connect_args": {
|
||||
"check_same_thread": False,
|
||||
"timeout": 60, # 增加超时时间
|
||||
},
|
||||
}
|
||||
)
|
||||
# 配置引擎参数
|
||||
engine_kwargs: dict[str, Any] = {
|
||||
"echo": False, # 生产环境关闭SQL日志
|
||||
"future": True,
|
||||
}
|
||||
|
||||
_engine = create_async_engine(database_url, **engine_kwargs)
|
||||
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
if config.database_type == "mysql":
|
||||
engine_kwargs.update(
|
||||
{
|
||||
"pool_size": config.connection_pool_size,
|
||||
"max_overflow": config.connection_pool_size * 2,
|
||||
"pool_timeout": config.connection_timeout,
|
||||
"pool_recycle": 3600,
|
||||
"pool_pre_ping": True,
|
||||
"connect_args": {
|
||||
"autocommit": config.mysql_autocommit,
|
||||
"charset": config.mysql_charset,
|
||||
"connect_timeout": config.connection_timeout,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
engine_kwargs.update(
|
||||
{
|
||||
"connect_args": {
|
||||
"check_same_thread": False,
|
||||
"timeout": 60,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 调用新的迁移函数,它会处理表的创建和列的添加
|
||||
from src.common.database.db_migration import check_and_migrate_database
|
||||
_engine = create_async_engine(database_url, **engine_kwargs)
|
||||
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
await check_and_migrate_database()
|
||||
# 迁移
|
||||
try:
|
||||
from src.common.database.db_migration import check_and_migrate_database
|
||||
await check_and_migrate_database(existing_engine=_engine)
|
||||
except TypeError:
|
||||
from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate
|
||||
await _legacy_migrate()
|
||||
|
||||
# 如果是 SQLite,启用 WAL 模式以提高并发性能
|
||||
if config.database_type == "sqlite":
|
||||
await enable_sqlite_wal_mode(_engine)
|
||||
if config.database_type == "sqlite":
|
||||
await enable_sqlite_wal_mode(_engine)
|
||||
|
||||
logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
|
||||
return _engine, _SessionLocal
|
||||
logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
|
||||
return _engine, _SessionLocal
|
||||
finally:
|
||||
_initializing = False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
@@ -26,6 +28,23 @@ from src.plugin_system.base.plus_command import PlusCommand
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
# 统一的组件类类型别名
|
||||
ComponentClassType = (
|
||||
type[BaseCommand]
|
||||
| type[BaseAction]
|
||||
| type[BaseTool]
|
||||
| type[BaseEventHandler]
|
||||
| type[PlusCommand]
|
||||
| type[BaseChatter]
|
||||
| type[BaseInterestCalculator]
|
||||
)
|
||||
|
||||
|
||||
def _assign_plugin_attrs(cls: Any, plugin_name: str, plugin_config: dict) -> None:
|
||||
"""为组件类动态赋予插件相关属性(避免在各注册函数中重复代码)。"""
|
||||
setattr(cls, "plugin_name", plugin_name)
|
||||
setattr(cls, "plugin_config", plugin_config)
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
"""统一的组件注册中心
|
||||
@@ -41,9 +60,8 @@ class ComponentRegistry:
|
||||
types: {} for types in ComponentType
|
||||
}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: dict[
|
||||
str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"]
|
||||
] = {}
|
||||
# 组件类注册表(命名空间式组件名 -> 组件类)
|
||||
self._components_classes: dict[str, ComponentClassType] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
@@ -100,10 +118,8 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def register_component(
|
||||
self,
|
||||
component_info: ComponentInfo,
|
||||
component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
|
||||
) -> bool:
|
||||
self, self_component_info: ComponentInfo, component_class: ComponentClassType
|
||||
) -> bool: # noqa: C901 (保持原有结构, 以后可再拆)
|
||||
"""注册组件
|
||||
|
||||
Args:
|
||||
@@ -113,9 +129,11 @@ class ComponentRegistry:
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
component_info = self_component_info # 局部别名
|
||||
component_name = component_info.name
|
||||
component_type = component_info.component_type
|
||||
plugin_name = getattr(component_info, "plugin_name", "unknown")
|
||||
|
||||
if "." in component_name:
|
||||
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
@@ -123,22 +141,19 @@ class ComponentRegistry:
|
||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
if namespaced_name in self._components:
|
||||
existing_info = self._components[namespaced_name]
|
||||
existing_plugin = getattr(existing_info, "plugin_name", "unknown")
|
||||
|
||||
logger.warning(
|
||||
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
|
||||
)
|
||||
return False
|
||||
|
||||
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
|
||||
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
|
||||
self._components[namespaced_name] = component_info
|
||||
self._components_by_type[component_type][component_name] = component_info
|
||||
self._components_classes[namespaced_name] = component_class
|
||||
|
||||
# 根据组件类型进行特定注册(使用原始名称)
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
assert isinstance(component_info, ActionInfo)
|
||||
@@ -175,12 +190,11 @@ class ComponentRegistry:
|
||||
if not ret:
|
||||
return False
|
||||
logger.debug(
|
||||
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
|
||||
f"({component_class.__name__}) [插件: {plugin_name}]"
|
||||
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' ({component_class.__name__}) [插件: {plugin_name}]"
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool:
|
||||
def _register_action_component(self, action_info: ActionInfo, action_class: type[BaseAction]) -> bool:
|
||||
"""注册Action组件到Action特定注册表"""
|
||||
if not (action_name := action_info.name):
|
||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
||||
@@ -188,19 +202,13 @@ class ComponentRegistry:
|
||||
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
|
||||
logger.error(f"注册失败: {action_name} 不是有效的Action")
|
||||
return False
|
||||
|
||||
action_class.plugin_name = action_info.plugin_name
|
||||
# 设置插件配置
|
||||
action_class.plugin_config = self.get_plugin_config(action_info.plugin_name) or {}
|
||||
_assign_plugin_attrs(action_class, action_info.plugin_name, self.get_plugin_config(action_info.plugin_name) or {})
|
||||
self._action_registry[action_name] = action_class
|
||||
|
||||
# 如果启用,添加到默认动作集
|
||||
if action_info.enabled:
|
||||
self._default_actions[action_name] = action_info
|
||||
|
||||
return True
|
||||
|
||||
def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool:
|
||||
def _register_command_component(self, command_info: CommandInfo, command_class: type[BaseCommand]) -> bool:
|
||||
"""注册Command组件到Command特定注册表"""
|
||||
if not (command_name := command_info.name):
|
||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
||||
@@ -208,13 +216,10 @@ class ComponentRegistry:
|
||||
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
|
||||
logger.error(f"注册失败: {command_name} 不是有效的Command")
|
||||
return False
|
||||
|
||||
command_class.plugin_name = command_info.plugin_name
|
||||
# 设置插件配置
|
||||
command_class.plugin_config = self.get_plugin_config(command_info.plugin_name) or {}
|
||||
_assign_plugin_attrs(
|
||||
command_class, command_info.plugin_name, self.get_plugin_config(command_info.plugin_name) or {}
|
||||
)
|
||||
self._command_registry[command_name] = command_class
|
||||
|
||||
# 如果启用了且有匹配模式
|
||||
if command_info.enabled and command_info.command_pattern:
|
||||
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||
if pattern not in self._command_patterns:
|
||||
@@ -223,11 +228,10 @@ class ComponentRegistry:
|
||||
logger.warning(
|
||||
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _register_plus_command_component(
|
||||
self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"]
|
||||
self, plus_command_info: PlusCommandInfo, plus_command_class: type[PlusCommand]
|
||||
) -> bool:
|
||||
"""注册PlusCommand组件到特定注册表"""
|
||||
plus_command_name = plus_command_info.name
|
||||
@@ -241,33 +245,27 @@ class ComponentRegistry:
|
||||
|
||||
# 创建专门的PlusCommand注册表(如果还没有)
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: dict[str, type["PlusCommand"]] = {}
|
||||
|
||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||
# 设置插件配置
|
||||
plus_command_class.plugin_config = self.get_plugin_config(plus_command_info.plugin_name) or {}
|
||||
self._plus_command_registry: dict[str, type[PlusCommand]] = {}
|
||||
_assign_plugin_attrs(
|
||||
plus_command_class,
|
||||
plus_command_info.plugin_name,
|
||||
self.get_plugin_config(plus_command_info.plugin_name) or {},
|
||||
)
|
||||
self._plus_command_registry[plus_command_name] = plus_command_class
|
||||
|
||||
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
|
||||
return True
|
||||
|
||||
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool:
|
||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: type[BaseTool]) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
tool_class.plugin_name = tool_info.plugin_name
|
||||
# 设置插件配置
|
||||
tool_class.plugin_config = self.get_plugin_config(tool_info.plugin_name) or {}
|
||||
_assign_plugin_attrs(tool_class, tool_info.plugin_name, self.get_plugin_config(tool_info.plugin_name) or {})
|
||||
self._tool_registry[tool_name] = tool_class
|
||||
|
||||
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
||||
if tool_info.enabled:
|
||||
self._llm_available_tools[tool_name] = tool_class
|
||||
|
||||
return True
|
||||
|
||||
def _register_event_handler_component(
|
||||
self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"]
|
||||
self, handler_info: EventHandlerInfo, handler_class: type[BaseEventHandler]
|
||||
) -> bool:
|
||||
if not (handler_name := handler_info.name):
|
||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
||||
@@ -275,25 +273,19 @@ class ComponentRegistry:
|
||||
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
|
||||
return False
|
||||
|
||||
handler_class.plugin_name = handler_info.plugin_name
|
||||
# 设置插件配置
|
||||
handler_class.plugin_config = self.get_plugin_config(handler_info.plugin_name) or {}
|
||||
_assign_plugin_attrs(
|
||||
handler_class, handler_info.plugin_name, self.get_plugin_config(handler_info.plugin_name) or {}
|
||||
)
|
||||
self._event_handler_registry[handler_name] = handler_class
|
||||
|
||||
if not handler_info.enabled:
|
||||
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
||||
return True # 未启用,但是也是注册成功
|
||||
|
||||
handler_class.plugin_name = handler_info.plugin_name
|
||||
# 使用EventManager进行事件处理器注册
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
return event_manager.register_event_handler(
|
||||
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
|
||||
)
|
||||
|
||||
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool:
|
||||
def _register_chatter_component(self, chatter_info: ChatterInfo, chatter_class: type[BaseChatter]) -> bool:
|
||||
"""注册Chatter组件到Chatter特定注册表"""
|
||||
chatter_name = chatter_info.name
|
||||
|
||||
@@ -303,18 +295,14 @@ class ComponentRegistry:
|
||||
if not isinstance(chatter_info, ChatterInfo) or not issubclass(chatter_class, BaseChatter):
|
||||
logger.error(f"注册失败: {chatter_name} 不是有效的Chatter")
|
||||
return False
|
||||
|
||||
chatter_class.plugin_name = chatter_info.plugin_name
|
||||
# 设置插件配置
|
||||
chatter_class.plugin_config = self.get_plugin_config(chatter_info.plugin_name) or {}
|
||||
|
||||
_assign_plugin_attrs(
|
||||
chatter_class, chatter_info.plugin_name, self.get_plugin_config(chatter_info.plugin_name) or {}
|
||||
)
|
||||
self._chatter_registry[chatter_name] = chatter_class
|
||||
|
||||
if not chatter_info.enabled:
|
||||
logger.warning(f"Chatter组件 {chatter_name} 未启用")
|
||||
return True # 未启用,但是也是注册成功
|
||||
self._enabled_chatter_registry[chatter_name] = chatter_class
|
||||
|
||||
logger.debug(f"已注册Chatter组件: {chatter_name}")
|
||||
return True
|
||||
|
||||
@@ -341,9 +329,13 @@ class ComponentRegistry:
|
||||
if not hasattr(self, "_enabled_interest_calculator_registry"):
|
||||
self._enabled_interest_calculator_registry: dict[str, type["BaseInterestCalculator"]] = {}
|
||||
|
||||
interest_calculator_class.plugin_name = interest_calculator_info.plugin_name
|
||||
setattr(interest_calculator_class, "plugin_name", interest_calculator_info.plugin_name)
|
||||
# 设置插件配置
|
||||
interest_calculator_class.plugin_config = self.get_plugin_config(interest_calculator_info.plugin_name) or {}
|
||||
setattr(
|
||||
interest_calculator_class,
|
||||
"plugin_config",
|
||||
self.get_plugin_config(interest_calculator_info.plugin_name) or {},
|
||||
)
|
||||
self._interest_calculator_registry[calculator_name] = interest_calculator_class
|
||||
|
||||
if not interest_calculator_info.enabled:
|
||||
@@ -356,7 +348,7 @@ class ComponentRegistry:
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def remove_component(self, component_name: str, component_type: "ComponentType", plugin_name: str) -> bool:
|
||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
if not target_component_class:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||
@@ -398,8 +390,14 @@ class ComponentRegistry:
|
||||
self._enabled_event_handlers.pop(component_name, None)
|
||||
try:
|
||||
handler = event_manager.get_event_handler(component_name)
|
||||
for event in handler.subscribed_events:
|
||||
await event_manager.unsubscribe_handler_from_event(event, component_name)
|
||||
# 事件处理器可能未找到或未声明 subscribed_events,需判空
|
||||
if handler and hasattr(handler, "subscribed_events"):
|
||||
for event in getattr(handler, "subscribed_events"):
|
||||
# 假设 unsubscribe_handler_from_event 是协程;若不是则移除 await
|
||||
result = event_manager.unsubscribe_handler_from_event(event, component_name)
|
||||
if hasattr(result, "__await__"):
|
||||
await result # type: ignore[func-returns-value]
|
||||
logger.debug(f"已移除EventHandler组件: {component_name}")
|
||||
logger.debug(f"已移除EventHandler组件: {component_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
|
||||
@@ -415,7 +413,7 @@ class ComponentRegistry:
|
||||
return False
|
||||
|
||||
# 移除通用注册信息
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
self._components.pop(namespaced_name, None)
|
||||
self._components_by_type[component_type].pop(component_name, None)
|
||||
self._components_classes.pop(namespaced_name, None)
|
||||
@@ -477,9 +475,10 @@ class ComponentRegistry:
|
||||
self._enabled_event_handlers[component_name] = target_component_class
|
||||
from .event_manager import event_manager # 延迟导入防止循环导入问题
|
||||
|
||||
event_manager.register_event_handler(component_name)
|
||||
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
# 重新注册事件处理器(启用)使用类而不是名称
|
||||
cfg = self.get_plugin_config(target_component_info.plugin_name) or {}
|
||||
event_manager.register_event_handler(target_component_class, cfg) # type: ignore[arg-type]
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
self._components[namespaced_name].enabled = True
|
||||
self._components_by_type[component_type][component_name].enabled = True
|
||||
logger.info(f"组件 {component_name} 已启用")
|
||||
@@ -512,10 +511,16 @@ class ComponentRegistry:
|
||||
from .event_manager import event_manager # 延迟导入防止循环导入问题
|
||||
|
||||
handler = event_manager.get_event_handler(component_name)
|
||||
for event in handler.subscribed_events:
|
||||
await event_manager.unsubscribe_handler_from_event(event, component_name)
|
||||
if handler and hasattr(handler, "subscribed_events"):
|
||||
for event in getattr(handler, "subscribed_events"):
|
||||
result = event_manager.unsubscribe_handler_from_event(event, component_name)
|
||||
if hasattr(result, "__await__"):
|
||||
await result # type: ignore[func-returns-value]
|
||||
|
||||
self._components[component_name].enabled = False
|
||||
# 组件主注册表使用命名空间 key
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
if namespaced_name in self._components:
|
||||
self._components[namespaced_name].enabled = False
|
||||
self._components_by_type[component_type][component_name].enabled = False
|
||||
logger.info(f"组件 {component_name} 已禁用")
|
||||
return True
|
||||
@@ -528,8 +533,8 @@ class ComponentRegistry:
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
self, component_name: str, component_type: Optional["ComponentType"] = None
|
||||
) -> Optional["ComponentInfo"]:
|
||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
||||
) -> Optional[ComponentInfo]:
|
||||
# sourcery skip: class-extract-method
|
||||
"""获取组件信息,支持自动命名空间解析
|
||||
|
||||
@@ -546,7 +551,7 @@ class ComponentRegistry:
|
||||
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
return self._components.get(namespaced_name)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
@@ -573,8 +578,17 @@ class ComponentRegistry:
|
||||
def get_component_class(
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: Optional["ComponentType"] = None,
|
||||
) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None:
|
||||
component_type: Optional[ComponentType] = None,
|
||||
) -> (
|
||||
type[BaseCommand]
|
||||
| type[BaseAction]
|
||||
| type[BaseEventHandler]
|
||||
| type[BaseTool]
|
||||
| type[PlusCommand]
|
||||
| type[BaseChatter]
|
||||
| type[BaseInterestCalculator]
|
||||
| None
|
||||
):
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
@@ -591,7 +605,17 @@ class ComponentRegistry:
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
return self._components_classes.get(namespaced_name) # type: ignore[valid-type]
|
||||
return cast(
|
||||
type[BaseCommand]
|
||||
| type[BaseAction]
|
||||
| type[BaseEventHandler]
|
||||
| type[BaseTool]
|
||||
| type[PlusCommand]
|
||||
| type[BaseChatter]
|
||||
| type[BaseInterestCalculator]
|
||||
| None,
|
||||
self._components_classes.get(namespaced_name),
|
||||
)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
@@ -616,22 +640,22 @@ class ComponentRegistry:
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
|
||||
def get_components_by_type(self, component_type: ComponentType) -> dict[str, ComponentInfo]:
|
||||
"""获取指定类型的所有组件"""
|
||||
return self._components_by_type.get(component_type, {}).copy()
|
||||
|
||||
def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
|
||||
def get_enabled_components_by_type(self, component_type: ComponentType) -> dict[str, ComponentInfo]:
|
||||
"""获取指定类型的所有启用组件"""
|
||||
components = self.get_components_by_type(component_type)
|
||||
return {name: info for name, info in components.items() if info.enabled}
|
||||
|
||||
# === Action特定查询方法 ===
|
||||
|
||||
def get_action_registry(self) -> dict[str, type["BaseAction"]]:
|
||||
def get_action_registry(self) -> dict[str, type[BaseAction]]:
|
||||
"""获取Action注册表"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
def get_registered_action_info(self, action_name: str) -> Optional["ActionInfo"]:
|
||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||
"""获取Action信息"""
|
||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
@@ -642,11 +666,11 @@ class ComponentRegistry:
|
||||
|
||||
# === Command特定查询方法 ===
|
||||
|
||||
def get_command_registry(self) -> dict[str, type["BaseCommand"]]:
|
||||
def get_command_registry(self) -> dict[str, type[BaseCommand]]:
|
||||
"""获取Command注册表"""
|
||||
return self._command_registry.copy()
|
||||
|
||||
def get_registered_command_info(self, command_name: str) -> Optional["CommandInfo"]:
|
||||
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
||||
"""获取Command信息"""
|
||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
@@ -655,7 +679,7 @@ class ComponentRegistry:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None:
|
||||
def find_command_by_text(self, text: str) -> tuple[type[BaseCommand], dict, CommandInfo] | None:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
@@ -682,15 +706,15 @@ class ComponentRegistry:
|
||||
return None
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
def get_tool_registry(self) -> dict[str, type["BaseTool"]]:
|
||||
def get_tool_registry(self) -> dict[str, type[BaseTool]]:
|
||||
"""获取Tool注册表"""
|
||||
return self._tool_registry.copy()
|
||||
|
||||
def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]:
|
||||
def get_llm_available_tools(self) -> dict[str, type[BaseTool]]:
|
||||
"""获取LLM可用的Tool列表"""
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
def get_registered_tool_info(self, tool_name: str) -> Optional["ToolInfo"]:
|
||||
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
|
||||
"""获取Tool信息
|
||||
|
||||
Args:
|
||||
@@ -703,13 +727,13 @@ class ComponentRegistry:
|
||||
return info if isinstance(info, ToolInfo) else None
|
||||
|
||||
# === PlusCommand 特定查询方法 ===
|
||||
def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]:
|
||||
def get_plus_command_registry(self) -> dict[str, type[PlusCommand]]:
|
||||
"""获取PlusCommand注册表"""
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: dict[str, type[PlusCommand]] = {}
|
||||
return self._plus_command_registry.copy()
|
||||
|
||||
def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]:
|
||||
def get_registered_plus_command_info(self, command_name: str) -> Optional[PlusCommandInfo]:
|
||||
"""获取PlusCommand信息
|
||||
|
||||
Args:
|
||||
@@ -723,44 +747,44 @@ class ComponentRegistry:
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]:
|
||||
def get_event_handler_registry(self) -> dict[str, type[BaseEventHandler]]:
|
||||
"""获取事件处理器注册表"""
|
||||
return self._event_handler_registry.copy()
|
||||
|
||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional["EventHandlerInfo"]:
|
||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
|
||||
"""获取事件处理器信息"""
|
||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
||||
return info if isinstance(info, EventHandlerInfo) else None
|
||||
|
||||
def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]:
|
||||
def get_enabled_event_handlers(self) -> dict[str, type[BaseEventHandler]]:
|
||||
"""获取启用的事件处理器"""
|
||||
return self._enabled_event_handlers.copy()
|
||||
|
||||
# === Chatter 特定查询方法 ===
|
||||
def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
|
||||
def get_chatter_registry(self) -> dict[str, type[BaseChatter]]:
|
||||
"""获取Chatter注册表"""
|
||||
if not hasattr(self, "_chatter_registry"):
|
||||
self._chatter_registry: dict[str, type[BaseChatter]] = {}
|
||||
return self._chatter_registry.copy()
|
||||
|
||||
def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
|
||||
def get_enabled_chatter_registry(self) -> dict[str, type[BaseChatter]]:
|
||||
"""获取启用的Chatter注册表"""
|
||||
if not hasattr(self, "_enabled_chatter_registry"):
|
||||
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
|
||||
return self._enabled_chatter_registry.copy()
|
||||
|
||||
def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]:
|
||||
def get_registered_chatter_info(self, chatter_name: str) -> Optional[ChatterInfo]:
|
||||
"""获取Chatter信息"""
|
||||
info = self.get_component_info(chatter_name, ComponentType.CHATTER)
|
||||
return info if isinstance(info, ChatterInfo) else None
|
||||
|
||||
# === 插件查询方法 ===
|
||||
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional["PluginInfo"]:
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
|
||||
"""获取插件信息"""
|
||||
return self._plugins.get(plugin_name)
|
||||
|
||||
def get_all_plugins(self) -> dict[str, "PluginInfo"]:
|
||||
def get_all_plugins(self) -> dict[str, PluginInfo]:
|
||||
"""获取所有插件"""
|
||||
return self._plugins.copy()
|
||||
|
||||
|
||||
@@ -22,8 +22,6 @@ class PermissionManager(IPermissionManager):
|
||||
"""权限管理器实现类"""
|
||||
|
||||
def __init__(self):
|
||||
self.engine = None
|
||||
self.SessionLocal = None
|
||||
self._master_users: set[tuple[str, str]] = set()
|
||||
self._load_master_users()
|
||||
|
||||
@@ -52,7 +50,7 @@ class PermissionManager(IPermissionManager):
|
||||
self._load_master_users()
|
||||
logger.info("Master用户配置已重新加载")
|
||||
|
||||
def is_master(self, user: UserInfo) -> bool:
|
||||
async def is_master(self, user: UserInfo) -> bool:
|
||||
"""
|
||||
检查用户是否为Master用户
|
||||
|
||||
@@ -81,7 +79,7 @@ class PermissionManager(IPermissionManager):
|
||||
"""
|
||||
try:
|
||||
# Master用户拥有所有权限
|
||||
if self.is_master(user):
|
||||
if await self.is_master(user):
|
||||
logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}")
|
||||
return True
|
||||
|
||||
@@ -288,10 +286,10 @@ class PermissionManager(IPermissionManager):
|
||||
"""
|
||||
try:
|
||||
# Master用户拥有所有权限
|
||||
if self.is_master(user):
|
||||
if await self.is_master(user):
|
||||
async with self.SessionLocal() as session:
|
||||
result = await session.execute(select(PermissionNodes.node_name))
|
||||
all_nodes = result.scalars().all()
|
||||
all_nodes = list(result.scalars().all())
|
||||
return all_nodes
|
||||
|
||||
permissions = []
|
||||
|
||||
@@ -95,6 +95,7 @@ class PluginManager:
|
||||
if not plugin_class:
|
||||
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
||||
return False, 1
|
||||
init_module = None # 预先定义,避免后续条件加载导致未绑定
|
||||
try:
|
||||
# 使用记录的插件目录路径
|
||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
||||
@@ -314,6 +315,7 @@ class PluginManager:
|
||||
module_name = ".".join(plugin_path.parent.parts)
|
||||
|
||||
try:
|
||||
init_module = None # 确保下方引用存在
|
||||
# 首先加载 __init__.py 来获取元数据
|
||||
init_file = os.path.join(plugin_dir, "__init__.py")
|
||||
if os.path.exists(init_file):
|
||||
@@ -524,13 +526,10 @@ class PluginManager:
|
||||
fut.result(timeout=5)
|
||||
else:
|
||||
asyncio.run(component_registry.unregister_plugin(plugin_name))
|
||||
except Exception:
|
||||
# 最后兜底:直接同步调用(如果 unregister_plugin 为非协程)或忽略错误
|
||||
try:
|
||||
# 如果 unregister_plugin 是普通函数
|
||||
component_registry.unregister_plugin(plugin_name)
|
||||
except Exception as e:
|
||||
logger.debug(f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}")
|
||||
except Exception as e: # 捕获并记录卸载阶段协程调用错误
|
||||
logger.debug(
|
||||
f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# 从已加载插件中移除
|
||||
del self.loaded_plugins[plugin_name]
|
||||
@@ -550,65 +549,6 @@ class PluginManager:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True)
|
||||
return False
|
||||
|
||||
def reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""重载指定插件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 重载是否成功
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🔄 开始重载插件: {plugin_name}")
|
||||
|
||||
# 卸载插件
|
||||
if plugin_name in self.loaded_plugins:
|
||||
if not self.unload_plugin(plugin_name):
|
||||
logger.warning(f"⚠️ 插件卸载失败,继续重载: {plugin_name}")
|
||||
|
||||
# 重新扫描插件目录
|
||||
self.rescan_plugin_directory()
|
||||
|
||||
# 重新加载插件实例
|
||||
if plugin_name in self.plugin_classes:
|
||||
success, _ = self.load_registered_plugin_classes(plugin_name)
|
||||
if success:
|
||||
logger.info(f"✅ 插件重载成功: {plugin_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件类未找到")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True)
|
||||
return False
|
||||
|
||||
def force_reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""强制重载插件(使用简化的方法)
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 重载是否成功
|
||||
"""
|
||||
return self.reload_plugin(plugin_name)
|
||||
|
||||
@staticmethod
|
||||
def clear_all_plugin_caches():
|
||||
"""清理所有插件相关的模块缓存(简化版)"""
|
||||
try:
|
||||
logger.info("🧹 清理模块缓存...")
|
||||
# 清理importlib缓存
|
||||
importlib.invalidate_caches()
|
||||
logger.info("🧹 模块缓存清理完成")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 全局插件管理器实例
|
||||
plugin_manager = PluginManager()
|
||||
|
||||
@@ -177,8 +177,8 @@ class ToolExecutor:
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in tool_calls:
|
||||
tool_name = getattr(tool_call, "func_name", "unknown_tool")
|
||||
try:
|
||||
tool_name = tool_call.func_name
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
|
||||
# 执行工具
|
||||
|
||||
Reference in New Issue
Block a user