初始化

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
parent ff7d1177fa
commit 2d4745cd58
257 changed files with 69069 additions and 0 deletions

View File

@@ -0,0 +1,104 @@
"""
MaiBot 插件系统
提供统一的插件开发和管理框架
"""
# 导出主要的公共接口
from .base import (
BasePlugin,
BaseAction,
BaseCommand,
BaseTool,
ConfigField,
ComponentType,
ActionActivationType,
ChatMode,
ComponentInfo,
ActionInfo,
CommandInfo,
PluginInfo,
ToolInfo,
PythonDependency,
BaseEventHandler,
EventHandlerInfo,
EventType,
MaiMessages,
ToolParamType,
)
# 导入工具模块
from .utils import (
ManifestValidator,
# ManifestGenerator,
# validate_plugin_manifest,
# generate_plugin_manifest,
)
from .apis import (
chat_api,
tool_api,
component_manage_api,
config_api,
database_api,
emoji_api,
generator_api,
llm_api,
message_api,
person_api,
plugin_manage_api,
send_api,
register_plugin,
get_logger,
)
__version__ = "2.0.0"
__all__ = [
# API 模块
"chat_api",
"tool_api",
"component_manage_api",
"config_api",
"database_api",
"emoji_api",
"generator_api",
"llm_api",
"message_api",
"person_api",
"plugin_manage_api",
"send_api",
"register_plugin",
"get_logger",
# 基础类
"BasePlugin",
"BaseAction",
"BaseCommand",
"BaseTool",
"BaseEventHandler",
# 类型定义
"ComponentType",
"ActionActivationType",
"ChatMode",
"ComponentInfo",
"ActionInfo",
"CommandInfo",
"PluginInfo",
"ToolInfo",
"PythonDependency",
"EventHandlerInfo",
"EventType",
"ToolParamType",
# 消息
"MaiMessages",
# 装饰器
"register_plugin",
"ConfigField",
# 工具函数
"ManifestValidator",
"get_logger",
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
]

View File

@@ -0,0 +1,41 @@
"""
插件系统API模块
提供了插件开发所需的各种API
"""
# 导入所有API模块
from src.plugin_system.apis import (
chat_api,
component_manage_api,
config_api,
database_api,
emoji_api,
generator_api,
llm_api,
message_api,
person_api,
plugin_manage_api,
send_api,
tool_api,
)
from .logging_api import get_logger
from .plugin_register_api import register_plugin
# 导出所有API模块使它们可以通过 apis.xxx 方式访问
__all__ = [
"chat_api",
"component_manage_api",
"config_api",
"database_api",
"emoji_api",
"generator_api",
"llm_api",
"message_api",
"person_api",
"plugin_manage_api",
"send_api",
"get_logger",
"register_plugin",
"tool_api",
]

View File

@@ -0,0 +1,325 @@
"""
聊天API模块
专门负责聊天信息的查询和管理采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import chat_api
streams = chat_api.get_all_group_streams()
chat_type = chat_api.get_stream_type(stream)
或者:
from src.plugin_system.apis.chat_api import ChatManager as chat
streams = chat.get_all_group_streams()
"""
from typing import List, Dict, Any, Optional
from enum import Enum
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
logger = get_logger("chat_api")
class SpecialTypes(Enum):
"""特殊枚举类型"""
ALL_PLATFORMS = "all_platforms"
class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有聊天流
Args:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有群聊聊天流
Args:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 群聊聊天流列表
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有私聊聊天流
Args:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 私聊聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
return streams
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
"""根据群ID获取聊天流
Args:
group_id: 群聊ID
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 group_id 为空字符串
TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
"""
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not group_id:
raise ValueError("group_id 不能为空")
try:
for _, stream in get_chat_manager().streams.items():
if (
stream.group_info
and str(stream.group_info.group_id) == str(group_id)
and stream.platform == platform
):
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatAPI] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatAPI] 查找群聊流失败: {e}")
return None
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
"""根据用户ID获取私聊流
Args:
user_id: 用户ID
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 user_id 为空字符串
TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
"""
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not user_id:
raise ValueError("user_id 不能为空")
try:
for _, stream in get_chat_manager().streams.items():
if (
not stream.group_info
and str(stream.user_info.user_id) == str(user_id)
and stream.platform == platform
):
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatAPI] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatAPI] 查找私聊流失败: {e}")
return None
@staticmethod
def get_stream_type(chat_stream: ChatStream) -> str:
"""获取聊天流类型
Args:
chat_stream: 聊天流对象
Returns:
str: 聊天类型 ("group", "private", "unknown")
Raises:
TypeError: 如果 chat_stream 不是 ChatStream 类型
ValueError: 如果 chat_stream 为空
"""
if not isinstance(chat_stream, ChatStream):
raise TypeError("chat_stream 必须是 ChatStream 类型")
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
return "unknown"
@staticmethod
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
"""获取聊天流详细信息
Args:
chat_stream: 聊天流对象
Returns:
Dict ({str: Any}): 聊天流信息字典
Raises:
TypeError: 如果 chat_stream 不是 ChatStream 类型
ValueError: 如果 chat_stream 为空
"""
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, ChatStream):
raise TypeError("chat_stream 必须是 ChatStream 类型")
try:
info: Dict[str, Any] = {
"stream_id": chat_stream.stream_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
}
if chat_stream.group_info:
info.update(
{
"group_id": chat_stream.group_info.group_id,
"group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"),
}
)
if chat_stream.user_info:
info.update(
{
"user_id": chat_stream.user_info.user_id,
"user_name": chat_stream.user_info.user_nickname,
}
)
return info
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流信息失败: {e}")
return {}
@staticmethod
def get_streams_summary() -> Dict[str, int]:
"""获取聊天流统计摘要
Returns:
Dict[str, int]: 包含各种统计信息的字典
"""
try:
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
summary = {
"total_streams": len(all_streams),
"group_streams": len(group_streams),
"private_streams": len(private_streams),
"qq_streams": len([s for s in all_streams if s.platform == "qq"]),
}
logger.debug(f"[ChatAPI] 聊天流统计: {summary}")
return summary
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
return {
"total_streams": 0,
"group_streams": 0,
"private_streams": 0,
"qq_streams": 0,
}
# =============================================================================
# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计
# =============================================================================
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_private_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: ChatStream) -> str:
"""获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)
def get_streams_summary() -> Dict[str, int]:
"""获取聊天流统计摘要的便捷函数"""
return ChatManager.get_streams_summary()

View File

@@ -0,0 +1,268 @@
from typing import Optional, Union, Dict
from src.plugin_system.base.component_types import (
CommandInfo,
ActionInfo,
EventHandlerInfo,
PluginInfo,
ComponentType,
ToolInfo,
)
# === 插件信息查询 ===
def get_all_plugin_info() -> Dict[str, PluginInfo]:
"""
获取所有插件的信息。
Returns:
dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_all_plugins()
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
"""
获取指定插件的信息。
Args:
plugin_name (str): 插件名称。
Returns:
PluginInfo: 插件信息对象,如果插件不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_plugin_info(plugin_name)
# === 组件查询方法 ===
def get_component_info(
component_name: str, component_type: ComponentType
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
"""
获取指定组件的信息。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
Returns:
Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_info(component_name, component_type) # type: ignore
def get_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
"""
获取指定类型的所有组件信息。
Args:
component_type (ComponentType): 组件类型。
Returns:
dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_components_by_type(component_type) # type: ignore
def get_enabled_components_info_by_type(
component_type: ComponentType,
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
"""
获取指定类型的所有启用的组件信息。
Args:
component_type (ComponentType): 组件类型。
Returns:
dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_enabled_components_by_type(component_type) # type: ignore
# === Action 查询方法 ===
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
"""
获取指定 Action 的注册信息。
Args:
action_name (str): Action 名称。
Returns:
ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_action_info(action_name)
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
"""
获取指定 Command 的注册信息。
Args:
command_name (str): Command 名称。
Returns:
CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_command_info(command_name)
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
"""
获取指定 Tool 的注册信息。
Args:
tool_name (str): Tool 名称。
Returns:
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_tool_info(tool_name)
# === EventHandler 特定查询方法 ===
def get_registered_event_handler_info(
event_handler_name: str,
) -> Optional[EventHandlerInfo]:
"""
获取指定 EventHandler 的注册信息。
Args:
event_handler_name (str): EventHandler 名称。
Returns:
EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_event_handler_info(event_handler_name)
# === 组件管理方法 ===
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
"""
全局启用指定组件。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
Returns:
bool: 启用成功返回 True否则返回 False。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.enable_component(component_name, component_type)
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
"""
全局禁用指定组件。
**此函数是异步的,确保在异步环境中调用。**
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
Returns:
bool: 禁用成功返回 True否则返回 False。
"""
from src.plugin_system.core.component_registry import component_registry
return await component_registry.disable_component(component_name, component_type)
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
"""
局部启用指定组件。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
stream_id (str): 消息流 ID。
Returns:
bool: 启用成功返回 True否则返回 False。
"""
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
match component_type:
case ComponentType.ACTION:
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
case _:
raise ValueError(f"未知 component type: {component_type}")
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
"""
局部禁用指定组件。
Args:
component_name (str): 组件名称。
component_type (ComponentType): 组件类型。
stream_id (str): 消息流 ID。
Returns:
bool: 禁用成功返回 True否则返回 False。
"""
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
match component_type:
case ComponentType.ACTION:
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
case _:
raise ValueError(f"未知 component type: {component_type}")
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
"""
获取指定消息流中禁用的组件列表。
Args:
stream_id (str): 消息流 ID。
component_type (ComponentType): 组件类型。
Returns:
list[str]: 禁用的组件名称列表。
"""
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
match component_type:
case ComponentType.ACTION:
return global_announcement_manager.get_disabled_chat_actions(stream_id)
case ComponentType.COMMAND:
return global_announcement_manager.get_disabled_chat_commands(stream_id)
case ComponentType.TOOL:
return global_announcement_manager.get_disabled_chat_tools(stream_id)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
case _:
raise ValueError(f"未知 component type: {component_type}")

View File

@@ -0,0 +1,77 @@
"""配置API模块
提供了配置读取和用户信息获取等功能
使用方式:
from src.plugin_system.apis import config_api
value = config_api.get_global_config("section.key")
platform, user_id = await config_api.get_user_id_by_person_name("用户名")
"""
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("config_api")
# =============================================================================
# 配置访问API函数
# =============================================================================
def get_global_config(key: str, default: Any = None) -> Any:
"""
安全地从全局配置中获取一个值。
插件应使用此方法读取全局配置,以保证只读和隔离性。
Args:
key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
# 支持嵌套键访问
keys = key.split(".")
current = global_config
try:
for k in keys:
if hasattr(current, k):
current = getattr(current, k)
else:
raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current
except Exception as e:
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}")
return default
def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any:
"""
从插件配置中获取值,支持嵌套键访问
Args:
plugin_config: 插件配置字典
key: 配置键名,支持嵌套访问如 "section.subsection.key",大小写敏感
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
# 支持嵌套键访问
keys = key.split(".")
current = plugin_config
try:
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
elif hasattr(current, k):
current = getattr(current, k)
else:
raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current
except Exception as e:
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}")
return default

View File

@@ -0,0 +1,29 @@
"""数据库API模块
提供数据库操作相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import database_api
records = await database_api.db_query(ActionRecords, query_type="get")
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
注意此模块现在使用SQLAlchemy实现提供更好的连接管理和错误处理
"""
from src.common.database.sqlalchemy_database_api import (
db_query,
db_save,
db_get,
store_action_info,
get_model_class,
MODEL_MAPPING
)
# 保持向后兼容性
__all__ = [
'db_query',
'db_save',
'db_get',
'store_action_info',
'get_model_class',
'MODEL_MAPPING'
]

View File

@@ -0,0 +1,268 @@
"""
表情API模块
提供表情包相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import emoji_api
result = await emoji_api.get_by_description("开心")
count = emoji_api.get_count()
"""
import random
from typing import Optional, Tuple, List
from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.utils.utils_image import image_path_to_base64
logger = get_logger("emoji_api")
# =============================================================================
# 表情包获取API函数
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
"""根据描述选择表情包
Args:
description: 表情包的描述文本,例如"开心""难过""愤怒"
Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
Raises:
ValueError: 如果描述为空字符串
TypeError: 如果描述不是字符串类型
"""
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("描述必须是字符串类型")
try:
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
emoji_manager = get_emoji_manager()
emoji_result = await emoji_manager.get_emoji_for_text(description)
if not emoji_result:
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
return None
emoji_path, emoji_description, matched_emotion = emoji_result
emoji_base64 = image_path_to_base64(emoji_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
return None
logger.debug(f"[EmojiAPI] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
return emoji_base64, emoji_description, matched_emotion
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包失败: {e}")
return None
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包
Args:
count: 要获取的表情包数量默认为1
Returns:
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表
Raises:
TypeError: 如果count不是整数类型
ValueError: 如果count为负数
"""
if not isinstance(count, int):
raise TypeError("count 必须是整数类型")
if count < 0:
raise ValueError("count 不能为负数")
if count == 0:
logger.warning("[EmojiAPI] count 为0返回空列表")
return []
try:
logger.info(f"[EmojiAPI] 随机获取 {count} 个表情包")
emoji_manager = get_emoji_manager()
all_emojis = emoji_manager.emoji_objects
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
return []
# 过滤有效表情包
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiAPI] 没有有效的表情包")
return []
if len(valid_emojis) < count:
logger.warning(
f"[EmojiAPI] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
)
count = len(valid_emojis)
# 随机选择
selected_emojis = random.sample(valid_emojis, count)
results = []
for selected_emoji in selected_emojis:
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
continue
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
# 记录使用次数
emoji_manager.record_usage(selected_emoji.hash)
results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0:
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
return []
logger.info(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
return results
except Exception as e:
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
"""根据情感标签获取表情包
Args:
emotion: 情感标签,如"happy""sad""angry"
Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
Raises:
ValueError: 如果情感标签为空字符串
TypeError: 如果情感标签不是字符串类型
"""
if not emotion:
raise ValueError("情感标签不能为空")
if not isinstance(emotion, str):
raise TypeError("情感标签必须是字符串类型")
try:
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
emoji_manager = get_emoji_manager()
all_emojis = emoji_manager.emoji_objects
# 筛选匹配情感的表情包
matching_emojis = []
matching_emojis.extend(
emoji_obj
for emoji_obj in all_emojis
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
)
if not matching_emojis:
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
return None
# 随机选择匹配的表情包
selected_emoji = random.choice(matching_emojis)
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
return None
# 记录使用次数
emoji_manager.record_usage(selected_emoji.hash)
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion
except Exception as e:
logger.error(f"[EmojiAPI] 根据情感获取表情包失败: {e}")
return None
# =============================================================================
# 表情包信息查询API函数
# =============================================================================
def get_count() -> int:
"""获取表情包数量
Returns:
int: 当前可用的表情包数量
"""
try:
emoji_manager = get_emoji_manager()
return emoji_manager.emoji_num
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
return 0
def get_info():
"""获取表情包系统信息
Returns:
dict: 包含表情包数量、最大数量、可用数量信息
"""
try:
emoji_manager = get_emoji_manager()
return {
"current_count": emoji_manager.emoji_num,
"max_count": emoji_manager.emoji_num_max,
"available_emojis": len([e for e in emoji_manager.emoji_objects if not e.is_deleted]),
}
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> List[str]:
"""获取所有可用的情感标签
Returns:
list: 所有表情包的情感标签列表(去重)
"""
try:
emoji_manager = get_emoji_manager()
emotions = set()
for emoji_obj in emoji_manager.emoji_objects:
if not emoji_obj.is_deleted and emoji_obj.emotion:
emotions.update(emoji_obj.emotion)
return sorted(list(emotions))
except Exception as e:
logger.error(f"[EmojiAPI] 获取情感标签失败: {e}")
return []
def get_descriptions() -> List[str]:
"""获取所有表情包描述
Returns:
list: 所有可用表情包的描述列表
"""
try:
emoji_manager = get_emoji_manager()
descriptions = []
descriptions.extend(
emoji_obj.description
for emoji_obj in emoji_manager.emoji_objects
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
return []

View File

@@ -0,0 +1,280 @@
"""
回复器API模块
提供回复器相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import generator_api
replyer = generator_api.get_replyer(chat_stream)
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
"""
import traceback
from typing import Tuple, Any, Dict, List, Optional
from rich.traceback import install
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo
install(extra_lines=3)
logger = get_logger("generator_api")
# =============================================================================
# 回复器获取API函数
# =============================================================================
def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
"""获取回复器对象
优先使用chat_stream如果没有则使用chat_id直接查找。
使用 ReplyerManager 来管理实例,避免重复创建。
Args:
chat_stream: 聊天流对象(优先)
chat_id: 聊天ID实际上就是stream_id
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
request_type: 请求类型
Returns:
Optional[DefaultReplyer]: 回复器对象如果获取失败则返回None
Raises:
ValueError: chat_stream 和 chat_id 均为空
"""
if not chat_id and not chat_stream:
raise ValueError("chat_stream 和 chat_id 不可均为空")
try:
logger.debug(f"[GeneratorAPI] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}")
return replyer_manager.get_replyer(
chat_stream=chat_stream,
chat_id=chat_id,
model_set_with_weight=model_set_with_weight,
request_type=request_type,
)
except Exception as e:
logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True)
traceback.print_exc()
return None
# =============================================================================
# 回复生成API函数
# =============================================================================
async def generate_reply(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_to: str = "",
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = False,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
return_prompt: bool = False,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "generator_api",
from_plugin: bool = True,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复
Args:
chat_stream: 聊天流对象(优先)
chat_id: 聊天ID备用
action_data: 动作数据向下兼容包含reply_to和extra_info
reply_to: 回复对象,格式为 "发送者:消息内容"
extra_info: 额外信息,用于补充上下文
available_actions: 可用动作
enable_tool: 是否启用工具调用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
return_prompt: 是否返回提示词
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
request_type: 请求类型可选记录LLM使用
from_plugin: 是否来自插件
Returns:
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
"""
try:
# 获取回复器
replyer = get_replyer(
chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type
)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
logger.debug("[GeneratorAPI] 开始生成回复")
if not reply_to and action_data:
reply_to = action_data.get("reply_to", "")
if not extra_info and action_data:
extra_info = action_data.get("extra_info", "")
# 调用回复器生成回复
success, llm_response_dict, prompt = await replyer.generate_reply_with_context(
reply_to=reply_to,
extra_info=extra_info,
available_actions=available_actions,
enable_tool=enable_tool,
from_plugin=from_plugin,
stream_id=chat_stream.stream_id if chat_stream else chat_id,
)
if not success:
logger.warning("[GeneratorAPI] 回复生成失败")
return False, [], None
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
if content := llm_response_dict.get("content", ""):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
else:
reply_set = []
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
if return_prompt:
return success, reply_set, prompt
else:
return success, reply_set, None
except ValueError as ve:
raise ve
except UserWarning as uw:
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
return False, [], None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
logger.error(traceback.format_exc())
return False, [], None
async def rewrite_reply(
chat_stream: Optional[ChatStream] = None,
reply_data: Optional[Dict[str, Any]] = None,
chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
return_prompt: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""重写回复
Args:
chat_stream: 聊天流对象(优先)
reply_data: 回复数据字典(向下兼容备用,当其他参数缺失时从此获取)
chat_id: 聊天ID备用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
raw_reply: 原始回复内容
reason: 回复原因
reply_to: 回复对象
return_prompt: 是否返回提示词
Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
logger.info("[GeneratorAPI] 开始重写回复")
# 如果参数缺失从reply_data中获取
if reply_data:
raw_reply = raw_reply or reply_data.get("raw_reply", "")
reason = reason or reply_data.get("reason", "")
reply_to = reply_to or reply_data.get("reply_to", "")
# 调用回复器重写回复
success, content, prompt = await replyer.rewrite_reply_with_context(
raw_reply=raw_reply,
reason=reason,
reply_to=reply_to,
return_prompt=return_prompt,
)
reply_set = []
if content:
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
if success:
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
else:
logger.warning("[GeneratorAPI] 重写回复失败")
return success, reply_set, prompt if return_prompt else None
except ValueError as ve:
raise ve
except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
return False, [], None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
"""将文本处理为更拟人化的文本
Args:
content: 文本内容
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
"""
if not isinstance(content, str):
raise ValueError("content 必须是字符串类型")
try:
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
reply_set = []
for text in processed_response:
reply_seg = ("text", text)
reply_set.append(reply_seg)
return reply_set
except Exception as e:
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
return []
async def generate_response_custom(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
prompt: str = "",
) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return None
try:
logger.debug("[GeneratorAPI] 开始生成自定义回复")
response, _, _, _ = await replyer.llm_generate_content(prompt)
if response:
logger.debug("[GeneratorAPI] 自定义回复生成成功")
return response
else:
logger.warning("[GeneratorAPI] 自定义回复生成失败")
return None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}")
return None

View File

@@ -0,0 +1,122 @@
"""LLM API模块
提供了与LLM模型交互的功能
使用方式:
from src.plugin_system.apis import llm_api
models = llm_api.get_available_models()
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
"""
from typing import Tuple, Dict, List, Any, Optional
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.config.api_ada_configs import TaskConfig
logger = get_logger("llm_api")
# =============================================================================
# LLM模型API函数
# =============================================================================
def get_available_models() -> Dict[str, TaskConfig]:
"""获取所有可用的模型配置
Returns:
Dict[str, Any]: 模型配置字典key为模型名称value为模型配置
"""
try:
# 自动获取所有属性并转换为字典形式
models = model_config.model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
for attr in attrs:
if not attr.startswith("__"):
try:
value = getattr(models, attr)
if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value
except Exception as e:
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
continue
return rets
except Exception as e:
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
return {}
async def generate_with_model(
prompt: str,
model_config: TaskConfig,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
Args:
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
request_type: 请求类型标识
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
model_name_list = model_config.model_list
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
return True, response, reasoning_content, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", ""
async def generate_with_model_with_tools(
prompt: str,
model_config: TaskConfig,
tool_options: List[Dict[str, Any]] | None = None,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
"""使用指定模型和工具生成内容
Args:
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
tool_options: 工具选项列表
request_type: 请求类型标识
temperature: 温度参数
max_tokens: 最大token数
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
model_name_list = model_config.model_list
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens
)
return True, response, reasoning_content, model_name, tool_call
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", "", None

View File

@@ -0,0 +1,3 @@
from src.common.logger import get_logger
__all__ = ["get_logger"]

View File

@@ -0,0 +1,483 @@
"""
消息API模块
提供消息查询和构建成字符串的功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import message_api
messages = message_api.get_messages_by_time_in_chat(chat_id, start_time, end_time)
readable_text = message_api.build_readable_messages(messages)
"""
from typing import List, Dict, Any, Tuple, Optional
from src.config.config import global_config
import time
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_users,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
build_readable_messages,
build_readable_messages_with_list,
get_person_id_list,
)
# =============================================================================
# 消息查询API函数
# =============================================================================
def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
"""
获取指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
def get_messages_by_time_in_chat(
chat_id: str,
start_time: float,
end_time: float,
limit: int = 0,
limit_mode: str = "latest",
filter_mai: bool = False,
filter_command: bool = False,
) -> List[Dict[str, Any]]:
"""
获取指定聊天中指定时间范围内的消息
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
filter_command: 是否过滤命令消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command))
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
def get_messages_by_time_in_chat_inclusive(
chat_id: str,
start_time: float,
end_time: float,
limit: int = 0,
limit_mode: str = "latest",
filter_mai: bool = False,
filter_command: bool = False,
) -> List[Dict[str, Any]]:
"""
获取指定聊天中指定时间范围内的消息(包含边界)
Args:
chat_id: 聊天ID
start_time: 开始时间戳(包含)
end_time: 结束时间戳(包含)
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
if filter_mai:
return filter_mai_messages(
get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
)
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
def get_messages_by_time_in_chat_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
"""
获取指定聊天中指定用户在指定时间范围内的消息
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
"""
随机选择一个聊天,返回该聊天在指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
获取指定用户在所有聊天中指定时间范围内的消息
Args:
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
"""
获取指定时间戳之前的消息
Args:
timestamp: 时间戳
limit: 限制返回的消息数量0为不限制
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit))
return get_raw_msg_before_timestamp(timestamp, limit)
def get_messages_before_time_in_chat(
chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False
) -> List[Dict[str, Any]]:
"""
获取指定聊天中指定时间戳之前的消息
Args:
chat_id: 聊天ID
timestamp: 时间戳
limit: 限制返回的消息数量0为不限制
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
if filter_mai:
return filter_mai_messages(get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit))
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]:
"""
获取指定用户在指定时间戳之前的消息
Args:
timestamp: 时间戳
person_ids: 用户ID列表
limit: 限制返回的消息数量0为不限制
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
"""
获取指定聊天中最近一段时间的消息
Args:
chat_id: 聊天ID
hours: 最近多少小时默认24小时
limit: 限制返回的消息数量默认100条
limit_mode: 当limit>0时生效'earliest'表示获取最早的记录,'latest'表示获取最新的记录
filter_mai: 是否过滤麦麦自身的消息默认为False
Returns:
List[Dict[str, Any]]: 消息列表
Raises:
ValueError: 如果参数不合法s
"""
if not isinstance(hours, (int, float)) or hours < 0:
raise ValueError("hours 不能是负数")
if not isinstance(limit, int) or limit < 0:
raise ValueError("limit 必须是非负整数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
now = time.time()
start_time = now - hours * 3600
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode))
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
# =============================================================================
# 消息计数API函数
# =============================================================================
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
"""
计算指定聊天中从开始时间到结束时间的新消息数量
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳如果为None则使用当前时间
Returns:
int: 新消息数量
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)):
raise ValueError("start_time 必须是数字类型")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since(chat_id, start_time, end_time)
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
"""
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
Args:
chat_id: 聊天ID
start_time: 开始时间戳
end_time: 结束时间戳
person_ids: 用户ID列表
Returns:
int: 新消息数量
Raises:
ValueError: 如果参数不合法
"""
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
# =============================================================================
# 消息格式化API函数
# =============================================================================
def build_readable_messages_to_str(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> str:
"""
将消息列表构建成可读的字符串
Args:
messages: 消息列表
replace_bot_name: 是否将机器人的名称替换为""
merge_messages: 是否合并连续消息
timestamp_mode: 时间戳显示模式,'relative''absolute'
read_mark: 已读标记时间戳,用于分割已读和未读消息
truncate: 是否截断长消息
show_actions: 是否显示动作记录
Returns:
格式化后的可读字符串
"""
return build_readable_messages(
messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
)
async def build_readable_messages_with_details(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
"""
将消息列表构建成可读的字符串,并返回详细信息
Args:
messages: 消息列表
replace_bot_name: 是否将机器人的名称替换为""
merge_messages: 是否合并连续消息
timestamp_mode: 时间戳显示模式,'relative''absolute'
truncate: 是否截断长消息
Returns:
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
"""
return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
"""
从消息列表中提取不重复的用户ID列表
Args:
messages: 消息列表
Returns:
用户ID列表
"""
return await get_person_id_list(messages)
# =============================================================================
# 消息过滤函数
# =============================================================================
def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
从消息列表中移除麦麦的消息
Args:
messages: 消息列表,每个元素是消息字典
Returns:
过滤后的消息列表
"""
return [msg for msg in messages if msg.get("user_id") != str(global_config.bot.qq_account)]

View File

@@ -0,0 +1,154 @@
"""个人信息API模块
提供个人信息查询功能,用于插件获取用户相关信息
使用方式:
from src.plugin_system.apis import person_api
person_id = person_api.get_person_id("qq", 123456)
value = await person_api.get_person_value(person_id, "nickname")
"""
from typing import Any, Optional
from src.common.logger import get_logger
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
logger = get_logger("person_api")
# =============================================================================
# 个人信息API函数
# =============================================================================
def get_person_id(platform: str, user_id: int) -> str:
"""根据平台和用户ID获取person_id
Args:
platform: 平台名称,如 "qq", "telegram"
user_id: 用户ID
Returns:
str: 唯一的person_idMD5哈希值
示例:
person_id = person_api.get_person_id("qq", 123456)
"""
try:
return PersonInfoManager.get_person_id(platform, user_id)
except Exception as e:
logger.error(f"[PersonAPI] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
return ""
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
"""根据person_id和字段名获取某个值
Args:
person_id: 用户的唯一标识ID
field_name: 要获取的字段名,如 "nickname", "impression"
default: 当字段不存在或获取失败时返回的默认值
Returns:
Any: 字段值或默认值
示例:
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
impression = await person_api.get_person_value(person_id, "impression")
"""
try:
person_info_manager = get_person_info_manager()
value = await person_info_manager.get_value(person_id, field_name)
return value if value is not None else default
except Exception as e:
logger.error(f"[PersonAPI] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
return default
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
"""批量获取用户信息字段值
Args:
person_id: 用户的唯一标识ID
field_names: 要获取的字段名列表
default_dict: 默认值字典,键为字段名,值为默认值
Returns:
dict: 字段名到值的映射字典
示例:
values = await person_api.get_person_values(
person_id,
["nickname", "impression", "know_times"],
{"nickname": "未知用户", "know_times": 0}
)
"""
try:
person_info_manager = get_person_info_manager()
values = await person_info_manager.get_values(person_id, field_names)
# 如果获取成功,返回结果
if values:
return values
# 如果获取失败,构建默认值字典
result = {}
if default_dict:
for field in field_names:
result[field] = default_dict.get(field, None)
else:
for field in field_names:
result[field] = None
return result
except Exception as e:
logger.error(f"[PersonAPI] 批量获取用户信息失败: person_id={person_id}, fields={field_names}, error={e}")
# 返回默认值字典
result = {}
if default_dict:
for field in field_names:
result[field] = default_dict.get(field, None)
else:
for field in field_names:
result[field] = None
return result
async def is_person_known(platform: str, user_id: int) -> bool:
"""判断是否认识某个用户
Args:
platform: 平台名称
user_id: 用户ID
Returns:
bool: 是否认识该用户
示例:
known = await person_api.is_person_known("qq", 123456)
"""
try:
person_info_manager = get_person_info_manager()
return await person_info_manager.is_person_known(platform, user_id)
except Exception as e:
logger.error(f"[PersonAPI] 检查用户是否已知失败: platform={platform}, user_id={user_id}, error={e}")
return False
def get_person_id_by_name(person_name: str) -> str:
"""根据用户名获取person_id
Args:
person_name: 用户名
Returns:
str: person_id如果未找到返回空字符串
示例:
person_id = person_api.get_person_id_by_name("张三")
"""
try:
person_info_manager = get_person_info_manager()
return person_info_manager.get_person_id_by_person_name(person_name)
except Exception as e:
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
return ""

View File

@@ -0,0 +1,120 @@
from typing import Tuple, List
def list_loaded_plugins() -> List[str]:
"""
列出所有当前加载的插件。
Returns:
List[str]: 当前加载的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.list_loaded_plugins()
def list_registered_plugins() -> List[str]:
"""
列出所有已注册的插件。
Returns:
List[str]: 已注册的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.list_registered_plugins()
def get_plugin_path(plugin_name: str) -> str:
"""
获取指定插件的路径。
Args:
plugin_name (str): 插件名称。
Returns:
str: 插件目录的绝对路径。
Raises:
ValueError: 如果插件不存在。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
if plugin_path := plugin_manager.get_plugin_path(plugin_name):
return plugin_path
else:
raise ValueError(f"插件 '{plugin_name}' 不存在。")
async def remove_plugin(plugin_name: str) -> bool:
"""
卸载指定的插件。
**此函数是异步的,确保在异步环境中调用。**
Args:
plugin_name (str): 要卸载的插件名称。
Returns:
bool: 卸载是否成功。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return await plugin_manager.remove_registered_plugin(plugin_name)
async def reload_plugin(plugin_name: str) -> bool:
"""
重新加载指定的插件。
**此函数是异步的,确保在异步环境中调用。**
Args:
plugin_name (str): 要重新加载的插件名称。
Returns:
bool: 重新加载是否成功。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return await plugin_manager.reload_registered_plugin(plugin_name)
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
"""
加载指定的插件。
Args:
plugin_name (str): 要加载的插件名称。
Returns:
Tuple[bool, int]: 加载是否成功,成功或失败个数。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.load_registered_plugin_classes(plugin_name)
def add_plugin_directory(plugin_directory: str) -> bool:
"""
添加插件目录。
Args:
plugin_directory (str): 要添加的插件目录路径。
Returns:
bool: 添加是否成功。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.add_plugin_directory(plugin_directory)
def rescan_plugin_directory() -> Tuple[int, int]:
"""
重新扫描插件目录,加载新插件。
Returns:
Tuple[int, int]: 成功加载的插件数量和失败的插件数量。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.rescan_plugin_directory()

View File

@@ -0,0 +1,46 @@
from pathlib import Path
from src.common.logger import get_logger
logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls):
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.base.base_plugin import BasePlugin
"""插件注册装饰器
用法:
@register_plugin
class MyPlugin(BasePlugin):
plugin_name = "my_plugin"
plugin_description = "我的插件"
...
"""
if not issubclass(cls, BasePlugin):
logger.error(f"{cls.__name__} 不是 BasePlugin 的子类")
return cls
# 只是注册插件类,不立即实例化
# 插件管理器会负责实例化和注册
plugin_name: str = cls.plugin_name # type: ignore
if "." in plugin_name:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
splitted_name = cls.__module__.split(".")
root_path = Path(__file__)
# 查找项目根目录
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
root_path = root_path.parent
if not (root_path / "pyproject.toml").exists():
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
return cls
plugin_manager.plugin_classes[plugin_name] = cls
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
return cls

View File

@@ -0,0 +1,369 @@
"""
发送API模块
专门负责发送各种类型的消息采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import send_api
# 方式1直接使用stream_id推荐
await send_api.text_to_stream("hello", stream_id)
await send_api.emoji_to_stream(emoji_base64, stream_id)
await send_api.custom_to_stream("video", video_data, stream_id)
# 方式2使用群聊/私聊指定函数
await send_api.text_to_group("hello", "123456")
await send_api.text_to_user("hello", "987654")
# 方式3使用通用custom_message函数
await send_api.custom_message("video", video_data, "123456", True)
"""
import traceback
import time
import difflib
from typing import Optional, Union
from src.common.logger import get_logger
# 导入依赖
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending, MessageRecv
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, replace_user_references_async
from src.person_info.person_info import get_person_info_manager
from maim_message import Seg, UserInfo
from src.config.config import global_config
logger = get_logger("send_api")
# =============================================================================
# 内部实现函数(不暴露给外部)
# =============================================================================
async def _send_to_target(
message_type: str,
content: Union[str, dict],
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_to: str = "",
reply_to_platform_id: Optional[str] = None,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
"""向指定目标发送消息的内部实现
Args:
message_type: 消息类型,如"text""image""emoji"
content: 消息内容
stream_id: 目标流ID
display_message: 显示消息
typing: 是否模拟打字等待。
reply_to: 回复消息,格式为"发送者:消息内容"
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
storage_message: 是否存储消息到数据库
show_log: 发送是否显示日志
Returns:
bool: 是否发送成功
"""
try:
if show_log:
logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
# 查找目标聊天流
target_stream = get_chat_manager().get_stream(stream_id)
if not target_stream:
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
return False
# 创建发送器
heart_fc_sender = HeartFCSender()
# 生成消息ID
current_time = time.time()
message_id = f"send_api_{int(current_time * 1000)}"
# 构建机器人用户信息
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=target_stream.platform,
)
# 创建消息段
message_segment = Seg(type=message_type, data=content) # type: ignore
# 处理回复消息
anchor_message = None
if reply_to:
anchor_message = await _find_reply_message(target_stream, reply_to)
if anchor_message and anchor_message.message_info.user_info and not reply_to_platform_id:
reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
)
# 构建发送消息对象
bot_message = MessageSending(
message_id=message_id,
chat_stream=target_stream,
bot_user_info=bot_user_info,
sender_info=target_stream.user_info,
message_segment=message_segment,
display_message=display_message,
reply=anchor_message,
is_head=True,
is_emoji=(message_type == "emoji"),
thinking_start_time=current_time,
reply_to=reply_to_platform_id,
)
# 发送消息
sent_msg = await heart_fc_sender.send_message(
bot_message,
typing=typing,
set_reply=(anchor_message is not None),
storage_message=storage_message,
show_log=show_log,
)
if sent_msg:
logger.debug(f"[SendAPI] 成功发送消息到 {stream_id}")
return True
else:
logger.error("[SendAPI] 发送消息失败")
return False
except Exception as e:
logger.error(f"[SendAPI] 发送消息时出错: {e}")
traceback.print_exc()
return False
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
# sourcery skip: inline-variable, use-named-expression
"""查找要回复的消息
Args:
target_stream: 目标聊天流
reply_to: 回复格式,如"发送者:消息内容""发送者:消息内容"
Returns:
Optional[MessageRecv]: 找到的消息如果没找到则返回None
"""
try:
# 解析reply_to参数
if ":" in reply_to:
parts = reply_to.split(":", 1)
elif "" in reply_to:
parts = reply_to.split("", 1)
else:
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
return None
if len(parts) != 2:
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
return None
sender = parts[0].strip()
text = parts[1].strip()
# 获取聊天流的最新20条消息
reverse_talking_message = get_raw_msg_before_timestamp_with_chat(
target_stream.stream_id,
time.time(), # 当前时间之前的消息
20, # 最新的20条消息
)
# 反转列表,使最新的消息在前面
reverse_talking_message = list(reversed(reverse_talking_message))
find_msg = None
for message in reverse_talking_message:
user_id = message["user_id"]
platform = message["chat_info_platform"]
person_id = get_person_info_manager().get_person_id(platform, user_id)
person_name = await get_person_info_manager().get_value(person_id, "person_name")
if person_name == sender:
translate_text = message["processed_plain_text"]
# 使用独立函数处理用户引用格式
translate_text = await replace_user_references_async(translate_text, platform)
similarity = difflib.SequenceMatcher(None, text, translate_text).ratio()
if similarity >= 0.9:
find_msg = message
break
if not find_msg:
logger.info("[SendAPI] 未找到匹配的回复消息")
return None
# 构建MessageRecv对象
user_info = {
"platform": find_msg.get("user_platform", ""),
"user_id": find_msg.get("user_id", ""),
"user_nickname": find_msg.get("user_nickname", ""),
"user_cardname": find_msg.get("user_cardname", ""),
}
group_info = {}
if find_msg.get("chat_info_group_id"):
group_info = {
"platform": find_msg.get("chat_info_group_platform", ""),
"group_id": find_msg.get("chat_info_group_id", ""),
"group_name": find_msg.get("chat_info_group_name", ""),
}
format_info = {"content_format": "", "accept_format": ""}
template_info = {"template_items": {}}
message_info = {
"platform": target_stream.platform,
"message_id": find_msg.get("message_id"),
"time": find_msg.get("time"),
"group_info": group_info,
"user_info": user_info,
"additional_config": find_msg.get("additional_config"),
"format_info": format_info,
"template_info": template_info,
}
message_dict = {
"message_info": message_info,
"raw_message": find_msg.get("processed_plain_text"),
"processed_plain_text": find_msg.get("processed_plain_text"),
}
find_rec_msg = MessageRecv(message_dict)
find_rec_msg.update_chat_stream(target_stream)
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {sender}")
return find_rec_msg
except Exception as e:
logger.error(f"[SendAPI] 查找回复消息时出错: {e}")
traceback.print_exc()
return None
# =============================================================================
# 公共API函数 - 预定义类型的发送函数
# =============================================================================
async def text_to_stream(
text: str,
stream_id: str,
typing: bool = False,
reply_to: str = "",
reply_to_platform_id: str = "",
storage_message: bool = True,
) -> bool:
"""向指定流发送文本消息
Args:
text: 要发送的文本内容
stream_id: 聊天流ID
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
return await _send_to_target(
"text",
text,
stream_id,
"",
typing,
reply_to,
reply_to_platform_id=reply_to_platform_id,
storage_message=storage_message,
)
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
"""向指定流发送表情包
Args:
emoji_base64: 表情包的base64编码
stream_id: 聊天流ID
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True) -> bool:
"""向指定流发送图片
Args:
image_base64: 图片的base64编码
stream_id: 聊天流ID
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message)
async def command_to_stream(
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = ""
) -> bool:
"""向指定流发送命令
Args:
command: 命令
stream_id: 聊天流ID
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
return await _send_to_target(
"command", command, stream_id, display_message, typing=False, storage_message=storage_message
)
async def custom_to_stream(
message_type: str,
content: str | dict,
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
show_log: bool = True,
) -> bool:
"""向指定流发送自定义类型消息
Args:
message_type: 消息类型,如"text""image""emoji""video""file"
content: 消息内容通常是base64编码或文本
stream_id: 聊天流ID
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
storage_message: 是否存储消息到数据库
show_log: 是否显示日志
Returns:
bool: 是否发送成功
"""
return await _send_to_target(
message_type=message_type,
content=content,
stream_id=stream_id,
display_message=display_message,
typing=typing,
reply_to=reply_to,
storage_message=storage_message,
show_log=show_log,
)

View File

@@ -0,0 +1,34 @@
from typing import Optional, Type
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取公开工具实例"""
from src.plugin_system.core import component_registry
# 获取插件配置
tool_info = component_registry.get_component_info(tool_name, ComponentType.TOOL)
if tool_info:
plugin_config = component_registry.get_plugin_config(tool_info.plugin_name)
else:
plugin_config = None
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class(plugin_config) if tool_class else None
def get_llm_available_tool_definitions():
"""获取LLM可用的工具定义列表
Returns:
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
"""
from src.plugin_system.core import component_registry
llm_available_tools = component_registry.get_llm_available_tools()
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]

View File

@@ -0,0 +1,49 @@
"""
插件基础类模块
提供插件开发的基础类和类型定义
"""
from .base_plugin import BasePlugin
from .base_action import BaseAction
from .base_tool import BaseTool
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .component_types import (
ComponentType,
ActionActivationType,
ChatMode,
ComponentInfo,
ActionInfo,
CommandInfo,
ToolInfo,
PluginInfo,
PythonDependency,
EventHandlerInfo,
EventType,
MaiMessages,
ToolParamType,
)
from .config_types import ConfigField
__all__ = [
"BasePlugin",
"BaseAction",
"BaseCommand",
"BaseTool",
"ComponentType",
"ActionActivationType",
"ChatMode",
"ComponentInfo",
"ActionInfo",
"CommandInfo",
"ToolInfo",
"PluginInfo",
"PythonDependency",
"ConfigField",
"EventHandlerInfo",
"EventType",
"BaseEventHandler",
"MaiMessages",
"ToolParamType",
]

View File

@@ -0,0 +1,437 @@
import time
import asyncio
from abc import ABC, abstractmethod
from typing import Tuple, Optional
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api, message_api
logger = get_logger("base_action")
class BaseAction(ABC):
"""Action组件基类
Action是插件的一种组件类型用于处理聊天中的动作逻辑
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
- focus_activation_type: 专注模式激活类型
- normal_activation_type: 普通模式激活类型
- activation_keywords: 激活关键词列表
- keyword_case_sensitive: 关键词是否区分大小写
- mode_enable: 启用的聊天模式
- parallel_action: 是否允许并行执行
- random_activation_probability: 随机激活概率
- llm_judge_prompt: LLM判断提示词
"""
def __init__(
self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: ChatStream,
log_prefix: str = "",
plugin_config: Optional[dict] = None,
action_message: Optional[dict] = None,
**kwargs,
):
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
"""初始化Action组件
Args:
action_data: 动作数据
reasoning: 执行该动作的理由
cycle_timers: 计时器字典
thinking_id: 思考ID
chat_stream: 聊天流对象
log_prefix: 日志前缀
plugin_config: 插件配置字典
action_message: 消息数据
**kwargs: 其他参数
"""
if plugin_config is None:
plugin_config = {}
self.action_data = action_data
self.reasoning = reasoning
self.cycle_timers = cycle_timers
self.thinking_id = thinking_id
self.log_prefix = log_prefix
self.plugin_config = plugin_config or {}
"""对应的插件配置"""
# 设置动作基本信息实例属性
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
"""Action的名字"""
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
"""Action的描述"""
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
# 设置激活类型实例属性(从类属性复制,提供默认值)
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
"""FOCUS模式下的激活类型"""
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
"""NORMAL模式下的激活类型"""
self.activation_type = getattr(self.__class__, "activation_type", self.focus_activation_type)
"""激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率"""
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
"""协助LLM进行判断的Prompt"""
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
"""激活类型为KEYWORD时的KEYWORDS列表"""
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# =============================================================================
# 获取聊天流对象
self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.stream_id
self.platform = getattr(self.chat_stream, "platform", None)
# 初始化基础信息(带类型注解)
self.action_message = action_message
self.group_id = None
self.group_name = None
self.user_id = None
self.user_nickname = None
self.is_group = False
self.target_id = None
self.has_action_message = False
if self.action_message:
self.has_action_message = True
else:
self.action_message = {}
if self.has_action_message:
if self.action_name != "no_reply":
self.group_id = str(self.action_message.get("chat_info_group_id", None))
self.group_name = self.action_message.get("chat_info_group_name", None)
self.user_id = str(self.action_message.get("user_id", None))
self.user_nickname = self.action_message.get("user_nickname", None)
if self.group_id:
self.is_group = True
self.target_id = self.group_id
else:
self.is_group = False
self.target_id = self.user_id
else:
if self.chat_stream.group_info:
self.group_id = self.chat_stream.group_info.group_id
self.group_name = self.chat_stream.group_info.group_name
self.is_group = True
self.target_id = self.group_id
else:
self.user_id = self.chat_stream.user_info.user_id
self.user_nickname = self.chat_stream.user_info.user_nickname
self.is_group = False
self.target_id = self.user_id
logger.debug(f"{self.log_prefix} Action组件初始化完成")
logger.debug(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
)
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
"""等待新消息或超时
在loop_start_time之后等待新消息如果没有新消息且没有超时就一直等待。
使用message_api检查self.chat_id对应的聊天中是否有新消息。
Args:
timeout: 超时时间默认1200秒
Returns:
Tuple[bool, str]: (是否收到新消息, 空字符串)
"""
try:
# 获取循环开始时间,如果没有则使用当前时间
loop_start_time = self.action_data.get("loop_start_time", time.time())
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
# 确保有有效的chat_id
if not self.chat_id:
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
return False, "没有有效的chat_id"
wait_start_time = asyncio.get_event_loop().time()
while True:
# 检查关闭标志
# shutting_down = self.get_action_context("shutting_down", False)
# if shutting_down:
# logger.info(f"{self.log_prefix} 等待新消息时检测到关闭信号,中断等待")
# return False, ""
# 检查新消息
current_time = time.time()
new_message_count = message_api.count_new_messages(
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
)
if new_message_count > 0:
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息聊天ID: {self.chat_id}")
return True, ""
# 检查超时
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
if elapsed_time > timeout:
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒)聊天ID: {self.chat_id}")
return False, ""
# 每30秒记录一次等待状态
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
# 短暂休眠
await asyncio.sleep(0.5)
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
return False, ""
except Exception as e:
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
async def send_text(
self, content: str, reply_to: str = "", typing: bool = False
) -> bool:
"""发送文本消息
Args:
content: 文本内容
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.text_to_stream(
text=content,
stream_id=self.chat_id,
reply_to=reply_to,
typing=typing,
)
async def send_emoji(self, emoji_base64: str) -> bool:
"""发送表情包
Args:
emoji_base64: 表情包的base64编码
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.emoji_to_stream(emoji_base64, self.chat_id)
async def send_image(self, image_base64: str) -> bool:
"""发送图片
Args:
image_base64: 图片的base64编码
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.image_to_stream(image_base64, self.chat_id)
async def send_custom(self, message_type: str, content: str, typing: bool = False, reply_to: str = "") -> bool:
"""发送自定义类型消息
Args:
message_type: 消息类型,如"video""file""audio"
content: 消息内容
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=self.chat_id,
typing=typing,
reply_to=reply_to,
)
async def store_action_info(
self,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
) -> None:
"""存储动作信息到数据库
Args:
action_build_into_prompt: 是否构建到提示中
action_prompt_display: 显示的action提示信息
action_done: action是否完成
"""
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=action_build_into_prompt,
action_prompt_display=action_prompt_display,
action_done=action_done,
thinking_id=self.thinking_id,
action_data=self.action_data,
action_name=self.action_name,
)
async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息
使用stream API发送命令
Args:
command_name: 命令名称
args: 命令参数
display_message: 显示消息
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
try:
if not self.chat_id:
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
success = await send_api.command_to_stream(
command=command_data,
stream_id=self.chat_id,
storage_message=storage_message,
display_message=display_message,
)
if success:
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
else:
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False
@classmethod
def get_action_info(cls) -> "ActionInfo":
"""从类属性生成ActionInfo
所有信息都从类属性中读取,确保一致性和完整性。
Action类必须定义所有必要的类属性。
Returns:
ActionInfo: 生成的Action信息对象
"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
if "." in name:
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
# 获取focus_activation_type和normal_activation_type
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
# 处理activation_type如果插件中声明了就用插件的值否则默认使用focus_activation_type
activation_type = getattr(cls, "activation_type", focus_activation_type)
return ActionInfo(
name=name,
component_type=ComponentType.ACTION,
description=getattr(cls, "action_description", "Action动作"),
focus_activation_type=focus_activation_type,
normal_activation_type=normal_activation_type,
activation_type=activation_type,
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
mode_enable=getattr(cls, "mode_enable", ChatMode.ALL),
parallel_action=getattr(cls, "parallel_action", True),
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
# 使用正确的字段名
action_parameters=getattr(cls, "action_parameters", {}).copy(),
action_require=getattr(cls, "action_require", []).copy(),
associated_types=getattr(cls, "associated_types", []).copy(),
)
@abstractmethod
async def execute(self) -> Tuple[bool, str]:
"""执行Action的抽象方法子类必须实现
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass
async def handle_action(self) -> Tuple[bool, str]:
"""兼容旧系统的handle_action接口委托给execute方法
为了保持向后兼容性旧系统的代码可能会调用handle_action方法。
此方法将调用委托给新的execute方法。
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
return await self.execute()
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current

View File

@@ -0,0 +1,228 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional
from src.common.logger import get_logger
from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import MessageRecv
from src.plugin_system.apis import send_api
logger = get_logger("base_command")
class BaseCommand(ABC):
"""Command组件基类
Command是插件的一种组件类型用于处理命令请求
子类可以通过类属性定义命令模式:
- command_pattern: 命令匹配的正则表达式
- command_help: 命令帮助信息
- command_examples: 命令使用示例列表
"""
command_name: str = ""
"""Command组件的名称"""
command_description: str = ""
"""Command组件的描述"""
# 默认命令设置
command_pattern: str = r""
"""命令匹配的正则表达式"""
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
"""初始化Command组件
Args:
message: 接收到的消息对象
plugin_config: 插件配置字典
"""
self.message = message
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.log_prefix = "[Command]"
logger.debug(f"{self.log_prefix} Command组件初始化完成")
def set_matched_groups(self, groups: Dict[str, str]) -> None:
"""设置正则表达式匹配的命名组
Args:
groups: 正则表达式匹配的命名组
"""
self.matched_groups = groups
@abstractmethod
async def execute(self) -> Tuple[bool, Optional[str], bool]:
"""执行Command的抽象方法子类必须实现
Returns:
Tuple[bool, Optional[str], bool]: (是否执行成功, 可选的回复消息, 是否拦截消息 不进行 后续处理)
"""
pass
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
async def send_text(self, content: str, reply_to: str = "") -> bool:
"""发送回复消息
Args:
content: 回复内容
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
) -> bool:
"""发送指定类型的回复消息到当前聊天环境
Args:
message_type: 消息类型,如"text""image""emoji"
content: 消息内容
display_message: 显示消息(可选)
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=chat_stream.stream_id,
display_message=display_message,
typing=typing,
reply_to=reply_to,
)
async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息
Args:
command_name: 命令名称
args: 命令参数
display_message: 显示消息
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
try:
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
# 构造命令数据
command_data = {"name": command_name, "args": args or {}}
success = await send_api.command_to_stream(
command=command_data,
stream_id=chat_stream.stream_id,
storage_message=storage_message,
display_message=display_message,
)
if success:
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
else:
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False
async def send_emoji(self, emoji_base64: str) -> bool:
"""发送表情包
Args:
emoji_base64: 表情包的base64编码
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
async def send_image(self, image_base64: str) -> bool:
"""发送图片
Args:
image_base64: 图片的base64编码
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
@classmethod
def get_command_info(cls) -> "CommandInfo":
"""从类属性生成CommandInfo
Args:
name: Command名称如果不提供则使用类名
description: Command描述如果不提供则使用类文档字符串
Returns:
CommandInfo: 生成的Command信息对象
"""
if "." in cls.command_name:
logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
return CommandInfo(
name=cls.command_name,
component_type=ComponentType.COMMAND,
description=cls.command_description,
command_pattern=cls.command_pattern,
)

View File

@@ -0,0 +1,101 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict
from src.common.logger import get_logger
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
logger = get_logger("base_event_handler")
class BaseEventHandler(ABC):
"""事件处理器基类
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
"""
event_type: EventType = EventType.UNKNOWN
"""事件类型,默认为未知"""
handler_name: str = ""
"""处理器名称"""
handler_description: str = ""
"""处理器描述"""
weight: int = 0
"""处理器权重,越大权重越高"""
intercept_message: bool = False
"""是否拦截消息,默认为否"""
def __init__(self):
self.log_prefix = "[EventHandler]"
self.plugin_name = ""
"""对应插件名"""
self.plugin_config: Optional[Dict] = None
"""插件配置字典"""
if self.event_type == EventType.UNKNOWN:
raise NotImplementedError("事件处理器必须指定 event_type")
@abstractmethod
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, Optional[str]]:
"""执行事件处理的抽象方法,子类必须实现
Returns:
Tuple[bool, bool, Optional[str]]: (是否执行成功, 是否需要继续处理, 可选的返回消息)
"""
raise NotImplementedError("子类必须实现 execute 方法")
@classmethod
def get_handler_info(cls) -> "EventHandlerInfo":
"""获取事件处理器的信息"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
if "." in name:
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
return EventHandlerInfo(
name=name,
component_type=ComponentType.EVENT_HANDLER,
description=getattr(cls, "handler_description", "events处理器"),
event_type=cls.event_type,
weight=cls.weight,
intercept_message=cls.intercept_message,
)
def set_plugin_config(self, plugin_config: Dict) -> None:
"""设置插件配置
Args:
plugin_config (dict): 插件配置字典
"""
self.plugin_config = plugin_config
def set_plugin_name(self, plugin_name: str) -> None:
"""设置插件名称
Args:
plugin_name (str): 插件名称
"""
self.plugin_name = plugin_name
def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current

View File

@@ -0,0 +1,76 @@
from abc import abstractmethod
from typing import List, Type, Tuple, Union
from .plugin_base import PluginBase
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
from .base_action import BaseAction
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .base_tool import BaseTool
logger = get_logger("base_plugin")
class BasePlugin(PluginBase):
"""基于Action和Command的插件基类
所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件:
- Action组件处理聊天中的动作
- Command组件处理命令请求
- 未来可扩展Scheduler、Listener等
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@abstractmethod
def get_plugin_components(
self,
) -> List[
Union[
Tuple[ActionInfo, Type[BaseAction]],
Tuple[CommandInfo, Type[BaseCommand]],
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
Tuple[ToolInfo, Type[BaseTool]],
]
]:
"""获取插件包含的组件列表
子类必须实现此方法,返回组件信息和组件类的列表
Returns:
List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
"""
raise NotImplementedError("Subclasses must implement this method")
def register_plugin(self) -> bool:
"""注册插件及其所有组件"""
from src.plugin_system.core.component_registry import component_registry
components = self.get_plugin_components()
# 检查依赖
if not self._check_dependencies():
logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册")
return False
# 注册所有组件
registered_components = []
for component_info, component_class in components:
component_info.plugin_name = self.plugin_name
if component_registry.register_component(component_info, component_class):
registered_components.append(component_info)
else:
logger.warning(f"{self.log_prefix} 组件 {component_info.name} 注册失败")
# 更新插件信息中的组件列表
self.plugin_info.components = registered_components
# 注册插件
if component_registry.register_plugin(self.plugin_info):
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件")
return True
else:
logger.error(f"{self.log_prefix} 插件注册失败")
return False

View File

@@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple
from rich.traceback import install
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
install(extra_lines=3)
logger = get_logger("base_tool")
class BaseTool(ABC):
"""所有工具的基类"""
name: str = ""
"""工具的名称"""
description: str = ""
"""工具的描述"""
parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
param_name: 参数名称
param_type: 参数类型
description: 参数描述
required: 是否必填
enum_values: 枚举值列表
例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])]
"""
available_for_llm: bool = False
"""是否可供LLM使用"""
def __init__(self, plugin_config: Optional[dict] = None):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
@classmethod
def get_tool_info(cls) -> ToolInfo:
"""获取工具信息"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return ToolInfo(
name=cls.name,
tool_description=cls.description,
enabled=cls.available_for_llm,
tool_parameters=cls.parameters,
component_type=ComponentType.TOOL,
)
@abstractmethod
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数(供llm调用)
通过该方法maicore会通过llm的tool call来调用工具
传入的是json格式的参数符合parameters定义的格式
Args:
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")
async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
"""直接执行工具函数(供插件调用)
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
插件可以直接调用此方法,用更加明了的方式传入参数
示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
Args:
**function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
for param_name in parameter_required:
if param_name not in function_args:
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
return await self.execute(function_args)
def get_config(self, key: str, default=None):
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
if not self.plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current

View File

@@ -0,0 +1,283 @@
from enum import Enum
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
# 组件类型枚举
class ComponentType(Enum):
"""组件类型枚举"""
ACTION = "action" # 动作组件
COMMAND = "command" # 命令组件
TOOL = "tool" # 服务组件(预留)
SCHEDULER = "scheduler" # 定时任务组件(预留)
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
def __str__(self) -> str:
return self.value
# 动作激活类型枚举
class ActionActivationType(Enum):
"""动作激活类型枚举"""
NEVER = "never" # 从不激活(默认关闭)
ALWAYS = "always" # 默认参与到planner
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
RANDOM = "random" # 随机启用action到planner
KEYWORD = "keyword" # 关键词触发启用action到planner
def __str__(self):
return self.value
# 聊天模式枚举
class ChatMode(Enum):
"""聊天模式枚举"""
FOCUS = "focus" # Focus聊天模式
NORMAL = "normal" # Normal聊天模式
PRIORITY = "priority" # 优先级聊天模式
ALL = "all" # 所有聊天模式
def __str__(self):
return self.value
# 事件类型枚举
class EventType(Enum):
"""
事件类型枚举类
"""
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_MESSAGE = "on_message"
ON_PLAN = "on_plan"
POST_LLM = "post_llm"
AFTER_LLM = "after_llm"
POST_SEND = "post_send"
AFTER_SEND = "after_send"
UNKNOWN = "unknown" # 未知事件类型
def __str__(self) -> str:
return self.value
@dataclass
class PythonDependency:
"""Python包依赖信息"""
package_name: str # 包名称
version: str = "" # 版本要求,例如: ">=1.0.0", "==2.1.3", ""表示任意版本
optional: bool = False # 是否为可选依赖
description: str = "" # 依赖描述
install_name: str = "" # 安装时的包名如果与import名不同
def __post_init__(self):
if not self.install_name:
self.install_name = self.package_name
def get_pip_requirement(self) -> str:
"""获取pip安装格式的依赖字符串"""
if self.version:
return f"{self.install_name}{self.version}"
return self.install_name
@dataclass
class ComponentInfo:
"""组件信息"""
name: str # 组件名称
component_type: ComponentType # 组件类型
description: str = "" # 组件描述
enabled: bool = True # 是否启用
plugin_name: str = "" # 所属插件名称
is_built_in: bool = False # 是否为内置组件
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
@dataclass
class ActionInfo(ComponentInfo):
"""动作组件信息"""
action_parameters: Dict[str, str] = field(
default_factory=dict
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
action_require: List[str] = field(default_factory=list) # 动作需求说明
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
# 激活类型相关
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS
activation_type: ActionActivationType = ActionActivationType.ALWAYS
random_activation_probability: float = 0.0
llm_judge_prompt: str = ""
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
keyword_case_sensitive: bool = False
# 模式和并行设置
mode_enable: ChatMode = ChatMode.ALL
parallel_action: bool = False
def __post_init__(self):
super().__post_init__()
if self.activation_keywords is None:
self.activation_keywords = []
if self.action_parameters is None:
self.action_parameters = {}
if self.action_require is None:
self.action_require = []
if self.associated_types is None:
self.associated_types = []
self.component_type = ComponentType.ACTION
@dataclass
class CommandInfo(ComponentInfo):
"""命令组件信息"""
command_pattern: str = "" # 命令匹配模式(正则表达式)
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.COMMAND
@dataclass
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
tool_description: str = "" # 工具描述
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.TOOL
@dataclass
class EventHandlerInfo(ComponentInfo):
"""事件处理器组件信息"""
event_type: EventType = EventType.ON_MESSAGE # 监听事件类型
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
weight: int = 0 # 事件处理器权重,决定执行顺序
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.EVENT_HANDLER
@dataclass
class PluginInfo:
"""插件信息"""
display_name: str # 插件显示名称
name: str # 插件名称
description: str # 插件描述
version: str = "1.0.0" # 插件版本
author: str = "" # 插件作者
enabled: bool = True # 是否启用
is_built_in: bool = False # 是否为内置插件
components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表
dependencies: List[str] = field(default_factory=list) # 依赖的其他插件
python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖
config_file: str = "" # 配置文件路径
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
# 新增manifest相关信息
manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据
license: str = "" # 插件许可证
homepage_url: str = "" # 插件主页
repository_url: str = "" # 插件仓库地址
keywords: List[str] = field(default_factory=list) # 插件关键词
categories: List[str] = field(default_factory=list) # 插件分类
min_host_version: str = "" # 最低主机版本要求
max_host_version: str = "" # 最高主机版本要求
def __post_init__(self):
if self.components is None:
self.components = []
if self.dependencies is None:
self.dependencies = []
if self.python_dependencies is None:
self.python_dependencies = []
if self.metadata is None:
self.metadata = {}
if self.manifest_data is None:
self.manifest_data = {}
if self.keywords is None:
self.keywords = []
if self.categories is None:
self.categories = []
def get_missing_packages(self) -> List[PythonDependency]:
"""检查缺失的Python包"""
missing = []
for dep in self.python_dependencies:
try:
__import__(dep.package_name)
except ImportError:
if not dep.optional:
missing.append(dep)
return missing
def get_pip_requirements(self) -> List[str]:
"""获取所有pip安装格式的依赖"""
return [dep.get_pip_requirement() for dep in self.python_dependencies]
@dataclass
class MaiMessages:
"""MaiM插件消息"""
message_segments: List[Seg] = field(default_factory=list)
"""消息段列表,支持多段消息"""
message_base_info: Dict[str, Any] = field(default_factory=dict)
"""消息基本信息,包含平台,用户信息等数据"""
plain_text: str = ""
"""纯文本消息内容"""
raw_message: Optional[str] = None
"""原始消息内容"""
is_group_message: bool = False
"""是否为群组消息"""
is_private_message: bool = False
"""是否为私聊消息"""
stream_id: Optional[str] = None
"""流ID用于标识消息流"""
llm_prompt: Optional[str] = None
"""LLM提示词"""
llm_response_content: Optional[str] = None
"""LLM响应内容"""
llm_response_reasoning: Optional[str] = None
"""LLM响应推理内容"""
llm_response_model: Optional[str] = None
"""LLM响应模型名称"""
llm_response_tool_call: Optional[List[ToolCall]] = None
"""LLM使用的工具调用"""
action_usage: Optional[List[str]] = None
"""使用的Action"""
additional_data: Dict[Any, Any] = field(default_factory=dict)
"""附加数据,可以存储额外信息"""
def __post_init__(self):
if self.message_segments is None:
self.message_segments = []

View File

@@ -0,0 +1,18 @@
"""
插件系统配置类型定义
"""
from typing import Any, Optional, List
from dataclasses import dataclass, field
@dataclass
class ConfigField:
"""配置字段定义"""
type: type # 字段类型
default: Any # 默认值
description: str # 字段描述
example: Optional[str] = None # 示例值
required: bool = False # 是否必需
choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表

View File

@@ -0,0 +1,577 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Union
import os
import inspect
import toml
import json
import shutil
import datetime
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
PluginInfo,
PythonDependency,
)
from src.plugin_system.base.config_types import ConfigField
from src.plugin_system.utils.manifest_utils import ManifestValidator
logger = get_logger("plugin_base")
class PluginBase(ABC):
"""插件总基类
所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。
"""
# 插件基本信息(子类必须定义)
@property
@abstractmethod
def plugin_name(self) -> str:
return "" # 插件内部标识符(如 "hello_world_plugin"
@property
@abstractmethod
def enable_plugin(self) -> bool:
return True # 是否启用插件
@property
@abstractmethod
def dependencies(self) -> List[str]:
return [] # 依赖的其他插件
@property
@abstractmethod
def python_dependencies(self) -> List[PythonDependency]:
return [] # Python包依赖
@property
@abstractmethod
def config_file_name(self) -> str:
return "" # 配置文件名
# manifest文件相关
manifest_file_name: str = "_manifest.json" # manifest文件名
manifest_data: Dict[str, Any] = {} # manifest数据
# 配置定义
@property
@abstractmethod
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
return {}
config_section_descriptions: Dict[str, str] = {}
def __init__(self, plugin_dir: str):
"""初始化插件
Args:
plugin_dir: 插件目录路径,由插件管理器传递
"""
self.config: Dict[str, Any] = {} # 插件配置
self.plugin_dir = plugin_dir # 插件目录路径
self.log_prefix = f"[Plugin:{self.plugin_name}]"
# 加载manifest文件
self._load_manifest()
# 验证插件信息
self._validate_plugin_info()
# 加载插件配置
self._load_plugin_config()
# 从manifest获取显示信息
self.display_name = self.get_manifest_info("name", self.plugin_name)
self.plugin_version = self.get_manifest_info("version", "1.0.0")
self.plugin_description = self.get_manifest_info("description", "")
self.plugin_author = self._get_author_name()
# 创建插件信息对象
self.plugin_info = PluginInfo(
name=self.plugin_name,
display_name=self.display_name,
description=self.plugin_description,
version=self.plugin_version,
author=self.plugin_author,
enabled=self.enable_plugin,
is_built_in=False,
config_file=self.config_file_name or "",
dependencies=self.dependencies.copy(),
python_dependencies=self.python_dependencies.copy(),
# manifest相关信息
manifest_data=self.manifest_data.copy(),
license=self.get_manifest_info("license", ""),
homepage_url=self.get_manifest_info("homepage_url", ""),
repository_url=self.get_manifest_info("repository_url", ""),
keywords=self.get_manifest_info("keywords", []).copy() if self.get_manifest_info("keywords") else [],
categories=self.get_manifest_info("categories", []).copy() if self.get_manifest_info("categories") else [],
min_host_version=self.get_manifest_info("host_application.min_version", ""),
max_host_version=self.get_manifest_info("host_application.max_version", ""),
)
logger.debug(f"{self.log_prefix} 插件基类初始化完成")
def _validate_plugin_info(self):
"""验证插件基本信息"""
if not self.plugin_name:
raise ValueError(f"插件类 {self.__class__.__name__} 必须定义 plugin_name")
# 验证manifest中的必需信息
if not self.get_manifest_info("name"):
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少name字段")
if not self.get_manifest_info("description"):
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段")
def _load_manifest(self): # sourcery skip: raise-from-previous-error
"""加载manifest文件强制要求"""
if not self.plugin_dir:
raise ValueError(f"{self.log_prefix} 没有插件目录路径无法加载manifest")
manifest_path = os.path.join(self.plugin_dir, self.manifest_file_name)
if not os.path.exists(manifest_path):
error_msg = f"{self.log_prefix} 缺少必需的manifest文件: {manifest_path}"
logger.error(error_msg)
raise FileNotFoundError(error_msg)
try:
with open(manifest_path, "r", encoding="utf-8") as f:
self.manifest_data = json.load(f)
logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}")
# 验证manifest格式
self._validate_manifest()
except json.JSONDecodeError as e:
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
logger.error(error_msg)
raise ValueError(error_msg) # noqa
except IOError as e:
error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}"
logger.error(error_msg)
raise IOError(error_msg) # noqa
def _get_author_name(self) -> str:
"""从manifest获取作者名称"""
author_info = self.get_manifest_info("author", {})
if isinstance(author_info, dict):
return author_info.get("name", "")
else:
return str(author_info) if author_info else ""
def _validate_manifest(self):
"""验证manifest文件格式使用强化的验证器"""
if not self.manifest_data:
raise ValueError(f"{self.log_prefix} manifest数据为空验证失败")
validator = ManifestValidator()
is_valid = validator.validate_manifest(self.manifest_data)
# 记录验证结果
if validator.validation_errors or validator.validation_warnings:
report = validator.get_validation_report()
logger.info(f"{self.log_prefix} Manifest验证结果:\n{report}")
# 如果有验证错误,抛出异常
if not is_valid:
error_msg = f"{self.log_prefix} Manifest文件验证失败"
if validator.validation_errors:
error_msg += f": {'; '.join(validator.validation_errors)}"
raise ValueError(error_msg)
def get_manifest_info(self, key: str, default: Any = None) -> Any:
"""获取manifest信息
Args:
key: 信息键,支持点分割的嵌套键(如 "author.name"
default: 默认值
Returns:
Any: 对应的值
"""
if not self.manifest_data:
return default
keys = key.split(".")
value = self.manifest_data
for k in keys:
if isinstance(value, dict) and k in value:
value = value[k]
else:
return default
return value
def _generate_and_save_default_config(self, config_file_path: str):
"""根据插件的Schema生成并保存默认配置文件"""
if not self.config_schema:
logger.debug(f"{self.log_prefix} 插件未定义config_schema不生成配置文件")
return
toml_str = f"# {self.plugin_name} - 自动生成的配置文件\n"
plugin_description = self.get_manifest_info("description", "插件配置文件")
toml_str += f"# {plugin_description}\n\n"
# 遍历每个配置节
for section, fields in self.config_schema.items():
# 添加节描述
if section in self.config_section_descriptions:
toml_str += f"# {self.config_section_descriptions[section]}\n"
toml_str += f"[{section}]\n\n"
# 遍历节内的字段
if isinstance(fields, dict):
for field_name, field in fields.items():
if isinstance(field, ConfigField):
# 添加字段描述
toml_str += f"# {field.description}"
if field.required:
toml_str += " (必需)"
toml_str += "\n"
# 如果有示例值,添加示例
if field.example:
toml_str += f"# 示例: {field.example}\n"
# 如果有可选值,添加说明
if field.choices:
choices_str = ", ".join(map(str, field.choices))
toml_str += f"# 可选值: {choices_str}\n"
# 添加字段值
value = field.default
if isinstance(value, str):
toml_str += f'{field_name} = "{value}"\n'
elif isinstance(value, bool):
toml_str += f"{field_name} = {str(value).lower()}\n"
else:
toml_str += f"{field_name} = {value}\n"
toml_str += "\n"
toml_str += "\n"
try:
with open(config_file_path, "w", encoding="utf-8") as f:
f.write(toml_str)
logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}")
except IOError as e:
logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True)
def _get_expected_config_version(self) -> str:
"""获取插件期望的配置版本号"""
# 从config_schema的plugin.config_version字段获取
if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict):
config_version_field = self.config_schema["plugin"].get("config_version")
if isinstance(config_version_field, ConfigField):
return config_version_field.default
return "1.0.0"
def _get_current_config_version(self, config: Dict[str, Any]) -> str:
"""从配置文件中获取当前版本号"""
if "plugin" in config and "config_version" in config["plugin"]:
return str(config["plugin"]["config_version"])
# 如果没有config_version字段视为最早的版本
return "0.0.0"
def _backup_config_file(self, config_file_path: str) -> str:
"""备份配置文件"""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = f"{config_file_path}.backup_{timestamp}"
try:
shutil.copy2(config_file_path, backup_path)
logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}")
return backup_path
except Exception as e:
logger.error(f"{self.log_prefix} 备份配置文件失败: {e}")
return ""
def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]:
"""将旧配置值迁移到新配置结构中
Args:
old_config: 旧配置数据
new_config: 基于新schema生成的默认配置
Returns:
Dict[str, Any]: 迁移后的配置
"""
def migrate_section(
old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str
) -> Dict[str, Any]:
"""迁移单个配置节"""
result = new_section.copy()
for key, value in old_section.items():
if key in new_section:
# 特殊处理config_version字段总是使用新版本
if section_name == "plugin" and key == "config_version":
# 保持新的版本号,不迁移旧值
logger.debug(
f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})"
)
continue
# 键存在于新配置中,复制值
if isinstance(value, dict) and isinstance(new_section[key], dict):
# 递归处理嵌套字典
result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}")
else:
result[key] = value
logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}")
else:
# 键在新配置中不存在,记录警告
logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除")
return result
migrated_config = {}
# 迁移每个配置节
for section_name, new_section_data in new_config.items():
if (
section_name in old_config
and isinstance(old_config[section_name], dict)
and isinstance(new_section_data, dict)
):
migrated_config[section_name] = migrate_section(
old_config[section_name], new_section_data, section_name
)
else:
# 新增的节或类型不匹配,使用默认值
migrated_config[section_name] = new_section_data
if section_name in old_config:
logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值")
# 检查旧配置中是否有新配置没有的节
for section_name in old_config:
if section_name not in migrated_config:
logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除")
return migrated_config
def _generate_config_from_schema(self) -> Dict[str, Any]:
# sourcery skip: dict-comprehension
"""根据schema生成配置数据结构不写入文件"""
if not self.config_schema:
return {}
config_data = {}
# 遍历每个配置节
for section, fields in self.config_schema.items():
if isinstance(fields, dict):
section_data = {}
# 遍历节内的字段
for field_name, field in fields.items():
if isinstance(field, ConfigField):
section_data[field_name] = field.default
config_data[section] = section_data
return config_data
def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str):
"""将配置数据保存为TOML文件包含注释"""
if not self.config_schema:
logger.debug(f"{self.log_prefix} 插件未定义config_schema不生成配置文件")
return
toml_str = f"# {self.plugin_name} - 配置文件\n"
plugin_description = self.get_manifest_info("description", "插件配置文件")
toml_str += f"# {plugin_description}\n"
# 获取当前期望的配置版本
expected_version = self._get_expected_config_version()
toml_str += f"# 配置版本: {expected_version}\n\n"
# 遍历每个配置节
for section, fields in self.config_schema.items():
# 添加节描述
if section in self.config_section_descriptions:
toml_str += f"# {self.config_section_descriptions[section]}\n"
toml_str += f"[{section}]\n\n"
# 遍历节内的字段
if isinstance(fields, dict) and section in config_data:
section_data = config_data[section]
for field_name, field in fields.items():
if isinstance(field, ConfigField):
# 添加字段描述
toml_str += f"# {field.description}"
if field.required:
toml_str += " (必需)"
toml_str += "\n"
# 如果有示例值,添加示例
if field.example:
toml_str += f"# 示例: {field.example}\n"
# 如果有可选值,添加说明
if field.choices:
choices_str = ", ".join(map(str, field.choices))
toml_str += f"# 可选值: {choices_str}\n"
# 添加字段值(使用迁移后的值)
value = section_data.get(field_name, field.default)
if isinstance(value, str):
toml_str += f'{field_name} = "{value}"\n'
elif isinstance(value, bool):
toml_str += f"{field_name} = {str(value).lower()}\n"
elif isinstance(value, list):
# 格式化列表
if all(isinstance(item, str) for item in value):
formatted_list = "[" + ", ".join(f'"{item}"' for item in value) + "]"
else:
formatted_list = str(value)
toml_str += f"{field_name} = {formatted_list}\n"
else:
toml_str += f"{field_name} = {value}\n"
toml_str += "\n"
toml_str += "\n"
try:
with open(config_file_path, "w", encoding="utf-8") as f:
f.write(toml_str)
logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}")
except IOError as e:
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
def _load_plugin_config(self): # sourcery skip: extract-method
"""加载插件配置文件,支持版本检查和自动迁移"""
if not self.config_file_name:
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
return
# 优先使用传入的插件目录路径
if self.plugin_dir:
plugin_dir = self.plugin_dir
else:
# fallback尝试从类的模块信息获取路径
try:
plugin_module_path = inspect.getfile(self.__class__)
plugin_dir = os.path.dirname(plugin_module_path)
except (TypeError, OSError):
# 最后的fallback从模块的__file__属性获取
module = inspect.getmodule(self.__class__)
if module and hasattr(module, "__file__") and module.__file__:
plugin_dir = os.path.dirname(module.__file__)
else:
logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载")
return
config_file_path = os.path.join(plugin_dir, self.config_file_name)
# 如果配置文件不存在,生成默认配置
if not os.path.exists(config_file_path):
logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。")
self._generate_and_save_default_config(config_file_path)
if not os.path.exists(config_file_path):
logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。")
return
file_ext = os.path.splitext(self.config_file_name)[1].lower()
if file_ext == ".toml":
# 加载现有配置
with open(config_file_path, "r", encoding="utf-8") as f:
existing_config = toml.load(f) or {}
# 检查配置版本
current_version = self._get_current_config_version(existing_config)
# 如果配置文件没有版本信息,跳过版本检查
if current_version == "0.0.0":
logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查")
self.config = existing_config
else:
expected_version = self._get_expected_config_version()
if current_version != expected_version:
logger.info(
f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
)
# 生成新的默认配置结构
new_config_structure = self._generate_config_from_schema()
# 迁移旧配置值到新结构
migrated_config = self._migrate_config_values(existing_config, new_config_structure)
# 保存迁移后的配置
self._save_config_to_file(migrated_config, config_file_path)
logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}")
self.config = migrated_config
else:
logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载")
self.config = existing_config
logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载")
# 从配置中更新 enable_plugin
if "plugin" in self.config and "enabled" in self.config["plugin"]:
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
else:
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
self.config = {}
def _check_dependencies(self) -> bool:
"""检查插件依赖"""
from src.plugin_system.core.component_registry import component_registry
if not self.dependencies:
return True
for dep in self.dependencies:
if not component_registry.get_plugin_info(dep):
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep}")
return False
return True
def get_config(self, key: str, default: Any = None) -> Any:
"""获取插件配置值,支持嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
default: 默认值
Returns:
Any: 配置值或默认值
"""
# 支持嵌套键访问
keys = key.split(".")
current = self.config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
@abstractmethod
def register_plugin(self) -> bool:
"""
注册插件到插件管理器
子类必须实现此方法,返回注册是否成功
Returns:
bool: 是否成功注册插件
"""
raise NotImplementedError("Subclasses must implement this method")

View File

@@ -0,0 +1,19 @@
"""
插件核心管理模块
提供插件的加载、注册和管理功能
"""
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
__all__ = [
"plugin_manager",
"component_registry",
"events_manager",
"global_announcement_manager",
"hot_reload_manager",
]

View File

@@ -0,0 +1,688 @@
import re
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
ComponentInfo,
ActionInfo,
ToolInfo,
CommandInfo,
EventHandlerInfo,
PluginInfo,
ComponentType,
)
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("component_registry")
class ComponentRegistry:
"""统一的组件注册中心
负责管理所有插件组件的注册、查询和生命周期管理
"""
def __init__(self):
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
self._components: Dict[str, ComponentInfo] = {}
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
self._plugins: Dict[str, PluginInfo] = {}
"""插件名 -> 插件信息"""
# Action特定注册表
self._action_registry: Dict[str, Type[BaseAction]] = {}
"""Action注册表 action名 -> action类"""
self._default_actions: Dict[str, ActionInfo] = {}
"""默认动作集即启用的Action集用于重置ActionManager状态"""
# Command特定注册表
self._command_registry: Dict[str, Type[BaseCommand]] = {}
"""Command类注册表 command名 -> command类"""
self._command_patterns: Dict[Pattern, str] = {}
"""编译后的正则 -> command名"""
# 工具特定注册表
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
# EventHandler特定注册表
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
"""event_handler名 -> event_handler类"""
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
"""启用的事件处理器 event_handler名 -> event_handler类"""
logger.info("组件注册中心初始化完成")
# == 注册方法 ==
def register_plugin(self, plugin_info: PluginInfo) -> bool:
"""注册插件
Args:
plugin_info: 插件信息
Returns:
bool: 是否注册成功
"""
plugin_name = plugin_info.name
if plugin_name in self._plugins:
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
return False
self._plugins[plugin_name] = plugin_info
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
return True
def register_component(
self,
component_info: ComponentInfo,
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
) -> bool:
"""注册组件
Args:
component_info (ComponentInfo): 组件信息
component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类
Returns:
bool: 是否注册成功
"""
component_name = component_info.name
component_type = component_info.component_type
plugin_name = getattr(component_info, "plugin_name", "unknown")
if "." in component_name:
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in plugin_name:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
namespaced_name = f"{component_type}.{component_name}"
if namespaced_name in self._components:
existing_info = self._components[namespaced_name]
existing_plugin = getattr(existing_info, "plugin_name", "unknown")
logger.warning(
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
)
return False
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
self._components_classes[namespaced_name] = component_class
# 根据组件类型进行特定注册(使用原始名称)
match component_type:
case ComponentType.ACTION:
assert isinstance(component_info, ActionInfo)
assert issubclass(component_class, BaseAction)
ret = self._register_action_component(component_info, component_class)
case ComponentType.COMMAND:
assert isinstance(component_info, CommandInfo)
assert issubclass(component_class, BaseCommand)
ret = self._register_command_component(component_info, component_class)
case ComponentType.TOOL:
assert isinstance(component_info, ToolInfo)
assert issubclass(component_class, BaseTool)
ret = self._register_tool_component(component_info, component_class)
case ComponentType.EVENT_HANDLER:
assert isinstance(component_info, EventHandlerInfo)
assert issubclass(component_class, BaseEventHandler)
ret = self._register_event_handler_component(component_info, component_class)
case _:
logger.warning(f"未知组件类型: {component_type}")
if not ret:
return False
logger.debug(
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
f"({component_class.__name__}) [插件: {plugin_name}]"
)
return True
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
"""注册Action组件到Action特定注册表"""
if not (action_name := action_info.name):
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
return False
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
logger.error(f"注册失败: {action_name} 不是有效的Action")
return False
self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
if action_info.enabled:
self._default_actions[action_name] = action_info
return True
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
"""注册Command组件到Command特定注册表"""
if not (command_name := command_info.name):
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
return False
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
logger.error(f"注册失败: {command_name} 不是有效的Command")
return False
self._command_registry[command_name] = command_class
# 如果启用了且有匹配模式
if command_info.enabled and command_info.command_pattern:
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
if pattern not in self._command_patterns:
self._command_patterns[pattern] = command_name
else:
logger.warning(
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
)
return True
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class
return True
def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool:
if not (handler_name := handler_info.name):
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
return False
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False
self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功
from .events_manager import events_manager # 延迟导入防止循环导入问题
if events_manager.register_event_subscriber(handler_info, handler_class):
self._enabled_event_handlers[handler_name] = handler_class
return True
else:
logger.error(f"注册事件处理器 {handler_name} 失败")
return False
# === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
target_component_class = self.get_component_class(component_name, component_type)
if not target_component_class:
logger.warning(f"组件 {component_name} 未注册,无法移除")
return False
try:
# 根据组件类型进行特定的清理操作
match component_type:
case ComponentType.ACTION:
# 移除Action注册
self._action_registry.pop(component_name, None)
self._default_actions.pop(component_name, None)
logger.debug(f"已移除Action组件: {component_name}")
case ComponentType.COMMAND:
# 移除Command注册和模式
self._command_registry.pop(component_name, None)
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key, None)
logger.debug(f"已移除Command组件: {component_name} (清理了 {len(keys_to_remove)} 个模式)")
case ComponentType.TOOL:
# 移除Tool注册
self._tool_registry.pop(component_name, None)
self._llm_available_tools.pop(component_name, None)
logger.debug(f"已移除Tool组件: {component_name}")
case ComponentType.EVENT_HANDLER:
# 移除EventHandler注册和事件订阅
from .events_manager import events_manager # 延迟导入防止循环导入问题
self._event_handler_registry.pop(component_name, None)
self._enabled_event_handlers.pop(component_name, None)
try:
await events_manager.unregister_event_subscriber(component_name)
logger.debug(f"已移除EventHandler组件: {component_name}")
except Exception as e:
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
case _:
logger.warning(f"未知的组件类型: {component_type}")
return False
# 移除通用注册信息
namespaced_name = f"{component_type}.{component_name}"
self._components.pop(namespaced_name, None)
self._components_by_type[component_type].pop(component_name, None)
self._components_classes.pop(namespaced_name, None)
logger.info(f"组件 {component_name} ({component_type}) 已完全移除")
return True
except Exception as e:
logger.error(f"移除组件 {component_name} ({component_type}) 时发生错误: {e}")
return False
def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息
Args:
plugin_name: 插件名称
Returns:
bool: 是否成功移除
"""
if plugin_name not in self._plugins:
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
return False
del self._plugins[plugin_name]
logger.info(f"插件 {plugin_name} 已移除")
return True
# === 组件全局启用/禁用方法 ===
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的启用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 启用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法启用")
return False
target_component_info.enabled = True
match component_type:
case ComponentType.ACTION:
assert isinstance(target_component_info, ActionInfo)
self._default_actions[component_name] = target_component_info
case ComponentType.COMMAND:
assert isinstance(target_component_info, CommandInfo)
pattern = target_component_info.command_pattern
self._command_patterns[re.compile(pattern)] = component_name
case ComponentType.TOOL:
assert isinstance(target_component_info, ToolInfo)
assert issubclass(target_component_class, BaseTool)
self._llm_available_tools[component_name] = target_component_class
case ComponentType.EVENT_HANDLER:
assert isinstance(target_component_info, EventHandlerInfo)
assert issubclass(target_component_class, BaseEventHandler)
self._enabled_event_handlers[component_name] = target_component_class
from .events_manager import events_manager # 延迟导入防止循环导入问题
events_manager.register_event_subscriber(target_component_info, target_component_class)
namespaced_name = f"{component_type}.{component_name}"
self._components[namespaced_name].enabled = True
self._components_by_type[component_type][component_name].enabled = True
logger.info(f"组件 {component_name} 已启用")
return True
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的禁用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 禁用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法禁用")
return False
target_component_info.enabled = False
try:
match component_type:
case ComponentType.ACTION:
self._default_actions.pop(component_name)
case ComponentType.COMMAND:
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
case ComponentType.TOOL:
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
self._enabled_event_handlers.pop(component_name)
from .events_manager import events_manager # 延迟导入防止循环导入问题
await events_manager.unregister_event_subscriber(component_name)
self._components[component_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用")
return True
except KeyError as e:
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
return False
# === 组件查询方法 ===
def get_component_info(
self, component_name: str, component_type: Optional[ComponentType] = None
) -> Optional[ComponentInfo]:
# sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[ComponentInfo]: 组件信息或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if "." in component_name:
return self._components.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
namespaced_name = f"{component_type}.{component_name}"
return self._components.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}"
if component_info := self._components.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_info))
if len(candidates) == 1:
# 只有一个匹配,直接返回
return candidates[0][2]
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_component_class(
self,
component_name: str,
component_type: Optional[ComponentType] = None,
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
"""获取组件类,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[Union[BaseCommand, BaseAction]]: 组件类或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if "." in component_name:
return self._components_classes.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
namespaced_name = f"{component_type.value}.{component_name}"
return self._components_classes.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}"
if component_class := self._components_classes.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_class))
if len(candidates) == 1:
# 只有一个匹配,直接返回
_, full_name, cls = candidates[0]
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
return cls
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有组件"""
return self._components_by_type.get(component_type, {}).copy()
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有启用组件"""
components = self.get_components_by_type(component_type)
return {name: info for name, info in components.items() if info.enabled}
# === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
"""获取Action注册表"""
return self._action_registry.copy()
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
"""获取Action信息"""
info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
# === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
"""获取Command注册表"""
return self._command_registry.copy()
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command信息"""
info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None
def get_command_patterns(self) -> Dict[Pattern, str]:
"""获取Command模式注册表"""
return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
# sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令
Args:
text: 输入文本
Returns:
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
"""
candidates = [pattern for pattern in self._command_patterns if pattern.match(text)]
if not candidates:
return None
if len(candidates) > 1:
logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配")
command_name = self._command_patterns[candidates[0]]
command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore
return (
self._command_registry[command_name],
candidates[0].match(text).groupdict(), # type: ignore
command_info,
)
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
"""获取Tool信息
Args:
tool_name: 工具名称
Returns:
ToolInfo: 工具信息对象,如果工具不存在则返回 None
"""
info = self.get_component_info(tool_name, ComponentType.TOOL)
return info if isinstance(info, ToolInfo) else None
# === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
"""获取事件处理器注册表"""
return self._event_handler_registry.copy()
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
"""获取事件处理器信息"""
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
"""获取启用的事件处理器"""
return self._enabled_event_handlers.copy()
# === 插件查询方法 ===
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
"""获取插件信息"""
return self._plugins.get(plugin_name)
def get_all_plugins(self) -> Dict[str, PluginInfo]:
"""获取所有插件"""
return self._plugins.copy()
# def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
# """获取所有启用的插件"""
# return {name: info for name, info in self._plugins.items() if info.enabled}
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
"""获取插件的所有组件"""
plugin_info = self.get_plugin_info(plugin_name)
return plugin_info.components if plugin_info else []
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
"""获取插件配置
Args:
plugin_name: 插件名称
Returns:
Optional[dict]: 插件配置字典或None
"""
# 从插件管理器获取插件实例的配置
from src.plugin_system.core.plugin_manager import plugin_manager
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
return plugin_instance.config if plugin_instance else None
def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息"""
action_components: int = 0
command_components: int = 0
tool_components: int = 0
events_handlers: int = 0
for component in self._components.values():
if component.component_type == ComponentType.ACTION:
action_components += 1
elif component.component_type == ComponentType.COMMAND:
command_components += 1
elif component.component_type == ComponentType.TOOL:
tool_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1
return {
"action_components": action_components,
"command_components": command_components,
"tool_components": tool_components,
"event_handlers": events_handlers,
"total_components": len(self._components),
"total_plugins": len(self._plugins),
"components_by_type": {
component_type.value: len(components) for component_type, components in self._components_by_type.items()
},
"enabled_components": len([c for c in self._components.values() if c.enabled]),
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
}
# === 组件移除相关 ===
async def unregister_plugin(self, plugin_name: str) -> bool:
"""卸载插件及其所有组件
Args:
plugin_name: 插件名称
Returns:
bool: 是否成功卸载
"""
plugin_info = self.get_plugin_info(plugin_name)
if not plugin_info:
logger.warning(f"插件 {plugin_name} 未注册,无法卸载")
return False
logger.info(f"开始卸载插件: {plugin_name}")
# 记录卸载失败的组件
failed_components = []
# 逐个移除插件的所有组件
for component_info in plugin_info.components:
try:
success = await self.remove_component(
component_info.name,
component_info.component_type,
plugin_name,
)
if not success:
failed_components.append(f"{component_info.component_type}.{component_info.name}")
except Exception as e:
logger.error(f"移除组件 {component_info.name} 时发生异常: {e}")
failed_components.append(f"{component_info.component_type}.{component_info.name}")
# 移除插件注册信息
plugin_removed = self.remove_plugin_registry(plugin_name)
if failed_components:
logger.warning(f"插件 {plugin_name} 部分组件卸载失败: {failed_components}")
return False
elif not plugin_removed:
logger.error(f"插件 {plugin_name} 注册信息移除失败")
return False
else:
logger.info(f"插件 {plugin_name} 卸载成功")
return True
# 创建全局组件注册中心实例
component_registry = ComponentRegistry()

View File

@@ -0,0 +1,262 @@
import asyncio
import contextlib
from typing import List, Dict, Optional, Type, Tuple, Any
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
from src.plugin_system.base.base_events_handler import BaseEventHandler
from .global_announcement_manager import global_announcement_manager
logger = get_logger("events_manager")
class EventsManager:
def __init__(self):
# 有权重的 events 订阅者注册表
self._events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器
Args:
handler_info (EventHandlerInfo): 事件处理器信息
handler_class (Type[BaseEventHandler]): 事件处理器类
Returns:
bool: 是否注册成功
"""
handler_name = handler_info.name
if handler_name in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return True
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False
self._handler_mapping[handler_name] = handler_class
return self._insert_event_handler(handler_class, handler_info)
async def handle_mai_events(
self,
event_type: EventType,
message: Optional[MessageRecv] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional[Dict[str, Any]] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> bool:
"""处理 events"""
from src.plugin_system.core import component_registry
continue_flag = True
transformed_message: Optional[MaiMessages] = None
if not message:
assert stream_id, "如果没有消息必须提供流ID"
if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]:
transformed_message = self._build_message_from_stream(stream_id, llm_prompt, llm_response)
else:
transformed_message = self._transform_event_without_message(
stream_id, llm_prompt, llm_response, action_usage
)
else:
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
for handler in self._events_subscribers.get(event_type, []):
if transformed_message.stream_id:
stream_id = transformed_message.stream_id
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
continue
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
if handler.intercept_message:
try:
success, continue_processing, result = await handler.execute(transformed_message)
if not success:
logger.error(f"EventHandler {handler.handler_name} 执行失败: {result}")
else:
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {result}")
continue_flag = continue_flag and continue_processing
except Exception as e:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}")
continue
else:
try:
handler_task = asyncio.create_task(handler.execute(transformed_message))
handler_task.add_done_callback(self._task_done_callback)
handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}")
if handler.handler_name not in self._handler_tasks:
self._handler_tasks[handler.handler_name] = []
self._handler_tasks[handler.handler_name].append(handler_task)
except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
continue
return continue_flag
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
if handler_class.event_type == EventType.UNKNOWN:
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
return False
handler_instance = handler_class()
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
self._events_subscribers[handler_class.event_type].append(handler_instance)
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
return True
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器"""
display_handler_name = handler_class.handler_name or handler_class.__name__
if handler_class.event_type == EventType.UNKNOWN:
logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
return False
handlers = self._events_subscribers[handler_class.event_type]
for i, handler in enumerate(handlers):
if isinstance(handler, handler_class):
del handlers[i]
logger.debug(f"事件处理器 {display_handler_name} 已移除")
return True
logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
return False
def _transform_event_message(
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
) -> MaiMessages:
"""转换事件消息格式"""
# 直接赋值部分内容
transformed_message = MaiMessages(
llm_prompt=llm_prompt,
llm_response_content=llm_response.get("content") if llm_response else None,
llm_response_reasoning=llm_response.get("reasoning") if llm_response else None,
llm_response_model=llm_response.get("model") if llm_response else None,
llm_response_tool_call=llm_response.get("tool_calls") if llm_response else None,
raw_message=message.raw_message,
additional_data=message.message_info.additional_config or {},
)
# 消息段处理
if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed_message.message_segments = [message.message_segment]
# stream_id 处理
if hasattr(message, "chat_stream") and message.chat_stream:
transformed_message.stream_id = message.chat_stream.stream_id
# 处理后文本
transformed_message.plain_text = message.processed_plain_text
# 基本信息
if hasattr(message, "message_info") and message.message_info:
if message.message_info.platform:
transformed_message.message_base_info["platform"] = message.message_info.platform
if message.message_info.group_info:
transformed_message.is_group_message = True
transformed_message.message_base_info.update(
{
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
}
)
if message.message_info.user_info:
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
transformed_message.message_base_info.update(
{
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
}
)
return transformed_message
def _build_message_from_stream(
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
) -> MaiMessages:
"""从流ID构建消息"""
chat_stream = get_chat_manager().get_stream(stream_id)
assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流"
message = chat_stream.context.get_last_message()
return self._transform_event_message(message, llm_prompt, llm_response)
def _transform_event_without_message(
self,
stream_id: str,
llm_prompt: Optional[str] = None,
llm_response: Optional[Dict[str, Any]] = None,
action_usage: Optional[List[str]] = None,
) -> MaiMessages:
"""没有message对象时进行转换"""
chat_stream = get_chat_manager().get_stream(stream_id)
assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流"
return MaiMessages(
stream_id=stream_id,
llm_prompt=llm_prompt,
llm_response_content=(llm_response.get("content") if llm_response else None),
llm_response_reasoning=(llm_response.get("reasoning") if llm_response else None),
llm_response_model=llm_response.get("model") if llm_response else None,
llm_response_tool_call=(llm_response.get("tool_calls") if llm_response else None),
is_group_message=(not (not chat_stream.group_info)),
is_private_message=(not chat_stream.group_info),
action_usage=action_usage,
additional_data={"response_is_processed": True},
)
def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]):
"""任务完成回调"""
task_name = task.get_name() or "Unknown Task"
try:
success, _, result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
if success:
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
else:
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
finally:
with contextlib.suppress(ValueError, KeyError):
self._handler_tasks[task_name].remove(task)
async def cancel_handler_tasks(self, handler_name: str) -> None:
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
for task in remaining_tasks:
task.cancel()
try:
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
except asyncio.TimeoutError:
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
except Exception as e:
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
if handler_name in self._handler_tasks:
del self._handler_tasks[handler_name]
async def unregister_event_subscriber(self, handler_name: str) -> bool:
"""取消注册事件处理器"""
if handler_name not in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
return False
await self.cancel_handler_tasks(handler_name)
handler_class = self._handler_mapping.pop(handler_name)
if not self._remove_event_handler_instance(handler_class):
return False
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
return True
events_manager = EventsManager()

View File

@@ -0,0 +1,120 @@
from typing import List, Dict
from src.common.logger import get_logger
logger = get_logger("global_announcement_manager")
class GlobalAnnouncementManager:
def __init__(self) -> None:
# 用户禁用的动作chat_id -> [action_name]
self._user_disabled_actions: Dict[str, List[str]] = {}
# 用户禁用的命令chat_id -> [command_name]
self._user_disabled_commands: Dict[str, List[str]] = {}
# 用户禁用的事件处理器chat_id -> [handler_name]
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
# 用户禁用的工具chat_id -> [tool_name]
self._user_disabled_tools: Dict[str, List[str]] = {}
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""禁用特定聊天的某个动作"""
if chat_id not in self._user_disabled_actions:
self._user_disabled_actions[chat_id] = []
if action_name in self._user_disabled_actions[chat_id]:
logger.warning(f"动作 {action_name} 已经被禁用")
return False
self._user_disabled_actions[chat_id].append(action_name)
return True
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""启用特定聊天的某个动作"""
if chat_id in self._user_disabled_actions:
try:
self._user_disabled_actions[chat_id].remove(action_name)
return True
except ValueError:
logger.warning(f"动作 {action_name} 不在禁用列表中")
return False
return False
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
"""禁用特定聊天的某个命令"""
if chat_id not in self._user_disabled_commands:
self._user_disabled_commands[chat_id] = []
if command_name in self._user_disabled_commands[chat_id]:
logger.warning(f"命令 {command_name} 已经被禁用")
return False
self._user_disabled_commands[chat_id].append(command_name)
return True
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
"""启用特定聊天的某个命令"""
if chat_id in self._user_disabled_commands:
try:
self._user_disabled_commands[chat_id].remove(command_name)
return True
except ValueError:
logger.warning(f"命令 {command_name} 不在禁用列表中")
return False
return False
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
"""禁用特定聊天的某个事件处理器"""
if chat_id not in self._user_disabled_event_handlers:
self._user_disabled_event_handlers[chat_id] = []
if handler_name in self._user_disabled_event_handlers[chat_id]:
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
return False
self._user_disabled_event_handlers[chat_id].append(handler_name)
return True
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
"""启用特定聊天的某个事件处理器"""
if chat_id in self._user_disabled_event_handlers:
try:
self._user_disabled_event_handlers[chat_id].remove(handler_name)
return True
except ValueError:
logger.warning(f"事件处理器 {handler_name} 不在禁用列表中")
return False
return False
def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""禁用特定聊天的某个工具"""
if chat_id not in self._user_disabled_tools:
self._user_disabled_tools[chat_id] = []
if tool_name in self._user_disabled_tools[chat_id]:
logger.warning(f"工具 {tool_name} 已经被禁用")
return False
self._user_disabled_tools[chat_id].append(tool_name)
return True
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""启用特定聊天的某个工具"""
if chat_id in self._user_disabled_tools:
try:
self._user_disabled_tools[chat_id].remove(tool_name)
return True
except ValueError:
logger.warning(f"工具 {tool_name} 不在禁用列表中")
return False
return False
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有动作"""
return self._user_disabled_actions.get(chat_id, []).copy()
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有命令"""
return self._user_disabled_commands.get(chat_id, []).copy()
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有工具"""
return self._user_disabled_tools.get(chat_id, []).copy()
global_announcement_manager = GlobalAnnouncementManager()

View File

@@ -0,0 +1,242 @@
"""
插件热重载模块
使用 Watchdog 监听插件目录变化,自动重载插件
"""
import os
import time
from pathlib import Path
from threading import Thread
from typing import Dict, Set
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from src.common.logger import get_logger
from .plugin_manager import plugin_manager
logger = get_logger("plugin_hot_reload")
class PluginFileHandler(FileSystemEventHandler):
"""插件文件变化处理器"""
def __init__(self, hot_reload_manager):
super().__init__()
self.hot_reload_manager = hot_reload_manager
self.pending_reloads: Set[str] = set() # 待重载的插件名称
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
self.debounce_delay = 1.0 # 防抖延迟(秒)
def on_modified(self, event):
"""文件修改事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "modified")
def on_created(self, event):
"""文件创建事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "created")
def on_deleted(self, event):
"""文件删除事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "deleted")
def _handle_file_change(self, file_path: str, change_type: str):
"""处理文件变化"""
try:
# 获取插件名称
plugin_name = self._get_plugin_name_from_path(file_path)
if not plugin_name:
return
current_time = time.time()
last_time = self.last_reload_time.get(plugin_name, 0)
# 防抖处理,避免频繁重载
if current_time - last_time < self.debounce_delay:
return
file_name = Path(file_path).name
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type})")
# 如果是删除事件,处理关键文件删除
if change_type == "deleted":
if file_name == "plugin.py":
if plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name}")
self.hot_reload_manager._unload_plugin(plugin_name)
return
elif file_name == "manifest.toml":
if plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name}")
self.hot_reload_manager._unload_plugin(plugin_name)
return
# 对于修改和创建事件,都进行重载
# 添加到待重载列表
self.pending_reloads.add(plugin_name)
self.last_reload_time[plugin_name] = current_time
# 延迟重载,避免文件正在写入时重载
reload_thread = Thread(
target=self._delayed_reload,
args=(plugin_name,),
daemon=True
)
reload_thread.start()
except Exception as e:
logger.error(f"❌ 处理文件变化时发生错误: {e}")
def _delayed_reload(self, plugin_name: str):
"""延迟重载插件"""
try:
time.sleep(self.debounce_delay)
if plugin_name in self.pending_reloads:
self.pending_reloads.remove(plugin_name)
self.hot_reload_manager._reload_plugin(plugin_name)
except Exception as e:
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}")
def _get_plugin_name_from_path(self, file_path: str) -> str:
"""从文件路径获取插件名称"""
try:
path = Path(file_path)
# 检查是否在监听的插件目录中
plugin_root = Path(self.hot_reload_manager.watch_directory)
if not path.is_relative_to(plugin_root):
return ""
# 获取插件目录名(插件名)
relative_path = path.relative_to(plugin_root)
plugin_name = relative_path.parts[0]
# 确认这是一个有效的插件目录(检查是否有 plugin.py 或 manifest.toml
plugin_dir = plugin_root / plugin_name
if plugin_dir.is_dir() and ((plugin_dir / "plugin.py").exists() or (plugin_dir / "manifest.toml").exists()):
return plugin_name
return ""
except Exception:
return ""
class PluginHotReloadManager:
"""插件热重载管理器"""
def __init__(self, watch_directory: str = None):
print("fuck")
print(os.getcwd())
self.watch_directory = os.path.join(os.getcwd(), "plugins")
self.observer = None
self.file_handler = None
self.is_running = False
# 确保监听目录存在
if not os.path.exists(self.watch_directory):
os.makedirs(self.watch_directory, exist_ok=True)
logger.info(f"创建插件监听目录: {self.watch_directory}")
def start(self):
"""启动热重载监听"""
if self.is_running:
logger.warning("插件热重载已经在运行中")
return
try:
self.observer = Observer()
self.file_handler = PluginFileHandler(self)
self.observer.schedule(
self.file_handler,
self.watch_directory,
recursive=True
)
self.observer.start()
self.is_running = True
logger.info("🚀 插件热重载已启动,监听目录: plugins")
except Exception as e:
logger.error(f"❌ 启动插件热重载失败: {e}")
self.is_running = False
def stop(self):
"""停止热重载监听"""
if not self.is_running:
return
if self.observer:
self.observer.stop()
self.observer.join()
self.is_running = False
def _reload_plugin(self, plugin_name: str):
"""重载指定插件"""
try:
logger.info(f"🔄 开始重载插件: {plugin_name}")
if plugin_manager.reload_plugin(plugin_name):
logger.info(f"✅ 插件重载成功: {plugin_name}")
else:
logger.error(f"❌ 插件重载失败: {plugin_name}")
except Exception as e:
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}")
def _unload_plugin(self, plugin_name: str):
"""卸载指定插件"""
try:
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
if plugin_manager.unload_plugin(plugin_name):
logger.info(f"✅ 插件卸载成功: {plugin_name}")
else:
logger.error(f"❌ 插件卸载失败: {plugin_name}")
except Exception as e:
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}")
def reload_all_plugins(self):
"""重载所有插件"""
try:
logger.info("🔄 开始重载所有插件...")
# 获取当前已加载的插件列表
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
success_count = 0
fail_count = 0
for plugin_name in loaded_plugins:
if plugin_manager.reload_plugin(plugin_name):
success_count += 1
else:
fail_count += 1
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count}")
except Exception as e:
logger.error(f"❌ 重载所有插件时发生错误: {e}")
def get_status(self) -> dict:
"""获取热重载状态"""
return {
"is_running": self.is_running,
"watch_directory": self.watch_directory,
"loaded_plugins": len(plugin_manager.loaded_plugins),
"failed_plugins": len(plugin_manager.failed_plugins),
}
# 全局热重载管理器实例
hot_reload_manager = PluginHotReloadManager()

View File

@@ -0,0 +1,593 @@
import os
import traceback
import sys
from typing import Dict, List, Optional, Tuple, Type, Any
from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
from src.common.logger import get_logger
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
logger = get_logger("plugin_manager")
class PluginManager:
"""
插件管理器类
负责加载,重载和卸载插件,同时管理插件的所有组件
"""
def __init__(self):
self.plugin_directories: List[str] = [] # 插件根目录列表
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
# 确保插件目录存在
self._ensure_plugin_directories()
logger.info("插件管理器初始化完成")
# === 插件目录管理 ===
def add_plugin_directory(self, directory: str) -> bool:
"""添加插件目录"""
if os.path.exists(directory):
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件目录: {directory}")
return True
else:
logger.warning(f"插件不可重复加载: {directory}")
else:
logger.warning(f"插件目录不存在: {directory}")
return False
# === 插件加载管理 ===
def load_all_plugins(self) -> Tuple[int, int]:
"""加载所有插件
Returns:
tuple[int, int]: (插件数量, 组件数量)
"""
logger.debug("开始加载所有插件...")
# 第一阶段:加载所有插件模块(注册插件类)
total_loaded_modules = 0
total_failed_modules = 0
for directory in self.plugin_directories:
loaded, failed = self._load_plugin_modules_from_directory(directory)
total_loaded_modules += loaded
total_failed_modules += failed
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
total_registered = 0
total_failed_registration = 0
for plugin_name in self.plugin_classes.keys():
load_status, count = self.load_registered_plugin_classes(plugin_name)
if load_status:
total_registered += 1
else:
total_failed_registration += count
self._show_stats(total_registered, total_failed_registration)
return total_registered, total_failed_registration
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
# sourcery skip: extract-duplicate-method, extract-method
"""
加载已经注册的插件类
"""
plugin_class = self.plugin_classes.get(plugin_name)
if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1
try:
# 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name)
# 如果没有记录,直接返回失败
if not plugin_dir:
return False, 1
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance:
logger.error(f"插件 {plugin_name} 实例化失败")
return False, 1
# 检查插件是否启用
if not plugin_instance.enable_plugin:
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
return False, 0
# 检查版本兼容性
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
plugin_name, plugin_instance.manifest_data
)
if not is_compatible:
self.failed_plugins[plugin_name] = compatibility_error
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
return False, 1
if plugin_instance.register_plugin():
self.loaded_plugins[plugin_name] = plugin_instance
self._show_plugin_components(plugin_name)
return True, 1
else:
self.failed_plugins[plugin_name] = "插件注册失败"
logger.error(f"❌ 插件注册失败: {plugin_name}")
return False, 1
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
return False, 1
async def remove_registered_plugin(self, plugin_name: str) -> bool:
"""
禁用插件模块
"""
if not plugin_name:
raise ValueError("插件名称不能为空")
if plugin_name not in self.loaded_plugins:
logger.warning(f"插件 {plugin_name} 未加载")
return False
plugin_instance = self.loaded_plugins[plugin_name]
plugin_info = plugin_instance.plugin_info
success = True
for component in plugin_info.components:
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
success &= component_registry.remove_plugin_registry(plugin_name)
del self.loaded_plugins[plugin_name]
return success
async def reload_registered_plugin(self, plugin_name: str) -> bool:
"""
重载插件模块
"""
if not await self.remove_registered_plugin(plugin_name):
return False
if not self.load_registered_plugin_classes(plugin_name)[0]:
return False
logger.debug(f"插件 {plugin_name} 重载成功")
return True
def rescan_plugin_directory(self) -> Tuple[int, int]:
"""
重新扫描插件根目录
"""
total_success = 0
total_fail = 0
for directory in self.plugin_directories:
if os.path.exists(directory):
logger.debug(f"重新扫描插件根目录: {directory}")
success, fail = self._load_plugin_modules_from_directory(directory)
total_success += success
total_fail += fail
else:
logger.warning(f"插件根目录不存在: {directory}")
return total_success, total_fail
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例
Args:
plugin_name: 插件名称
Returns:
Optional[BasePlugin]: 插件实例或None
"""
return self.loaded_plugins.get(plugin_name)
# === 查询方法 ===
def list_loaded_plugins(self) -> List[str]:
"""
列出所有当前加载的插件。
Returns:
list: 当前加载的插件名称列表。
"""
return list(self.loaded_plugins.keys())
def list_registered_plugins(self) -> List[str]:
"""
列出所有已注册的插件类。
Returns:
list: 已注册的插件类名称列表。
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
"""
获取指定插件的路径。
Args:
plugin_name: 插件名称
Returns:
Optional[str]: 插件目录的绝对路径如果插件不存在则返回None。
"""
return self.plugin_paths.get(plugin_name)
# === 私有方法 ===
# == 目录管理 ==
def _ensure_plugin_directories(self) -> None:
"""确保所有插件根目录存在,如果不存在则创建"""
default_directories = ["src/plugins/built_in", "plugins"]
for directory in default_directories:
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"创建插件根目录: {directory}")
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件根目录: {directory}")
else:
logger.warning(f"根目录不可重复加载: {directory}")
# == 插件加载 ==
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
"""从指定目录加载插件模块"""
loaded_count = 0
failed_count = 0
if not os.path.exists(directory):
logger.warning(f"插件根目录不存在: {directory}")
return 0, 1
logger.debug(f"正在扫描插件根目录: {directory}")
# 遍历目录中的所有包
for item in os.listdir(directory):
item_path = os.path.join(directory, item)
if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
plugin_file = os.path.join(item_path, "plugin.py")
if os.path.exists(plugin_file):
if self._load_plugin_module_file(plugin_file):
loaded_count += 1
else:
failed_count += 1
return loaded_count, failed_count
def _load_plugin_module_file(self, plugin_file: str) -> bool:
# sourcery skip: extract-method
"""加载单个插件模块文件
Args:
plugin_file: 插件文件路径
plugin_name: 插件名称
plugin_dir: 插件目录路径
"""
# 生成模块名
plugin_path = Path(plugin_file)
module_name = ".".join(plugin_path.parent.parts)
try:
# 动态导入插件模块
spec = spec_from_file_location(module_name, plugin_file)
if spec is None or spec.loader is None:
logger.error(f"无法创建模块规范: {plugin_file}")
return False
module = module_from_spec(spec)
module.__package__ = module_name # 设置模块包名
spec.loader.exec_module(module)
logger.debug(f"插件模块加载成功: {plugin_file}")
return True
except Exception as e:
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
logger.error(error_msg)
self.failed_plugins[module_name] = error_msg
return False
# == 兼容性检查 ==
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
"""检查插件版本兼容性
Args:
plugin_name: 插件名称
manifest_data: manifest数据
Returns:
Tuple[bool, str]: (是否兼容, 错误信息)
"""
if "host_application" not in manifest_data:
return True, "" # 没有版本要求,默认兼容
host_app = manifest_data["host_application"]
if not isinstance(host_app, dict):
return True, ""
min_version = host_app.get("min_version", "")
max_version = host_app.get("max_version", "")
if not min_version and not max_version:
return True, "" # 没有版本要求,默认兼容
try:
current_version = VersionComparator.get_current_host_version()
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
if not is_compatible:
return False, f"版本不兼容: {error_msg}"
logger.debug(f"插件 {plugin_name} 版本兼容性检查通过")
return True, ""
except Exception as e:
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
# == 显示统计与插件信息 ==
def _show_stats(self, total_registered: int, total_failed_registration: int):
# sourcery skip: low-code-quality
# 获取组件统计信息
stats = component_registry.get_registry_stats()
action_count = stats.get("action_components", 0)
command_count = stats.get("command_components", 0)
tool_count = stats.get("tool_components", 0)
event_handler_count = stats.get("event_handlers", 0)
total_components = stats.get("total_components", 0)
# 📋 显示插件加载总览
if total_registered > 0:
logger.info("🎉 插件系统加载完成!")
logger.info(
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
)
# 显示详细的插件列表
logger.info("📋 已加载插件详情:")
for plugin_name in self.loaded_plugins.keys():
if plugin_info := component_registry.get_plugin_info(plugin_name):
# 插件基本信息
version_info = f"v{plugin_info.version}" if plugin_info.version else ""
author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown"
license_info = f"[{plugin_info.license}]" if plugin_info.license else ""
info_parts = [part for part in [version_info, author_info, license_info] if part]
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
logger.info(f" 📦 {plugin_info.display_name}{extra_info}")
# Manifest信息
if plugin_info.manifest_data:
"""
if plugin_info.keywords:
logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}")
if plugin_info.categories:
logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}")
"""
if plugin_info.homepage_url:
logger.info(f" 🌐 主页: {plugin_info.homepage_url}")
# 组件列表
if plugin_info.components:
action_components = [
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
]
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
tool_components = [
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
if action_components:
action_names = [c.name for c in action_components]
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
if command_components:
command_names = [c.name for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
if tool_components:
tool_names = [c.name for c in tool_components]
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
if event_handler_components:
event_handler_names = [c.name for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
# 依赖信息
if plugin_info.dependencies:
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
# 配置文件信息
if plugin_info.config_file:
config_status = "" if self.plugin_paths.get(plugin_name) else ""
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
root_path = Path(__file__)
# 查找项目根目录
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
root_path = root_path.parent
# 显示目录统计
logger.info("📂 加载目录统计:")
for directory in self.plugin_directories:
if os.path.exists(directory):
plugins_in_dir = []
for plugin_name in self.loaded_plugins.keys():
plugin_path = self.plugin_paths.get(plugin_name, "")
if (
Path(plugin_path)
.resolve()
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
):
plugins_in_dir.append(plugin_name)
if plugins_in_dir:
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
else:
logger.info(f" 📁 {directory}: 0个插件")
# 失败信息
if total_failed_registration > 0:
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
for failed_plugin, error in self.failed_plugins.items():
logger.info(f"{failed_plugin}: {error}")
else:
logger.warning("😕 没有成功加载任何插件")
def _show_plugin_components(self, plugin_name: str) -> None:
if plugin_info := component_registry.get_plugin_info(plugin_name):
component_types = {}
for comp in plugin_info.components:
comp_type = comp.component_type.name
component_types[comp_type] = component_types.get(comp_type, 0) + 1
components_str = ", ".join([f"{count}{ctype}" for ctype, count in component_types.items()])
# 显示manifest信息
manifest_info = ""
if plugin_info.license:
manifest_info += f" [{plugin_info.license}]"
if plugin_info.keywords:
manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词
if len(plugin_info.keywords) > 3:
manifest_info += "..."
logger.info(
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}"
)
else:
logger.info(f"✅ 插件加载成功: {plugin_name}")
# === 插件卸载和重载管理 ===
def unload_plugin(self, plugin_name: str) -> bool:
"""卸载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 卸载是否成功
"""
if plugin_name not in self.loaded_plugins:
logger.warning(f"插件 {plugin_name} 未加载,无需卸载")
return False
try:
# 获取插件实例
plugin_instance = self.loaded_plugins[plugin_name]
# 调用插件的清理方法(如果有的话)
if hasattr(plugin_instance, 'on_unload'):
plugin_instance.on_unload()
# 从组件注册表中移除插件的所有组件
component_registry.unregister_plugin(plugin_name)
# 从已加载插件中移除
del self.loaded_plugins[plugin_name]
# 从失败列表中移除(如果存在)
if plugin_name in self.failed_plugins:
del self.failed_plugins[plugin_name]
logger.info(f"✅ 插件卸载成功: {plugin_name}")
return True
except Exception as e:
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}")
return False
def reload_plugin(self, plugin_name: str) -> bool:
"""重载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""
try:
# 先卸载插件
if plugin_name in self.loaded_plugins:
self.unload_plugin(plugin_name)
# 清除Python模块缓存
plugin_path = self.plugin_paths.get(plugin_name)
if plugin_path:
plugin_file = os.path.join(plugin_path, "plugin.py")
if os.path.exists(plugin_file):
# 从sys.modules中移除相关模块
modules_to_remove = []
plugin_module_prefix = ".".join(Path(plugin_file).parent.parts)
for module_name in sys.modules:
if module_name.startswith(plugin_module_prefix):
modules_to_remove.append(module_name)
for module_name in modules_to_remove:
del sys.modules[module_name]
# 从插件类注册表中移除
if plugin_name in self.plugin_classes:
del self.plugin_classes[plugin_name]
# 重新加载插件模块
if self._load_plugin_module_file(plugin_file):
# 重新加载插件实例
success, _ = self.load_registered_plugin_classes(plugin_name)
if success:
logger.info(f"🔄 插件重载成功: {plugin_name}")
return True
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 模块加载失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件文件不存在")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件路径未知")
return False
except Exception as e:
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}")
logger.debug("详细错误信息: ", exc_info=True)
return False
# 全局插件管理器实例
plugin_manager = PluginManager()

View File

@@ -0,0 +1,421 @@
import time
from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
logger = get_logger("tool_use")
def init_tool_executor_prompt():
"""初始化工具执行器的提示词"""
tool_executor_prompt = """
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}
群里正在进行的聊天内容:
{chat_history}
现在,{sender}发送了内容:{target_message},你想要回复ta。
请仔细分析聊天内容,考虑以下几点:
1. 内容中是否包含需要查询信息的问题
2. 是否有明确的工具使用指令
If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed".
"""
Prompt(tool_executor_prompt, "tool_executor_prompt")
# 初始化提示词
init_tool_executor_prompt()
class ToolExecutor:
"""独立的工具执行器组件
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
"""
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
"""初始化工具执行器
Args:
executor_id: 执行器标识符,用于日志记录
enable_cache: 是否启用缓存机制
cache_ttl: 缓存生存时间(周期数)
"""
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 缓存配置
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}}
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'}TTL={cache_ttl}")
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> Tuple[List[Dict[str, Any]], List[str], str]:
"""从聊天消息执行工具
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
return_details: 是否返回详细信息(使用的工具列表和提示词)
Returns:
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
"""
# 首先检查缓存
cache_key = self._generate_cache_key(target_message, chat_history, sender)
if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if not return_details:
return cached_result, [], ""
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, ""
# 缓存未命中,执行工具调用
# 获取可用工具
tools = self._get_tool_definitions()
# 获取当前时间
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
# 构建工具调用提示词
prompt = await global_prompt_manager.format_prompt(
"tool_executor_prompt",
target_message=target_message,
chat_history=chat_history,
sender=sender,
bot_name=bot_name,
time_now=time_now,
)
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
# 调用LLM进行工具决策
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools, raise_when_empty=False
)
# 执行工具调用
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
# 缓存结果
if tool_results:
self._set_cache(cache_key, tool_results)
if used_tools:
logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}")
if return_details:
return tool_results, used_tools, prompt
else:
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [definition for name, definition in all_tools if name not in user_disabled_tools]
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""执行工具调用
Args:
tool_calls: LLM返回的工具调用列表
Returns:
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
"""
tool_results: List[Dict[str, Any]] = []
used_tools = []
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
return [], []
# 提取tool_calls中的函数名称
func_names = [call.func_name for call in tool_calls if call.func_name]
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
# 执行每个工具调用
for tool_call in tool_calls:
try:
tool_name = tool_call.func_name
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
result = await self.execute_tool_call(tool_call)
if result:
tool_info = {
"type": result.get("type", "unknown_type"),
"id": result.get("id", f"tool_exec_{time.time()}"),
"content": result.get("content", ""),
"tool_name": tool_name,
"timestamp": time.time(),
}
content = tool_info["content"]
if not isinstance(content, (str, list, tuple)):
tool_info["content"] = str(content)
tool_results.append(tool_info)
used_tools.append(tool_name)
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
preview = content[:200]
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
# 添加错误信息到结果中
error_info = {
"type": "tool_error",
"id": f"tool_error_{time.time()}",
"content": f"工具{tool_name}执行失败: {str(e)}",
"tool_name": tool_name,
"timestamp": time.time(),
}
tool_results.append(error_info)
return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用
Args:
tool_call: 工具调用对象
Returns:
Optional[Dict]: 工具调用结果如果失败则返回None
"""
try:
function_name = tool_call.func_name
function_args = tool_call.args or {}
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args)
if result:
return {
"tool_call_id": tool_call.call_id,
"role": "tool",
"name": function_name,
"type": "function",
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
raise e
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
Returns:
str: 缓存键
"""
import hashlib
# 使用消息内容和群聊状态生成唯一缓存键
content = f"{target_message}_{chat_history}_{sender}"
return hashlib.md5(content.encode()).hexdigest()
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
"""从缓存获取结果
Args:
cache_key: 缓存键
Returns:
Optional[List[Dict]]: 缓存的结果如果不存在或过期则返回None
"""
if not self.enable_cache or cache_key not in self.tool_cache:
return None
cache_item = self.tool_cache[cache_key]
if cache_item["ttl"] <= 0:
# 缓存过期,删除
del self.tool_cache[cache_key]
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}")
return None
# 减少TTL
cache_item["ttl"] -= 1
logger.debug(f"{self.log_prefix}使用缓存结果剩余TTL: {cache_item['ttl']}")
return cache_item["result"]
def _set_cache(self, cache_key: str, result: List[Dict]):
"""设置缓存
Args:
cache_key: 缓存键
result: 要缓存的结果
"""
if not self.enable_cache:
return
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
logger.debug(f"{self.log_prefix}设置缓存TTL: {self.cache_ttl}")
def _cleanup_expired_cache(self):
"""清理过期的缓存"""
if not self.enable_cache:
return
expired_keys = []
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
for key in expired_keys:
del self.tool_cache[key]
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
"""直接执行指定工具
Args:
tool_name: 工具名称
tool_args: 工具参数
validate_args: 是否验证参数
Returns:
Optional[Dict]: 工具执行结果失败时返回None
"""
try:
tool_call = ToolCall(
call_id=f"direct_tool_{time.time()}",
func_name=tool_name,
args=tool_args,
)
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
result = await self.execute_tool_call(tool_call)
if result:
tool_info = {
"type": result.get("type", "unknown_type"),
"id": result.get("id", f"direct_tool_{time.time()}"),
"content": result.get("content", ""),
"tool_name": tool_name,
"timestamp": time.time(),
}
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
return tool_info
except Exception as e:
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
return None
def clear_cache(self):
"""清空所有缓存"""
if self.enable_cache:
cache_count = len(self.tool_cache)
self.tool_cache.clear()
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
def get_cache_status(self) -> Dict:
"""获取缓存状态信息
Returns:
Dict: 包含缓存统计信息的字典
"""
if not self.enable_cache:
return {"enabled": False, "cache_count": 0}
# 清理过期缓存
self._cleanup_expired_cache()
total_count = len(self.tool_cache)
ttl_distribution = {}
for cache_item in self.tool_cache.values():
ttl = cache_item["ttl"]
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
return {
"enabled": True,
"cache_count": total_count,
"cache_ttl": self.cache_ttl,
"ttl_distribution": ttl_distribution,
}
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
"""动态修改缓存配置
Args:
enable_cache: 是否启用缓存
cache_ttl: 缓存TTL
"""
if enable_cache is not None:
self.enable_cache = enable_cache
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
if cache_ttl > 0:
self.cache_ttl = cache_ttl
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
"""
ToolExecutor使用示例
# 1. 基础使用 - 从聊天消息执行工具启用缓存默认TTL=3
executor = ToolExecutor(executor_id="my_executor")
results, _, _ = await executor.execute_from_chat_message(
talking_message_str="今天天气怎么样?现在几点了?",
is_group_chat=False
)
# 2. 禁用缓存的执行器
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
# 3. 自定义缓存TTL
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
# 4. 获取详细信息
results, used_tools, prompt = await executor.execute_from_chat_message(
talking_message_str="帮我查询Python相关知识",
is_group_chat=False,
return_details=True
)
# 5. 直接执行特定工具
result = await executor.execute_specific_tool_simple(
tool_name="get_knowledge",
tool_args={"query": "机器学习"}
)
# 6. 缓存管理
cache_status = executor.get_cache_status() # 查看缓存状态
executor.clear_cache() # 清空缓存
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
"""

View File

@@ -0,0 +1,19 @@
"""
插件系统工具模块
提供插件开发和管理的实用工具
"""
from .manifest_utils import (
ManifestValidator,
# ManifestGenerator,
# validate_plugin_manifest,
# generate_plugin_manifest,
)
__all__ = [
"ManifestValidator",
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
]

View File

@@ -0,0 +1,515 @@
"""
插件Manifest工具模块
提供manifest文件的验证、生成和管理功能
"""
import re
from typing import Dict, Any, Tuple
from src.common.logger import get_logger
from src.config.config import MMC_VERSION
# if TYPE_CHECKING:
# from src.plugin_system.base.base_plugin import BasePlugin
logger = get_logger("manifest_utils")
class VersionComparator:
"""版本号比较器
支持语义化版本号比较自动处理snapshot版本并支持向前兼容性检查
"""
# 版本兼容性映射表(硬编码)
# 格式: {插件最大支持版本: [实际兼容的版本列表]}
COMPATIBILITY_MAP = {
# 0.8.x 系列向前兼容规则
"0.8.0": ["0.8.1", "0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.1": ["0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.2": ["0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.3": ["0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.4": ["0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.5": ["0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.6": ["0.8.7", "0.8.8", "0.8.9", "0.8.10"],
"0.8.7": ["0.8.8", "0.8.9", "0.8.10"],
"0.8.8": ["0.8.9", "0.8.10"],
"0.8.9": ["0.8.10"],
# 可以根据需要添加更多兼容映射
# "0.9.0": ["0.9.1", "0.9.2", "0.9.3"], # 示例0.9.x系列兼容
}
@staticmethod
def normalize_version(version: str) -> str:
"""标准化版本号移除snapshot标识
Args:
version: 原始版本号,如 "0.8.0-snapshot.1"
Returns:
str: 标准化后的版本号,如 "0.8.0"
"""
if not version:
return "0.0.0"
# 移除snapshot部分
normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
# 确保版本号格式正确
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
# 如果不是有效的版本号格式,返回默认版本
return "0.0.0"
# 尝试补全版本号
parts = normalized.split(".")
while len(parts) < 3:
parts.append("0")
normalized = ".".join(parts[:3])
return normalized
@staticmethod
def parse_version(version: str) -> Tuple[int, int, int]:
"""解析版本号为元组
Args:
version: 版本号字符串
Returns:
Tuple[int, int, int]: (major, minor, patch)
"""
normalized = VersionComparator.normalize_version(version)
try:
parts = normalized.split(".")
return (int(parts[0]), int(parts[1]), int(parts[2]))
except (ValueError, IndexError):
logger.warning(f"无法解析版本号: {version},使用默认版本 0.0.0")
return (0, 0, 0)
@staticmethod
def compare_versions(version1: str, version2: str) -> int:
"""比较两个版本号
Args:
version1: 第一个版本号
version2: 第二个版本号
Returns:
int: -1 if version1 < version2, 0 if equal, 1 if version1 > version2
"""
v1_tuple = VersionComparator.parse_version(version1)
v2_tuple = VersionComparator.parse_version(version2)
if v1_tuple < v2_tuple:
return -1
elif v1_tuple > v2_tuple:
return 1
else:
return 0
@staticmethod
def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]:
"""检查向前兼容性(仅使用兼容性映射表)
Args:
current_version: 当前版本
max_version: 插件声明的最大支持版本
Returns:
Tuple[bool, str]: (是否兼容, 兼容信息)
"""
current_normalized = VersionComparator.normalize_version(current_version)
max_normalized = VersionComparator.normalize_version(max_version)
# 检查兼容性映射表
if max_normalized in VersionComparator.COMPATIBILITY_MAP:
compatible_versions = VersionComparator.COMPATIBILITY_MAP[max_normalized]
if current_normalized in compatible_versions:
return True, f"根据兼容性映射表,版本 {current_normalized}{max_normalized} 兼容"
return False, ""
@staticmethod
def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
"""检查版本是否在指定范围内,支持兼容性检查
Args:
version: 要检查的版本号
min_version: 最小版本号(可选)
max_version: 最大版本号(可选)
Returns:
Tuple[bool, str]: (是否兼容, 错误信息或兼容信息)
"""
if not min_version and not max_version:
return True, ""
version_normalized = VersionComparator.normalize_version(version)
# 检查最小版本
if min_version:
min_normalized = VersionComparator.normalize_version(min_version)
if VersionComparator.compare_versions(version_normalized, min_normalized) < 0:
return False, f"版本 {version_normalized} 低于最小要求版本 {min_normalized}"
# 检查最大版本
if max_version:
max_normalized = VersionComparator.normalize_version(max_version)
comparison = VersionComparator.compare_versions(version_normalized, max_normalized)
if comparison > 0:
# 严格版本检查失败,尝试兼容性检查
is_compatible, compat_msg = VersionComparator.check_forward_compatibility(
version_normalized, max_normalized
)
if not is_compatible:
return False, f"版本 {version_normalized} 高于最大支持版本 {max_normalized},且无兼容性映射"
logger.info(f"版本兼容性检查:{compat_msg}")
return True, compat_msg
return True, ""
@staticmethod
def get_current_host_version() -> str:
"""获取当前主机应用版本
Returns:
str: 当前版本号
"""
return VersionComparator.normalize_version(MMC_VERSION)
@staticmethod
def add_compatibility_mapping(base_version: str, compatible_versions: list) -> None:
"""动态添加兼容性映射
Args:
base_version: 基础版本(插件声明的最大支持版本)
compatible_versions: 兼容的版本列表
"""
base_normalized = VersionComparator.normalize_version(base_version)
VersionComparator.COMPATIBILITY_MAP[base_normalized] = [
VersionComparator.normalize_version(v) for v in compatible_versions
]
logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}")
@staticmethod
def get_compatibility_info() -> Dict[str, list]:
"""获取当前的兼容性映射表
Returns:
Dict[str, list]: 兼容性映射表的副本
"""
return VersionComparator.COMPATIBILITY_MAP.copy()
class ManifestValidator:
"""Manifest文件验证器"""
# 必需字段(必须存在且不能为空)
REQUIRED_FIELDS = ["manifest_version", "name", "version", "description", "author"]
# 可选字段(可以不存在或为空)
OPTIONAL_FIELDS = [
"license",
"host_application",
"homepage_url",
"repository_url",
"keywords",
"categories",
"default_locale",
"locales_path",
"plugin_info",
]
# 建议填写的字段(会给出警告但不会导致验证失败)
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
SUPPORTED_MANIFEST_VERSIONS = [1]
def __init__(self):
self.validation_errors = []
self.validation_warnings = []
def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool:
"""验证manifest数据
Args:
manifest_data: manifest数据字典
Returns:
bool: 是否验证通过(只有错误会导致验证失败,警告不会)
"""
self.validation_errors.clear()
self.validation_warnings.clear()
# 检查必需字段
for field in self.REQUIRED_FIELDS:
if field not in manifest_data:
self.validation_errors.append(f"缺少必需字段: {field}")
elif not manifest_data[field]:
self.validation_errors.append(f"必需字段不能为空: {field}")
# 检查manifest版本
if "manifest_version" in manifest_data:
version = manifest_data["manifest_version"]
if version not in self.SUPPORTED_MANIFEST_VERSIONS:
self.validation_errors.append(
f"不支持的manifest版本: {version},支持的版本: {self.SUPPORTED_MANIFEST_VERSIONS}"
)
# 检查作者信息格式
if "author" in manifest_data:
author = manifest_data["author"]
if isinstance(author, dict):
if "name" not in author or not author["name"]:
self.validation_errors.append("作者信息缺少name字段或为空")
# url字段是可选的
if "url" in author and author["url"]:
url = author["url"]
if not (url.startswith("http://") or url.startswith("https://")):
self.validation_warnings.append("作者URL建议使用完整的URL格式")
elif isinstance(author, str):
if not author.strip():
self.validation_errors.append("作者信息不能为空")
else:
self.validation_errors.append("作者信息格式错误应为字符串或包含name字段的对象")
# 检查主机应用版本要求(可选)
if "host_application" in manifest_data:
host_app = manifest_data["host_application"]
if isinstance(host_app, dict):
min_version = host_app.get("min_version", "")
max_version = host_app.get("max_version", "")
# 验证版本字段格式
for version_field in ["min_version", "max_version"]:
if version_field in host_app and not host_app[version_field]:
self.validation_warnings.append(f"host_application.{version_field}为空")
# 检查当前主机版本兼容性
if min_version or max_version:
current_version = VersionComparator.get_current_host_version()
is_compatible, error_msg = VersionComparator.is_version_in_range(
current_version, min_version, max_version
)
if not is_compatible:
self.validation_errors.append(f"版本兼容性检查失败: {error_msg} (当前版本: {current_version})")
else:
logger.debug(
f"版本兼容性检查通过: 当前版本 {current_version} 符合要求 [{min_version}, {max_version}]"
)
else:
self.validation_errors.append("host_application格式错误应为对象")
# 检查URL格式可选字段
for url_field in ["homepage_url", "repository_url"]:
if url_field in manifest_data and manifest_data[url_field]:
url: str = manifest_data[url_field]
if not (url.startswith("http://") or url.startswith("https://")):
self.validation_warnings.append(f"{url_field}建议使用完整的URL格式")
# 检查数组字段格式(可选字段)
for list_field in ["keywords", "categories"]:
if list_field in manifest_data:
field_value = manifest_data[list_field]
if field_value is not None and not isinstance(field_value, list):
self.validation_errors.append(f"{list_field}应为数组格式")
elif isinstance(field_value, list):
# 检查数组元素是否为字符串
for i, item in enumerate(field_value):
if not isinstance(item, str):
self.validation_warnings.append(f"{list_field}[{i}]应为字符串")
# 检查建议字段(给出警告)
for field in self.RECOMMENDED_FIELDS:
if field not in manifest_data or not manifest_data[field]:
self.validation_warnings.append(f"建议填写字段: {field}")
# 检查plugin_info结构可选
if "plugin_info" in manifest_data:
plugin_info = manifest_data["plugin_info"]
if isinstance(plugin_info, dict):
# 检查components数组
if "components" in plugin_info:
components = plugin_info["components"]
if not isinstance(components, list):
self.validation_errors.append("plugin_info.components应为数组格式")
else:
for i, component in enumerate(components):
if not isinstance(component, dict):
self.validation_errors.append(f"plugin_info.components[{i}]应为对象")
else:
# 检查组件必需字段
for comp_field in ["type", "name", "description"]:
if comp_field not in component or not component[comp_field]:
self.validation_errors.append(
f"plugin_info.components[{i}]缺少必需字段: {comp_field}"
)
else:
self.validation_errors.append("plugin_info应为对象格式")
return len(self.validation_errors) == 0
def get_validation_report(self) -> str:
"""获取验证报告"""
report = []
if self.validation_errors:
report.append("❌ 验证错误:")
report.extend(f" - {error}" for error in self.validation_errors)
if self.validation_warnings:
report.append("⚠️ 验证警告:")
report.extend(f" - {warning}" for warning in self.validation_warnings)
if not self.validation_errors and not self.validation_warnings:
report.append("✅ Manifest文件验证通过")
return "\n".join(report)
# class ManifestGenerator:
# """Manifest文件生成器"""
# def __init__(self):
# self.template = {
# "manifest_version": 1,
# "name": "",
# "version": "1.0.0",
# "description": "",
# "author": {"name": "", "url": ""},
# "license": "MIT",
# "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
# "homepage_url": "",
# "repository_url": "",
# "keywords": [],
# "categories": [],
# "default_locale": "zh-CN",
# "locales_path": "_locales",
# }
# def generate_from_plugin(self, plugin_instance: BasePlugin) -> Dict[str, Any]:
# """从插件实例生成manifest
# Args:
# plugin_instance: BasePlugin实例
# Returns:
# Dict[str, Any]: 生成的manifest数据
# """
# manifest = self.template.copy()
# # 基本信息
# manifest["name"] = plugin_instance.plugin_name
# manifest["version"] = plugin_instance.plugin_version
# manifest["description"] = plugin_instance.plugin_description
# # 作者信息
# if plugin_instance.plugin_author:
# manifest["author"]["name"] = plugin_instance.plugin_author
# # 组件信息
# components = []
# plugin_components = plugin_instance.get_plugin_components()
# for component_info, component_class in plugin_components:
# component_data: Dict[str, Any] = {
# "type": component_info.component_type.value,
# "name": component_info.name,
# "description": component_info.description,
# }
# # 添加激活模式信息对于Action组件
# if hasattr(component_class, "focus_activation_type"):
# activation_modes = []
# if hasattr(component_class, "focus_activation_type"):
# activation_modes.append(component_class.focus_activation_type.value)
# if hasattr(component_class, "normal_activation_type"):
# activation_modes.append(component_class.normal_activation_type.value)
# component_data["activation_modes"] = list(set(activation_modes))
# # 添加关键词信息
# if hasattr(component_class, "activation_keywords"):
# keywords = getattr(component_class, "activation_keywords", [])
# if keywords:
# component_data["keywords"] = keywords
# components.append(component_data)
# manifest["plugin_info"] = {"is_built_in": True, "plugin_type": "general", "components": components}
# return manifest
# def save_manifest(self, manifest_data: Dict[str, Any], plugin_dir: str) -> bool:
# """保存manifest文件
# Args:
# manifest_data: manifest数据
# plugin_dir: 插件目录
# Returns:
# bool: 是否保存成功
# """
# try:
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
# with open(manifest_path, "w", encoding="utf-8") as f:
# json.dump(manifest_data, f, ensure_ascii=False, indent=2)
# logger.info(f"Manifest文件已保存: {manifest_path}")
# return True
# except Exception as e:
# logger.error(f"保存manifest文件失败: {e}")
# return False
# def validate_plugin_manifest(plugin_dir: str) -> bool:
# """验证插件目录中的manifest文件
# Args:
# plugin_dir: 插件目录路径
# Returns:
# bool: 是否验证通过
# """
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
# if not os.path.exists(manifest_path):
# logger.warning(f"未找到manifest文件: {manifest_path}")
# return False
# try:
# with open(manifest_path, "r", encoding="utf-8") as f:
# manifest_data = json.load(f)
# validator = ManifestValidator()
# is_valid = validator.validate_manifest(manifest_data)
# logger.info(f"Manifest验证结果:\n{validator.get_validation_report()}")
# return is_valid
# except Exception as e:
# logger.error(f"读取或验证manifest文件失败: {e}")
# return False
# def generate_plugin_manifest(plugin_instance: BasePlugin, save_to_file: bool = True) -> Optional[Dict[str, Any]]:
# """为插件生成manifest文件
# Args:
# plugin_instance: BasePlugin实例
# save_to_file: 是否保存到文件
# Returns:
# Optional[Dict[str, Any]]: 生成的manifest数据
# """
# try:
# generator = ManifestGenerator()
# manifest_data = generator.generate_from_plugin(plugin_instance)
# if save_to_file and plugin_instance.plugin_dir:
# generator.save_manifest(manifest_data, plugin_instance.plugin_dir)
# return manifest_data
# except Exception as e:
# logger.error(f"生成manifest文件失败: {e}")
# return None