This commit is contained in:
tt-P607
2025-10-07 16:47:50 +08:00
12 changed files with 490 additions and 267 deletions

4
bot.py
View File

@@ -560,9 +560,9 @@ class MaiBotMain:
logger.info("正在初始化数据库表结构...") logger.info("正在初始化数据库表结构...")
try: try:
start_time = time.time() start_time = time.time()
from src.common.database.sqlalchemy_models import initialize_database as init_db from src.common.database.sqlalchemy_models import initialize_database
await init_db() await initialize_database()
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}") logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}")
except Exception as e: except Exception as e:

View File

@@ -2,8 +2,7 @@
"$schema": "https://raw.githubusercontent.com/microsoft/pyright/main/packages/vscode-pyright/schemas/pyrightconfig.schema.json", "$schema": "https://raw.githubusercontent.com/microsoft/pyright/main/packages/vscode-pyright/schemas/pyrightconfig.schema.json",
"include": [ "include": [
"src", "src",
"bot.py", "bot.py"
"__main__.py"
], ],
"exclude": [ "exclude": [
"**/__pycache__", "**/__pycache__",
@@ -11,7 +10,9 @@
"logs", "logs",
"tests", "tests",
"target", "target",
"*.egg-info" "*.egg-info",
"src/plugins/built_in/*",
"__main__.py"
], ],
"typeCheckingMode": "standard", "typeCheckingMode": "standard",
"reportMissingImports": false, "reportMissingImports": false,

View File

@@ -5,7 +5,7 @@
""" """
import datetime import datetime
from typing import Any, Optional, TypedDict, Literal, Union, Callable, TypeVar, cast from typing import Any, Optional, TypeVar, cast
from sqlalchemy import select, delete from sqlalchemy import select, delete
@@ -47,13 +47,12 @@ class AntiInjectionStatistics:
"""当前会话开始时间""" """当前会话开始时间"""
@staticmethod @staticmethod
async def get_or_create_stats() -> Optional[AntiInjectionStats]: # type: ignore[name-defined] async def get_or_create_stats() -> AntiInjectionStats:
"""获取或创建统计记录 """获取或创建统计记录
Returns: Returns:
AntiInjectionStats | None: 成功返回模型实例,否则 None AntiInjectionStats | None: 成功返回模型实例,否则 None
""" """
try:
async with get_db_session() as session: async with get_db_session() as session:
# 获取最新的统计记录,如果没有则创建 # 获取最新的统计记录,如果没有则创建
stats = ( stats = (
@@ -67,9 +66,7 @@ class AntiInjectionStatistics:
await session.commit() await session.commit()
await session.refresh(stats) await session.refresh(stats)
return stats return stats
except Exception as e:
logger.error(f"获取统计记录失败: {e}")
return None
@staticmethod @staticmethod
async def update_stats(**kwargs: Any) -> None: async def update_stats(**kwargs: Any) -> None:
@@ -97,7 +94,7 @@ class AntiInjectionStatistics:
if key == "processing_time_delta": if key == "processing_time_delta":
# 处理时间累加 - 确保不为 None # 处理时间累加 - 确保不为 None
delta = float(value) delta = float(value)
stats.processing_time_total = _add_optional(stats.processing_time_total, delta) # type: ignore[attr-defined] stats.processing_time_total = _add_optional(stats.processing_time_total, delta)
continue continue
elif key == "last_processing_time": elif key == "last_processing_time":
# 直接设置最后处理时间 # 直接设置最后处理时间
@@ -146,7 +143,7 @@ class AntiInjectionStatistics:
# 计算派生统计信息 - 处理 None 值 # 计算派生统计信息 - 处理 None 值
total_messages = stats.total_messages or 0 # type: ignore[attr-defined] total_messages = stats.total_messages or 0
detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined] detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined]
processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined] processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined]

View File

@@ -110,8 +110,24 @@ class ChatterManager:
self.stats["streams_processed"] += 1 self.stats["streams_processed"] += 1
try: try:
result = await self.instances[stream_id].execute(context) result = await self.instances[stream_id].execute(context)
# 检查执行结果是否真正成功
success = result.get("success", False)
if success:
self.stats["successful_executions"] += 1 self.stats["successful_executions"] += 1
# 只有真正成功时才清空未读消息
try:
from src.chat.message_manager.message_manager import message_manager
await message_manager.clear_stream_unread_messages(stream_id)
logger.debug(f"{stream_id} 处理成功,已清空未读消息")
except Exception as clear_e:
logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}")
else:
self.stats["failed_executions"] += 1
logger.warning(f"{stream_id} 处理失败,不清空未读消息")
# 从 mood_manager 获取最新的 chat_stream 并同步回 StreamContext # 从 mood_manager 获取最新的 chat_stream 并同步回 StreamContext
try: try:
from src.mood.mood_manager import mood_manager from src.mood.mood_manager import mood_manager
@@ -124,19 +140,14 @@ class ChatterManager:
logger.error(f"同步 chat_stream 回 StreamContext 失败: {sync_e}") logger.error(f"同步 chat_stream 回 StreamContext 失败: {sync_e}")
# 记录处理结果 # 记录处理结果
success = result.get("success", False)
actions_count = result.get("actions_count", 0) actions_count = result.get("actions_count", 0)
logger.debug(f"{stream_id} 处理完成: 成功={success}, 动作数={actions_count}") logger.debug(f"{stream_id} 处理完成: 成功={success}, 动作数={actions_count}")
# 在处理完成后,清除该流的未读消息
try:
from src.chat.message_manager.message_manager import message_manager
await message_manager.clear_stream_unread_messages(stream_id)
except Exception as clear_e:
logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}")
return result return result
except asyncio.CancelledError:
self.stats["failed_executions"] += 1
logger.info(f"{stream_id} 处理被取消,不清空未读消息")
raise
except Exception as e: except Exception as e:
self.stats["failed_executions"] += 1 self.stats["failed_executions"] += 1
logger.error(f"处理流 {stream_id} 时发生错误: {e}") logger.error(f"处理流 {stream_id} 时发生错误: {e}")

View File

@@ -55,7 +55,51 @@ class SingleStreamContextManager:
bool: 是否成功添加 bool: 是否成功添加
""" """
try: try:
# 直接操作上下文的消息列表 # 使用MessageManager的内置缓存系统
try:
from .message_manager import message_manager
# 如果MessageManager正在运行使用缓存系统
if message_manager.is_running:
# 先计算兴趣值(需要在缓存前计算)
await self._calculate_message_interest(message)
message.is_read = False
# 添加到缓存而不是直接添加到未读消息
cache_success = message_manager.add_message_to_cache(self.stream_id, message)
if cache_success:
# 自动检测和更新chat type
self._detect_chat_type(message)
self.total_messages += 1
self.last_access_time = time.time()
# 检查当前是否正在处理消息
is_processing = message_manager.get_stream_processing_status(self.stream_id)
if not is_processing:
# 如果当前没有在处理,立即刷新缓存到未读消息
cached_messages = message_manager.flush_cached_messages(self.stream_id)
for cached_msg in cached_messages:
self.context.unread_messages.append(cached_msg)
logger.debug(f"立即刷新缓存到未读消息: stream={self.stream_id}, 数量={len(cached_messages)}")
else:
logger.debug(f"消息已缓存,等待当前处理完成: stream={self.stream_id}")
# 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
logger.debug(f"添加消息到缓存系统: {self.stream_id}")
return True
else:
logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}")
except ImportError:
logger.debug("MessageManager不可用使用直接添加模式")
except Exception as e:
logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}")
# 回退方案:直接添加到未读消息
message.is_read = False message.is_read = False
self.context.unread_messages.append(message) self.context.unread_messages.append(message)

View File

@@ -364,9 +364,17 @@ class StreamLoopManager:
logger.warning(f"Chatter管理器未设置: {stream_id}") logger.warning(f"Chatter管理器未设置: {stream_id}")
return False return False
# 设置处理状态为正在处理
self._set_stream_processing_status(stream_id, True)
try: try:
start_time = time.time() start_time = time.time()
# 在处理开始前,先刷新缓存到未读消息
cached_messages = await self._flush_cached_messages_to_unread(stream_id)
if cached_messages:
logger.info(f"处理开始前刷新缓存消息: stream={stream_id}, 数量={len(cached_messages)}")
# 直接调用chatter_manager处理流上下文 # 直接调用chatter_manager处理流上下文
task = asyncio.create_task(self.chatter_manager.process_stream_context(stream_id, context)) task = asyncio.create_task(self.chatter_manager.process_stream_context(stream_id, context))
self.chatter_manager.set_processing_task(stream_id, task) self.chatter_manager.set_processing_task(stream_id, task)
@@ -374,6 +382,11 @@ class StreamLoopManager:
success = results.get("success", False) success = results.get("success", False)
if success: if success:
# 处理成功后,再次刷新缓存中可能的新消息
additional_messages = await self._flush_cached_messages_to_unread(stream_id)
if additional_messages:
logger.info(f"处理完成后刷新新消息: stream={stream_id}, 数量={len(additional_messages)}")
asyncio.create_task(self._refresh_focus_energy(stream_id)) asyncio.create_task(self._refresh_focus_energy(stream_id))
process_time = time.time() - start_time process_time = time.time() - start_time
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)") logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
@@ -385,6 +398,57 @@ class StreamLoopManager:
except Exception as e: except Exception as e:
logger.error(f"流处理异常: {stream_id} - {e}", exc_info=True) logger.error(f"流处理异常: {stream_id} - {e}", exc_info=True)
return False return False
finally:
# 无论成功或失败,都要设置处理状态为未处理
self._set_stream_processing_status(stream_id, False)
def _set_stream_processing_status(self, stream_id: str, is_processing: bool) -> None:
"""设置流的处理状态"""
try:
from .message_manager import message_manager
if message_manager.is_running:
message_manager.set_stream_processing_status(stream_id, is_processing)
logger.debug(f"设置流处理状态: stream={stream_id}, processing={is_processing}")
except ImportError:
logger.debug("MessageManager不可用跳过状态设置")
except Exception as e:
logger.warning(f"设置流处理状态失败: stream={stream_id}, error={e}")
async def _flush_cached_messages_to_unread(self, stream_id: str) -> list:
"""将缓存消息刷新到未读消息列表"""
try:
from .message_manager import message_manager
if message_manager.is_running and message_manager.has_cached_messages(stream_id):
# 获取缓存消息
cached_messages = message_manager.flush_cached_messages(stream_id)
if cached_messages:
# 获取聊天流并添加到未读消息
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
for message in cached_messages:
chat_stream.context_manager.context.unread_messages.append(message)
logger.debug(f"刷新缓存消息到未读列表: stream={stream_id}, 数量={len(cached_messages)}")
else:
logger.warning(f"无法找到聊天流: {stream_id}")
return cached_messages
return []
except ImportError:
logger.debug("MessageManager不可用跳过缓存刷新")
return []
except Exception as e:
logger.warning(f"刷新缓存消息失败: stream={stream_id}, error={e}")
return []
async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float: async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float:
"""计算下次检查间隔 """计算下次检查间隔

View File

@@ -6,6 +6,7 @@
import asyncio import asyncio
import random import random
import time import time
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager from src.chat.chatter_manager import ChatterManager
@@ -46,6 +47,14 @@ class MessageManager:
self.sleep_manager = SleepManager() self.sleep_manager = SleepManager()
self.wakeup_manager = WakeUpManager(self.sleep_manager) self.wakeup_manager = WakeUpManager(self.sleep_manager)
# 消息缓存系统 - 直接集成到消息管理器
self.message_caches: Dict[str, deque] = defaultdict(deque) # 每个流的消息缓存
self.stream_processing_status: Dict[str, bool] = defaultdict(bool) # 流的处理状态
self.cache_stats = {
"total_cached_messages": 0,
"total_flushed_messages": 0,
}
# 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context_manager # 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context_manager
async def start(self): async def start(self):
@@ -72,6 +81,9 @@ class MessageManager:
except Exception as e: except Exception as e:
logger.error(f"启动流缓存管理器失败: {e}") logger.error(f"启动流缓存管理器失败: {e}")
# 启动消息缓存系统(内置)
logger.info("📦 消息缓存系统已启动")
# 启动自适应流管理器 # 启动自适应流管理器
try: try:
from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
@@ -115,6 +127,11 @@ class MessageManager:
except Exception as e: except Exception as e:
logger.error(f"停止流缓存管理器失败: {e}") logger.error(f"停止流缓存管理器失败: {e}")
# 停止消息缓存系统(内置)
self.message_caches.clear()
self.stream_processing_status.clear()
logger.info("📦 消息缓存系统已停止")
# 停止自适应流管理器 # 停止自适应流管理器
try: try:
from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
@@ -429,6 +446,115 @@ class MessageManager:
except Exception as e: except Exception as e:
logger.error(f"清除流 {stream_id} 的未读消息时发生错误: {e}") logger.error(f"清除流 {stream_id} 的未读消息时发生错误: {e}")
# ===== 消息缓存系统方法 =====
def add_message_to_cache(self, stream_id: str, message: DatabaseMessages) -> bool:
"""添加消息到缓存
Args:
stream_id: 流ID
message: 消息对象
Returns:
bool: 是否成功添加到缓存
"""
try:
if not self.is_running:
return False
self.message_caches[stream_id].append(message)
self.cache_stats["total_cached_messages"] += 1
logger.debug(f"消息已添加到缓存: stream={stream_id}, content={message.processed_plain_text[:50]}...")
return True
except Exception as e:
logger.error(f"添加消息到缓存失败: stream={stream_id}, error={e}")
return False
def flush_cached_messages(self, stream_id: str) -> list[DatabaseMessages]:
"""刷新缓存消息到未读消息列表
Args:
stream_id: 流ID
Returns:
List[DatabaseMessages]: 缓存的消息列表
"""
try:
if stream_id not in self.message_caches:
return []
cached_messages = list(self.message_caches[stream_id])
self.message_caches[stream_id].clear()
self.cache_stats["total_flushed_messages"] += len(cached_messages)
logger.debug(f"刷新缓存消息: stream={stream_id}, 数量={len(cached_messages)}")
return cached_messages
except Exception as e:
logger.error(f"刷新缓存消息失败: stream={stream_id}, error={e}")
return []
def set_stream_processing_status(self, stream_id: str, is_processing: bool):
"""设置流的处理状态
Args:
stream_id: 流ID
is_processing: 是否正在处理
"""
try:
self.stream_processing_status[stream_id] = is_processing
logger.debug(f"设置流处理状态: stream={stream_id}, processing={is_processing}")
except Exception as e:
logger.error(f"设置流处理状态失败: stream={stream_id}, error={e}")
def get_stream_processing_status(self, stream_id: str) -> bool:
"""获取流的处理状态
Args:
stream_id: 流ID
Returns:
bool: 是否正在处理
"""
return self.stream_processing_status.get(stream_id, False)
def has_cached_messages(self, stream_id: str) -> bool:
"""检查流是否有缓存消息
Args:
stream_id: 流ID
Returns:
bool: 是否有缓存消息
"""
return stream_id in self.message_caches and len(self.message_caches[stream_id]) > 0
def get_cached_messages_count(self, stream_id: str) -> int:
"""获取流的缓存消息数量
Args:
stream_id: 流ID
Returns:
int: 缓存消息数量
"""
return len(self.message_caches.get(stream_id, []))
def get_cache_stats(self) -> dict[str, Any]:
"""获取缓存统计信息
Returns:
Dict[str, Any]: 缓存统计信息
"""
return {
"total_cached_messages": self.cache_stats["total_cached_messages"],
"total_flushed_messages": self.cache_stats["total_flushed_messages"],
"active_caches": len(self.message_caches),
"cached_streams": len([s for s in self.message_caches.keys() if self.message_caches[s]]),
"processing_streams": len([s for s in self.stream_processing_status.keys() if self.stream_processing_status[s]]),
}
# 创建全局消息管理器实例 # 创建全局消息管理器实例
message_manager = MessageManager() message_manager = MessageManager()

View File

@@ -19,7 +19,7 @@ from contextlib import asynccontextmanager
from typing import Any from typing import Any
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@@ -31,6 +31,10 @@ logger = get_logger("sqlalchemy_models")
# 创建基类 # 创建基类
Base = declarative_base() Base = declarative_base()
# 全局异步引擎与会话工厂占位(延迟初始化)
_engine: AsyncEngine | None = None
_SessionLocal: async_sessionmaker[AsyncSession] | None = None
async def enable_sqlite_wal_mode(engine): async def enable_sqlite_wal_mode(engine):
"""为 SQLite 启用 WAL 模式以提高并发性能""" """为 SQLite 启用 WAL 模式以提高并发性能"""
@@ -649,23 +653,13 @@ class MonthlyPlan(Base):
last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True) last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
# 保留 is_deleted 字段以兼容现有数据,但标记为已弃用
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
__table_args__ = ( __table_args__ = (
Index("idx_monthlyplan_target_month_status", "target_month", "status"), Index("idx_monthlyplan_target_month_status", "target_month", "status"),
Index("idx_monthlyplan_last_used_date", "last_used_date"), Index("idx_monthlyplan_last_used_date", "last_used_date"),
Index("idx_monthlyplan_usage_count", "usage_count"), Index("idx_monthlyplan_usage_count", "usage_count"),
# 保留旧索引以兼容
Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"),
) )
# 数据库引擎和会话管理
_engine = None
_SessionLocal = None
def get_database_url(): def get_database_url():
"""获取数据库连接URL""" """获取数据库连接URL"""
from src.config.config import global_config from src.config.config import global_config
@@ -709,13 +703,35 @@ def get_database_url():
return f"sqlite+aiosqlite:///{db_path}" return f"sqlite+aiosqlite:///{db_path}"
async def initialize_database(): _initializing: bool = False # 防止递归初始化
"""初始化异步数据库引擎和会话"""
global _engine, _SessionLocal
if _engine is not None: async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[AsyncSession]]:
"""初始化异步数据库引擎和会话
Returns:
tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: 创建好的异步引擎与会话工厂。
说明:
显式的返回类型标注有助于 Pyright/Pylance 正确推断调用处的对象,
避免后续对返回值再次 `await` 时出现 *"tuple[...] 并非 awaitable"* 的误用。
"""
global _engine, _SessionLocal, _initializing
# 已经初始化直接返回
if _engine is not None and _SessionLocal is not None:
return _engine, _SessionLocal return _engine, _SessionLocal
# 正在初始化的并发调用等待主初始化完成,避免递归
if _initializing:
import asyncio
for _ in range(1000): # 最多等待约10秒
await asyncio.sleep(0.01)
if _engine is not None and _SessionLocal is not None:
return _engine, _SessionLocal
raise RuntimeError("等待数据库初始化完成超时 (reentrancy guard)")
_initializing = True
try:
database_url = get_database_url() database_url = get_database_url()
from src.config.config import global_config from src.config.config import global_config
@@ -728,14 +744,13 @@ async def initialize_database():
} }
if config.database_type == "mysql": if config.database_type == "mysql":
# MySQL连接池配置 - 异步引擎使用默认连接池
engine_kwargs.update( engine_kwargs.update(
{ {
"pool_size": config.connection_pool_size, "pool_size": config.connection_pool_size,
"max_overflow": config.connection_pool_size * 2, "max_overflow": config.connection_pool_size * 2,
"pool_timeout": config.connection_timeout, "pool_timeout": config.connection_timeout,
"pool_recycle": 3600, # 1小时回收连接 "pool_recycle": 3600,
"pool_pre_ping": True, # 连接前ping检查 "pool_pre_ping": True,
"connect_args": { "connect_args": {
"autocommit": config.mysql_autocommit, "autocommit": config.mysql_autocommit,
"charset": config.mysql_charset, "charset": config.mysql_charset,
@@ -744,12 +759,11 @@ async def initialize_database():
} }
) )
else: else:
# SQLite配置 - aiosqlite不支持连接池参数
engine_kwargs.update( engine_kwargs.update(
{ {
"connect_args": { "connect_args": {
"check_same_thread": False, "check_same_thread": False,
"timeout": 60, # 增加超时时间 "timeout": 60,
}, },
} }
) )
@@ -757,17 +771,21 @@ async def initialize_database():
_engine = create_async_engine(database_url, **engine_kwargs) _engine = create_async_engine(database_url, **engine_kwargs)
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
# 调用新的迁移函数,它会处理表的创建和列的添加 # 迁移
try:
from src.common.database.db_migration import check_and_migrate_database from src.common.database.db_migration import check_and_migrate_database
await check_and_migrate_database(existing_engine=_engine)
except TypeError:
from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate
await _legacy_migrate()
await check_and_migrate_database()
# 如果是 SQLite启用 WAL 模式以提高并发性能
if config.database_type == "sqlite": if config.database_type == "sqlite":
await enable_sqlite_wal_mode(_engine) await enable_sqlite_wal_mode(_engine)
logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}") logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
return _engine, _SessionLocal return _engine, _SessionLocal
finally:
_initializing = False
@asynccontextmanager @asynccontextmanager

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import re import re
from pathlib import Path from pathlib import Path
from re import Pattern from re import Pattern
from typing import Any, Optional, Union from typing import Any, Optional, Union, cast
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_action import BaseAction
@@ -26,6 +28,23 @@ from src.plugin_system.base.plus_command import PlusCommand
logger = get_logger("component_registry") logger = get_logger("component_registry")
# 统一的组件类类型别名
ComponentClassType = (
type[BaseCommand]
| type[BaseAction]
| type[BaseTool]
| type[BaseEventHandler]
| type[PlusCommand]
| type[BaseChatter]
| type[BaseInterestCalculator]
)
def _assign_plugin_attrs(cls: Any, plugin_name: str, plugin_config: dict) -> None:
"""为组件类动态赋予插件相关属性(避免在各注册函数中重复代码)。"""
setattr(cls, "plugin_name", plugin_name)
setattr(cls, "plugin_config", plugin_config)
class ComponentRegistry: class ComponentRegistry:
"""统一的组件注册中心 """统一的组件注册中心
@@ -41,9 +60,8 @@ class ComponentRegistry:
types: {} for types in ComponentType types: {} for types in ComponentType
} }
"""类型 -> 组件原名称 -> 组件信息""" """类型 -> 组件原名称 -> 组件信息"""
self._components_classes: dict[ # 组件类注册表(命名空间式组件名 -> 组件类)
str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"] self._components_classes: dict[str, ComponentClassType] = {}
] = {}
"""命名空间式组件名 -> 组件类""" """命名空间式组件名 -> 组件类"""
# 插件注册表 # 插件注册表
@@ -100,10 +118,8 @@ class ComponentRegistry:
return True return True
def register_component( def register_component(
self, self, self_component_info: ComponentInfo, component_class: ComponentClassType
component_info: ComponentInfo, ) -> bool: # noqa: C901 (保持原有结构, 以后可再拆)
component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
) -> bool:
"""注册组件 """注册组件
Args: Args:
@@ -113,9 +129,11 @@ class ComponentRegistry:
Returns: Returns:
bool: 是否注册成功 bool: 是否注册成功
""" """
component_info = self_component_info # 局部别名
component_name = component_info.name component_name = component_info.name
component_type = component_info.component_type component_type = component_info.component_type
plugin_name = getattr(component_info, "plugin_name", "unknown") plugin_name = getattr(component_info, "plugin_name", "unknown")
if "." in component_name: if "." in component_name:
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代") logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
return False return False
@@ -123,22 +141,19 @@ class ComponentRegistry:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False return False
namespaced_name = f"{component_type}.{component_name}" namespaced_name = f"{component_type.value}.{component_name}"
if namespaced_name in self._components: if namespaced_name in self._components:
existing_info = self._components[namespaced_name] existing_info = self._components[namespaced_name]
existing_plugin = getattr(existing_info, "plugin_name", "unknown") existing_plugin = getattr(existing_info, "plugin_name", "unknown")
logger.warning( logger.warning(
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册" f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
) )
return False return False
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称) self._components[namespaced_name] = component_info
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名 self._components_by_type[component_type][component_name] = component_info
self._components_classes[namespaced_name] = component_class self._components_classes[namespaced_name] = component_class
# 根据组件类型进行特定注册(使用原始名称)
match component_type: match component_type:
case ComponentType.ACTION: case ComponentType.ACTION:
assert isinstance(component_info, ActionInfo) assert isinstance(component_info, ActionInfo)
@@ -175,12 +190,11 @@ class ComponentRegistry:
if not ret: if not ret:
return False return False
logger.debug( logger.debug(
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' " f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' ({component_class.__name__}) [插件: {plugin_name}]"
f"({component_class.__name__}) [插件: {plugin_name}]"
) )
return True return True
def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool: def _register_action_component(self, action_info: ActionInfo, action_class: type[BaseAction]) -> bool:
"""注册Action组件到Action特定注册表""" """注册Action组件到Action特定注册表"""
if not (action_name := action_info.name): if not (action_name := action_info.name):
logger.error(f"Action组件 {action_class.__name__} 必须指定名称") logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
@@ -188,19 +202,13 @@ class ComponentRegistry:
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction): if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
logger.error(f"注册失败: {action_name} 不是有效的Action") logger.error(f"注册失败: {action_name} 不是有效的Action")
return False return False
_assign_plugin_attrs(action_class, action_info.plugin_name, self.get_plugin_config(action_info.plugin_name) or {})
action_class.plugin_name = action_info.plugin_name
# 设置插件配置
action_class.plugin_config = self.get_plugin_config(action_info.plugin_name) or {}
self._action_registry[action_name] = action_class self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
if action_info.enabled: if action_info.enabled:
self._default_actions[action_name] = action_info self._default_actions[action_name] = action_info
return True return True
def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool: def _register_command_component(self, command_info: CommandInfo, command_class: type[BaseCommand]) -> bool:
"""注册Command组件到Command特定注册表""" """注册Command组件到Command特定注册表"""
if not (command_name := command_info.name): if not (command_name := command_info.name):
logger.error(f"Command组件 {command_class.__name__} 必须指定名称") logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
@@ -208,13 +216,10 @@ class ComponentRegistry:
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand): if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
logger.error(f"注册失败: {command_name} 不是有效的Command") logger.error(f"注册失败: {command_name} 不是有效的Command")
return False return False
_assign_plugin_attrs(
command_class.plugin_name = command_info.plugin_name command_class, command_info.plugin_name, self.get_plugin_config(command_info.plugin_name) or {}
# 设置插件配置 )
command_class.plugin_config = self.get_plugin_config(command_info.plugin_name) or {}
self._command_registry[command_name] = command_class self._command_registry[command_name] = command_class
# 如果启用了且有匹配模式
if command_info.enabled and command_info.command_pattern: if command_info.enabled and command_info.command_pattern:
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL) pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
if pattern not in self._command_patterns: if pattern not in self._command_patterns:
@@ -223,11 +228,10 @@ class ComponentRegistry:
logger.warning( logger.warning(
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令" f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
) )
return True return True
def _register_plus_command_component( def _register_plus_command_component(
self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"] self, plus_command_info: PlusCommandInfo, plus_command_class: type[PlusCommand]
) -> bool: ) -> bool:
"""注册PlusCommand组件到特定注册表""" """注册PlusCommand组件到特定注册表"""
plus_command_name = plus_command_info.name plus_command_name = plus_command_info.name
@@ -241,33 +245,27 @@ class ComponentRegistry:
# 创建专门的PlusCommand注册表如果还没有 # 创建专门的PlusCommand注册表如果还没有
if not hasattr(self, "_plus_command_registry"): if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: dict[str, type["PlusCommand"]] = {} self._plus_command_registry: dict[str, type[PlusCommand]] = {}
_assign_plugin_attrs(
plus_command_class.plugin_name = plus_command_info.plugin_name plus_command_class,
# 设置插件配置 plus_command_info.plugin_name,
plus_command_class.plugin_config = self.get_plugin_config(plus_command_info.plugin_name) or {} self.get_plugin_config(plus_command_info.plugin_name) or {},
)
self._plus_command_registry[plus_command_name] = plus_command_class self._plus_command_registry[plus_command_name] = plus_command_class
logger.debug(f"已注册PlusCommand组件: {plus_command_name}") logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
return True return True
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool: def _register_tool_component(self, tool_info: ToolInfo, tool_class: type[BaseTool]) -> bool:
"""注册Tool组件到Tool特定注册表""" """注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name tool_name = tool_info.name
_assign_plugin_attrs(tool_class, tool_info.plugin_name, self.get_plugin_config(tool_info.plugin_name) or {})
tool_class.plugin_name = tool_info.plugin_name
# 设置插件配置
tool_class.plugin_config = self.get_plugin_config(tool_info.plugin_name) or {}
self._tool_registry[tool_name] = tool_class self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.enabled: if tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class self._llm_available_tools[tool_name] = tool_class
return True return True
def _register_event_handler_component( def _register_event_handler_component(
self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"] self, handler_info: EventHandlerInfo, handler_class: type[BaseEventHandler]
) -> bool: ) -> bool:
if not (handler_name := handler_info.name): if not (handler_name := handler_info.name):
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
@@ -275,25 +273,19 @@ class ComponentRegistry:
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler): if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler") logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False return False
_assign_plugin_attrs(
handler_class.plugin_name = handler_info.plugin_name handler_class, handler_info.plugin_name, self.get_plugin_config(handler_info.plugin_name) or {}
# 设置插件配置 )
handler_class.plugin_config = self.get_plugin_config(handler_info.plugin_name) or {}
self._event_handler_registry[handler_name] = handler_class self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled: if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用") logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功 return True # 未启用,但是也是注册成功
handler_class.plugin_name = handler_info.plugin_name
# 使用EventManager进行事件处理器注册
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
return event_manager.register_event_handler( return event_manager.register_event_handler(
handler_class, self.get_plugin_config(handler_info.plugin_name) or {} handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
) )
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool: def _register_chatter_component(self, chatter_info: ChatterInfo, chatter_class: type[BaseChatter]) -> bool:
"""注册Chatter组件到Chatter特定注册表""" """注册Chatter组件到Chatter特定注册表"""
chatter_name = chatter_info.name chatter_name = chatter_info.name
@@ -303,18 +295,14 @@ class ComponentRegistry:
if not isinstance(chatter_info, ChatterInfo) or not issubclass(chatter_class, BaseChatter): if not isinstance(chatter_info, ChatterInfo) or not issubclass(chatter_class, BaseChatter):
logger.error(f"注册失败: {chatter_name} 不是有效的Chatter") logger.error(f"注册失败: {chatter_name} 不是有效的Chatter")
return False return False
_assign_plugin_attrs(
chatter_class.plugin_name = chatter_info.plugin_name chatter_class, chatter_info.plugin_name, self.get_plugin_config(chatter_info.plugin_name) or {}
# 设置插件配置 )
chatter_class.plugin_config = self.get_plugin_config(chatter_info.plugin_name) or {}
self._chatter_registry[chatter_name] = chatter_class self._chatter_registry[chatter_name] = chatter_class
if not chatter_info.enabled: if not chatter_info.enabled:
logger.warning(f"Chatter组件 {chatter_name} 未启用") logger.warning(f"Chatter组件 {chatter_name} 未启用")
return True # 未启用,但是也是注册成功 return True # 未启用,但是也是注册成功
self._enabled_chatter_registry[chatter_name] = chatter_class self._enabled_chatter_registry[chatter_name] = chatter_class
logger.debug(f"已注册Chatter组件: {chatter_name}") logger.debug(f"已注册Chatter组件: {chatter_name}")
return True return True
@@ -341,9 +329,13 @@ class ComponentRegistry:
if not hasattr(self, "_enabled_interest_calculator_registry"): if not hasattr(self, "_enabled_interest_calculator_registry"):
self._enabled_interest_calculator_registry: dict[str, type["BaseInterestCalculator"]] = {} self._enabled_interest_calculator_registry: dict[str, type["BaseInterestCalculator"]] = {}
interest_calculator_class.plugin_name = interest_calculator_info.plugin_name setattr(interest_calculator_class, "plugin_name", interest_calculator_info.plugin_name)
# 设置插件配置 # 设置插件配置
interest_calculator_class.plugin_config = self.get_plugin_config(interest_calculator_info.plugin_name) or {} setattr(
interest_calculator_class,
"plugin_config",
self.get_plugin_config(interest_calculator_info.plugin_name) or {},
)
self._interest_calculator_registry[calculator_name] = interest_calculator_class self._interest_calculator_registry[calculator_name] = interest_calculator_class
if not interest_calculator_info.enabled: if not interest_calculator_info.enabled:
@@ -356,7 +348,7 @@ class ComponentRegistry:
# === 组件移除相关 === # === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: "ComponentType", plugin_name: str) -> bool: async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
target_component_class = self.get_component_class(component_name, component_type) target_component_class = self.get_component_class(component_name, component_type)
if not target_component_class: if not target_component_class:
logger.warning(f"组件 {component_name} 未注册,无法移除") logger.warning(f"组件 {component_name} 未注册,无法移除")
@@ -398,8 +390,14 @@ class ComponentRegistry:
self._enabled_event_handlers.pop(component_name, None) self._enabled_event_handlers.pop(component_name, None)
try: try:
handler = event_manager.get_event_handler(component_name) handler = event_manager.get_event_handler(component_name)
for event in handler.subscribed_events: # 事件处理器可能未找到或未声明 subscribed_events,需判空
await event_manager.unsubscribe_handler_from_event(event, component_name) if handler and hasattr(handler, "subscribed_events"):
for event in getattr(handler, "subscribed_events"):
# 假设 unsubscribe_handler_from_event 是协程;若不是则移除 await
result = event_manager.unsubscribe_handler_from_event(event, component_name)
if hasattr(result, "__await__"):
await result # type: ignore[func-returns-value]
logger.debug(f"已移除EventHandler组件: {component_name}")
logger.debug(f"已移除EventHandler组件: {component_name}") logger.debug(f"已移除EventHandler组件: {component_name}")
except Exception as e: except Exception as e:
logger.warning(f"移除EventHandler事件订阅时出错: {e}") logger.warning(f"移除EventHandler事件订阅时出错: {e}")
@@ -415,7 +413,7 @@ class ComponentRegistry:
return False return False
# 移除通用注册信息 # 移除通用注册信息
namespaced_name = f"{component_type}.{component_name}" namespaced_name = f"{component_type.value}.{component_name}"
self._components.pop(namespaced_name, None) self._components.pop(namespaced_name, None)
self._components_by_type[component_type].pop(component_name, None) self._components_by_type[component_type].pop(component_name, None)
self._components_classes.pop(namespaced_name, None) self._components_classes.pop(namespaced_name, None)
@@ -477,9 +475,10 @@ class ComponentRegistry:
self._enabled_event_handlers[component_name] = target_component_class self._enabled_event_handlers[component_name] = target_component_class
from .event_manager import event_manager # 延迟导入防止循环导入问题 from .event_manager import event_manager # 延迟导入防止循环导入问题
event_manager.register_event_handler(component_name) # 重新注册事件处理器(启用)使用类而不是名称
cfg = self.get_plugin_config(target_component_info.plugin_name) or {}
namespaced_name = f"{component_type}.{component_name}" event_manager.register_event_handler(target_component_class, cfg) # type: ignore[arg-type]
namespaced_name = f"{component_type.value}.{component_name}"
self._components[namespaced_name].enabled = True self._components[namespaced_name].enabled = True
self._components_by_type[component_type][component_name].enabled = True self._components_by_type[component_type][component_name].enabled = True
logger.info(f"组件 {component_name} 已启用") logger.info(f"组件 {component_name} 已启用")
@@ -512,10 +511,16 @@ class ComponentRegistry:
from .event_manager import event_manager # 延迟导入防止循环导入问题 from .event_manager import event_manager # 延迟导入防止循环导入问题
handler = event_manager.get_event_handler(component_name) handler = event_manager.get_event_handler(component_name)
for event in handler.subscribed_events: if handler and hasattr(handler, "subscribed_events"):
await event_manager.unsubscribe_handler_from_event(event, component_name) for event in getattr(handler, "subscribed_events"):
result = event_manager.unsubscribe_handler_from_event(event, component_name)
if hasattr(result, "__await__"):
await result # type: ignore[func-returns-value]
self._components[component_name].enabled = False # 组件主注册表使用命名空间 key
namespaced_name = f"{component_type.value}.{component_name}"
if namespaced_name in self._components:
self._components[namespaced_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用") logger.info(f"组件 {component_name} 已禁用")
return True return True
@@ -528,8 +533,8 @@ class ComponentRegistry:
# === 组件查询方法 === # === 组件查询方法 ===
def get_component_info( def get_component_info(
self, component_name: str, component_type: Optional["ComponentType"] = None self, component_name: str, component_type: Optional[ComponentType] = None
) -> Optional["ComponentInfo"]: ) -> Optional[ComponentInfo]:
# sourcery skip: class-extract-method # sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析 """获取组件信息,支持自动命名空间解析
@@ -546,7 +551,7 @@ class ComponentRegistry:
# 2. 如果指定了组件类型,构造命名空间化的名称查找 # 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type: if component_type:
namespaced_name = f"{component_type}.{component_name}" namespaced_name = f"{component_type.value}.{component_name}"
return self._components.get(namespaced_name) return self._components.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找 # 3. 如果没有指定类型,尝试在所有命名空间中查找
@@ -573,8 +578,17 @@ class ComponentRegistry:
def get_component_class( def get_component_class(
self, self,
component_name: str, component_name: str,
component_type: Optional["ComponentType"] = None, component_type: Optional[ComponentType] = None,
) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None: ) -> (
type[BaseCommand]
| type[BaseAction]
| type[BaseEventHandler]
| type[BaseTool]
| type[PlusCommand]
| type[BaseChatter]
| type[BaseInterestCalculator]
| None
):
"""获取组件类,支持自动命名空间解析 """获取组件类,支持自动命名空间解析
Args: Args:
@@ -591,7 +605,17 @@ class ComponentRegistry:
# 2. 如果指定了组件类型,构造命名空间化的名称查找 # 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type: if component_type:
namespaced_name = f"{component_type.value}.{component_name}" namespaced_name = f"{component_type.value}.{component_name}"
return self._components_classes.get(namespaced_name) # type: ignore[valid-type] return cast(
type[BaseCommand]
| type[BaseAction]
| type[BaseEventHandler]
| type[BaseTool]
| type[PlusCommand]
| type[BaseChatter]
| type[BaseInterestCalculator]
| None,
self._components_classes.get(namespaced_name),
)
# 3. 如果没有指定类型,尝试在所有命名空间中查找 # 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = [] candidates = []
@@ -616,22 +640,22 @@ class ComponentRegistry:
# 4. 都没找到 # 4. 都没找到
return None return None
def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: def get_components_by_type(self, component_type: ComponentType) -> dict[str, ComponentInfo]:
"""获取指定类型的所有组件""" """获取指定类型的所有组件"""
return self._components_by_type.get(component_type, {}).copy() return self._components_by_type.get(component_type, {}).copy()
def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]: def get_enabled_components_by_type(self, component_type: ComponentType) -> dict[str, ComponentInfo]:
"""获取指定类型的所有启用组件""" """获取指定类型的所有启用组件"""
components = self.get_components_by_type(component_type) components = self.get_components_by_type(component_type)
return {name: info for name, info in components.items() if info.enabled} return {name: info for name, info in components.items() if info.enabled}
# === Action特定查询方法 === # === Action特定查询方法 ===
def get_action_registry(self) -> dict[str, type["BaseAction"]]: def get_action_registry(self) -> dict[str, type[BaseAction]]:
"""获取Action注册表""" """获取Action注册表"""
return self._action_registry.copy() return self._action_registry.copy()
def get_registered_action_info(self, action_name: str) -> Optional["ActionInfo"]: def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
"""获取Action信息""" """获取Action信息"""
info = self.get_component_info(action_name, ComponentType.ACTION) info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None return info if isinstance(info, ActionInfo) else None
@@ -642,11 +666,11 @@ class ComponentRegistry:
# === Command特定查询方法 === # === Command特定查询方法 ===
def get_command_registry(self) -> dict[str, type["BaseCommand"]]: def get_command_registry(self) -> dict[str, type[BaseCommand]]:
"""获取Command注册表""" """获取Command注册表"""
return self._command_registry.copy() return self._command_registry.copy()
def get_registered_command_info(self, command_name: str) -> Optional["CommandInfo"]: def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command信息""" """获取Command信息"""
info = self.get_component_info(command_name, ComponentType.COMMAND) info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None return info if isinstance(info, CommandInfo) else None
@@ -655,7 +679,7 @@ class ComponentRegistry:
"""获取Command模式注册表""" """获取Command模式注册表"""
return self._command_patterns.copy() return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None: def find_command_by_text(self, text: str) -> tuple[type[BaseCommand], dict, CommandInfo] | None:
# sourcery skip: use-named-expression, use-next # sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令 """根据文本查找匹配的命令
@@ -682,15 +706,15 @@ class ComponentRegistry:
return None return None
# === Tool 特定查询方法 === # === Tool 特定查询方法 ===
def get_tool_registry(self) -> dict[str, type["BaseTool"]]: def get_tool_registry(self) -> dict[str, type[BaseTool]]:
"""获取Tool注册表""" """获取Tool注册表"""
return self._tool_registry.copy() return self._tool_registry.copy()
def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]: def get_llm_available_tools(self) -> dict[str, type[BaseTool]]:
"""获取LLM可用的Tool列表""" """获取LLM可用的Tool列表"""
return self._llm_available_tools.copy() return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional["ToolInfo"]: def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
"""获取Tool信息 """获取Tool信息
Args: Args:
@@ -703,13 +727,13 @@ class ComponentRegistry:
return info if isinstance(info, ToolInfo) else None return info if isinstance(info, ToolInfo) else None
# === PlusCommand 特定查询方法 === # === PlusCommand 特定查询方法 ===
def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]: def get_plus_command_registry(self) -> dict[str, type[PlusCommand]]:
"""获取PlusCommand注册表""" """获取PlusCommand注册表"""
if not hasattr(self, "_plus_command_registry"): if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: dict[str, type[PlusCommand]] = {} self._plus_command_registry: dict[str, type[PlusCommand]] = {}
return self._plus_command_registry.copy() return self._plus_command_registry.copy()
def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]: def get_registered_plus_command_info(self, command_name: str) -> Optional[PlusCommandInfo]:
"""获取PlusCommand信息 """获取PlusCommand信息
Args: Args:
@@ -723,44 +747,44 @@ class ComponentRegistry:
# === EventHandler 特定查询方法 === # === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]: def get_event_handler_registry(self) -> dict[str, type[BaseEventHandler]]:
"""获取事件处理器注册表""" """获取事件处理器注册表"""
return self._event_handler_registry.copy() return self._event_handler_registry.copy()
def get_registered_event_handler_info(self, handler_name: str) -> Optional["EventHandlerInfo"]: def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
"""获取事件处理器信息""" """获取事件处理器信息"""
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None return info if isinstance(info, EventHandlerInfo) else None
def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]: def get_enabled_event_handlers(self) -> dict[str, type[BaseEventHandler]]:
"""获取启用的事件处理器""" """获取启用的事件处理器"""
return self._enabled_event_handlers.copy() return self._enabled_event_handlers.copy()
# === Chatter 特定查询方法 === # === Chatter 特定查询方法 ===
def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]: def get_chatter_registry(self) -> dict[str, type[BaseChatter]]:
"""获取Chatter注册表""" """获取Chatter注册表"""
if not hasattr(self, "_chatter_registry"): if not hasattr(self, "_chatter_registry"):
self._chatter_registry: dict[str, type[BaseChatter]] = {} self._chatter_registry: dict[str, type[BaseChatter]] = {}
return self._chatter_registry.copy() return self._chatter_registry.copy()
def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]: def get_enabled_chatter_registry(self) -> dict[str, type[BaseChatter]]:
"""获取启用的Chatter注册表""" """获取启用的Chatter注册表"""
if not hasattr(self, "_enabled_chatter_registry"): if not hasattr(self, "_enabled_chatter_registry"):
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {} self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
return self._enabled_chatter_registry.copy() return self._enabled_chatter_registry.copy()
def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]: def get_registered_chatter_info(self, chatter_name: str) -> Optional[ChatterInfo]:
"""获取Chatter信息""" """获取Chatter信息"""
info = self.get_component_info(chatter_name, ComponentType.CHATTER) info = self.get_component_info(chatter_name, ComponentType.CHATTER)
return info if isinstance(info, ChatterInfo) else None return info if isinstance(info, ChatterInfo) else None
# === 插件查询方法 === # === 插件查询方法 ===
def get_plugin_info(self, plugin_name: str) -> Optional["PluginInfo"]: def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
"""获取插件信息""" """获取插件信息"""
return self._plugins.get(plugin_name) return self._plugins.get(plugin_name)
def get_all_plugins(self) -> dict[str, "PluginInfo"]: def get_all_plugins(self) -> dict[str, PluginInfo]:
"""获取所有插件""" """获取所有插件"""
return self._plugins.copy() return self._plugins.copy()

View File

@@ -22,8 +22,6 @@ class PermissionManager(IPermissionManager):
"""权限管理器实现类""" """权限管理器实现类"""
def __init__(self): def __init__(self):
self.engine = None
self.SessionLocal = None
self._master_users: set[tuple[str, str]] = set() self._master_users: set[tuple[str, str]] = set()
self._load_master_users() self._load_master_users()
@@ -52,7 +50,7 @@ class PermissionManager(IPermissionManager):
self._load_master_users() self._load_master_users()
logger.info("Master用户配置已重新加载") logger.info("Master用户配置已重新加载")
def is_master(self, user: UserInfo) -> bool: async def is_master(self, user: UserInfo) -> bool:
""" """
检查用户是否为Master用户 检查用户是否为Master用户
@@ -81,7 +79,7 @@ class PermissionManager(IPermissionManager):
""" """
try: try:
# Master用户拥有所有权限 # Master用户拥有所有权限
if self.is_master(user): if await self.is_master(user):
logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}") logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}")
return True return True
@@ -288,10 +286,10 @@ class PermissionManager(IPermissionManager):
""" """
try: try:
# Master用户拥有所有权限 # Master用户拥有所有权限
if self.is_master(user): if await self.is_master(user):
async with self.SessionLocal() as session: async with self.SessionLocal() as session:
result = await session.execute(select(PermissionNodes.node_name)) result = await session.execute(select(PermissionNodes.node_name))
all_nodes = result.scalars().all() all_nodes = list(result.scalars().all())
return all_nodes return all_nodes
permissions = [] permissions = []

View File

@@ -95,6 +95,7 @@ class PluginManager:
if not plugin_class: if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在") logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1 return False, 1
init_module = None # 预先定义,避免后续条件加载导致未绑定
try: try:
# 使用记录的插件目录路径 # 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name) plugin_dir = self.plugin_paths.get(plugin_name)
@@ -314,6 +315,7 @@ class PluginManager:
module_name = ".".join(plugin_path.parent.parts) module_name = ".".join(plugin_path.parent.parts)
try: try:
init_module = None # 确保下方引用存在
# 首先加载 __init__.py 来获取元数据 # 首先加载 __init__.py 来获取元数据
init_file = os.path.join(plugin_dir, "__init__.py") init_file = os.path.join(plugin_dir, "__init__.py")
if os.path.exists(init_file): if os.path.exists(init_file):
@@ -524,13 +526,10 @@ class PluginManager:
fut.result(timeout=5) fut.result(timeout=5)
else: else:
asyncio.run(component_registry.unregister_plugin(plugin_name)) asyncio.run(component_registry.unregister_plugin(plugin_name))
except Exception: except Exception as e: # 捕获并记录卸载阶段协程调用错误
# 最后兜底:直接同步调用(如果 unregister_plugin 为非协程)或忽略错误 logger.debug(
try: f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}", exc_info=True
# 如果 unregister_plugin 是普通函数 )
component_registry.unregister_plugin(plugin_name)
except Exception as e:
logger.debug(f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}")
# 从已加载插件中移除 # 从已加载插件中移除
del self.loaded_plugins[plugin_name] del self.loaded_plugins[plugin_name]
@@ -550,65 +549,6 @@ class PluginManager:
logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True) logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True)
return False return False
def reload_plugin(self, plugin_name: str) -> bool:
"""重载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""
try:
logger.info(f"🔄 开始重载插件: {plugin_name}")
# 卸载插件
if plugin_name in self.loaded_plugins:
if not self.unload_plugin(plugin_name):
logger.warning(f"⚠️ 插件卸载失败,继续重载: {plugin_name}")
# 重新扫描插件目录
self.rescan_plugin_directory()
# 重新加载插件实例
if plugin_name in self.plugin_classes:
success, _ = self.load_registered_plugin_classes(plugin_name)
if success:
logger.info(f"✅ 插件重载成功: {plugin_name}")
return True
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件类未找到")
return False
except Exception as e:
logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True)
return False
def force_reload_plugin(self, plugin_name: str) -> bool:
"""强制重载插件(使用简化的方法)
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""
return self.reload_plugin(plugin_name)
@staticmethod
def clear_all_plugin_caches():
"""清理所有插件相关的模块缓存(简化版)"""
try:
logger.info("🧹 清理模块缓存...")
# 清理importlib缓存
importlib.invalidate_caches()
logger.info("🧹 模块缓存清理完成")
except Exception as e:
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)
# 全局插件管理器实例 # 全局插件管理器实例
plugin_manager = PluginManager() plugin_manager = PluginManager()

View File

@@ -177,8 +177,8 @@ class ToolExecutor:
# 执行每个工具调用 # 执行每个工具调用
for tool_call in tool_calls: for tool_call in tool_calls:
tool_name = getattr(tool_call, "func_name", "unknown_tool")
try: try:
tool_name = tool_call.func_name
logger.debug(f"{self.log_prefix}执行工具: {tool_name}") logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具 # 执行工具